LCOV - code coverage report
Current view: top level - proxy/src - lib.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 98.0 % 50 49
Test Date: 2024-02-07 07:37:29 Functions: 48.0 % 125 60

            Line data    Source code
       1              : #![deny(clippy::undocumented_unsafe_blocks)]
       2              : 
       3              : use std::convert::Infallible;
       4              : 
       5              : use anyhow::{bail, Context};
       6              : use tokio::task::JoinError;
       7              : use tokio_util::sync::CancellationToken;
       8              : use tracing::warn;
       9              : 
      10              : pub mod auth;
      11              : pub mod cache;
      12              : pub mod cancellation;
      13              : pub mod compute;
      14              : pub mod config;
      15              : pub mod console;
      16              : pub mod context;
      17              : pub mod error;
      18              : pub mod http;
      19              : pub mod intern;
      20              : pub mod jemalloc;
      21              : pub mod logging;
      22              : pub mod metrics;
      23              : pub mod parse;
      24              : pub mod protocol2;
      25              : pub mod proxy;
      26              : pub mod rate_limiter;
      27              : pub mod redis;
      28              : pub mod sasl;
      29              : pub mod scram;
      30              : pub mod serverless;
      31              : pub mod stream;
      32              : pub mod url;
      33              : pub mod usage_metrics;
      34              : pub mod waiters;
      35              : 
      36              : /// Handle unix signals appropriately.
                        ///
                        /// Installs listeners for SIGHUP, SIGINT and SIGTERM and then waits on all three in a loop:
                        ///
                        /// * SIGHUP: logs a warning (config reload is not supported) and keeps waiting.
                        /// * SIGINT: logs a warning and returns an `"interrupted"` error to the caller.
                        /// * SIGTERM: logs a warning, cancels `token`, and keeps waiting for further signals.
                        ///
                        /// # Errors
                        ///
                        /// Returns an error if any of the signal listeners cannot be installed, or once SIGINT
                        /// is received. The `Ok` type is [`Infallible`] and the loop has no `break`, so this
                        /// function never returns `Ok`.
      37           24 : pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
      38              :     use tokio::signal::unix::{signal, SignalKind};
      39              : 
                            // Installing a listener is fallible; `?` hands that error straight to the caller.
      40           24 :     let mut hangup = signal(SignalKind::hangup())?;
      41           24 :     let mut interrupt = signal(SignalKind::interrupt())?;
      42           24 :     let mut terminate = signal(SignalKind::terminate())?;
      43              : 
                            // Each iteration handles whichever signal arrives first. The value produced by
                            // `recv()` is discarded by the `_` pattern.
                            // NOTE(review): if `recv()` can yield `None`, that branch would still run — confirm
                            // against the tokio documentation.
      44           48 :     loop {
      45           72 :         tokio::select! {
      46           72 :             // Hangup is commonly used for config reload.
      47           72 :             _ = hangup.recv() => {
      48           72 :                 warn!("received SIGHUP; config reload is not supported");
      49           72 :             }
      50           72 :             // Shut down the whole application.
      51           72 :             _ = interrupt.recv() => {
      52           72 :                 warn!("received SIGINT, exiting immediately");
                                        // `bail!` returns the error from the function, which ends the loop.
      53           72 :                 bail!("interrupted");
      54           72 :             }
      55           72 :             _ = terminate.recv() => {
      56           72 :                 warn!("received SIGTERM, shutting down once all existing connections have closed");
                                        // Only the token is cancelled here; there is no return, so the loop keeps
                                        // listening (e.g. a later SIGINT still ends the function). How holders of
                                        // the token react is decided elsewhere.
      57           72 :                 token.cancel();
      58           72 :             }
      59           72 :         }
      60           48 :     }
      61            0 : }
      62              : 
      63              : /// Flattens `Result<Result<T>>` into `Result<T>`.
                        ///
                        /// A [`JoinError`] in the outer `Result` is converted into an `anyhow` error carrying the
                        /// context message `"join error"`. Otherwise the inner `anyhow::Result<T>` is returned
                        /// unchanged, whether it is `Ok` or `Err`.
      64           70 : pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
                            // `context` turns the outer error into `anyhow::Error`; `and_then(|x| x)` then
                            // collapses the two `Result` layers into one.
      65           70 :     r.context("join error").and_then(|x| x)
      66           70 : }
      67              : 
      68              : macro_rules! smol_str_wrapper {
                            // Generates a public newtype `$name` wrapping `smol_str::SmolStr`, together with
                            // string accessors, comparison/conversion impls and serde support. Every generated
                            // impl delegates to the wrapped `SmolStr`.
      69              :     ($name:ident) => {
      70      8919593 :         #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
      71              :         pub struct $name(smol_str::SmolStr);
      72              : 
      73              :         impl $name {
                                    // Borrows the wrapped string as `&str`.
      74          143 :             pub fn as_str(&self) -> &str {
      75          143 :                 self.0.as_str()
      76          143 :             }
      77              :         }
      78              : 
                                // Formats exactly like the wrapped `SmolStr`.
      79              :         impl std::fmt::Display for $name {
      80         1078 :             fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      81         1078 :                 self.0.fmt(f)
      82         1078 :             }
      83              :         }
      84              : 
                                // Lets `$name` be compared with any `T` that `SmolStr` itself can be compared
                                // with. This is in addition to the derived `$name == $name` comparison.
      85              :         impl<T> std::cmp::PartialEq<T> for $name
      86              :         where
      87              :             smol_str::SmolStr: std::cmp::PartialEq<T>,
      88              :         {
      89           36 :             fn eq(&self, other: &T) -> bool {
      90           36 :                 self.0.eq(other)
      91           36 :             }
      92              :         }
      93              : 
                                // Builds a `$name` from anything that converts into a `SmolStr`.
      94              :         impl<T> From<T> for $name
      95              :         where
      96              :             smol_str::SmolStr: From<T>,
      97              :         {
      98      2000621 :             fn from(x: T) -> Self {
      99      2000621 :                 Self(x.into())
     100      2000621 :             }
     101              :         }
     102              : 
     103              :         impl AsRef<str> for $name {
     104           58 :             fn as_ref(&self) -> &str {
     105           58 :                 self.0.as_ref()
     106           58 :             }
     107              :         }
     108              : 
                                // Dereferences to `str`, so `str` methods can be called directly on a `$name`.
     109              :         impl std::ops::Deref for $name {
     110              :             type Target = str;
     111          633 :             fn deref(&self) -> &str {
     112          633 :                 &*self.0
     113          633 :             }
     114              :         }
     115              : 
                                // Deserializes a `SmolStr` and wraps it; no additional validation happens here.
     116              :         impl<'de> serde::de::Deserialize<'de> for $name {
     117           49 :             fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
     118           49 :                 <smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
     119           49 :             }
     120              :         }
     121              : 
                                // Serializes by delegating to the wrapped `SmolStr`.
     122              :         impl serde::Serialize for $name {
     123           12 :             fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
     124           12 :                 self.0.serialize(s)
     125           12 :             }
     126              :         }
     127              :     };
     128              : }
     129              : 
     130              : // Identifier newtypes generated by `smol_str_wrapper!`.
                        //
                        // The length figures below are the original author's observations.
                        // NOTE(review): presumably they matter because `SmolStr` stores short strings inline;
                        // confirm the inline capacity for the `smol_str` version in use.
                        //
                        // 90% of role name strings are 20 characters or fewer.
     131              : smol_str_wrapper!(RoleName);
     132              : // 50% of endpoint strings are 23 characters or fewer.
     133              : smol_str_wrapper!(EndpointId);
     134              : // 50% of branch strings are 23 characters or fewer.
     135              : smol_str_wrapper!(BranchId);
     136              : // 90% of project strings are 23 characters or fewer.
     137              : smol_str_wrapper!(ProjectId);
     138              : 
     139              : // Will usually equal the endpoint ID.
     140              : smol_str_wrapper!(EndpointCacheKey);
     141              : 
                        // Presumably a database name (inferred from the identifier); no length statistics are
                        // recorded for it here.
     142              : smol_str_wrapper!(DbName);
        

Generated by: LCOV version 2.1-beta