Line data Source code
1 : use std::io;
2 : use std::net::{IpAddr, SocketAddr};
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use async_trait::async_trait;
7 : use ed25519_dalek::SigningKey;
8 : use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
9 : use jose_jwk::jose_b64;
10 : use postgres_client::config::SslMode;
11 : use rand::rngs::OsRng;
12 : use rustls::pki_types::{DnsName, ServerName};
13 : use tokio::net::{TcpStream, lookup_host};
14 : use tokio_rustls::TlsConnector;
15 : use tracing::field::display;
16 : use tracing::{debug, info};
17 :
18 : use super::AsyncRW;
19 : use super::conn_pool::poll_client;
20 : use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
21 : use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
22 : use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
23 : use crate::auth::backend::local::StaticAuthRules;
24 : use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
25 : use crate::auth::{self, AuthError, check_peer_addr_is_in_list};
26 : use crate::compute;
27 : use crate::compute_ctl::{
28 : ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
29 : };
30 : use crate::config::{ComputeConfig, ProxyConfig};
31 : use crate::context::RequestContext;
32 : use crate::control_plane::CachedNodeInfo;
33 : use crate::control_plane::client::ApiLockError;
34 : use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
35 : use crate::control_plane::locks::ApiLocks;
36 : use crate::error::{ErrorKind, ReportableError, UserFacingError};
37 : use crate::intern::EndpointIdInt;
38 : use crate::protocol2::ConnectionInfoExtra;
39 : use crate::proxy::connect_compute::ConnectMechanism;
40 : use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
41 : use crate::rate_limiter::EndpointRateLimiter;
42 : use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
43 :
/// Shared state used to authenticate incoming requests and to hand out
/// (pooled or freshly opened) connections to compute.
pub(crate) struct PoolingBackend {
    /// Pool of HTTP/2 clients; consulted and filled by `connect_to_local_proxy`.
    pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
    /// Pool of postgres clients used by `connect_to_local_postgres`.
    pub(crate) local_pool: Arc<LocalConnPool<postgres_client::Client>>,
    /// Pool of postgres clients used by `connect_to_compute`.
    pub(crate) pool:
        Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,

    /// Proxy configuration (authentication settings, connect locks, retry settings).
    pub(crate) config: &'static ProxyConfig,
    /// Auth backend; either a control-plane backend or a local one.
    pub(crate) auth_backend: &'static crate::auth::Backend<'static, ()>,
    /// Per-endpoint rate limiter checked during password authentication.
    pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
54 :
55 : impl PoolingBackend {
56 0 : pub(crate) async fn authenticate_with_password(
57 0 : &self,
58 0 : ctx: &RequestContext,
59 0 : user_info: &ComputeUserInfo,
60 0 : password: &[u8],
61 0 : ) -> Result<ComputeCredentials, AuthError> {
62 0 : ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
63 0 :
64 0 : let user_info = user_info.clone();
65 0 : let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
66 0 : let allowed_ips = backend.get_allowed_ips(ctx).await?;
67 :
68 0 : if self.config.authentication_config.ip_allowlist_check_enabled
69 0 : && !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
70 : {
71 0 : return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
72 0 : }
73 :
74 0 : let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?;
75 0 : if self.config.authentication_config.is_vpc_acccess_proxy {
76 0 : if access_blocker_flags.vpc_access_blocked {
77 0 : return Err(AuthError::NetworkNotAllowed);
78 0 : }
79 0 :
80 0 : let extra = ctx.extra();
81 0 : let incoming_endpoint_id = match extra {
82 0 : None => String::new(),
83 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
84 0 : Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
85 : };
86 :
87 0 : if incoming_endpoint_id.is_empty() {
88 0 : return Err(AuthError::MissingVPCEndpointId);
89 0 : }
90 :
91 0 : let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?;
92 : // TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
93 0 : if !allowed_vpc_endpoint_ids.is_empty()
94 0 : && !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id)
95 : {
96 0 : return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id));
97 0 : }
98 0 : } else if access_blocker_flags.public_access_blocked {
99 0 : return Err(AuthError::NetworkNotAllowed);
100 0 : }
101 :
102 0 : if !self
103 0 : .endpoint_rate_limiter
104 0 : .check(user_info.endpoint.clone().into(), 1)
105 : {
106 0 : return Err(AuthError::too_many_connections());
107 0 : }
108 0 : let cached_secret = backend.get_role_secret(ctx).await?;
109 0 : let secret = match cached_secret.value.clone() {
110 0 : Some(secret) => self.config.authentication_config.check_rate_limit(
111 0 : ctx,
112 0 : secret,
113 0 : &user_info.endpoint,
114 0 : true,
115 0 : )?,
116 : None => {
117 : // If we don't have an authentication secret, for the http flow we can just return an error.
118 0 : info!("authentication info not found");
119 0 : return Err(AuthError::password_failed(&*user_info.user));
120 : }
121 : };
122 0 : let ep = EndpointIdInt::from(&user_info.endpoint);
123 0 : let auth_outcome = crate::auth::validate_password_and_exchange(
124 0 : &self.config.authentication_config.thread_pool,
125 0 : ep,
126 0 : password,
127 0 : secret,
128 0 : )
129 0 : .await?;
130 0 : let res = match auth_outcome {
131 0 : crate::sasl::Outcome::Success(key) => {
132 0 : info!("user successfully authenticated");
133 0 : Ok(key)
134 : }
135 0 : crate::sasl::Outcome::Failure(reason) => {
136 0 : info!("auth backend failed with an error: {reason}");
137 0 : Err(AuthError::password_failed(&*user_info.user))
138 : }
139 : };
140 0 : res.map(|key| ComputeCredentials {
141 0 : info: user_info,
142 0 : keys: key,
143 0 : })
144 0 : }
145 :
146 0 : pub(crate) async fn authenticate_with_jwt(
147 0 : &self,
148 0 : ctx: &RequestContext,
149 0 : user_info: &ComputeUserInfo,
150 0 : jwt: String,
151 0 : ) -> Result<ComputeCredentials, AuthError> {
152 0 : ctx.set_auth_method(crate::context::AuthMethod::Jwt);
153 0 :
154 0 : match &self.auth_backend {
155 0 : crate::auth::Backend::ControlPlane(console, ()) => {
156 0 : self.config
157 0 : .authentication_config
158 0 : .jwks_cache
159 0 : .check_jwt(
160 0 : ctx,
161 0 : user_info.endpoint.clone(),
162 0 : &user_info.user,
163 0 : &**console,
164 0 : &jwt,
165 0 : )
166 0 : .await?;
167 :
168 0 : Ok(ComputeCredentials {
169 0 : info: user_info.clone(),
170 0 : keys: crate::auth::backend::ComputeCredentialKeys::None,
171 0 : })
172 : }
173 : crate::auth::Backend::Local(_) => {
174 0 : let keys = self
175 0 : .config
176 0 : .authentication_config
177 0 : .jwks_cache
178 0 : .check_jwt(
179 0 : ctx,
180 0 : user_info.endpoint.clone(),
181 0 : &user_info.user,
182 0 : &StaticAuthRules,
183 0 : &jwt,
184 0 : )
185 0 : .await?;
186 :
187 0 : Ok(ComputeCredentials {
188 0 : info: user_info.clone(),
189 0 : keys,
190 0 : })
191 : }
192 : }
193 0 : }
194 :
195 : // Wake up the destination if needed. Code here is a bit involved because
196 : // we reuse the code from the usual proxy and we need to prepare few structures
197 : // that this code expects.
198 : #[tracing::instrument(skip_all, fields(
199 : pid = tracing::field::Empty,
200 : compute_id = tracing::field::Empty,
201 : conn_id = tracing::field::Empty,
202 : ))]
203 : pub(crate) async fn connect_to_compute(
204 : &self,
205 : ctx: &RequestContext,
206 : conn_info: ConnInfo,
207 : keys: ComputeCredentials,
208 : force_new: bool,
209 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
210 : let maybe_client = if force_new {
211 : debug!("pool: pool is disabled");
212 : None
213 : } else {
214 : debug!("pool: looking for an existing connection");
215 : self.pool.get(ctx, &conn_info)?
216 : };
217 :
218 : if let Some(client) = maybe_client {
219 : return Ok(client);
220 : }
221 : let conn_id = uuid::Uuid::new_v4();
222 : tracing::Span::current().record("conn_id", display(conn_id));
223 : info!(%conn_id, "pool: opening a new connection '{conn_info}'");
224 0 : let backend = self.auth_backend.as_ref().map(|()| keys);
225 : crate::proxy::connect_compute::connect_to_compute(
226 : ctx,
227 : &TokioMechanism {
228 : conn_id,
229 : conn_info,
230 : pool: self.pool.clone(),
231 : locks: &self.config.connect_compute_locks,
232 : },
233 : &backend,
234 : self.config.wake_compute_retry_config,
235 : &self.config.connect_to_compute,
236 : )
237 : .await
238 : }
239 :
240 : // Wake up the destination if needed
241 : #[tracing::instrument(skip_all, fields(
242 : compute_id = tracing::field::Empty,
243 : conn_id = tracing::field::Empty,
244 : ))]
245 : pub(crate) async fn connect_to_local_proxy(
246 : &self,
247 : ctx: &RequestContext,
248 : conn_info: ConnInfo,
249 : ) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
250 : debug!("pool: looking for an existing connection");
251 : if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) {
252 : return Ok(client);
253 : }
254 :
255 : let conn_id = uuid::Uuid::new_v4();
256 : tracing::Span::current().record("conn_id", display(conn_id));
257 : debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
258 0 : let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
259 0 : info: ComputeUserInfo {
260 0 : user: conn_info.user_info.user.clone(),
261 0 : endpoint: EndpointId::from(format!(
262 0 : "{}{LOCAL_PROXY_SUFFIX}",
263 0 : conn_info.user_info.endpoint.normalize()
264 0 : )),
265 0 : options: conn_info.user_info.options.clone(),
266 0 : },
267 0 : keys: crate::auth::backend::ComputeCredentialKeys::None,
268 0 : });
269 : crate::proxy::connect_compute::connect_to_compute(
270 : ctx,
271 : &HyperMechanism {
272 : conn_id,
273 : conn_info,
274 : pool: self.http_conn_pool.clone(),
275 : locks: &self.config.connect_compute_locks,
276 : },
277 : &backend,
278 : self.config.wake_compute_retry_config,
279 : &self.config.connect_to_compute,
280 : )
281 : .await
282 : }
283 :
284 : /// Connect to postgres over localhost.
285 : ///
286 : /// We expect postgres to be started here, so we won't do any retries.
287 : ///
288 : /// # Panics
289 : ///
290 : /// Panics if called with a non-local_proxy backend.
291 : #[tracing::instrument(skip_all, fields(
292 : pid = tracing::field::Empty,
293 : conn_id = tracing::field::Empty,
294 : ))]
295 : pub(crate) async fn connect_to_local_postgres(
296 : &self,
297 : ctx: &RequestContext,
298 : conn_info: ConnInfo,
299 : ) -> Result<Client<postgres_client::Client>, HttpConnError> {
300 : if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
301 : return Ok(client);
302 : }
303 :
304 : let local_backend = match &self.auth_backend {
305 : auth::Backend::ControlPlane(_, ()) => {
306 : unreachable!("only local_proxy can connect to local postgres")
307 : }
308 : auth::Backend::Local(local) => local,
309 : };
310 :
311 : if !self.local_pool.initialized(&conn_info) {
312 : // only install and grant usage one at a time.
313 : let _permit = local_backend
314 : .initialize
315 : .acquire()
316 : .await
317 : .expect("semaphore should never be closed");
318 :
319 : // check again for race
320 : if !self.local_pool.initialized(&conn_info) {
321 : local_backend
322 : .compute_ctl
323 : .install_extension(&ExtensionInstallRequest {
324 : extension: EXT_NAME,
325 : database: conn_info.dbname.clone(),
326 : version: EXT_VERSION,
327 : })
328 : .await?;
329 :
330 : local_backend
331 : .compute_ctl
332 : .grant_role(&SetRoleGrantsRequest {
333 : schema: EXT_SCHEMA,
334 : privileges: vec![Privilege::Usage],
335 : database: conn_info.dbname.clone(),
336 : role: conn_info.user_info.user.clone(),
337 : })
338 : .await?;
339 :
340 : self.local_pool.set_initialized(&conn_info);
341 : }
342 : }
343 :
344 : let conn_id = uuid::Uuid::new_v4();
345 : tracing::Span::current().record("conn_id", display(conn_id));
346 : info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
347 :
348 : let mut node_info = local_backend.node_info.clone();
349 :
350 : let (key, jwk) = create_random_jwk();
351 :
352 : let config = node_info
353 : .config
354 : .user(&conn_info.user_info.user)
355 : .dbname(&conn_info.dbname)
356 : .set_param(
357 : "options",
358 : &format!(
359 : "-c pg_session_jwt.jwk={}",
360 : serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
361 : ),
362 : );
363 :
364 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
365 : let (client, connection) = config.connect(postgres_client::NoTls).await?;
366 : drop(pause);
367 :
368 : let pid = client.get_process_id();
369 : tracing::Span::current().record("pid", pid);
370 :
371 : let mut handle = local_conn_pool::poll_client(
372 : self.local_pool.clone(),
373 : ctx,
374 : conn_info,
375 : client,
376 : connection,
377 : key,
378 : conn_id,
379 : node_info.aux.clone(),
380 : );
381 :
382 : {
383 : let (client, mut discard) = handle.inner();
384 : debug!("setting up backend session state");
385 :
386 : // initiates the auth session
387 : if let Err(e) = client.batch_execute("select auth.init();").await {
388 : discard.discard();
389 : return Err(e.into());
390 : }
391 :
392 : info!("backend session state initialized");
393 : }
394 :
395 : Ok(handle)
396 : }
397 : }
398 :
399 0 : fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
400 0 : let key = SigningKey::generate(&mut OsRng);
401 0 :
402 0 : let jwk = jose_jwk::Key::Okp(jose_jwk::Okp {
403 0 : crv: jose_jwk::OkpCurves::Ed25519,
404 0 : x: jose_b64::serde::Bytes::from(key.verifying_key().to_bytes().to_vec()),
405 0 : d: None,
406 0 : });
407 0 :
408 0 : (key, jwk)
409 0 : }
410 :
/// Errors that can occur while obtaining a connection for an HTTP request.
///
/// Most variants wrap an underlying error and can be created from it via `?`
/// (`#[from]`).
#[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError {
    /// Sending a connection id over a watch channel failed.
    #[error("pooled connection closed at inconsistent state")]
    ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
    /// Error reported by the postgres client.
    #[error("could not connect to postgres in compute")]
    PostgresConnectionError(#[from] postgres_client::Error),
    /// Error while establishing the HTTP/2 connection to local-proxy.
    #[error("could not connect to local-proxy in compute")]
    LocalProxyConnectionError(#[from] LocalProxyConnError),
    /// JSON error for a JWT payload. Has no `#[from]`, so it must be
    /// constructed explicitly.
    #[error("could not parse JWT payload")]
    JwtPayloadError(serde_json::Error),

    /// Error from compute_ctl; `connect_to_local_postgres` produces it for
    /// both the extension install and the role grant request.
    #[error("could not install extension: {0}")]
    ComputeCtl(#[from] ComputeCtlError),
    /// Fetching auth info failed.
    #[error("could not get auth info")]
    GetAuthInfo(#[from] GetAuthInfoError),
    /// Authentication failed.
    #[error("user not authenticated")]
    AuthError(#[from] AuthError),
    /// Waking up compute failed.
    #[error("wake_compute returned error")]
    WakeCompute(#[from] WakeComputeError),
    /// Acquiring a connect permit from the connect locks failed.
    #[error("error acquiring resource permit: {0}")]
    TooManyConnectionAttempts(#[from] ApiLockError),
}
433 :
/// Errors from establishing the HTTP/2 connection to local-proxy
/// (see `connect_http2`).
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
    /// I/O level failure: address resolution, TCP connect (including
    /// timeouts), socket options or the TLS handshake.
    #[error("error with connection to local-proxy")]
    Io(#[source] std::io::Error),
    /// The HTTP/2 handshake failed.
    #[error("could not establish h2 connection")]
    H2(#[from] hyper::Error),
}
441 :
442 : impl ReportableError for HttpConnError {
443 0 : fn get_error_kind(&self) -> ErrorKind {
444 0 : match self {
445 0 : HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
446 0 : HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
447 0 : HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
448 0 : HttpConnError::ComputeCtl(_) => ErrorKind::Service,
449 0 : HttpConnError::JwtPayloadError(_) => ErrorKind::User,
450 0 : HttpConnError::GetAuthInfo(a) => a.get_error_kind(),
451 0 : HttpConnError::AuthError(a) => a.get_error_kind(),
452 0 : HttpConnError::WakeCompute(w) => w.get_error_kind(),
453 0 : HttpConnError::TooManyConnectionAttempts(w) => w.get_error_kind(),
454 : }
455 0 : }
456 : }
457 :
458 : impl UserFacingError for HttpConnError {
459 0 : fn to_string_client(&self) -> String {
460 0 : match self {
461 0 : HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
462 0 : HttpConnError::PostgresConnectionError(p) => p.to_string(),
463 0 : HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
464 0 : HttpConnError::ComputeCtl(_) => "could not set up the JWT authorization database extension".to_string(),
465 0 : HttpConnError::JwtPayloadError(p) => p.to_string(),
466 0 : HttpConnError::GetAuthInfo(c) => c.to_string_client(),
467 0 : HttpConnError::AuthError(c) => c.to_string_client(),
468 0 : HttpConnError::WakeCompute(c) => c.to_string_client(),
469 : HttpConnError::TooManyConnectionAttempts(_) => {
470 0 : "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
471 : }
472 : }
473 0 : }
474 : }
475 :
476 : impl CouldRetry for HttpConnError {
477 0 : fn could_retry(&self) -> bool {
478 0 : match self {
479 0 : HttpConnError::PostgresConnectionError(e) => e.could_retry(),
480 0 : HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
481 0 : HttpConnError::ComputeCtl(_) => false,
482 0 : HttpConnError::ConnectionClosedAbruptly(_) => false,
483 0 : HttpConnError::JwtPayloadError(_) => false,
484 0 : HttpConnError::GetAuthInfo(_) => false,
485 0 : HttpConnError::AuthError(_) => false,
486 0 : HttpConnError::WakeCompute(_) => false,
487 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
488 : }
489 0 : }
490 : }
491 : impl ShouldRetryWakeCompute for HttpConnError {
492 0 : fn should_retry_wake_compute(&self) -> bool {
493 0 : match self {
494 0 : HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
495 : // we never checked cache validity
496 0 : HttpConnError::TooManyConnectionAttempts(_) => false,
497 0 : _ => true,
498 : }
499 0 : }
500 : }
501 :
502 : impl ReportableError for LocalProxyConnError {
503 0 : fn get_error_kind(&self) -> ErrorKind {
504 0 : match self {
505 0 : LocalProxyConnError::Io(_) => ErrorKind::Compute,
506 0 : LocalProxyConnError::H2(_) => ErrorKind::Compute,
507 : }
508 0 : }
509 : }
510 :
511 : impl UserFacingError for LocalProxyConnError {
512 0 : fn to_string_client(&self) -> String {
513 0 : "Could not establish HTTP connection to the database".to_string()
514 0 : }
515 : }
516 :
517 : impl CouldRetry for LocalProxyConnError {
518 0 : fn could_retry(&self) -> bool {
519 0 : match self {
520 0 : LocalProxyConnError::Io(_) => false,
521 0 : LocalProxyConnError::H2(_) => false,
522 : }
523 0 : }
524 : }
525 : impl ShouldRetryWakeCompute for LocalProxyConnError {
526 0 : fn should_retry_wake_compute(&self) -> bool {
527 0 : match self {
528 0 : LocalProxyConnError::Io(_) => false,
529 0 : LocalProxyConnError::H2(_) => false,
530 : }
531 0 : }
532 : }
533 :
/// Connect mechanism that opens a postgres connection to compute and
/// registers it with the postgres connection pool.
struct TokioMechanism {
    /// Pool the new connection is handed to.
    pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
    /// Connection parameters; user and database name are applied to the connect config.
    conn_info: ConnInfo,
    /// Id assigned to the connection being opened.
    conn_id: uuid::Uuid,

    /// connect_to_compute concurrency lock
    locks: &'static ApiLocks<Host>,
}
542 :
543 : #[async_trait]
544 : impl ConnectMechanism for TokioMechanism {
545 : type Connection = Client<postgres_client::Client>;
546 : type ConnectError = HttpConnError;
547 : type Error = HttpConnError;
548 :
549 0 : async fn connect_once(
550 0 : &self,
551 0 : ctx: &RequestContext,
552 0 : node_info: &CachedNodeInfo,
553 0 : compute_config: &ComputeConfig,
554 0 : ) -> Result<Self::Connection, Self::ConnectError> {
555 0 : let host = node_info.config.get_host();
556 0 : let permit = self.locks.get_permit(&host).await?;
557 :
558 0 : let mut config = (*node_info.config).clone();
559 0 : let config = config
560 0 : .user(&self.conn_info.user_info.user)
561 0 : .dbname(&self.conn_info.dbname)
562 0 : .connect_timeout(compute_config.timeout);
563 0 :
564 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
565 0 : let res = config.connect(postgres_client::NoTls).await;
566 0 : drop(pause);
567 0 : let (client, connection) = permit.release_result(res)?;
568 :
569 0 : tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
570 0 : tracing::Span::current().record(
571 0 : "compute_id",
572 0 : tracing::field::display(&node_info.aux.compute_id),
573 0 : );
574 :
575 0 : if let Some(query_id) = ctx.get_testodrome_id() {
576 0 : info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
577 0 : }
578 :
579 0 : Ok(poll_client(
580 0 : self.pool.clone(),
581 0 : ctx,
582 0 : self.conn_info.clone(),
583 0 : client,
584 0 : connection,
585 0 : self.conn_id,
586 0 : node_info.aux.clone(),
587 0 : ))
588 0 : }
589 :
590 0 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
591 : }
592 :
/// Connect mechanism that opens an HTTP/2 connection (see `connect_http2`)
/// and registers it with the HTTP connection pool.
struct HyperMechanism {
    /// Pool the new connection is handed to.
    pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
    /// Connection parameters passed on to the pool when registering the connection.
    conn_info: ConnInfo,
    /// Id assigned to the connection being opened.
    conn_id: uuid::Uuid,

    /// connect_to_compute concurrency lock
    locks: &'static ApiLocks<Host>,
}
601 :
602 : #[async_trait]
603 : impl ConnectMechanism for HyperMechanism {
604 : type Connection = http_conn_pool::Client<Send>;
605 : type ConnectError = HttpConnError;
606 : type Error = HttpConnError;
607 :
608 0 : async fn connect_once(
609 0 : &self,
610 0 : ctx: &RequestContext,
611 0 : node_info: &CachedNodeInfo,
612 0 : config: &ComputeConfig,
613 0 : ) -> Result<Self::Connection, Self::ConnectError> {
614 0 : let host_addr = node_info.config.get_host_addr();
615 0 : let host = node_info.config.get_host();
616 0 : let permit = self.locks.get_permit(&host).await?;
617 :
618 0 : let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
619 :
620 0 : let tls = if node_info.config.get_ssl_mode() == SslMode::Disable {
621 0 : None
622 : } else {
623 0 : Some(&config.tls)
624 : };
625 :
626 0 : let port = node_info.config.get_port();
627 0 : let res = connect_http2(host_addr, &host, port, config.timeout, tls).await;
628 0 : drop(pause);
629 0 : let (client, connection) = permit.release_result(res)?;
630 :
631 0 : tracing::Span::current().record(
632 0 : "compute_id",
633 0 : tracing::field::display(&node_info.aux.compute_id),
634 0 : );
635 :
636 0 : if let Some(query_id) = ctx.get_testodrome_id() {
637 0 : info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
638 0 : }
639 :
640 0 : Ok(poll_http2_client(
641 0 : self.pool.clone(),
642 0 : ctx,
643 0 : &self.conn_info,
644 0 : client,
645 0 : connection,
646 0 : self.conn_id,
647 0 : node_info.aux.clone(),
648 0 : ))
649 0 : }
650 :
651 0 : fn update_connect_config(&self, _config: &mut compute::ConnCfg) {}
652 : }
653 :
654 0 : async fn connect_http2(
655 0 : host_addr: Option<IpAddr>,
656 0 : host: &str,
657 0 : port: u16,
658 0 : timeout: Duration,
659 0 : tls: Option<&Arc<rustls::ClientConfig>>,
660 0 : ) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
661 0 : let addrs = match host_addr {
662 0 : Some(addr) => vec![SocketAddr::new(addr, port)],
663 0 : None => lookup_host((host, port))
664 0 : .await
665 0 : .map_err(LocalProxyConnError::Io)?
666 0 : .collect(),
667 : };
668 0 : let mut last_err = None;
669 0 :
670 0 : let mut addrs = addrs.into_iter();
671 0 : let stream = loop {
672 0 : let Some(addr) = addrs.next() else {
673 0 : return Err(last_err.unwrap_or_else(|| {
674 0 : LocalProxyConnError::Io(io::Error::new(
675 0 : io::ErrorKind::InvalidInput,
676 0 : "could not resolve any addresses",
677 0 : ))
678 0 : }));
679 : };
680 :
681 0 : match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
682 0 : Ok(Ok(stream)) => {
683 0 : stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
684 0 : break stream;
685 : }
686 0 : Ok(Err(e)) => {
687 0 : last_err = Some(LocalProxyConnError::Io(e));
688 0 : }
689 0 : Err(e) => {
690 0 : last_err = Some(LocalProxyConnError::Io(io::Error::new(
691 0 : io::ErrorKind::TimedOut,
692 0 : e,
693 0 : )));
694 0 : }
695 : }
696 : };
697 :
698 0 : let stream = if let Some(tls) = tls {
699 0 : let host = DnsName::try_from(host)
700 0 : .map_err(io::Error::other)
701 0 : .map_err(LocalProxyConnError::Io)?
702 0 : .to_owned();
703 0 : let stream = TlsConnector::from(tls.clone())
704 0 : .connect(ServerName::DnsName(host), stream)
705 0 : .await
706 0 : .map_err(LocalProxyConnError::Io)?;
707 0 : Box::pin(stream) as AsyncRW
708 : } else {
709 0 : Box::pin(stream) as AsyncRW
710 : };
711 :
712 0 : let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
713 0 : .timer(TokioTimer::new())
714 0 : .keep_alive_interval(Duration::from_secs(20))
715 0 : .keep_alive_while_idle(true)
716 0 : .keep_alive_timeout(Duration::from_secs(5))
717 0 : .handshake(TokioIo::new(stream))
718 0 : .await?;
719 :
720 0 : Ok((client, connection))
721 0 : }
|