Line data Source code
1 : //! Manages the pool of connections between local_proxy and postgres.
2 : //!
3 : //! The pool is keyed by database and role_name, and can contain multiple connections
4 : //! shared between users.
5 : //!
6 : //! The pool manages the pg_session_jwt extension used for authorizing
7 : //! requests in the db.
8 : //!
9 : //! The first time a db/role pair is seen, local_proxy attempts to install the extension
10 : //! and grant usage to the role on the given schema.
11 :
12 : use std::collections::HashMap;
13 : use std::pin::pin;
14 : use std::sync::Arc;
15 : use std::sync::atomic::AtomicUsize;
16 : use std::task::{Poll, ready};
17 : use std::time::Duration;
18 :
19 : use ed25519_dalek::{Signature, Signer, SigningKey};
20 : use futures::Future;
21 : use futures::future::poll_fn;
22 : use indexmap::IndexMap;
23 : use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
24 : use parking_lot::RwLock;
25 : use postgres_client::AsyncMessage;
26 : use postgres_client::tls::NoTlsStream;
27 : use serde_json::value::RawValue;
28 : use tokio::net::TcpStream;
29 : use tokio::time::Instant;
30 : use tokio_util::sync::CancellationToken;
31 : use tracing::{Instrument, debug, error, info, info_span, warn};
32 :
33 : use super::backend::HttpConnError;
34 : use super::conn_pool_lib::{
35 : Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, DbUserConn,
36 : EndpointConnPool,
37 : };
38 : use super::sql_over_http::SqlOverHttpError;
39 : use crate::context::RequestContext;
40 : use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
41 : use crate::metrics::Metrics;
42 :
/// Name of the Postgres extension the pool relies on for authorizing requests in the db.
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
/// Version string for the extension.
// NOTE(review): presumably the version local_proxy installs/expects — confirm at the use site.
pub(crate) const EXT_VERSION: &str = "0.2.0";
/// Schema name associated with the extension; `set_jwt_session` in this file calls
/// `auth.jwt_session_init`.
pub(crate) const EXT_SCHEMA: &str = "auth";
46 :
47 : #[derive(Clone)]
48 : pub(crate) struct ClientDataLocal {
49 : session: tokio::sync::watch::Sender<uuid::Uuid>,
50 : cancel: CancellationToken,
51 : key: SigningKey,
52 : jti: u64,
53 : }
54 :
55 : impl ClientDataLocal {
56 0 : pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
57 0 : &mut self.session
58 0 : }
59 :
60 0 : pub fn cancel(&mut self) {
61 0 : self.cancel.cancel();
62 0 : }
63 : }
64 :
65 : pub(crate) struct LocalConnPool<C: ClientInnerExt> {
66 : global_pool: Arc<RwLock<EndpointConnPool<C>>>,
67 :
68 : config: &'static crate::config::HttpConfig,
69 : }
70 :
71 : impl<C: ClientInnerExt> LocalConnPool<C> {
72 0 : pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
73 0 : Arc::new(Self {
74 0 : global_pool: Arc::new(RwLock::new(EndpointConnPool::new(
75 0 : HashMap::new(),
76 0 : 0,
77 0 : config.pool_options.max_conns_per_endpoint,
78 0 : Arc::new(AtomicUsize::new(0)),
79 0 : config.pool_options.max_total_conns,
80 0 : String::from("local_pool"),
81 0 : ))),
82 0 : config,
83 0 : })
84 0 : }
85 :
86 0 : pub(crate) fn get_idle_timeout(&self) -> Duration {
87 0 : self.config.pool_options.idle_timeout
88 0 : }
89 :
90 0 : pub(crate) fn get(
91 0 : self: &Arc<Self>,
92 0 : ctx: &RequestContext,
93 0 : conn_info: &ConnInfo,
94 0 : ) -> Result<Option<Client<C>>, HttpConnError> {
95 0 : let client = self
96 0 : .global_pool
97 0 : .write()
98 0 : .get_conn_entry(conn_info.db_and_user())
99 0 : .map(|entry| entry.conn);
100 :
101 : // ok return cached connection if found and establish a new one otherwise
102 0 : if let Some(mut client) = client {
103 0 : if client.inner.is_closed() {
104 0 : info!("local_pool: cached connection '{conn_info}' is closed, opening a new one");
105 0 : return Ok(None);
106 0 : }
107 0 :
108 0 : tracing::Span::current()
109 0 : .record("conn_id", tracing::field::display(client.get_conn_id()));
110 0 : tracing::Span::current().record(
111 0 : "pid",
112 0 : tracing::field::display(client.inner.get_process_id()),
113 0 : );
114 0 : debug!(
115 0 : cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
116 0 : "local_pool: reusing connection '{conn_info}'"
117 : );
118 :
119 0 : match client.get_data() {
120 0 : ClientDataEnum::Local(data) => {
121 0 : data.session().send(ctx.session_id())?;
122 : }
123 :
124 0 : ClientDataEnum::Remote(data) => {
125 0 : data.session().send(ctx.session_id())?;
126 : }
127 0 : ClientDataEnum::Http(_) => (),
128 : }
129 :
130 0 : ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
131 0 : ctx.success();
132 0 :
133 0 : return Ok(Some(Client::new(
134 0 : client,
135 0 : conn_info.clone(),
136 0 : Arc::downgrade(&self.global_pool),
137 0 : )));
138 0 : }
139 0 : Ok(None)
140 0 : }
141 :
142 0 : pub(crate) fn initialized(self: &Arc<Self>, conn_info: &ConnInfo) -> bool {
143 0 : if let Some(pool) = self.global_pool.read().get_pool(conn_info.db_and_user()) {
144 0 : return pool.is_initialized();
145 0 : }
146 0 : false
147 0 : }
148 :
149 0 : pub(crate) fn set_initialized(self: &Arc<Self>, conn_info: &ConnInfo) {
150 0 : if let Some(pool) = self
151 0 : .global_pool
152 0 : .write()
153 0 : .get_pool_mut(conn_info.db_and_user())
154 0 : {
155 0 : pool.set_initialized();
156 0 : }
157 0 : }
158 : }
159 :
160 : #[allow(clippy::too_many_arguments)]
161 0 : pub(crate) fn poll_client<C: ClientInnerExt>(
162 0 : global_pool: Arc<LocalConnPool<C>>,
163 0 : ctx: &RequestContext,
164 0 : conn_info: ConnInfo,
165 0 : client: C,
166 0 : mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
167 0 : key: SigningKey,
168 0 : conn_id: uuid::Uuid,
169 0 : aux: MetricsAuxInfo,
170 0 : ) -> Client<C> {
171 0 : let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
172 0 : let mut session_id = ctx.session_id();
173 0 : let (tx, mut rx) = tokio::sync::watch::channel(session_id);
174 :
175 0 : let span = info_span!(parent: None, "connection", %conn_id);
176 0 : let cold_start_info = ctx.cold_start_info();
177 0 : span.in_scope(|| {
178 0 : info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
179 0 : });
180 0 : let pool = Arc::downgrade(&global_pool);
181 0 :
182 0 : let db_user = conn_info.db_and_user();
183 0 : let idle = global_pool.get_idle_timeout();
184 0 : let cancel = CancellationToken::new();
185 0 : let cancelled = cancel.clone().cancelled_owned();
186 0 :
187 0 : tokio::spawn(
188 0 : async move {
189 0 : let _conn_gauge = conn_gauge;
190 0 : let mut idle_timeout = pin!(tokio::time::sleep(idle));
191 0 : let mut cancelled = pin!(cancelled);
192 0 :
193 0 : poll_fn(move |cx| {
194 0 : if cancelled.as_mut().poll(cx).is_ready() {
195 0 : info!("connection dropped");
196 0 : return Poll::Ready(())
197 0 : }
198 0 :
199 0 : match rx.has_changed() {
200 : Ok(true) => {
201 0 : session_id = *rx.borrow_and_update();
202 0 : info!(%session_id, "changed session");
203 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
204 : }
205 : Err(_) => {
206 0 : info!("connection dropped");
207 0 : return Poll::Ready(())
208 : }
209 0 : _ => {}
210 : }
211 :
212 : // 5 minute idle connection timeout
213 0 : if idle_timeout.as_mut().poll(cx).is_ready() {
214 0 : idle_timeout.as_mut().reset(Instant::now() + idle);
215 0 : info!("connection idle");
216 0 : if let Some(pool) = pool.clone().upgrade() {
217 : // remove client from pool - should close the connection if it's idle.
218 : // does nothing if the client is currently checked-out and in-use
219 0 : if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
220 0 : info!("idle connection removed");
221 0 : }
222 0 : }
223 0 : }
224 :
225 : loop {
226 0 : let message = ready!(connection.poll_message(cx));
227 :
228 0 : match message {
229 0 : Some(Ok(AsyncMessage::Notice(notice))) => {
230 0 : info!(%session_id, "notice: {}", notice);
231 : }
232 0 : Some(Ok(AsyncMessage::Notification(notif))) => {
233 0 : warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
234 : }
235 : Some(Ok(_)) => {
236 0 : warn!(%session_id, "unknown message");
237 : }
238 0 : Some(Err(e)) => {
239 0 : error!(%session_id, "connection error: {}", e);
240 0 : break
241 : }
242 : None => {
243 0 : info!("connection closed");
244 0 : break
245 : }
246 : }
247 : }
248 :
249 : // remove from connection pool
250 0 : if let Some(pool) = pool.clone().upgrade() {
251 0 : if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
252 0 : info!("closed connection removed");
253 0 : }
254 0 : }
255 :
256 0 : Poll::Ready(())
257 0 : }).await;
258 :
259 0 : }
260 0 : .instrument(span));
261 0 :
262 0 : let inner = ClientInnerCommon {
263 0 : inner: client,
264 0 : aux,
265 0 : conn_id,
266 0 : data: ClientDataEnum::Local(ClientDataLocal {
267 0 : session: tx,
268 0 : cancel,
269 0 : key,
270 0 : jti: 0,
271 0 : }),
272 0 : };
273 0 :
274 0 : Client::new(inner, conn_info, Arc::downgrade(&global_pool.global_pool))
275 0 : }
276 :
277 : impl ClientInnerCommon<postgres_client::Client> {
278 0 : pub(crate) async fn set_jwt_session(&mut self, payload: &[u8]) -> Result<(), SqlOverHttpError> {
279 0 : if let ClientDataEnum::Local(local_data) = &mut self.data {
280 0 : local_data.jti += 1;
281 0 : let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
282 :
283 0 : self.inner
284 0 : .discard_all()
285 0 : .await
286 0 : .map_err(SqlOverHttpError::InternalPostgres)?;
287 :
288 : // initiates the auth session
289 : // this is safe from query injections as the jwt format free of any escape characters.
290 0 : let query = format!("select auth.jwt_session_init('{token}')");
291 0 : self.inner
292 0 : .batch_execute(&query)
293 0 : .await
294 0 : .map_err(SqlOverHttpError::InternalPostgres)?;
295 :
296 0 : let pid = self.inner.get_process_id();
297 0 : info!(pid, jti = local_data.jti, "user session state init");
298 0 : Ok(())
299 : } else {
300 0 : panic!("unexpected client data type");
301 : }
302 0 : }
303 : }
304 :
305 : /// implements relatively efficient in-place json object key upserting
306 : ///
307 : /// only supports top-level keys
308 1 : fn upsert_json_object(
309 1 : payload: &[u8],
310 1 : key: &str,
311 1 : value: &RawValue,
312 1 : ) -> Result<String, serde_json::Error> {
313 1 : let mut payload = serde_json::from_slice::<IndexMap<&str, &RawValue>>(payload)?;
314 1 : payload.insert(key, value);
315 1 : serde_json::to_string(&payload)
316 1 : }
317 :
318 1 : fn resign_jwt(sk: &SigningKey, payload: &[u8], jti: u64) -> Result<String, HttpConnError> {
319 1 : let mut buffer = itoa::Buffer::new();
320 1 :
321 1 : // encode the jti integer to a json rawvalue
322 1 : let jti = serde_json::from_str::<&RawValue>(buffer.format(jti))
323 1 : .expect("itoa formatted integer should be guaranteed valid json");
324 :
325 : // update the jti in-place
326 1 : let payload =
327 1 : upsert_json_object(payload, "jti", jti).map_err(HttpConnError::JwtPayloadError)?;
328 :
329 : // sign the jwt
330 1 : let token = sign_jwt(sk, payload.as_bytes());
331 1 :
332 1 : Ok(token)
333 1 : }
334 :
335 1 : fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
336 1 : let header_len = 20;
337 1 : let payload_len = Base64UrlUnpadded::encoded_len(payload);
338 1 : let signature_len = Base64UrlUnpadded::encoded_len(&[0; 64]);
339 1 : let total_len = header_len + payload_len + signature_len + 2;
340 1 :
341 1 : let mut jwt = String::with_capacity(total_len);
342 1 : let cap = jwt.capacity();
343 1 :
344 1 : // we only need an empty header with the alg specified.
345 1 : // base64url(r#"{"alg":"EdDSA"}"#) == "eyJhbGciOiJFZERTQSJ9"
346 1 : jwt.push_str("eyJhbGciOiJFZERTQSJ9.");
347 1 :
348 1 : // encode the jwt payload in-place
349 1 : base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt);
350 1 :
351 1 : // create the signature from the encoded header || payload
352 1 : let sig: Signature = sk.sign(jwt.as_bytes());
353 1 :
354 1 : jwt.push('.');
355 1 :
356 1 : // encode the jwt signature in-place
357 1 : base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt);
358 1 :
359 1 : debug_assert_eq!(
360 1 : jwt.len(),
361 : total_len,
362 0 : "the jwt len should match our expected len"
363 : );
364 1 : debug_assert_eq!(jwt.capacity(), cap, "the jwt capacity should not change");
365 :
366 1 : jwt
367 1 : }
368 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use ed25519_dalek::SigningKey;
    use typed_json::json;

    use super::resign_jwt;

    /// Snapshot of a re-signed token: the top-level `jti` is replaced, the nested one
    /// is left untouched.
    #[test]
    fn jwt_token_snapshot() {
        let signing_key = SigningKey::from_bytes(&[1; 32]);
        let claims =
            json!({"foo":"bar","jti":"foo\nbar","nested":{"jti":"tricky nesting"}}).to_string();

        let token = resign_jwt(&signing_key, claims.as_bytes(), 2).unwrap();

        // To validate the JWT by hand, verify it against the following jwk public key:
        // `{"kty":"OKP","crv":"Ed25519","x":"iojj3XQJ8ZX9UtstPLpdcspnCb8dlBIb83SIAbQPb1w"}`
        // Note - https://jwt.io/ cannot be used for this, it doesn't support EdDSA :(
        // https://github.com/jsonwebtoken/jsonwebtoken.github.io/issues/509
        //
        // The jwk above can be regenerated with:
        //
        // let jwk = jose_jwk::Key::Okp(jose_jwk::Okp {
        //     crv: jose_jwk::OkpCurves::Ed25519,
        //     x: jose_jwk::jose_b64::serde::Bytes::from(key.verifying_key().to_bytes().to_vec()),
        //     d: None,
        // });
        // println!("{}", serde_json::to_string(&jwk).unwrap());

        assert_eq!(
            token,
            "eyJhbGciOiJFZERTQSJ9.eyJmb28iOiJiYXIiLCJqdGkiOjIsIm5lc3RlZCI6eyJqdGkiOiJ0cmlja3kgbmVzdGluZyJ9fQ.Cvyc2By33KI0f0obystwdy8PN111L3Sc9_Mr2CU3XshtSqSdxuRxNEZGbb_RvyJf2IzheC_s7aBZ-jLeQ9N0Bg"
        );
    }
}
|