Line data Source code
1 : #![deny(clippy::undocumented_unsafe_blocks)]
2 :
3 : use std::convert::Infallible;
4 :
5 : use anyhow::{bail, Context};
6 : use tokio::task::JoinError;
7 : use tokio_util::sync::CancellationToken;
8 : use tracing::warn;
9 :
10 : pub mod auth;
11 : pub mod cache;
12 : pub mod cancellation;
13 : pub mod compute;
14 : pub mod config;
15 : pub mod console;
16 : pub mod context;
17 : pub mod error;
18 : pub mod http;
19 : pub mod intern;
20 : pub mod jemalloc;
21 : pub mod logging;
22 : pub mod metrics;
23 : pub mod parse;
24 : pub mod protocol2;
25 : pub mod proxy;
26 : pub mod rate_limiter;
27 : pub mod redis;
28 : pub mod sasl;
29 : pub mod scram;
30 : pub mod serverless;
31 : pub mod stream;
32 : pub mod url;
33 : pub mod usage_metrics;
34 : pub mod waiters;
35 :
36 : /// Handle unix signals appropriately.
37 0 : pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
38 : use tokio::signal::unix::{signal, SignalKind};
39 :
40 0 : let mut hangup = signal(SignalKind::hangup())?;
41 0 : let mut interrupt = signal(SignalKind::interrupt())?;
42 0 : let mut terminate = signal(SignalKind::terminate())?;
43 :
44 0 : loop {
45 0 : tokio::select! {
46 : // Hangup is commonly used for config reload.
47 : _ = hangup.recv() => {
48 : warn!("received SIGHUP; config reload is not supported");
49 : }
50 : // Shut down the whole application.
51 : _ = interrupt.recv() => {
52 : warn!("received SIGINT, exiting immediately");
53 : bail!("interrupted");
54 : }
55 : _ = terminate.recv() => {
56 : warn!("received SIGTERM, shutting down once all existing connections have closed");
57 : token.cancel();
58 : }
59 0 : }
60 0 : }
61 0 : }
62 :
63 : /// Flattens `Result<Result<T>>` into `Result<T>`.
64 0 : pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
65 0 : r.context("join error").and_then(|x| x)
66 0 : }
67 :
/// Declares `$name` as a newtype around `smol_str::SmolStr`.
///
/// Besides the derived value traits, the generated type:
/// * exposes its contents as `&str` (`as_str`, `AsRef<str>`, `Deref<Target = str>`),
/// * converts from anything `SmolStr` can be built from,
/// * displays exactly like the wrapped string,
/// * compares equal to anything `SmolStr` can be compared with,
/// * serializes and deserializes exactly as the inner `SmolStr` does.
macro_rules! smol_str_wrapper {
    ($name:ident) => {
        #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
        pub struct $name(smol_str::SmolStr);

        impl $name {
            /// Borrows the wrapped string.
            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        // Construction: accept any input `SmolStr` itself accepts.
        impl<T> From<T> for $name
        where
            smol_str::SmolStr: From<T>,
        {
            fn from(value: T) -> Self {
                Self(smol_str::SmolStr::from(value))
            }
        }

        // Borrowing views as `str`.
        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.0.as_str()
            }
        }

        impl std::ops::Deref for $name {
            type Target = str;
            fn deref(&self) -> &str {
                self.0.as_str()
            }
        }

        // Formatting and comparison delegate straight to the inner `SmolStr`.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        impl<T> std::cmp::PartialEq<T> for $name
        where
            smol_str::SmolStr: std::cmp::PartialEq<T>,
        {
            fn eq(&self, other: &T) -> bool {
                <smol_str::SmolStr as std::cmp::PartialEq<T>>::eq(&self.0, other)
            }
        }

        // Serde: behave exactly like the inner `SmolStr`.
        impl serde::Serialize for $name {
            fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
                serde::Serialize::serialize(&self.0, s)
            }
        }

        impl<'de> serde::de::Deserialize<'de> for $name {
            fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
                <smol_str::SmolStr as serde::de::Deserialize<'de>>::deserialize(d).map(Self)
            }
        }
    };
}
129 :
/// Suffix that [`Normalize::normalize`] strips from identifiers.
const POOLER_SUFFIX: &str = "-pooler";

/// Normalization of string-like identifiers.
pub trait Normalize {
    /// Returns a normalized copy of `self`.
    fn normalize(&self) -> Self;
}

impl<S: Clone + AsRef<str> + From<String>> Normalize for S {
    /// Removes one trailing `-pooler` suffix, if present.
    ///
    /// Values without the suffix are returned as a plain clone. Only a single
    /// occurrence is stripped, so `"x-pooler-pooler"` becomes `"x-pooler"`.
    fn normalize(&self) -> Self {
        match self.as_ref().strip_suffix(POOLER_SUFFIX) {
            // Build a new value only when there is actually something to strip.
            Some(base) => base.to_owned().into(),
            None => self.clone(),
        }
    }
}
147 :
// String newtypes for the identifiers handled by the proxy. The length
// statistics below presumably motivate using `SmolStr`, which (per its docs)
// keeps short strings inline instead of heap-allocating them.
// NOTE(review): confirm the inline threshold for the smol_str version in use.

// 90% of role name strings are 20 characters or less.
smol_str_wrapper!(RoleName);
// 50% of endpoint strings are 23 characters or less.
smol_str_wrapper!(EndpointId);
// 50% of branch strings are 23 characters or less.
smol_str_wrapper!(BranchId);
// 90% of project strings are 23 characters or less.
smol_str_wrapper!(ProjectId);

// Cache key; will usually equal the endpoint ID.
smol_str_wrapper!(EndpointCacheKey);

// Presumably a database name (inferred from the type name only).
smol_str_wrapper!(DbName);

// Postgres host. NOTE(review): an earlier note said this "will likely be a
// port:ip addr"; presumably it means an `ip:port`-style address — confirm.
smol_str_wrapper!(Host);
164 :
165 : // Endpoints are a bit tricky. Rare they might be branches or projects.
166 : impl EndpointId {
167 0 : pub fn is_endpoint(&self) -> bool {
168 0 : self.0.starts_with("ep-")
169 0 : }
170 0 : pub fn is_branch(&self) -> bool {
171 0 : self.0.starts_with("br-")
172 0 : }
173 0 : pub fn is_project(&self) -> bool {
174 0 : !self.is_endpoint() && !self.is_branch()
175 0 : }
176 0 : pub fn as_branch(&self) -> BranchId {
177 0 : BranchId(self.0.clone())
178 0 : }
179 0 : pub fn as_project(&self) -> ProjectId {
180 0 : ProjectId(self.0.clone())
181 0 : }
182 : }
|