Line data Source code
1 : #![deny(clippy::undocumented_unsafe_blocks)]
2 :
3 : use std::convert::Infallible;
4 :
5 : use anyhow::{bail, Context};
6 : use tokio::task::JoinError;
7 : use tokio_util::sync::CancellationToken;
8 : use tracing::warn;
9 :
10 : pub mod auth;
11 : pub mod cache;
12 : pub mod cancellation;
13 : pub mod compute;
14 : pub mod config;
15 : pub mod console;
16 : pub mod context;
17 : pub mod error;
18 : pub mod http;
19 : pub mod intern;
20 : pub mod jemalloc;
21 : pub mod logging;
22 : pub mod metrics;
23 : pub mod parse;
24 : pub mod protocol2;
25 : pub mod proxy;
26 : pub mod rate_limiter;
27 : pub mod redis;
28 : pub mod sasl;
29 : pub mod scram;
30 : pub mod serverless;
31 : pub mod stream;
32 : pub mod url;
33 : pub mod usage_metrics;
34 : pub mod waiters;
35 :
36 : /// Handle unix signals appropriately.
37 26 : pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
38 : use tokio::signal::unix::{signal, SignalKind};
39 :
40 26 : let mut hangup = signal(SignalKind::hangup())?;
41 26 : let mut interrupt = signal(SignalKind::interrupt())?;
42 26 : let mut terminate = signal(SignalKind::terminate())?;
43 :
44 52 : loop {
45 78 : tokio::select! {
46 78 : // Hangup is commonly used for config reload.
47 78 : _ = hangup.recv() => {
48 78 : warn!("received SIGHUP; config reload is not supported");
49 78 : }
50 78 : // Shut down the whole application.
51 78 : _ = interrupt.recv() => {
52 78 : warn!("received SIGINT, exiting immediately");
53 78 : bail!("interrupted");
54 78 : }
55 78 : _ = terminate.recv() => {
56 78 : warn!("received SIGTERM, shutting down once all existing connections have closed");
57 78 : token.cancel();
58 78 : }
59 78 : }
60 52 : }
61 0 : }
62 :
63 : /// Flattens `Result<Result<T>>` into `Result<T>`.
64 76 : pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
65 76 : r.context("join error").and_then(|x| x)
66 76 : }
67 :
/// Generates a public tuple struct `$name` wrapping a `smol_str::SmolStr`.
///
/// The generated type derives `Clone`, `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord`,
/// `Hash` and `Default`, and additionally receives:
///
/// * an inherent `as_str` accessor,
/// * `Display`, `AsRef<str>` and `Deref<Target = str>` that forward to the inner string,
/// * `PartialEq<T>` and `From<T>` for every `T` that `SmolStr` itself supports,
/// * serde `Serialize` / `Deserialize` that delegate to the `SmolStr` implementations.
macro_rules! smol_str_wrapper {
    ($name:ident) => {
        #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
        pub struct $name(smol_str::SmolStr);

        impl $name {
            // Borrows the wrapped string as a plain `&str`.
            pub fn as_str(&self) -> &str {
                smol_str::SmolStr::as_str(&self.0)
            }
        }

        impl std::fmt::Display for $name {
            // Formats exactly like the wrapped string.
            fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, formatter)
            }
        }

        // Comparable with anything the inner `SmolStr` can be compared with.
        impl<T> std::cmp::PartialEq<T> for $name
        where
            smol_str::SmolStr: std::cmp::PartialEq<T>,
        {
            fn eq(&self, rhs: &T) -> bool {
                self.0 == *rhs
            }
        }

        // Constructible from anything the inner `SmolStr` can be built from.
        impl<T> From<T> for $name
        where
            smol_str::SmolStr: From<T>,
        {
            fn from(value: T) -> Self {
                Self(smol_str::SmolStr::from(value))
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                AsRef::<str>::as_ref(&self.0)
            }
        }

        impl std::ops::Deref for $name {
            type Target = str;
            fn deref(&self) -> &str {
                // `&SmolStr` deref-coerces to `&str` at the return site.
                &self.0
            }
        }

        impl<'de> serde::de::Deserialize<'de> for $name {
            fn deserialize<D: serde::de::Deserializer<'de>>(
                deserializer: D,
            ) -> Result<Self, D::Error> {
                let inner =
                    <smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(deserializer)?;
                Ok(Self(inner))
            }
        }

        impl serde::Serialize for $name {
            fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
                serde::Serialize::serialize(&self.0, serializer)
            }
        }
    };
}
129 :
// String newtypes generated by `smol_str_wrapper!`.
//
// NOTE(review): the length statistics below presumably motivate wrapping
// `SmolStr`, which is designed to hold short strings without a heap
// allocation — confirm the inline capacity of the smol_str version in use.

// 90% of role name strings are 20 characters or less.
smol_str_wrapper!(RoleName);
// 50% of endpoint strings are 23 characters or less.
smol_str_wrapper!(EndpointId);
// 50% of branch strings are 23 characters or less.
smol_str_wrapper!(BranchId);
// 90% of project strings are 23 characters or less.
smol_str_wrapper!(ProjectId);

// Will usually equal the endpoint ID.
smol_str_wrapper!(EndpointCacheKey);

// No length statistics are recorded here for database names.
smol_str_wrapper!(DbName);
|