Line data Source code
1 : #![deny(clippy::undocumented_unsafe_blocks)]
2 :
3 : use std::convert::Infallible;
4 :
5 : use anyhow::{bail, Context};
6 : use intern::{EndpointIdInt, EndpointIdTag, InternId};
7 : use tokio::task::JoinError;
8 : use tokio_util::sync::CancellationToken;
9 : use tracing::warn;
10 :
11 : pub mod auth;
12 : pub mod cache;
13 : pub mod cancellation;
14 : pub mod compute;
15 : pub mod config;
16 : pub mod console;
17 : pub mod context;
18 : pub mod error;
19 : pub mod http;
20 : pub mod intern;
21 : pub mod jemalloc;
22 : pub mod logging;
23 : pub mod metrics;
24 : pub mod parse;
25 : pub mod protocol2;
26 : pub mod proxy;
27 : pub mod rate_limiter;
28 : pub mod redis;
29 : pub mod sasl;
30 : pub mod scram;
31 : pub mod serverless;
32 : pub mod stream;
33 : pub mod url;
34 : pub mod usage_metrics;
35 : pub mod waiters;
36 :
37 : /// Handle unix signals appropriately.
38 0 : pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
39 : use tokio::signal::unix::{signal, SignalKind};
40 :
41 0 : let mut hangup = signal(SignalKind::hangup())?;
42 0 : let mut interrupt = signal(SignalKind::interrupt())?;
43 0 : let mut terminate = signal(SignalKind::terminate())?;
44 :
45 0 : loop {
46 0 : tokio::select! {
47 : // Hangup is commonly used for config reload.
48 : _ = hangup.recv() => {
49 : warn!("received SIGHUP; config reload is not supported");
50 : }
51 : // Shut down the whole application.
52 : _ = interrupt.recv() => {
53 : warn!("received SIGINT, exiting immediately");
54 : bail!("interrupted");
55 : }
56 : _ = terminate.recv() => {
57 : warn!("received SIGTERM, shutting down once all existing connections have closed");
58 : token.cancel();
59 : }
60 0 : }
61 0 : }
62 0 : }
63 :
64 : /// Flattens `Result<Result<T>>` into `Result<T>`.
65 0 : pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
66 0 : r.context("join error").and_then(|x| x)
67 0 : }
68 :
/// Declares a public newtype around `smol_str::SmolStr` named after the given identifier.
///
/// The generated type gets: an `as_str` accessor, `Display`, comparison against
/// anything `SmolStr` can be compared with, construction from anything `SmolStr`
/// can be built from, `AsRef<str>`, `Deref<Target = str>`, and serde
/// (de)serialization that delegates to the wrapped `SmolStr`.
macro_rules! smol_str_wrapper {
    ($wrapper:ident) => {
        #[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
        pub struct $wrapper(smol_str::SmolStr);

        impl $wrapper {
            /// Borrows the wrapped string as a plain `&str`.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl std::fmt::Display for $wrapper {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        // Comparison with any type the inner `SmolStr` already knows how to compare to.
        impl<T> std::cmp::PartialEq<T> for $wrapper
        where
            smol_str::SmolStr: std::cmp::PartialEq<T>,
        {
            fn eq(&self, other: &T) -> bool {
                self.0 == *other
            }
        }

        // Construction from any type the inner `SmolStr` can be built from.
        impl<T> From<T> for $wrapper
        where
            smol_str::SmolStr: From<T>,
        {
            fn from(x: T) -> Self {
                Self(smol_str::SmolStr::from(x))
            }
        }

        impl AsRef<str> for $wrapper {
            fn as_ref(&self) -> &str {
                self.0.as_str()
            }
        }

        impl std::ops::Deref for $wrapper {
            type Target = str;

            fn deref(&self) -> &str {
                self.0.as_str()
            }
        }

        impl<'de> serde::de::Deserialize<'de> for $wrapper {
            fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
                let inner: smol_str::SmolStr = serde::de::Deserialize::deserialize(d)?;
                Ok(Self(inner))
            }
        }

        impl serde::Serialize for $wrapper {
            fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
                serde::Serialize::serialize(&self.0, s)
            }
        }
    };
}
130 :
131 : const POOLER_SUFFIX: &str = "-pooler";
132 :
133 : impl EndpointId {
134 6 : fn normalize(&self) -> Self {
135 6 : if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
136 0 : stripped.into()
137 : } else {
138 6 : self.clone()
139 : }
140 6 : }
141 :
142 0 : fn normalize_intern(&self) -> EndpointIdInt {
143 0 : if let Some(stripped) = self.as_ref().strip_suffix(POOLER_SUFFIX) {
144 0 : EndpointIdTag::get_interner().get_or_intern(stripped)
145 : } else {
146 0 : self.into()
147 : }
148 0 : }
149 : }
150 :
151 : // 90% of role name strings are 20 characters or less.
152 : smol_str_wrapper!(RoleName);
153 : // 50% of endpoint strings are 23 characters or less.
154 : smol_str_wrapper!(EndpointId);
155 : // 50% of branch strings are 23 characters or less.
156 : smol_str_wrapper!(BranchId);
157 : // 90% of project strings are 23 characters or less.
158 : smol_str_wrapper!(ProjectId);
159 :
160 : // will usually equal endpoint ID
161 : smol_str_wrapper!(EndpointCacheKey);
162 :
163 : smol_str_wrapper!(DbName);
164 :
165 : // postgres hostname, will likely be a port:ip addr
166 : smol_str_wrapper!(Host);
167 :
168 : // Endpoints are a bit tricky. Rare they might be branches or projects.
169 : impl EndpointId {
170 0 : pub fn is_endpoint(&self) -> bool {
171 0 : self.0.starts_with("ep-")
172 0 : }
173 0 : pub fn is_branch(&self) -> bool {
174 0 : self.0.starts_with("br-")
175 0 : }
176 0 : pub fn is_project(&self) -> bool {
177 0 : !self.is_endpoint() && !self.is_branch()
178 0 : }
179 0 : pub fn as_branch(&self) -> BranchId {
180 0 : BranchId(self.0.clone())
181 0 : }
182 0 : pub fn as_project(&self) -> ProjectId {
183 0 : ProjectId(self.0.clone())
184 0 : }
185 : }
|