LCOV - code coverage report
Current view: top level - proxy/src - lib.rs (source / functions) Coverage Total Hit
Test: 86c536b7fe84b2afe03c3bb264199e9c319ae0f8.info Lines: 31.1 % 61 19
Test Date: 2024-06-24 16:38:41 Functions: 22.2 % 81 18

            Line data    Source code
       1              : #![deny(clippy::undocumented_unsafe_blocks)]
       2              : 
       3              : use std::convert::Infallible;
       4              : 
       5              : use anyhow::{bail, Context};
       6              : use intern::{EndpointIdInt, EndpointIdTag, InternId};
       7              : use tokio::task::JoinError;
       8              : use tokio_util::sync::CancellationToken;
       9              : use tracing::warn;
      10              : 
      11              : pub mod auth;
      12              : pub mod cache;
      13              : pub mod cancellation;
      14              : pub mod compute;
      15              : pub mod config;
      16              : pub mod console;
      17              : pub mod context;
      18              : pub mod error;
      19              : pub mod http;
      20              : pub mod intern;
      21              : pub mod jemalloc;
      22              : pub mod logging;
      23              : pub mod metrics;
      24              : pub mod parse;
      25              : pub mod protocol2;
      26              : pub mod proxy;
      27              : pub mod rate_limiter;
      28              : pub mod redis;
      29              : pub mod sasl;
      30              : pub mod scram;
      31              : pub mod serverless;
      32              : pub mod stream;
      33              : pub mod url;
      34              : pub mod usage_metrics;
      35              : pub mod waiters;
      36              : 
      37              : /// Handle Unix signals: log SIGHUP, bail out on SIGINT, and cancel `token` on SIGTERM.
      38            0 : pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
      39              :     use tokio::signal::unix::{signal, SignalKind};
      40              : 
      41            0 :     let mut hangup = signal(SignalKind::hangup())?;
      42            0 :     let mut interrupt = signal(SignalKind::interrupt())?;
      43            0 :     let mut terminate = signal(SignalKind::terminate())?;
      44              : 
      45            0 :     loop {
      46            0 :         tokio::select! {
      47              :             // Hangup is commonly used for config reload.
      48              :             _ = hangup.recv() => {
      49              :                 warn!("received SIGHUP; config reload is not supported");
      50              :             }
      51              :             // Shut down the whole application.
      52              :             _ = interrupt.recv() => {
      53              :                 warn!("received SIGINT, exiting immediately");
      54              :                 bail!("interrupted");
      55              :             }
      56              :             _ = terminate.recv() => {
      57              :                 warn!("received SIGTERM, shutting down once all existing connections have closed");
      58              :                 token.cancel();
      59              :             }
      60            0 :         }
      61            0 :     }
      62            0 : }
      63              : 
      64              : /// Flattens `Result<anyhow::Result<T>, JoinError>` into `anyhow::Result<T>`, adding "join error" context to a `JoinError`.
      65            0 : pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
      66            0 :     r.context("join error").and_then(|x| x)
      67            0 : }
      68              : 
      69              : macro_rules! smol_str_wrapper {
      70              :     ($name:ident) => {
      71              :         #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
      72              :         pub struct $name(smol_str::SmolStr);
      73              : 
      74              :         impl $name {
      75            0 :             pub fn as_str(&self) -> &str {
      76            0 :                 self.0.as_str()
      77            0 :             }
      78              :         }
      79              : 
      80              :         impl std::fmt::Display for $name {
      81            6 :             fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      82            6 :                 self.0.fmt(f)
      83            6 :             }
      84              :         }
      85              : 
      86              :         impl<T> std::cmp::PartialEq<T> for $name
      87              :         where
      88              :             smol_str::SmolStr: std::cmp::PartialEq<T>,
      89              :         {
      90           38 :             fn eq(&self, other: &T) -> bool {
      91           38 :                 self.0.eq(other)
      92           38 :             }
      93              :         }
      94              : 
      95              :         impl<T> From<T> for $name
      96              :         where
      97              :             smol_str::SmolStr: From<T>,
      98              :         {
      99          284 :             fn from(x: T) -> Self {
     100          284 :                 Self(x.into())
     101          284 :             }
     102              :         }
     103              : 
     104              :         impl AsRef<str> for $name {
     105           20 :             fn as_ref(&self) -> &str {
     106           20 :                 self.0.as_ref()
     107           20 :             }
     108              :         }
     109              : 
     110              :         impl std::ops::Deref for $name {
     111              :             type Target = str;
     112          282 :             fn deref(&self) -> &str {
     113          282 :                 &*self.0
     114          282 :             }
     115              :         }
     116              : 
     117              :         impl<'de> serde::de::Deserialize<'de> for $name {
     118            0 :             fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
     119            0 :                 <smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
     120            0 :             }
     121              :         }
     122              : 
     123              :         impl serde::Serialize for $name {
     124            0 :             fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
     125            0 :                 self.0.serialize(s)
     126            0 :             }
     127              :         }
     128              :     };
     129              : }
     130              : 
     131              : const POOLER_SUFFIX: &str = "-pooler";
     132              : 
     133              : impl EndpointId {
     134            6 :     fn normalize(&self) -> Self {
     135            6 :         if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
     136            0 :             stripped.into()
     137              :         } else {
     138            6 :             self.clone()
     139              :         }
     140            6 :     }
     141              : 
     142            0 :     fn normalize_intern(&self) -> EndpointIdInt {
     143            0 :         if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
     144            0 :             EndpointIdTag::get_interner().get_or_intern(stripped)
     145              :         } else {
     146            0 :             self.into()
     147              :         }
     148            0 :     }
     149              : }
     150              : 
     151              : // 90% of role name strings are 20 characters or less.
     152              : smol_str_wrapper!(RoleName);
     153              : // 50% of endpoint strings are 23 characters or less.
     154              : smol_str_wrapper!(EndpointId);
     155              : // 50% of branch strings are 23 characters or less.
     156              : smol_str_wrapper!(BranchId);
     157              : // 90% of project strings are 23 characters or less.
     158              : smol_str_wrapper!(ProjectId);
     159              : 
     160              : // will usually equal endpoint ID
     161              : smol_str_wrapper!(EndpointCacheKey);
     162              : 
     163              : smol_str_wrapper!(DbName);
     164              : 
     165              : // Postgres hostname; will likely be an ip:port address
     166              : smol_str_wrapper!(Host);
     167              : 
     168              : // Endpoints are a bit tricky. Rarely, they might be branches or projects.
     169              : impl EndpointId {
     170            0 :     pub fn is_endpoint(&self) -> bool {
     171            0 :         self.0.starts_with("ep-")
     172            0 :     }
     173            0 :     pub fn is_branch(&self) -> bool {
     174            0 :         self.0.starts_with("br-")
     175            0 :     }
     176            0 :     pub fn is_project(&self) -> bool {
     177            0 :         !self.is_endpoint() && !self.is_branch()
     178            0 :     }
     179            0 :     pub fn as_branch(&self) -> BranchId {
     180            0 :         BranchId(self.0.clone())
     181            0 :     }
     182            0 :     pub fn as_project(&self) -> ProjectId {
     183            0 :         ProjectId(self.0.clone())
     184            0 :     }
     185              : }
        

Generated by: LCOV version 2.1-beta