#[cfg(test)]
mod tests;

use crate::{
    auth::{self, backend::AuthSuccess},
    cancellation::{self, CancelMap},
    compute::{self, PostgresConnection},
    config::{ProxyConfig, TlsConfig},
    console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api},
    http::StatusCode,
    metrics::{Ids, USAGE_METRICS},
    protocol2::WithClientIp,
    stream::{PqStream, Stream},
};
use anyhow::{bail, Context};
use async_trait::async_trait;
use futures::TryFutureExt;
use metrics::{exponential_buckets, register_int_counter_vec, IntCounterVec};
use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use prometheus::{register_histogram_vec, HistogramVec};
use std::{error::Error, io, ops::ControlFlow, sync::Arc, time::Instant};
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
    time,
};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use utils::measured_stream::MeasuredStream;

/// Number of times we should retry the `/proxy_wake_compute` http request.
/// Retry duration is BASE_RETRY_WAIT_DURATION * RETRY_WAIT_EXPONENT_BASE ^ n, where n starts at 0
pub const NUM_RETRIES_CONNECT: u32 = 16;
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
const BASE_RETRY_WAIT_DURATION: time::Duration = time::Duration::from_millis(25);
const RETRY_WAIT_EXPONENT_BASE: f64 = std::f64::consts::SQRT_2;
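
// For illustration, derived from the constants above: `retry_after` (defined
// below) yields 25ms, ~35ms, 50ms, ~71ms, ... for n = 1, 2, 3, 4, doubling
// every two retries (base sqrt(2)), so all 16 attempts add up to roughly 15s
// of cumulative back-off.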

const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";

pub static NUM_DB_CONNECTIONS_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_opened_db_connections_total",
        "Number of opened connections to a database.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_DB_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_closed_db_connections_total",
        "Number of closed connections to a database.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CLIENT_CONNECTION_OPENED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_opened_client_connections_total",
        "Number of opened connections from a client.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CLIENT_CONNECTION_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_closed_client_connections_total",
        "Number of closed connections from a client.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CONNECTIONS_ACCEPTED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_accepted_connections_total",
        "Number of client connections accepted.",
        &["protocol"],
    )
    .unwrap()
});

pub static NUM_CONNECTIONS_CLOSED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_closed_connections_total",
        "Number of client connections closed.",
        &["protocol"],
    )
    .unwrap()
});

static COMPUTE_CONNECTION_LATENCY: Lazy<HistogramVec> = Lazy::new(|| {
    register_histogram_vec!(
        "proxy_compute_connection_latency_seconds",
        "Time it took for proxy to establish a connection to the compute endpoint",
        &["protocol", "cache_miss", "pool_miss"],
        // 16 buckets starting at 0.5ms; largest bucket = 2^15 * 0.5ms ~= 16s
        exponential_buckets(0.0005, 2.0, 16).unwrap(),
    )
    .unwrap()
});

pub struct LatencyTimer {
    start: Instant,
    pool_miss: bool,
    cache_miss: bool,
    protocol: &'static str,
}

impl LatencyTimer {
    pub fn new(protocol: &'static str) -> Self {
        Self {
            start: Instant::now(),
            cache_miss: false,
            // by default we don't do pooling
            pool_miss: true,
            protocol,
        }
    }

    pub fn cache_miss(&mut self) {
        self.cache_miss = true;
    }

    pub fn pool_hit(&mut self) {
        self.pool_miss = false;
    }
}

impl Drop for LatencyTimer {
    fn drop(&mut self) {
        let duration = self.start.elapsed().as_secs_f64();
        COMPUTE_CONNECTION_LATENCY
            .with_label_values(&[
                self.protocol,
                bool_to_str(self.cache_miss),
                bool_to_str(self.pool_miss),
            ])
            .observe(duration)
    }
}
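
// A minimal sketch (assumed, not from the original test suite) of the intended
// `LatencyTimer` lifecycle: create it before dialing compute, flip the miss
// flags as events happen, and let `Drop` record the histogram sample.
#[cfg(test)]
mod latency_timer_example {
    use super::LatencyTimer;

    #[test]
    fn records_on_drop() {
        let mut timer = LatencyTimer::new("tcp");
        timer.cache_miss(); // e.g. the cached node info turned out to be stale
        timer.pool_hit(); // e.g. a pooled compute connection was reused
        drop(timer); // the latency observation happens here
    }
}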

static NUM_CONNECTION_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_connection_failures_total",
        "Number of connection failures (per kind).",
        &["kind"],
    )
    .unwrap()
});

static NUM_WAKEUP_FAILURES: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_connection_failures_breakdown",
        "Number of wake-up failures (per kind).",
        &["retry", "kind"],
    )
    .unwrap()
});

static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "proxy_io_bytes_per_client",
        "Number of bytes sent/received between client and backend.",
        crate::console::messages::MetricsAuxInfo::TRAFFIC_LABELS,
    )
    .unwrap()
});

pub async fn task_main(
    config: &'static ProxyConfig,
    listener: tokio::net::TcpListener,
    cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
    scopeguard::defer! {
        info!("proxy has shut down");
    }

    // When set for the server socket, the keepalive setting
    // will be inherited by all accepted client sockets.
    socket2::SockRef::from(&listener).set_keepalive(true)?;

    let mut connections = tokio::task::JoinSet::new();
    let cancel_map = Arc::new(CancelMap::default());

    loop {
        tokio::select! {
            accept_result = listener.accept() => {
                let (socket, _) = accept_result?;

                let session_id = uuid::Uuid::new_v4();
                let cancel_map = Arc::clone(&cancel_map);
                connections.spawn(
                    async move {
                        info!("accepted postgres client connection");

                        let mut socket = WithClientIp::new(socket);
                        if let Some(ip) = socket.wait_for_addr().await? {
                            tracing::Span::current().record("peer_addr", &tracing::field::display(ip));
                        } else if config.require_client_ip {
                            bail!("missing required client IP");
                        }

                        socket
                            .inner
                            .set_nodelay(true)
                            .context("failed to set socket option")?;

                        handle_client(config, &cancel_map, session_id, socket, ClientMode::Tcp).await
                    }
                    .instrument(info_span!("handle_client", ?session_id, peer_addr = tracing::field::Empty))
                    .unwrap_or_else(move |e| {
                        // Acknowledge that the task has finished with an error.
                        error!(?session_id, "per-client task finished with an error: {e:#}");
                    }),
                );
            }
            Some(Err(e)) = connections.join_next(), if !connections.is_empty() => {
                if !e.is_panic() && !e.is_cancelled() {
                    warn!("unexpected error from joined connection task: {e:?}");
                }
            }
            _ = cancellation_token.cancelled() => {
                drop(listener);
                break;
            }
        }
    }
    // Drain connections
    while let Some(res) = connections.join_next().await {
        if let Err(e) = res {
            if !e.is_panic() && !e.is_cancelled() {
                warn!("unexpected error from joined connection task: {e:?}");
            }
        }
    }
    Ok(())
}

pub enum ClientMode {
    Tcp,
    Websockets { hostname: Option<String> },
}

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
    fn protocol_label(&self) -> &'static str {
        match self {
            ClientMode::Tcp => "tcp",
            ClientMode::Websockets { .. } => "ws",
        }
    }

    fn allow_cleartext(&self) -> bool {
        match self {
            ClientMode::Tcp => false,
            ClientMode::Websockets { .. } => true,
        }
    }

    fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
        match self {
            ClientMode::Tcp => config.allow_self_signed_compute,
            ClientMode::Websockets { .. } => false,
        }
    }

    fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
        match self {
            ClientMode::Tcp => s.sni_hostname(),
            ClientMode::Websockets { hostname } => hostname.as_deref(),
        }
    }

    fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
        match self {
            ClientMode::Tcp => tls,
            // TLS is None here if using websockets, because the connection is already encrypted.
            ClientMode::Websockets { .. } => None,
        }
    }
}
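
// A minimal sketch (assumed, not from the original test suite) of how the two
// modes differ in behavior; `protocol_label` also feeds the metric labels above.
#[cfg(test)]
mod client_mode_example {
    use super::ClientMode;

    #[test]
    fn tcp_and_ws_modes_differ() {
        let tcp = ClientMode::Tcp;
        let ws = ClientMode::Websockets { hostname: None };

        assert_eq!(tcp.protocol_label(), "tcp");
        assert_eq!(ws.protocol_label(), "ws");

        // Cleartext auth is only acceptable over websockets, which are
        // already encrypted at the transport layer.
        assert!(!tcp.allow_cleartext());
        assert!(ws.allow_cleartext());
    }
}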

pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    cancel_map: &CancelMap,
    session_id: uuid::Uuid,
    stream: S,
    mode: ClientMode,
) -> anyhow::Result<()> {
    info!(
        protocol = mode.protocol_label(),
        "handling interactive connection from client"
    );

    let proto = mode.protocol_label();
    NUM_CLIENT_CONNECTION_OPENED_COUNTER
        .with_label_values(&[proto])
        .inc();
    NUM_CONNECTIONS_ACCEPTED_COUNTER
        .with_label_values(&[proto])
        .inc();
    scopeguard::defer! {
        NUM_CLIENT_CONNECTION_CLOSED_COUNTER.with_label_values(&[proto]).inc();
        NUM_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
    }

    let tls = config.tls_config.as_ref();

    let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
    let (mut stream, params) = match do_handshake.await? {
        Some(x) => x,
        None => return Ok(()), // it's a cancellation request
    };

    // Extract credentials which we're going to use for auth.
    let creds = {
        let hostname = mode.hostname(stream.get_ref());
        let common_names = tls.and_then(|tls| tls.common_names.clone());
        let result = config
            .auth_backend
            .as_ref()
            .map(|_| auth::ClientCredentials::parse(&params, hostname, common_names))
            .transpose();

        match result {
            Ok(creds) => creds,
            Err(e) => stream.throw_error(e).await?,
        }
    };

    let client = Client::new(
        stream,
        creds,
        &params,
        session_id,
        mode.allow_self_signed_compute(config),
    );
    cancel_map
        .with_session(|session| client.connect_to_db(session, mode))
        .await
}

/// Establish a (most probably, secure) connection with the client.
/// For a better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with an owned `stream` here, as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
    stream: S,
    mut tls: Option<&TlsConfig>,
    cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
    // Client may try upgrading to each protocol only once
    let (mut tried_ssl, mut tried_gss) = (false, false);

    let mut stream = PqStream::new(Stream::from_raw(stream));
    loop {
        let msg = stream.read_startup_packet().await?;
        info!("received {msg:?}");

        use FeStartupPacket::*;
        match msg {
            SslRequest => match stream.get_ref() {
                Stream::Raw { .. } if !tried_ssl => {
                    tried_ssl = true;

                    // We can't perform TLS handshake without a config
                    let enc = tls.is_some();
                    stream.write_message(&Be::EncryptionResponse(enc)).await?;
                    if let Some(tls) = tls.take() {
                        // Upgrade raw stream into a secure TLS-backed stream.
                        // NOTE: We've consumed `tls`; this fact will be used later.

                        let (raw, read_buf) = stream.into_inner();
                        // TODO: Normally, the client doesn't send any data before
                        // the server says the TLS handshake is ok, so read_buf is
                        // empty. However, you could imagine pipelining the postgres
                        // SSLRequest + TLS ClientHello in one hunk similar to
                        // pipelining in our node js driver. We should probably
                        // support that by chaining read_buf with the stream.
                        if !read_buf.is_empty() {
                            bail!("data is sent before server replied with EncryptionResponse");
                        }
                        stream = PqStream::new(raw.upgrade(tls.to_server_config()).await?);
                    }
                }
                _ => bail!(ERR_PROTO_VIOLATION),
            },
            GssEncRequest => match stream.get_ref() {
                Stream::Raw { .. } if !tried_gss => {
                    tried_gss = true;

                    // Currently, we don't support GSSAPI
                    stream.write_message(&Be::EncryptionResponse(false)).await?;
                }
                _ => bail!(ERR_PROTO_VIOLATION),
            },
            StartupMessage { params, .. } => {
                // Check that the config has been consumed during upgrade
                // OR we didn't provide it at all (for dev purposes).
                if tls.is_some() {
                    stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
                }

                info!(session_type = "normal", "successful handshake");
                break Ok(Some((stream, params)));
            }
            CancelRequest(cancel_key_data) => {
                cancel_map.cancel_session(cancel_key_data).await?;

                info!(session_type = "cancellation", "successful handshake");
                break Ok(None);
            }
        }
    }
}
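
// For reference, the message flow the loop above implements (a summary of the
// code, not normative protocol documentation):
//
//   client                          proxy
//     |--- SslRequest ------------->|   at most once; reply EncryptionResponse,
//     |<-- EncryptionResponse ------|   then upgrade to TLS if configured
//     |--- GssEncRequest ---------->|   at most once; always declined
//     |--- StartupMessage --------->|   accepted only once TLS is settled
//     |--- CancelRequest ---------->|   cancels a session, returns Ok(None)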

/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
    let is_cached = node_info.cached();
    if is_cached {
        warn!("invalidating stale compute node info cache entry");
    }
    let label = match is_cached {
        true => "compute_cached",
        false => "compute_uncached",
    };
    NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();

    node_info.invalidate().config
}

/// Try to connect to the compute node once.
#[tracing::instrument(name = "connect_once", skip_all)]
async fn connect_to_compute_once(
    node_info: &console::CachedNodeInfo,
    timeout: time::Duration,
) -> Result<PostgresConnection, compute::ConnectionError> {
    let allow_self_signed_compute = node_info.allow_self_signed_compute;

    node_info
        .config
        .connect(allow_self_signed_compute, timeout)
        .await
}

#[async_trait]
pub trait ConnectMechanism {
    type Connection;
    type ConnectError;
    type Error: From<Self::ConnectError>;
    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<Self::Connection, Self::ConnectError>;

    fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
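
// A minimal sketch of a custom `ConnectMechanism` (e.g. for tests). The choice
// of `io::Error` for both error types is an assumption for illustration; it
// satisfies `Error: From<ConnectError>` via the identity `From` impl. Note that
// plugging it into `connect_to_compute` would additionally require
// `Error: From<WakeComputeError>`.
#[cfg(test)]
mod connect_mechanism_example {
    use super::*;

    struct NoopMechanism;

    #[async_trait]
    impl ConnectMechanism for NoopMechanism {
        type Connection = ();
        type ConnectError = io::Error;
        type Error = io::Error;

        async fn connect_once(
            &self,
            _node_info: &console::CachedNodeInfo,
            _timeout: time::Duration,
        ) -> Result<Self::Connection, Self::ConnectError> {
            // Pretend every connection attempt succeeds instantly.
            Ok(())
        }

        fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {}
    }
}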

pub struct TcpMechanism<'a> {
    /// KV-dictionary with PostgreSQL connection params.
    pub params: &'a StartupMessageParams,
}

#[async_trait]
impl ConnectMechanism for TcpMechanism<'_> {
    type Connection = PostgresConnection;
    type ConnectError = compute::ConnectionError;
    type Error = compute::ConnectionError;

    async fn connect_once(
        &self,
        node_info: &console::CachedNodeInfo,
        timeout: time::Duration,
    ) -> Result<PostgresConnection, Self::Error> {
        connect_to_compute_once(node_info, timeout).await
    }

    fn update_connect_config(&self, config: &mut compute::ConnCfg) {
        config.set_startup_params(self.params);
    }
}

const fn bool_to_str(x: bool) -> &'static str {
    if x {
        "true"
    } else {
        "false"
    }
}

fn report_error(e: &WakeComputeError, retry: bool) {
    use crate::console::errors::ApiError;
    let retry = bool_to_str(retry);
    let kind = match e {
        WakeComputeError::BadComputeAddress(_) => "bad_compute_address",
        WakeComputeError::ApiError(ApiError::Transport(_)) => "api_transport_error",
        WakeComputeError::ApiError(ApiError::Console {
            status: StatusCode::LOCKED,
            ref text,
        }) if text.contains("written data quota exceeded")
            || text.contains("the limit for current plan reached") =>
        {
            "quota_exceeded"
        }
        WakeComputeError::ApiError(ApiError::Console {
            status: StatusCode::LOCKED,
            ..
        }) => "api_console_locked",
        WakeComputeError::ApiError(ApiError::Console {
            status: StatusCode::BAD_REQUEST,
            ..
        }) => "api_console_bad_request",
        WakeComputeError::ApiError(ApiError::Console { status, .. })
            if status.is_server_error() =>
        {
            "api_console_other_server_error"
        }
        WakeComputeError::ApiError(ApiError::Console { .. }) => "api_console_other_error",
    };
    NUM_WAKEUP_FAILURES.with_label_values(&[retry, kind]).inc();
}

/// Try to connect to the compute node, retrying if necessary.
/// This function might update `node_info`, so we take it as an owned `mut` value.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism>(
    mechanism: &M,
    mut node_info: console::CachedNodeInfo,
    extra: &console::ConsoleReqExtra<'_>,
    creds: &auth::BackendType<'_, auth::ClientCredentials<'_>>,
    mut latency_timer: LatencyTimer,
) -> Result<M::Connection, M::Error>
where
    M::ConnectError: ShouldRetry + std::fmt::Debug,
    M::Error: From<WakeComputeError>,
{
    mechanism.update_connect_config(&mut node_info.config);

    // try once
    let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
        Ok(res) => return Ok(res),
        Err(e) => {
            error!(error = ?e, "could not connect to compute node");
            (invalidate_cache(node_info), e)
        }
    };

    latency_timer.cache_miss();

    let mut num_retries = 1;

    // If we failed to connect, the compute node was likely suspended;
    // request a wake-up for a new compute node.
    info!("compute node's state has likely changed; requesting a wake-up");
    let node_info = loop {
        let wake_res = match creds {
            auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
            auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
            // nothing to do?
            auth::BackendType::Link(_) => return Err(err.into()),
            // test backend
            auth::BackendType::Test(x) => x.wake_compute(),
        };

        match handle_try_wake(wake_res, num_retries) {
            Err(e) => {
                error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
                report_error(&e, false);
                return Err(e.into());
            }
            // failed to wake up but we can continue to retry
            Ok(ControlFlow::Continue(e)) => {
                report_error(&e, true);
                warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
            }
            // successfully woke up a compute node and can break the wakeup loop
            Ok(ControlFlow::Break(mut node_info)) => {
                node_info.config.reuse_password(&config);
                mechanism.update_connect_config(&mut node_info.config);
                break node_info;
            }
        }

        let wait_duration = retry_after(num_retries);
        num_retries += 1;

        time::sleep(wait_duration).await;
    };

    // Now that we have a new node, try to connect to it repeatedly.
    // This can error for a few reasons, for instance:
    // * DNS connection settings haven't quite propagated yet
    info!("wake_compute success. attempting to connect");
    loop {
        match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
            Ok(res) => return Ok(res),
            Err(e) => {
                let retriable = e.should_retry(num_retries);
                if !retriable {
                    error!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
                    return Err(e.into());
                }
                warn!(error = ?e, num_retries, retriable, "couldn't connect to compute node");
            }
        }

        let wait_duration = retry_after(num_retries);
        num_retries += 1;

        time::sleep(wait_duration).await;
    }
}

/// Attempts to wake up the compute node.
/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
/// * Returns Ok(Break(node)) if the wakeup succeeded
/// * Returns Err(e) if there was an error
pub fn handle_try_wake(
    result: Result<console::CachedNodeInfo, WakeComputeError>,
    num_retries: u32,
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
    match result {
        Err(err) => match &err {
            WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
                Ok(ControlFlow::Continue(err))
            }
            _ => Err(err),
        },
        // The wake-up succeeded; break out of the retry loop.
        Ok(new) => Ok(ControlFlow::Break(new)),
    }
}
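
// A minimal sketch of the fatal path through `handle_try_wake`; the payload of
// `BadComputeAddress` is assumed to convert from `&str` here. The `Continue`
// and `Break` paths need an `ApiError`/`CachedNodeInfo` and are elided.
#[cfg(test)]
mod handle_try_wake_example {
    use super::*;

    #[test]
    fn non_retriable_error_is_fatal() {
        // `BadComputeAddress` is not an `ApiError`, so it can never be retried.
        let err = WakeComputeError::BadComputeAddress("not-an-address".into());
        assert!(handle_try_wake(Err(err), 1).is_err());
    }
}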

pub trait ShouldRetry {
    fn could_retry(&self) -> bool;
    fn should_retry(&self, num_retries: u32) -> bool {
        match self {
            _ if num_retries >= NUM_RETRIES_CONNECT => false,
            err => err.could_retry(),
        }
    }
}

impl ShouldRetry for io::Error {
    fn could_retry(&self) -> bool {
        use std::io::ErrorKind;
        matches!(
            self.kind(),
            ErrorKind::ConnectionRefused | ErrorKind::AddrNotAvailable | ErrorKind::TimedOut
        )
    }
}

impl ShouldRetry for tokio_postgres::error::DbError {
    fn could_retry(&self) -> bool {
        use tokio_postgres::error::SqlState;
        matches!(
            self.code(),
            &SqlState::CONNECTION_FAILURE
                | &SqlState::CONNECTION_EXCEPTION
                | &SqlState::CONNECTION_DOES_NOT_EXIST
                | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
        )
    }
}

impl ShouldRetry for tokio_postgres::Error {
    fn could_retry(&self) -> bool {
        if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
            io::Error::could_retry(io_err)
        } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
            tokio_postgres::error::DbError::could_retry(db_err)
        } else {
            false
        }
    }
}

impl ShouldRetry for compute::ConnectionError {
    fn could_retry(&self) -> bool {
        match self {
            compute::ConnectionError::Postgres(err) => err.could_retry(),
            compute::ConnectionError::CouldNotConnect(err) => err.could_retry(),
            _ => false,
        }
    }
}

pub fn retry_after(num_retries: u32) -> time::Duration {
    BASE_RETRY_WAIT_DURATION.mul_f64(RETRY_WAIT_EXPONENT_BASE.powi((num_retries as i32) - 1))
}
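
// A small sketch checking the back-off schedule implied by the constants above
// (these assertions are derived from the formula, not from the original tests).
#[cfg(test)]
mod retry_after_example {
    use super::retry_after;

    #[test]
    fn backoff_doubles_every_two_retries() {
        // n = 1 -> 25ms * (sqrt 2)^0 = 25ms
        assert_eq!(retry_after(1).as_millis(), 25);
        // n = 3 -> 25ms * (sqrt 2)^2 = 50ms
        assert_eq!(retry_after(3).as_millis(), 50);
        // n = 5 -> 25ms * (sqrt 2)^4 = 100ms
        assert_eq!(retry_after(5).as_millis(), 100);
    }
}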

/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
    node: &compute::PostgresConnection,
    reported_auth_ok: bool,
    session: cancellation::Session<'_>,
    stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
    // Register compute's query cancellation token and produce a new, unique one.
    // The new token (cancel_key_data) will be sent to the client.
    let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());

    // Report authentication success if we haven't done this already.
    // Note that we do this only (for the most part) after we've connected
    // to a compute (see above) which performs its own authentication.
    if !reported_auth_ok {
        stream.write_message_noflush(&Be::AuthenticationOk)?;
    }

    // Forward all postgres connection params to the client.
    // Right now the implementation is very hacky and inefficient (ideally,
    // we don't need an intermediate hashmap), but at least it should be correct.
    for (name, value) in &node.params {
        // TODO: Theoretically, this could result in a big pile of params...
        stream.write_message_noflush(&Be::ParameterStatus {
            name: name.as_bytes(),
            value: value.as_bytes(),
        })?;
    }

    stream
        .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
        .write_message(&Be::ReadyForQuery)
        .await?;

    Ok(())
}

/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
    client: impl AsyncRead + AsyncWrite + Unpin,
    compute: impl AsyncRead + AsyncWrite + Unpin,
    aux: &MetricsAuxInfo,
) -> anyhow::Result<()> {
    let usage = USAGE_METRICS.register(Ids {
        endpoint_id: aux.endpoint_id.to_string(),
        branch_id: aux.branch_id.to_string(),
    });

    let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
    let mut client = MeasuredStream::new(
        client,
        |_| {},
        |cnt| {
            // Number of bytes we sent to the client (outbound).
            m_sent.inc_by(cnt as u64);
            usage.record_egress(cnt as u64);
        },
    );

    let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
    let mut compute = MeasuredStream::new(
        compute,
        |_| {},
        |cnt| {
            // Number of bytes the client sent to the compute node (inbound).
            m_recv.inc_by(cnt as u64);
        },
    );

    // Starting from here we only proxy the client's traffic.
    info!("performing the proxy pass...");
    let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;

    Ok(())
}

/// Thin connection context.
struct Client<'a, S> {
    /// The underlying libpq protocol stream.
    stream: PqStream<S>,
    /// Client credentials that we care about.
    creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
    /// KV-dictionary with PostgreSQL connection params.
    params: &'a StartupMessageParams,
    /// Unique connection ID.
    session_id: uuid::Uuid,
    /// Allow self-signed certificates (for testing).
    allow_self_signed_compute: bool,
}

impl<'a, S> Client<'a, S> {
    /// Construct a new connection context.
    fn new(
        stream: PqStream<S>,
        creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
        params: &'a StartupMessageParams,
        session_id: uuid::Uuid,
        allow_self_signed_compute: bool,
    ) -> Self {
        Self {
            stream,
            creds,
            params,
            session_id,
            allow_self_signed_compute,
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
    /// Let the client authenticate and connect to the designated compute node.
    // Instrumentation logs the endpoint name everywhere. This doesn't work for
    // link auth; strictly speaking, we don't know the endpoint name in that case.
    #[tracing::instrument(name = "", fields(ep = self.creds.get_endpoint().unwrap_or("".to_owned())), skip_all)]
    async fn connect_to_db(
        self,
        session: cancellation::Session<'_>,
        mode: ClientMode,
    ) -> anyhow::Result<()> {
        let Self {
            mut stream,
            mut creds,
            params,
            session_id,
            allow_self_signed_compute,
        } = self;

        let extra = console::ConsoleReqExtra {
            session_id, // aka this connection's id
            application_name: params.get("application_name"),
        };

        let latency_timer = LatencyTimer::new(mode.protocol_label());

        let auth_result = match creds
            .authenticate(&extra, &mut stream, mode.allow_cleartext())
            .await
        {
            Ok(auth_result) => auth_result,
            Err(e) => {
                let user = creds.get_user();
                let db = params.get("database");
                let app = params.get("application_name");
                let params_span = tracing::info_span!("", ?user, ?db, ?app);

                return stream.throw_error(e).instrument(params_span).await;
            }
        };

        let AuthSuccess {
            reported_auth_ok,
            value: mut node_info,
        } = auth_result;

        node_info.allow_self_signed_compute = allow_self_signed_compute;

        let aux = node_info.aux.clone();
        let mut node = connect_to_compute(
            &TcpMechanism { params },
            node_info,
            &extra,
            &creds,
            latency_timer,
        )
        .or_else(|e| stream.throw_error(e))
        .await?;

        let proto = mode.protocol_label();
        NUM_DB_CONNECTIONS_OPENED_COUNTER
            .with_label_values(&[proto])
            .inc();
        scopeguard::defer! {
            NUM_DB_CONNECTIONS_CLOSED_COUNTER.with_label_values(&[proto]).inc();
        }

        prepare_client_connection(&node, reported_auth_ok, session, &mut stream).await?;
        // Before proxy passing, forward to compute whatever data is left in the
        // PqStream input buffer. Normally there is none, but our serverless npm
        // driver in pipeline mode sends startup, password and first query
        // immediately after opening the connection.
        let (stream, read_buf) = stream.into_inner();
        node.stream.write_all(&read_buf).await?;
        proxy_pass(stream, node.stream, &aux).await
    }
}