Line data Source code
1 : #[cfg(test)]
2 : mod tests;
3 :
4 : pub mod connect_compute;
5 : mod copy_bidirectional;
6 : pub mod handshake;
7 : pub mod passthrough;
8 : pub mod retry;
9 : pub mod wake_compute;
10 :
11 : use crate::{
12 : auth,
13 : cancellation::{self, CancelMap},
14 : compute,
15 : config::{ProxyConfig, TlsConfig},
16 : context::RequestMonitoring,
17 : error::ReportableError,
18 : metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
19 : protocol2::WithClientIp,
20 : proxy::handshake::{handshake, HandshakeData},
21 : rate_limiter::EndpointRateLimiter,
22 : stream::{PqStream, Stream},
23 : EndpointCacheKey,
24 : };
25 : use anyhow::{bail, Context};
26 : use futures::TryFutureExt;
27 : use itertools::Itertools;
28 : use once_cell::sync::OnceCell;
29 : use pq_proto::{BeMessage as Be, StartupMessageParams};
30 : use regex::Regex;
31 : use smol_str::{format_smolstr, SmolStr};
32 : use std::sync::Arc;
33 : use thiserror::Error;
34 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
35 : use tokio_util::sync::CancellationToken;
36 : use tracing::{error, info, info_span, Instrument};
37 :
38 : use self::{
39 : connect_compute::{connect_to_compute, TcpMechanism},
40 : passthrough::ProxyPassthrough,
41 : };
42 :
/// Message text telling a client that its connection is insecure and
/// suggesting `sslmode=require`.
///
/// Not referenced in the part of the module visible here; presumably used by
/// the submodules or tests — verify before removing.
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
44 :
45 91 : pub async fn run_until_cancelled<F: std::future::Future>(
46 91 : f: F,
47 91 : cancellation_token: &CancellationToken,
48 91 : ) -> Option<F::Output> {
49 91 : match futures::future::select(
50 91 : std::pin::pin!(f),
51 91 : std::pin::pin!(cancellation_token.cancelled()),
52 91 : )
53 90 : .await
54 : {
55 65 : futures::future::Either::Left((f, _)) => Some(f),
56 26 : futures::future::Either::Right(((), _)) => None,
57 : }
58 91 : }
59 :
60 25 : pub async fn task_main(
61 25 : config: &'static ProxyConfig,
62 25 : listener: tokio::net::TcpListener,
63 25 : cancellation_token: CancellationToken,
64 25 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
65 25 : ) -> anyhow::Result<()> {
66 25 : scopeguard::defer! {
67 25 : info!("proxy has shut down");
68 : }
69 :
70 : // When set for the server socket, the keepalive setting
71 : // will be inherited by all accepted client sockets.
72 25 : socket2::SockRef::from(&listener).set_keepalive(true)?;
73 :
74 25 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
75 25 : let cancel_map = Arc::new(CancelMap::default());
76 :
77 63 : while let Some(accept_result) =
78 88 : run_until_cancelled(listener.accept(), &cancellation_token).await
79 : {
80 63 : let (socket, peer_addr) = accept_result?;
81 :
82 63 : let session_id = uuid::Uuid::new_v4();
83 63 : let cancel_map = Arc::clone(&cancel_map);
84 63 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
85 :
86 63 : let session_span = info_span!(
87 63 : "handle_client",
88 63 : ?session_id,
89 63 : peer_addr = tracing::field::Empty,
90 63 : ep = tracing::field::Empty,
91 63 : );
92 :
93 63 : connections.spawn(
94 63 : async move {
95 63 : info!("accepted postgres client connection");
96 :
97 63 : let mut socket = WithClientIp::new(socket);
98 63 : let mut peer_addr = peer_addr.ip();
99 63 : if let Some(addr) = socket.wait_for_addr().await? {
100 0 : peer_addr = addr.ip();
101 0 : tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
102 63 : } else if config.require_client_ip {
103 0 : bail!("missing required client IP");
104 63 : }
105 :
106 63 : socket
107 63 : .inner
108 63 : .set_nodelay(true)
109 63 : .context("failed to set socket option")?;
110 :
111 63 : let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
112 :
113 63 : let res = handle_client(
114 63 : config,
115 63 : &mut ctx,
116 63 : cancel_map,
117 63 : socket,
118 63 : ClientMode::Tcp,
119 63 : endpoint_rate_limiter,
120 63 : )
121 894 : .await;
122 :
123 41 : match res {
124 22 : Err(e) => {
125 22 : // todo: log and push to ctx the error kind
126 22 : ctx.set_error_kind(e.get_error_kind());
127 22 : ctx.log();
128 22 : Err(e.into())
129 : }
130 : Ok(None) => {
131 0 : ctx.set_success();
132 0 : ctx.log();
133 0 : Ok(())
134 : }
135 41 : Ok(Some(p)) => {
136 41 : ctx.set_success();
137 41 : ctx.log();
138 121 : p.proxy_pass().await
139 : }
140 : }
141 63 : }
142 63 : .unwrap_or_else(move |e| {
143 63 : // Acknowledge that the task has finished with an error.
144 63 : error!("per-client task finished with an error: {e:#}");
145 63 : })
146 63 : .instrument(session_span),
147 63 : );
148 : }
149 :
150 25 : connections.close();
151 25 : drop(listener);
152 25 :
153 25 : // Drain connections
154 25 : connections.wait().await;
155 :
156 25 : Ok(())
157 25 : }
158 :
159 : pub enum ClientMode {
160 : Tcp,
161 : Websockets { hostname: Option<String> },
162 : }
163 :
164 : /// Abstracts the logic of handling TCP vs WS clients
165 : impl ClientMode {
166 52 : fn allow_cleartext(&self) -> bool {
167 52 : match self {
168 52 : ClientMode::Tcp => false,
169 0 : ClientMode::Websockets { .. } => true,
170 : }
171 52 : }
172 :
173 41 : fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
174 41 : match self {
175 41 : ClientMode::Tcp => config.allow_self_signed_compute,
176 0 : ClientMode::Websockets { .. } => false,
177 : }
178 41 : }
179 :
180 52 : fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
181 52 : match self {
182 52 : ClientMode::Tcp => s.sni_hostname(),
183 0 : ClientMode::Websockets { hostname } => hostname.as_deref(),
184 : }
185 52 : }
186 :
187 63 : fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
188 63 : match self {
189 63 : ClientMode::Tcp => tls,
190 : // TLS is None here if using websockets, because the connection is already encrypted.
191 0 : ClientMode::Websockets { .. } => None,
192 : }
193 63 : }
194 : }
195 :
196 88 : #[derive(Debug, Error)]
197 : // almost all errors should be reported to the user, but there's a few cases where we cannot
198 : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
199 : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
200 : // we cannot be sure the client even understands our error message
201 : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
202 : pub enum ClientRequestError {
203 : #[error("{0}")]
204 : Cancellation(#[from] cancellation::CancelError),
205 : #[error("{0}")]
206 : Handshake(#[from] handshake::HandshakeError),
207 : #[error("{0}")]
208 : HandshakeTimeout(#[from] tokio::time::error::Elapsed),
209 : #[error("{0}")]
210 : PrepareClient(#[from] std::io::Error),
211 : #[error("{0}")]
212 : ReportedError(#[from] crate::stream::ReportedError),
213 : }
214 :
215 : impl ReportableError for ClientRequestError {
216 22 : fn get_error_kind(&self) -> crate::error::ErrorKind {
217 22 : match self {
218 0 : ClientRequestError::Cancellation(e) => e.get_error_kind(),
219 11 : ClientRequestError::Handshake(e) => e.get_error_kind(),
220 0 : ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
221 11 : ClientRequestError::ReportedError(e) => e.get_error_kind(),
222 0 : ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
223 : }
224 22 : }
225 : }
226 :
227 63 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
228 63 : config: &'static ProxyConfig,
229 63 : ctx: &mut RequestMonitoring,
230 63 : cancel_map: Arc<CancelMap>,
231 63 : stream: S,
232 63 : mode: ClientMode,
233 63 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
234 63 : ) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
235 63 : info!(
236 63 : protocol = ctx.protocol,
237 63 : "handling interactive connection from client"
238 63 : );
239 :
240 63 : let proto = ctx.protocol;
241 63 : let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
242 63 : .with_label_values(&[proto])
243 63 : .guard();
244 63 : let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
245 63 : .with_label_values(&[proto])
246 63 : .guard();
247 63 :
248 63 : let tls = config.tls_config.as_ref();
249 63 :
250 63 : let pause = ctx.latency_timer.pause();
251 63 : let do_handshake = handshake(stream, mode.handshake_tls(tls));
252 52 : let (mut stream, params) =
253 108 : match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
254 52 : HandshakeData::Startup(stream, params) => (stream, params),
255 0 : HandshakeData::Cancel(cancel_key_data) => {
256 0 : return Ok(cancel_map
257 0 : .cancel_session(cancel_key_data)
258 0 : .await
259 0 : .map(|()| None)?)
260 : }
261 : };
262 52 : drop(pause);
263 52 :
264 52 : let hostname = mode.hostname(stream.get_ref());
265 52 :
266 52 : let common_names = tls.map(|tls| &tls.common_names);
267 52 :
268 52 : // Extract credentials which we're going to use for auth.
269 52 : let result = config
270 52 : .auth_backend
271 52 : .as_ref()
272 52 : .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
273 52 : .transpose();
274 :
275 52 : let user_info = match result {
276 52 : Ok(user_info) => user_info,
277 0 : Err(e) => stream.throw_error(e).await?,
278 : };
279 :
280 : // check rate limit
281 52 : if let Some(ep) = user_info.get_endpoint() {
282 49 : if !endpoint_rate_limiter.check(ep) {
283 0 : return stream
284 0 : .throw_error(auth::AuthError::too_many_connections())
285 0 : .await?;
286 49 : }
287 3 : }
288 :
289 52 : let user = user_info.get_user().to_owned();
290 52 : let (mut node_info, user_info) = match user_info
291 52 : .authenticate(
292 52 : ctx,
293 52 : &mut stream,
294 52 : mode.allow_cleartext(),
295 52 : &config.authentication_config,
296 52 : )
297 666 : .await
298 : {
299 41 : Ok(auth_result) => auth_result,
300 11 : Err(e) => {
301 11 : let db = params.get("database");
302 11 : let app = params.get("application_name");
303 11 : let params_span = tracing::info_span!("", ?user, ?db, ?app);
304 :
305 11 : return stream.throw_error(e).instrument(params_span).await?;
306 : }
307 : };
308 :
309 41 : node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
310 41 :
311 41 : let aux = node_info.aux.clone();
312 41 : let mut node = connect_to_compute(
313 41 : ctx,
314 41 : &TcpMechanism { params: ¶ms },
315 41 : node_info,
316 41 : &user_info,
317 41 : )
318 41 : .or_else(|e| stream.throw_error(e))
319 120 : .await?;
320 :
321 41 : let session = cancel_map.get_session();
322 41 : prepare_client_connection(&node, &session, &mut stream).await?;
323 :
324 : // Before proxy passing, forward to compute whatever data is left in the
325 : // PqStream input buffer. Normally there is none, but our serverless npm
326 : // driver in pipeline mode sends startup, password and first query
327 : // immediately after opening the connection.
328 41 : let (stream, read_buf) = stream.into_inner();
329 41 : node.stream.write_all(&read_buf).await?;
330 :
331 41 : Ok(Some(ProxyPassthrough {
332 41 : client: stream,
333 41 : compute: node,
334 41 : aux,
335 41 : req: _request_gauge,
336 41 : conn: _client_gauge,
337 41 : }))
338 63 : }
339 :
340 : /// Finish client connection initialization: confirm auth success, send params, etc.
341 82 : #[tracing::instrument(skip_all)]
342 : async fn prepare_client_connection(
343 : node: &compute::PostgresConnection,
344 : session: &cancellation::Session,
345 : stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
346 : ) -> Result<(), std::io::Error> {
347 : // Register compute's query cancellation token and produce a new, unique one.
348 : // The new token (cancel_key_data) will be sent to the client.
349 : let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
350 :
351 : // Forward all postgres connection params to the client.
352 : // Right now the implementation is very hacky and inefficent (ideally,
353 : // we don't need an intermediate hashmap), but at least it should be correct.
354 : for (name, value) in &node.params {
355 : // TODO: Theoretically, this could result in a big pile of params...
356 : stream.write_message_noflush(&Be::ParameterStatus {
357 : name: name.as_bytes(),
358 : value: value.as_bytes(),
359 : })?;
360 : }
361 :
362 : stream
363 : .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
364 : .write_message(&Be::ReadyForQuery)
365 : .await?;
366 :
367 : Ok(())
368 : }
369 :
370 248 : #[derive(Debug, Clone, PartialEq, Eq, Default)]
371 : pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
372 :
373 : impl NeonOptions {
374 71 : pub fn parse_params(params: &StartupMessageParams) -> Self {
375 71 : params
376 71 : .options_raw()
377 71 : .map(Self::parse_from_iter)
378 71 : .unwrap_or_default()
379 71 : }
380 14 : pub fn parse_options_raw(options: &str) -> Self {
381 14 : Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
382 14 : }
383 :
384 75 : fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
385 75 : let mut options = options
386 75 : .filter_map(neon_option)
387 75 : .map(|(k, v)| (k.into(), v.into()))
388 75 : .collect_vec();
389 75 : options.sort();
390 75 : Self(options)
391 75 : }
392 :
393 328 : pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
394 328 : // prefix + format!(" {k}:{v}")
395 328 : // kinda jank because SmolStr is immutable
396 328 : std::iter::once(prefix)
397 328 : .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
398 328 : .collect::<SmolStr>()
399 328 : .into()
400 328 : }
401 :
402 : /// <https://swagger.io/docs/specification/serialization/> DeepObject format
403 : /// `paramName[prop1]=value1¶mName[prop2]=value2&...`
404 0 : pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
405 0 : self.0
406 0 : .iter()
407 0 : .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
408 0 : .collect()
409 0 : }
410 : }
411 :
412 185 : pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
413 185 : static RE: OnceCell<Regex> = OnceCell::new();
414 185 : let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
415 :
416 185 : let cap = re.captures(bytes)?;
417 8 : let (_, [k, v]) = cap.extract();
418 8 : Some((k, v))
419 185 : }
|