LCOV - code coverage report
Current view: top level - proxy/src/context - mod.rs (source / functions) Coverage Total Hit
Test: 91bf6c8f32e5e69adde6241313e732fdd6d6e277.info Lines: 32.7 % 330 108
Test Date: 2025-03-04 12:19:20 Functions: 25.6 % 43 11

            Line data    Source code
       1              : //! Connection request monitoring contexts
       2              : 
       3              : use std::net::IpAddr;
       4              : 
       5              : use chrono::Utc;
       6              : use once_cell::sync::OnceCell;
       7              : use pq_proto::StartupMessageParams;
       8              : use smol_str::SmolStr;
       9              : use tokio::sync::mpsc;
      10              : use tracing::field::display;
      11              : use tracing::{Span, debug, error, info_span};
      12              : use try_lock::TryLock;
      13              : use uuid::Uuid;
      14              : 
      15              : use self::parquet::RequestData;
      16              : use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
      17              : use crate::error::ErrorKind;
      18              : use crate::intern::{BranchIdInt, ProjectIdInt};
      19              : use crate::metrics::{
      20              :     ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
      21              :     Waiting,
      22              : };
      23              : use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
      24              : use crate::types::{DbName, EndpointId, RoleName};
      25              : 
      26              : pub mod parquet;
      27              : 
      28              : pub(crate) static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
      29              : pub(crate) static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>> =
      30              :     OnceCell::new();
      31              : 
      32              : /// Context data for a single request to connect to a database.
      33              : ///
      34              : /// This data should **not** be used for connection logic, only for observability and limiting purposes.
      35              : /// All connection logic should instead use strongly typed state machines, not a bunch of Options.
      36              : pub struct RequestContext(
      37              :     /// To allow easier use of the ctx object, we have interior mutability.
      38              :     /// I would typically use a RefCell but that would break the `Send` requirements
      39              :     /// so we need something with thread-safety. `TryLock` is a cheap alternative
      40              :     /// that offers similar semantics to a `RefCell` but with synchronisation.
      41              :     TryLock<RequestContextInner>,
      42              : );
      43              : 
      44              : struct RequestContextInner {
      45              :     pub(crate) conn_info: ConnectionInfo,
      46              :     pub(crate) session_id: Uuid,
      47              :     pub(crate) protocol: Protocol,
      48              :     first_packet: chrono::DateTime<Utc>,
      49              :     region: &'static str,
      50              :     pub(crate) span: Span,
      51              : 
      52              :     // filled in as they are discovered
      53              :     project: Option<ProjectIdInt>,
      54              :     branch: Option<BranchIdInt>,
      55              :     endpoint_id: Option<EndpointId>,
      56              :     dbname: Option<DbName>,
      57              :     user: Option<RoleName>,
      58              :     application: Option<SmolStr>,
      59              :     user_agent: Option<SmolStr>,
      60              :     error_kind: Option<ErrorKind>,
      61              :     pub(crate) auth_method: Option<AuthMethod>,
      62              :     jwt_issuer: Option<String>,
      63              :     success: bool,
      64              :     pub(crate) cold_start_info: ColdStartInfo,
      65              :     pg_options: Option<StartupMessageParams>,
      66              : 
      67              :     // extra
      68              :     // This sender is here to keep the request monitoring channel open while requests are taking place.
      69              :     sender: Option<mpsc::UnboundedSender<RequestData>>,
      70              :     // This sender is only used to log the length of the session in case of success.
      71              :     disconnect_sender: Option<mpsc::UnboundedSender<RequestData>>,
      72              :     pub(crate) latency_timer: LatencyTimer,
      73              :     // Whether proxy decided that it's not a valid endpoint and rejected it before going to cplane.
      74              :     rejected: Option<bool>,
      75              :     disconnect_timestamp: Option<chrono::DateTime<Utc>>,
      76              : }
      77              : 
      78              : #[derive(Clone, Debug)]
      79              : pub(crate) enum AuthMethod {
      80              :     // aka passwordless, fka link
      81              :     ConsoleRedirect,
      82              :     ScramSha256,
      83              :     ScramSha256Plus,
      84              :     Cleartext,
      85              :     Jwt,
      86              : }
      87              : 
      88              : impl Clone for RequestContext {
      89            0 :     fn clone(&self) -> Self {
      90            0 :         let inner = self.0.try_lock().expect("should not deadlock");
      91            0 :         let new = RequestContextInner {
      92            0 :             conn_info: inner.conn_info.clone(),
      93            0 :             session_id: inner.session_id,
      94            0 :             protocol: inner.protocol,
      95            0 :             first_packet: inner.first_packet,
      96            0 :             region: inner.region,
      97            0 :             span: info_span!("background_task"),
      98              : 
      99            0 :             project: inner.project,
     100            0 :             branch: inner.branch,
     101            0 :             endpoint_id: inner.endpoint_id.clone(),
     102            0 :             dbname: inner.dbname.clone(),
     103            0 :             user: inner.user.clone(),
     104            0 :             application: inner.application.clone(),
     105            0 :             user_agent: inner.user_agent.clone(),
     106            0 :             error_kind: inner.error_kind,
     107            0 :             auth_method: inner.auth_method.clone(),
     108            0 :             jwt_issuer: inner.jwt_issuer.clone(),
     109            0 :             success: inner.success,
     110            0 :             rejected: inner.rejected,
     111            0 :             cold_start_info: inner.cold_start_info,
     112            0 :             pg_options: inner.pg_options.clone(),
     113            0 : 
     114            0 :             sender: None,
     115            0 :             disconnect_sender: None,
     116            0 :             latency_timer: LatencyTimer::noop(inner.protocol),
     117            0 :             disconnect_timestamp: inner.disconnect_timestamp,
     118            0 :         };
     119            0 : 
     120            0 :         Self(TryLock::new(new))
     121            0 :     }
     122              : }
     123              : 
     124              : impl RequestContext {
     125           70 :     pub fn new(
     126           70 :         session_id: Uuid,
     127           70 :         conn_info: ConnectionInfo,
     128           70 :         protocol: Protocol,
     129           70 :         region: &'static str,
     130           70 :     ) -> Self {
     131              :         // TODO: be careful with long lived spans
     132           70 :         let span = info_span!(
     133           70 :             "connect_request",
     134           70 :             %protocol,
     135           70 :             ?session_id,
     136           70 :             %conn_info,
     137           70 :             ep = tracing::field::Empty,
     138           70 :             role = tracing::field::Empty,
     139           70 :         );
     140              : 
     141           70 :         let inner = RequestContextInner {
     142           70 :             conn_info,
     143           70 :             session_id,
     144           70 :             protocol,
     145           70 :             first_packet: Utc::now(),
     146           70 :             region,
     147           70 :             span,
     148           70 : 
     149           70 :             project: None,
     150           70 :             branch: None,
     151           70 :             endpoint_id: None,
     152           70 :             dbname: None,
     153           70 :             user: None,
     154           70 :             application: None,
     155           70 :             user_agent: None,
     156           70 :             error_kind: None,
     157           70 :             auth_method: None,
     158           70 :             jwt_issuer: None,
     159           70 :             success: false,
     160           70 :             rejected: None,
     161           70 :             cold_start_info: ColdStartInfo::Unknown,
     162           70 :             pg_options: None,
     163           70 : 
     164           70 :             sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
     165           70 :             disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
     166           70 :             latency_timer: LatencyTimer::new(protocol),
     167           70 :             disconnect_timestamp: None,
     168           70 :         };
     169           70 : 
     170           70 :         Self(TryLock::new(inner))
     171           70 :     }
     172              : 
     173              :     #[cfg(test)]
     174           70 :     pub(crate) fn test() -> Self {
     175              :         use std::net::SocketAddr;
     176           70 :         let ip = IpAddr::from([127, 0, 0, 1]);
     177           70 :         let addr = SocketAddr::new(ip, 5432);
     178           70 :         let conn_info = ConnectionInfo { addr, extra: None };
     179           70 :         RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
     180           70 :     }
     181              : 
     182            0 :     pub(crate) fn console_application_name(&self) -> String {
     183            0 :         let this = self.0.try_lock().expect("should not deadlock");
     184            0 :         format!(
     185            0 :             "{}/{}",
     186            0 :             this.application.as_deref().unwrap_or_default(),
     187            0 :             this.protocol
     188            0 :         )
     189            0 :     }
     190              : 
     191            0 :     pub(crate) fn set_rejected(&self, rejected: bool) {
     192            0 :         let mut this = self.0.try_lock().expect("should not deadlock");
     193            0 :         this.rejected = Some(rejected);
     194            0 :     }
     195              : 
     196            0 :     pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) {
     197            0 :         self.0
     198            0 :             .try_lock()
     199            0 :             .expect("should not deadlock")
     200            0 :             .set_cold_start_info(info);
     201            0 :     }
     202              : 
     203            0 :     pub(crate) fn set_db_options(&self, options: StartupMessageParams) {
     204            0 :         let mut this = self.0.try_lock().expect("should not deadlock");
     205            0 :         this.set_application(options.get("application_name").map(SmolStr::from));
     206            0 :         if let Some(user) = options.get("user") {
     207            0 :             this.set_user(user.into());
     208            0 :         }
     209            0 :         if let Some(dbname) = options.get("database") {
     210            0 :             this.set_dbname(dbname.into());
     211            0 :         }
     212              : 
     213            0 :         this.pg_options = Some(options);
     214            0 :     }
     215              : 
     216            0 :     pub(crate) fn set_project(&self, x: MetricsAuxInfo) {
     217            0 :         let mut this = self.0.try_lock().expect("should not deadlock");
     218            0 :         if this.endpoint_id.is_none() {
     219            0 :             this.set_endpoint_id(x.endpoint_id.as_str().into());
     220            0 :         }
     221            0 :         this.branch = Some(x.branch_id);
     222            0 :         this.project = Some(x.project_id);
     223            0 :         this.set_cold_start_info(x.cold_start_info);
     224            0 :     }
     225              : 
     226            0 :     pub(crate) fn set_project_id(&self, project_id: ProjectIdInt) {
     227            0 :         let mut this = self.0.try_lock().expect("should not deadlock");
     228            0 :         this.project = Some(project_id);
     229            0 :     }
     230              : 
     231           28 :     pub(crate) fn set_endpoint_id(&self, endpoint_id: EndpointId) {
     232           28 :         self.0
     233           28 :             .try_lock()
     234           28 :             .expect("should not deadlock")
     235           28 :             .set_endpoint_id(endpoint_id);
     236           28 :     }
     237              : 
     238            0 :     pub(crate) fn set_dbname(&self, dbname: DbName) {
     239            0 :         self.0
     240            0 :             .try_lock()
     241            0 :             .expect("should not deadlock")
     242            0 :             .set_dbname(dbname);
     243            0 :     }
     244              : 
     245            0 :     pub(crate) fn set_user(&self, user: RoleName) {
     246            0 :         self.0
     247            0 :             .try_lock()
     248            0 :             .expect("should not deadlock")
     249            0 :             .set_user(user);
     250            0 :     }
     251              : 
     252            0 :     pub(crate) fn set_user_agent(&self, user_agent: Option<SmolStr>) {
     253            0 :         self.0
     254            0 :             .try_lock()
     255            0 :             .expect("should not deadlock")
     256            0 :             .set_user_agent(user_agent);
     257            0 :     }
     258              : 
     259           15 :     pub(crate) fn set_auth_method(&self, auth_method: AuthMethod) {
     260           15 :         let mut this = self.0.try_lock().expect("should not deadlock");
     261           15 :         this.auth_method = Some(auth_method);
     262           15 :     }
     263              : 
     264           12 :     pub(crate) fn set_jwt_issuer(&self, jwt_issuer: String) {
     265           12 :         let mut this = self.0.try_lock().expect("should not deadlock");
     266           12 :         this.jwt_issuer = Some(jwt_issuer);
     267           12 :     }
     268              : 
     269            0 :     pub fn has_private_peer_addr(&self) -> bool {
     270            0 :         self.0
     271            0 :             .try_lock()
     272            0 :             .expect("should not deadlock")
     273            0 :             .has_private_peer_addr()
     274            0 :     }
     275              : 
     276            0 :     pub(crate) fn set_error_kind(&self, kind: ErrorKind) {
     277            0 :         let mut this = self.0.try_lock().expect("should not deadlock");
     278            0 :         // Do not record errors from private addresses to metrics.
     279            0 :         if !this.has_private_peer_addr() {
     280            0 :             Metrics::get().proxy.errors_total.inc(kind);
     281            0 :         }
     282            0 :         if let Some(ep) = &this.endpoint_id {
     283            0 :             let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
     284            0 :             let label = metric.with_labels(kind);
     285            0 :             metric.get_metric(label).measure(ep);
     286            0 :         }
     287            0 :         this.error_kind = Some(kind);
     288            0 :     }
     289              : 
     290            0 :     pub fn set_success(&self) {
     291            0 :         let mut this = self.0.try_lock().expect("should not deadlock");
     292            0 :         this.success = true;
     293            0 :     }
     294              : 
     295            0 :     pub fn log_connect(self) -> DisconnectLogger {
     296            0 :         let mut this = self.0.into_inner();
     297            0 :         this.log_connect();
     298            0 : 
     299            0 :         // close current span.
     300            0 :         this.span = Span::none();
     301            0 : 
     302            0 :         DisconnectLogger(this)
     303            0 :     }
     304              : 
     305            0 :     pub(crate) fn protocol(&self) -> Protocol {
     306            0 :         self.0.try_lock().expect("should not deadlock").protocol
     307            0 :     }
     308              : 
     309            0 :     pub(crate) fn span(&self) -> Span {
     310            0 :         self.0.try_lock().expect("should not deadlock").span.clone()
     311            0 :     }
     312              : 
     313            0 :     pub(crate) fn session_id(&self) -> Uuid {
     314            0 :         self.0.try_lock().expect("should not deadlock").session_id
     315            0 :     }
     316              : 
     317            6 :     pub(crate) fn peer_addr(&self) -> IpAddr {
     318            6 :         self.0
     319            6 :             .try_lock()
     320            6 :             .expect("should not deadlock")
     321            6 :             .conn_info
     322            6 :             .addr
     323            6 :             .ip()
     324            6 :     }
     325              : 
     326            0 :     pub(crate) fn extra(&self) -> Option<ConnectionInfoExtra> {
     327            0 :         self.0
     328            0 :             .try_lock()
     329            0 :             .expect("should not deadlock")
     330            0 :             .conn_info
     331            0 :             .extra
     332            0 :             .clone()
     333            0 :     }
     334              : 
     335            0 :     pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
     336            0 :         self.0
     337            0 :             .try_lock()
     338            0 :             .expect("should not deadlock")
     339            0 :             .cold_start_info
     340            0 :     }
     341              : 
     342           28 :     pub(crate) fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
     343           28 :         LatencyTimerPause {
     344           28 :             ctx: self,
     345           28 :             start: tokio::time::Instant::now(),
     346           28 :             waiting_for,
     347           28 :         }
     348           28 :     }
     349              : 
     350            0 :     pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated {
     351            0 :         self.0
     352            0 :             .try_lock()
     353            0 :             .expect("should not deadlock")
     354            0 :             .latency_timer
     355            0 :             .accumulated()
     356            0 :     }
     357              : 
     358            4 :     pub(crate) fn success(&self) {
     359            4 :         self.0
     360            4 :             .try_lock()
     361            4 :             .expect("should not deadlock")
     362            4 :             .latency_timer
     363            4 :             .success();
     364            4 :     }
     365              : }
     366              : 
     367              : pub(crate) struct LatencyTimerPause<'a> {
     368              :     ctx: &'a RequestContext,
     369              :     start: tokio::time::Instant,
     370              :     waiting_for: Waiting,
     371              : }
     372              : 
     373              : impl Drop for LatencyTimerPause<'_> {
     374           28 :     fn drop(&mut self) {
     375           28 :         self.ctx
     376           28 :             .0
     377           28 :             .try_lock()
     378           28 :             .expect("should not deadlock")
     379           28 :             .latency_timer
     380           28 :             .unpause(self.start, self.waiting_for);
     381           28 :     }
     382              : }
     383              : 
     384              : impl RequestContextInner {
     385            0 :     fn set_cold_start_info(&mut self, info: ColdStartInfo) {
     386            0 :         self.cold_start_info = info;
     387            0 :         self.latency_timer.cold_start_info(info);
     388            0 :     }
     389              : 
     390           28 :     fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
     391           28 :         if self.endpoint_id.is_none() {
     392           28 :             self.span.record("ep", display(&endpoint_id));
     393           28 :             let metric = &Metrics::get().proxy.connecting_endpoints;
     394           28 :             let label = metric.with_labels(self.protocol);
     395           28 :             metric.get_metric(label).measure(&endpoint_id);
     396           28 :             self.endpoint_id = Some(endpoint_id);
     397           28 :         }
     398           28 :     }
     399              : 
     400            0 :     fn set_application(&mut self, app: Option<SmolStr>) {
     401            0 :         if let Some(app) = app {
     402            0 :             self.application = Some(app);
     403            0 :         }
     404            0 :     }
     405              : 
     406            0 :     fn set_user_agent(&mut self, user_agent: Option<SmolStr>) {
     407            0 :         self.user_agent = user_agent;
     408            0 :     }
     409              : 
     410            0 :     fn set_dbname(&mut self, dbname: DbName) {
     411            0 :         self.dbname = Some(dbname);
     412            0 :     }
     413              : 
     414            0 :     fn set_user(&mut self, user: RoleName) {
     415            0 :         self.span.record("role", display(&user));
     416            0 :         self.user = Some(user);
     417            0 :     }
     418              : 
     419            0 :     fn has_private_peer_addr(&self) -> bool {
     420            0 :         match self.conn_info.addr.ip() {
     421            0 :             IpAddr::V4(ip) => ip.is_private(),
     422            0 :             IpAddr::V6(_) => false,
     423              :         }
     424            0 :     }
     425              : 
     426            0 :     fn log_connect(&mut self) {
     427            0 :         let outcome = if self.success {
     428            0 :             ConnectOutcome::Success
     429              :         } else {
     430            0 :             ConnectOutcome::Failed
     431              :         };
     432              : 
     433              :         // TODO: get rid of entirely/refactor
     434              :         // check for false positives
     435              :         // AND false negatives
     436            0 :         if let Some(rejected) = self.rejected {
     437            0 :             let ep = self
     438            0 :                 .endpoint_id
     439            0 :                 .as_ref()
     440            0 :                 .map(|x| x.as_str())
     441            0 :                 .unwrap_or_default();
     442            0 :             // This makes sense only if the cache is disabled
     443            0 :             debug!(
     444              :                 ?outcome,
     445              :                 ?rejected,
     446              :                 ?ep,
     447            0 :                 "check endpoint is valid with outcome"
     448              :             );
     449            0 :             Metrics::get()
     450            0 :                 .proxy
     451            0 :                 .invalid_endpoints_total
     452            0 :                 .inc(InvalidEndpointsGroup {
     453            0 :                     protocol: self.protocol,
     454            0 :                     rejected: rejected.into(),
     455            0 :                     outcome,
     456            0 :                 });
     457            0 :         }
     458              : 
     459            0 :         if let Some(tx) = self.sender.take() {
     460              :             // If type changes, this error handling needs to be updated.
     461            0 :             let tx: mpsc::UnboundedSender<RequestData> = tx;
     462            0 :             if let Err(e) = tx.send(RequestData::from(&*self)) {
     463            0 :                 error!("log_connect channel send failed: {e}");
     464            0 :             }
     465            0 :         }
     466            0 :     }
     467              : 
     468            0 :     fn log_disconnect(&mut self) {
     469            0 :         // If we are here, it's guaranteed that the user successfully connected to the endpoint.
     470            0 :         // Here we log the length of the session.
     471            0 :         self.disconnect_timestamp = Some(Utc::now());
     472            0 :         if let Some(tx) = self.disconnect_sender.take() {
     473              :             // If type changes, this error handling needs to be updated.
     474            0 :             let tx: mpsc::UnboundedSender<RequestData> = tx;
     475            0 :             if let Err(e) = tx.send(RequestData::from(&*self)) {
     476            0 :                 error!("log_disconnect channel send failed: {e}");
     477            0 :             }
     478            0 :         }
     479            0 :     }
     480              : }
     481              : 
     482              : impl Drop for RequestContextInner {
     483           70 :     fn drop(&mut self) {
     484           70 :         if self.sender.is_some() {
     485            0 :             self.log_connect();
     486           70 :         }
     487           70 :     }
     488              : }
     489              : 
     490              : pub struct DisconnectLogger(RequestContextInner);
     491              : 
     492              : impl Drop for DisconnectLogger {
     493            0 :     fn drop(&mut self) {
     494            0 :         self.0.log_disconnect();
     495            0 :     }
     496              : }
        

Generated by: LCOV version 2.1-beta