Line data Source code
1 : #[cfg(test)]
2 : mod tests;
3 :
4 : pub mod connect_compute;
5 : mod copy_bidirectional;
6 : pub mod handshake;
7 : pub mod passthrough;
8 : pub mod retry;
9 : pub mod wake_compute;
10 : pub use copy_bidirectional::copy_bidirectional_client_compute;
11 :
12 : use crate::{
13 : auth,
14 : cancellation::{self, CancellationHandlerMain, CancellationHandlerMainInternal},
15 : compute,
16 : config::{ProxyConfig, TlsConfig},
17 : context::RequestMonitoring,
18 : error::ReportableError,
19 : metrics::{Metrics, NumClientConnectionsGuard},
20 : protocol2::read_proxy_protocol,
21 : proxy::handshake::{handshake, HandshakeData},
22 : rate_limiter::EndpointRateLimiter,
23 : stream::{PqStream, Stream},
24 : EndpointCacheKey,
25 : };
26 : use futures::TryFutureExt;
27 : use itertools::Itertools;
28 : use once_cell::sync::OnceCell;
29 : use pq_proto::{BeMessage as Be, StartupMessageParams};
30 : use regex::Regex;
31 : use smol_str::{format_smolstr, SmolStr};
32 : use std::sync::Arc;
33 : use thiserror::Error;
34 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
35 : use tokio_util::sync::CancellationToken;
36 : use tracing::{error, info, Instrument};
37 :
38 : use self::{
39 : connect_compute::{connect_to_compute, TcpMechanism},
40 : passthrough::ProxyPassthrough,
41 : };
42 :
/// User-facing error text for a connection rejected as insecure; it points the
/// client at `sslmode=require` as the remedy.
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
44 :
45 0 : pub async fn run_until_cancelled<F: std::future::Future>(
46 0 : f: F,
47 0 : cancellation_token: &CancellationToken,
48 0 : ) -> Option<F::Output> {
49 0 : match futures::future::select(
50 0 : std::pin::pin!(f),
51 0 : std::pin::pin!(cancellation_token.cancelled()),
52 0 : )
53 0 : .await
54 : {
55 0 : futures::future::Either::Left((f, _)) => Some(f),
56 0 : futures::future::Either::Right(((), _)) => None,
57 : }
58 0 : }
59 :
60 0 : pub async fn task_main(
61 0 : config: &'static ProxyConfig,
62 0 : listener: tokio::net::TcpListener,
63 0 : cancellation_token: CancellationToken,
64 0 : cancellation_handler: Arc<CancellationHandlerMain>,
65 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
66 0 : ) -> anyhow::Result<()> {
67 : scopeguard::defer! {
68 : info!("proxy has shut down");
69 : }
70 :
71 : // When set for the server socket, the keepalive setting
72 : // will be inherited by all accepted client sockets.
73 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
74 :
75 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
76 :
77 0 : while let Some(accept_result) =
78 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
79 : {
80 0 : let (socket, peer_addr) = accept_result?;
81 :
82 0 : let conn_gauge = Metrics::get()
83 0 : .proxy
84 0 : .client_connections
85 0 : .guard(crate::metrics::Protocol::Tcp);
86 0 :
87 0 : let session_id = uuid::Uuid::new_v4();
88 0 : let cancellation_handler = Arc::clone(&cancellation_handler);
89 0 :
90 0 : tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");
91 0 : let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
92 0 :
93 0 : connections.spawn(async move {
94 0 : let (socket, peer_addr) = match read_proxy_protocol(socket).await {
95 0 : Ok((socket, Some(addr))) => (socket, addr.ip()),
96 0 : Err(e) => {
97 0 : error!("per-client task finished with an error: {e:#}");
98 0 : return;
99 : }
100 0 : Ok((_socket, None)) if config.require_client_ip => {
101 0 : error!("missing required client IP");
102 0 : return;
103 : }
104 0 : Ok((socket, None)) => (socket, peer_addr.ip()),
105 : };
106 :
107 0 : match socket.inner.set_nodelay(true) {
108 0 : Ok(()) => {}
109 0 : Err(e) => {
110 0 : error!("per-client task finished with an error: failed to set socket option: {e:#}");
111 0 : return;
112 : }
113 : };
114 :
115 0 : let mut ctx = RequestMonitoring::new(
116 0 : session_id,
117 0 : peer_addr,
118 0 : crate::metrics::Protocol::Tcp,
119 0 : &config.region,
120 0 : );
121 0 : let span = ctx.span.clone();
122 0 :
123 0 : let startup = Box::pin(
124 0 : handle_client(
125 0 : config,
126 0 : &mut ctx,
127 0 : cancellation_handler,
128 0 : socket,
129 0 : ClientMode::Tcp,
130 0 : endpoint_rate_limiter2,
131 0 : conn_gauge,
132 0 : )
133 0 : .instrument(span.clone()),
134 0 : );
135 0 : let res = startup.await;
136 :
137 0 : match res {
138 0 : Err(e) => {
139 0 : // todo: log and push to ctx the error kind
140 0 : ctx.set_error_kind(e.get_error_kind());
141 : error!(parent: &span, "per-client task finished with an error: {e:#}");
142 : }
143 0 : Ok(None) => {
144 0 : ctx.set_success();
145 0 : }
146 0 : Ok(Some(p)) => {
147 0 : ctx.set_success();
148 0 : ctx.log_connect();
149 0 : match p.proxy_pass().instrument(span.clone()).await {
150 0 : Ok(()) => {}
151 0 : Err(e) => {
152 0 : error!(parent: &span, "per-client task finished with an error: {e:#}");
153 0 : }
154 : }
155 : }
156 : }
157 0 : });
158 : }
159 :
160 0 : connections.close();
161 0 : drop(listener);
162 0 :
163 0 : // Drain connections
164 0 : connections.wait().await;
165 :
166 0 : Ok(())
167 0 : }
168 :
169 : pub enum ClientMode {
170 : Tcp,
171 : Websockets { hostname: Option<String> },
172 : }
173 :
174 : /// Abstracts the logic of handling TCP vs WS clients
175 : impl ClientMode {
176 0 : pub fn allow_cleartext(&self) -> bool {
177 0 : match self {
178 0 : ClientMode::Tcp => false,
179 0 : ClientMode::Websockets { .. } => true,
180 : }
181 0 : }
182 :
183 0 : pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
184 0 : match self {
185 0 : ClientMode::Tcp => config.allow_self_signed_compute,
186 0 : ClientMode::Websockets { .. } => false,
187 : }
188 0 : }
189 :
190 0 : fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
191 0 : match self {
192 0 : ClientMode::Tcp => s.sni_hostname(),
193 0 : ClientMode::Websockets { hostname } => hostname.as_deref(),
194 : }
195 0 : }
196 :
197 0 : fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
198 0 : match self {
199 0 : ClientMode::Tcp => tls,
200 : // TLS is None here if using websockets, because the connection is already encrypted.
201 0 : ClientMode::Websockets { .. } => None,
202 : }
203 0 : }
204 : }
205 :
206 0 : #[derive(Debug, Error)]
207 : // almost all errors should be reported to the user, but there's a few cases where we cannot
208 : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
209 : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
210 : // we cannot be sure the client even understands our error message
211 : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
212 : pub enum ClientRequestError {
213 : #[error("{0}")]
214 : Cancellation(#[from] cancellation::CancelError),
215 : #[error("{0}")]
216 : Handshake(#[from] handshake::HandshakeError),
217 : #[error("{0}")]
218 : HandshakeTimeout(#[from] tokio::time::error::Elapsed),
219 : #[error("{0}")]
220 : PrepareClient(#[from] std::io::Error),
221 : #[error("{0}")]
222 : ReportedError(#[from] crate::stream::ReportedError),
223 : }
224 :
225 : impl ReportableError for ClientRequestError {
226 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
227 0 : match self {
228 0 : ClientRequestError::Cancellation(e) => e.get_error_kind(),
229 0 : ClientRequestError::Handshake(e) => e.get_error_kind(),
230 0 : ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
231 0 : ClientRequestError::ReportedError(e) => e.get_error_kind(),
232 0 : ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
233 : }
234 0 : }
235 : }
236 :
237 0 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
238 0 : config: &'static ProxyConfig,
239 0 : ctx: &mut RequestMonitoring,
240 0 : cancellation_handler: Arc<CancellationHandlerMain>,
241 0 : stream: S,
242 0 : mode: ClientMode,
243 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
244 0 : conn_gauge: NumClientConnectionsGuard<'static>,
245 0 : ) -> Result<Option<ProxyPassthrough<CancellationHandlerMainInternal, S>>, ClientRequestError> {
246 0 : info!(
247 : protocol = %ctx.protocol,
248 0 : "handling interactive connection from client"
249 : );
250 :
251 0 : let metrics = &Metrics::get().proxy;
252 0 : let proto = ctx.protocol;
253 0 : let _request_gauge = metrics.connection_requests.guard(proto);
254 0 :
255 0 : let tls = config.tls_config.as_ref();
256 0 :
257 0 : let record_handshake_error = !ctx.has_private_peer_addr();
258 0 : let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Client);
259 0 : let do_handshake = handshake(stream, mode.handshake_tls(tls), record_handshake_error);
260 0 : let (mut stream, params) =
261 0 : match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
262 0 : HandshakeData::Startup(stream, params) => (stream, params),
263 0 : HandshakeData::Cancel(cancel_key_data) => {
264 0 : return Ok(cancellation_handler
265 0 : .cancel_session(cancel_key_data, ctx.session_id)
266 0 : .await
267 0 : .map(|()| None)?)
268 : }
269 : };
270 0 : drop(pause);
271 0 :
272 0 : ctx.set_db_options(params.clone());
273 0 :
274 0 : let hostname = mode.hostname(stream.get_ref());
275 0 :
276 0 : let common_names = tls.map(|tls| &tls.common_names);
277 0 :
278 0 : // Extract credentials which we're going to use for auth.
279 0 : let result = config
280 0 : .auth_backend
281 0 : .as_ref()
282 0 : .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
283 0 : .transpose();
284 :
285 0 : let user_info = match result {
286 0 : Ok(user_info) => user_info,
287 0 : Err(e) => stream.throw_error(e).await?,
288 : };
289 :
290 0 : let user = user_info.get_user().to_owned();
291 0 : let user_info = match user_info
292 0 : .authenticate(
293 0 : ctx,
294 0 : &mut stream,
295 0 : mode.allow_cleartext(),
296 0 : &config.authentication_config,
297 0 : endpoint_rate_limiter,
298 0 : )
299 0 : .await
300 : {
301 0 : Ok(auth_result) => auth_result,
302 0 : Err(e) => {
303 0 : let db = params.get("database");
304 0 : let app = params.get("application_name");
305 0 : let params_span = tracing::info_span!("", ?user, ?db, ?app);
306 :
307 0 : return stream.throw_error(e).instrument(params_span).await?;
308 : }
309 : };
310 :
311 0 : let mut node = connect_to_compute(
312 0 : ctx,
313 0 : &TcpMechanism {
314 0 : params: ¶ms,
315 0 : locks: &config.connect_compute_locks,
316 0 : },
317 0 : &user_info,
318 0 : mode.allow_self_signed_compute(config),
319 0 : config.wake_compute_retry_config,
320 0 : config.connect_to_compute_retry_config,
321 0 : )
322 0 : .or_else(|e| stream.throw_error(e))
323 0 : .await?;
324 :
325 0 : let session = cancellation_handler.get_session();
326 0 : prepare_client_connection(&node, &session, &mut stream).await?;
327 :
328 : // Before proxy passing, forward to compute whatever data is left in the
329 : // PqStream input buffer. Normally there is none, but our serverless npm
330 : // driver in pipeline mode sends startup, password and first query
331 : // immediately after opening the connection.
332 0 : let (stream, read_buf) = stream.into_inner();
333 0 : node.stream.write_all(&read_buf).await?;
334 :
335 0 : Ok(Some(ProxyPassthrough {
336 0 : client: stream,
337 0 : aux: node.aux.clone(),
338 0 : compute: node,
339 0 : req: _request_gauge,
340 0 : conn: conn_gauge,
341 0 : cancel: session,
342 0 : }))
343 0 : }
344 :
345 : /// Finish client connection initialization: confirm auth success, send params, etc.
346 0 : #[tracing::instrument(skip_all)]
347 : async fn prepare_client_connection<P>(
348 : node: &compute::PostgresConnection,
349 : session: &cancellation::Session<P>,
350 : stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
351 : ) -> Result<(), std::io::Error> {
352 : // Register compute's query cancellation token and produce a new, unique one.
353 : // The new token (cancel_key_data) will be sent to the client.
354 : let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
355 :
356 : // Forward all postgres connection params to the client.
357 : // Right now the implementation is very hacky and inefficent (ideally,
358 : // we don't need an intermediate hashmap), but at least it should be correct.
359 : for (name, value) in &node.params {
360 : // TODO: Theoretically, this could result in a big pile of params...
361 : stream.write_message_noflush(&Be::ParameterStatus {
362 : name: name.as_bytes(),
363 : value: value.as_bytes(),
364 : })?;
365 : }
366 :
367 : stream
368 : .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
369 : .write_message(&Be::ReadyForQuery)
370 : .await?;
371 :
372 : Ok(())
373 : }
374 :
375 : #[derive(Debug, Clone, PartialEq, Eq, Default)]
376 : pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
377 :
378 : impl NeonOptions {
379 22 : pub fn parse_params(params: &StartupMessageParams) -> Self {
380 22 : params
381 22 : .options_raw()
382 22 : .map(Self::parse_from_iter)
383 22 : .unwrap_or_default()
384 22 : }
385 14 : pub fn parse_options_raw(options: &str) -> Self {
386 14 : Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
387 14 : }
388 :
389 4 : pub fn is_ephemeral(&self) -> bool {
390 4 : // Currently, neon endpoint options are all reserved for ephemeral endpoints.
391 4 : !self.0.is_empty()
392 4 : }
393 :
394 26 : fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
395 26 : let mut options = options
396 26 : .filter_map(neon_option)
397 26 : .map(|(k, v)| (k.into(), v.into()))
398 26 : .collect_vec();
399 26 : options.sort();
400 26 : Self(options)
401 26 : }
402 :
403 8 : pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
404 8 : // prefix + format!(" {k}:{v}")
405 8 : // kinda jank because SmolStr is immutable
406 8 : std::iter::once(prefix)
407 8 : .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
408 8 : .collect::<SmolStr>()
409 8 : .into()
410 8 : }
411 :
412 : /// <https://swagger.io/docs/specification/serialization/> DeepObject format
413 : /// `paramName[prop1]=value1¶mName[prop2]=value2&...`
414 0 : pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
415 0 : self.0
416 0 : .iter()
417 0 : .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
418 0 : .collect()
419 0 : }
420 : }
421 :
422 64 : pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
423 64 : static RE: OnceCell<Regex> = OnceCell::new();
424 64 : let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
425 :
426 64 : let cap = re.captures(bytes)?;
427 8 : let (_, [k, v]) = cap.extract();
428 8 : Some((k, v))
429 64 : }
|