#[cfg(test)]
mod tests;

pub mod connect_compute;
pub mod handshake;
pub mod passthrough;
pub mod retry;
pub mod wake_compute;

use crate::{
    auth,
    cancellation::{self, CancelMap},
    compute,
    config::{ProxyConfig, TlsConfig},
    context::RequestMonitoring,
    metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
    protocol2::WithClientIp,
    proxy::{handshake::handshake, passthrough::proxy_pass},
    rate_limiter::EndpointRateLimiter,
    stream::{PqStream, Stream},
    EndpointCacheKey,
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, StartupMessageParams};
use regex::Regex;
use smol_str::{format_smolstr, SmolStr};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};

use self::connect_compute::{connect_to_compute, TcpMechanism};

const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";

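/// Drive `f` to completion unless the cancellation token fires first.
/// Returns `Some(output)` if the future finished, or `None` if it was
/// cancelled (the future is dropped without completing).
///
/// Illustrative use, a sketch mirroring the accept loop in `task_main` below
/// (the `token` binding is assumed for the example):
/// ```ignore
/// while let Some(res) = run_until_cancelled(listener.accept(), &token).await {
///     let (socket, addr) = res?;
///     // handle the new connection...
/// }
/// ```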
pub async fn run_until_cancelled<F: std::future::Future>(
    f: F,
    cancellation_token: &CancellationToken,
) -> Option<F::Output> {
    match futures::future::select(
        std::pin::pin!(f),
        std::pin::pin!(cancellation_token.cancelled()),
    )
    .await
    {
        futures::future::Either::Left((f, _)) => Some(f),
        futures::future::Either::Right(((), _)) => None,
    }
}

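/// Main accept loop of the TCP proxy: accepts client connections until the
/// cancellation token fires, spawns a per-client task for each connection,
/// and then waits for all spawned tasks to drain before returning.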
pub async fn task_main(
    config: &'static ProxyConfig,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    let connections = tokio_util::task::task_tracker::TaskTracker::new();
    let cancel_map = Arc::new(CancelMap::default());

    while let Some(accept_result) =
        run_until_cancelled(listener.accept(), &cancellation_token).await
    {
        let (socket, peer_addr) = accept_result?;

        let session_id = uuid::Uuid::new_v4();
        let cancel_map = Arc::clone(&cancel_map);
        let endpoint_rate_limiter = endpoint_rate_limiter.clone();

        let session_span = info_span!(
            "handle_client",
            ?session_id,
            peer_addr = tracing::field::Empty,
            ep = tracing::field::Empty,
        );

        connections.spawn(
            async move {
                info!("accepted postgres client connection");

                let mut socket = WithClientIp::new(socket);
                let mut peer_addr = peer_addr.ip();
                if let Some(addr) = socket.wait_for_addr().await? {
                    peer_addr = addr.ip();
                    tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
                } else if config.require_client_ip {
                    bail!("missing required client IP");
                }

                let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);

                socket
                    .inner
                    .set_nodelay(true)
                    .context("failed to set socket option")?;

                handle_client(
                    config,
                    &mut ctx,
                    cancel_map,
                    socket,
                    ClientMode::Tcp,
                    endpoint_rate_limiter,
                )
                .await
            }
            .unwrap_or_else(move |e| {
                // Acknowledge that the task has finished with an error.
                error!("per-client task finished with an error: {e:#}");
            })
            .instrument(session_span),
        );
    }

    connections.close();
    drop(listener);

    // Drain connections
    connections.wait().await;

    Ok(())
}

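/// How the client reached the proxy: a plain TCP connection (which may be
/// upgraded to TLS during the handshake) or a websocket tunnel, which arrives
/// already encrypted.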
pub enum ClientMode {
    Tcp,
    Websockets { hostname: Option<String> },
}

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
    fn allow_cleartext(&self) -> bool {
        match self {
            ClientMode::Tcp => false,
            ClientMode::Websockets { .. } => true,
        }
    }

    fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
        match self {
            ClientMode::Tcp => config.allow_self_signed_compute,
            ClientMode::Websockets { .. } => false,
        }
    }

    fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
        match self {
            ClientMode::Tcp => s.sni_hostname(),
            ClientMode::Websockets { hostname } => hostname.as_deref(),
        }
    }

    fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
        match self {
            ClientMode::Tcp => tls,
            // TLS is None here if using websockets, because the connection is already encrypted.
            ClientMode::Websockets { .. } => None,
        }
    }
}

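/// Handle a single interactive client connection: perform the startup
/// handshake, authenticate the user, connect to (and if necessary wake) the
/// compute node, and then proxy bytes between client and compute until one
/// side disconnects.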
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    cancel_map: Arc<CancelMap>,
    stream: S,
    mode: ClientMode,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
    info!(
        protocol = ctx.protocol,
        "handling interactive connection from client"
    );

    let proto = ctx.protocol;
    let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
        .with_label_values(&[proto])
        .guard();
    let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
        .with_label_values(&[proto])
        .guard();

    let tls = config.tls_config.as_ref();

    let pause = ctx.latency_timer.pause();
    let do_handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map);
    let (mut stream, params) = match do_handshake.await? {
        Some(x) => x,
        None => return Ok(()), // it's a cancellation request
    };
    drop(pause);

    let hostname = mode.hostname(stream.get_ref());

    let common_names = tls.map(|tls| &tls.common_names);

    // Extract credentials which we're going to use for auth.
    let result = config
        .auth_backend
        .as_ref()
        .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
        .transpose();

    let user_info = match result {
        Ok(user_info) => user_info,
        Err(e) => stream.throw_error(e).await?,
    };

    // check rate limit
    if let Some(ep) = user_info.get_endpoint() {
        if !endpoint_rate_limiter.check(ep) {
            return stream
                .throw_error(auth::AuthError::too_many_connections())
                .await;
        }
    }

    let user = user_info.get_user().to_owned();
    let (mut node_info, user_info) = match user_info
        .authenticate(
            ctx,
            &mut stream,
            mode.allow_cleartext(),
            &config.authentication_config,
        )
        .await
    {
        Ok(auth_result) => auth_result,
        Err(e) => {
            let db = params.get("database");
            let app = params.get("application_name");
            let params_span = tracing::info_span!("", ?user, ?db, ?app);

            return stream.throw_error(e).instrument(params_span).await;
        }
    };

    node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);

    let aux = node_info.aux.clone();
    let mut node = connect_to_compute(
        ctx,
        &TcpMechanism { params: &params },
        node_info,
        &user_info,
    )
    .or_else(|e| stream.throw_error(e))
    .await?;

    let session = cancel_map.get_session();
    prepare_client_connection(&node, &session, &mut stream).await?;

    // Before proxy passing, forward to compute whatever data is left in the
    // PqStream input buffer. Normally there is none, but our serverless npm
    // driver in pipeline mode sends startup, password and first query
    // immediately after opening the connection.
    let (stream, read_buf) = stream.into_inner();
    node.stream.write_all(&read_buf).await?;

    proxy_pass(ctx, stream, node.stream, aux).await
}

/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
    node: &compute::PostgresConnection,
    session: &cancellation::Session,
    stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
    // Register compute's query cancellation token and produce a new, unique one.
    // The new token (cancel_key_data) will be sent to the client.
    let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

    // Forward all postgres connection params to the client.
    // Right now the implementation is very hacky and inefficient (ideally,
    // we don't need an intermediate hashmap), but at least it should be correct.
    for (name, value) in &node.params {
        // TODO: Theoretically, this could result in a big pile of params...
        stream.write_message_noflush(&Be::ParameterStatus {
            name: name.as_bytes(),
            value: value.as_bytes(),
        })?;
    }

    stream
        .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
        .write_message(&Be::ReadyForQuery)
        .await?;

    Ok(())
}

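/// `neon_*` key/value options extracted from the startup packet's `options`
/// parameter, kept sorted so the list can serve as part of a stable cache key.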
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);

impl NeonOptions {
    pub fn parse_params(params: &StartupMessageParams) -> Self {
        params
            .options_raw()
            .map(Self::parse_from_iter)
            .unwrap_or_default()
    }
    pub fn parse_options_raw(options: &str) -> Self {
        Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
    }

    fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
        let mut options = options
            .filter_map(neon_option)
            .map(|(k, v)| (k.into(), v.into()))
            .collect_vec();
        options.sort();
        Self(options)
    }

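    /// Builds a cache key by appending each option to the prefix as ` {k}:{v}`.
    /// Illustrative sketch with hypothetical values: options
    /// `[("endpoint_type", "read_write"), ("lsn", "0/15E2B80")]` and prefix
    /// `"ep-foo"` would yield `"ep-foo endpoint_type:read_write lsn:0/15E2B80"`.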
    pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
        // prefix + format!(" {k}:{v}")
        // kinda jank because SmolStr is immutable
        std::iter::once(prefix)
            .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
            .collect::<SmolStr>()
            .into()
    }

    /// <https://swagger.io/docs/specification/serialization/> DeepObject format
    /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
    pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
        self.0
            .iter()
            .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
            .collect()
    }
}

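/// Parse a single `options` entry of the form `neon_<key>:<value>`.
/// For illustration (hypothetical values): `neon_lsn:0/1234ABCD` yields
/// `Some(("lsn", "0/1234ABCD"))`, while an entry without the `neon_` prefix
/// (e.g. `statement_timeout=0`) yields `None`.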
pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
    static RE: OnceCell<Regex> = OnceCell::new();
    let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());

    let cap = re.captures(bytes)?;
    let (_, [k, v]) = cap.extract();
    Some((k, v))
}