Line data Source code
1 : //! Connection request monitoring contexts
2 :
3 : use std::net::IpAddr;
4 :
5 : use chrono::Utc;
6 : use once_cell::sync::OnceCell;
7 : use pq_proto::StartupMessageParams;
8 : use smol_str::SmolStr;
9 : use tokio::sync::mpsc;
10 : use tracing::field::display;
11 : use tracing::{debug, error, info_span, Span};
12 : use try_lock::TryLock;
13 : use uuid::Uuid;
14 :
15 : use self::parquet::RequestData;
16 : use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
17 : use crate::error::ErrorKind;
18 : use crate::intern::{BranchIdInt, ProjectIdInt};
19 : use crate::metrics::{
20 : ConnectOutcome, InvalidEndpointsGroup, LatencyTimer, Metrics, Protocol, Waiting,
21 : };
22 : use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
23 : use crate::types::{DbName, EndpointId, RoleName};
24 :
25 : pub mod parquet;
26 :
/// Weak handle to the channel that receives one [`RequestData`] record per connection attempt.
///
/// `RequestContext::new` upgrades this handle to a strong sender (when the cell is set and
/// the channel is still alive) and the context sends its record when the connection is logged.
pub(crate) static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
/// Weak handle to the channel that receives a [`RequestData`] record when a session is
/// logged as disconnected.
///
/// Upgraded in `RequestContext::new` the same way as [`LOG_CHAN`].
pub(crate) static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>> =
    OnceCell::new();
30 :
/// Context data for a single request to connect to a database.
///
/// This data should **not** be used for connection logic, only for observability and limiting purposes.
/// All connection logic should instead use strongly typed state machines, not a bunch of Options.
///
/// # Panics
///
/// Most accessors acquire the inner lock with `try_lock` and panic with
/// `"should not deadlock"` if the lock is already held at that moment.
pub struct RequestContext(
    /// To allow easier use of the ctx object, we have interior mutability.
    /// I would typically use a RefCell but that would break the `Send` requirements
    /// so we need something with thread-safety. `TryLock` is a cheap alternative
    /// that offers similar semantics to a `RefCell` but with synchronisation.
    TryLock<RequestContextInner>,
);
42 :
/// The state guarded by the lock inside [`RequestContext`].
///
/// Dropping a value whose `sender` is still present logs the connection attempt
/// (see the `Drop` impl for this type).
struct RequestContextInner {
    pub(crate) conn_info: ConnectionInfo,
    pub(crate) session_id: Uuid,
    pub(crate) protocol: Protocol,
    // Wall-clock time captured when the context is constructed.
    first_packet: chrono::DateTime<Utc>,
    region: &'static str,
    // Tracing span of the request; its `ep` and `role` fields are recorded once known.
    pub(crate) span: Span,

    // filled in as they are discovered
    project: Option<ProjectIdInt>,
    branch: Option<BranchIdInt>,
    // Only the first endpoint id that is set is kept.
    endpoint_id: Option<EndpointId>,
    dbname: Option<DbName>,
    user: Option<RoleName>,
    application: Option<SmolStr>,
    error_kind: Option<ErrorKind>,
    pub(crate) auth_method: Option<AuthMethod>,
    jwt_issuer: Option<String>,
    // Set by `set_success`; selects the outcome reported by `log_connect`.
    success: bool,
    pub(crate) cold_start_info: ColdStartInfo,
    pg_options: Option<StartupMessageParams>,

    // extra
    // This sender is here to keep the request monitoring channel open while requests are taking place.
    sender: Option<mpsc::UnboundedSender<RequestData>>,
    // This sender is only used to log the length of session in case of success.
    disconnect_sender: Option<mpsc::UnboundedSender<RequestData>>,
    pub(crate) latency_timer: LatencyTimer,
    // Whether proxy decided that it's not a valid endpoint and rejected it before going to cplane.
    rejected: Option<bool>,
    // Set to the current time by `log_disconnect`.
    disconnect_timestamp: Option<chrono::DateTime<Utc>>,
}
75 :
/// The authentication method recorded on a request context via
/// `RequestContext::set_auth_method`.
#[derive(Clone, Debug)]
pub(crate) enum AuthMethod {
    // aka passwordless, fka link
    ConsoleRedirect,
    ScramSha256,
    ScramSha256Plus,
    Cleartext,
    Jwt,
}
85 :
86 : impl Clone for RequestContext {
87 0 : fn clone(&self) -> Self {
88 0 : let inner = self.0.try_lock().expect("should not deadlock");
89 0 : let new = RequestContextInner {
90 0 : conn_info: inner.conn_info.clone(),
91 0 : session_id: inner.session_id,
92 0 : protocol: inner.protocol,
93 0 : first_packet: inner.first_packet,
94 0 : region: inner.region,
95 0 : span: info_span!("background_task"),
96 :
97 0 : project: inner.project,
98 0 : branch: inner.branch,
99 0 : endpoint_id: inner.endpoint_id.clone(),
100 0 : dbname: inner.dbname.clone(),
101 0 : user: inner.user.clone(),
102 0 : application: inner.application.clone(),
103 0 : error_kind: inner.error_kind,
104 0 : auth_method: inner.auth_method.clone(),
105 0 : jwt_issuer: inner.jwt_issuer.clone(),
106 0 : success: inner.success,
107 0 : rejected: inner.rejected,
108 0 : cold_start_info: inner.cold_start_info,
109 0 : pg_options: inner.pg_options.clone(),
110 0 :
111 0 : sender: None,
112 0 : disconnect_sender: None,
113 0 : latency_timer: LatencyTimer::noop(inner.protocol),
114 0 : disconnect_timestamp: inner.disconnect_timestamp,
115 0 : };
116 0 :
117 0 : Self(TryLock::new(new))
118 0 : }
119 : }
120 :
121 : impl RequestContext {
122 70 : pub fn new(
123 70 : session_id: Uuid,
124 70 : conn_info: ConnectionInfo,
125 70 : protocol: Protocol,
126 70 : region: &'static str,
127 70 : ) -> Self {
128 : // TODO: be careful with long lived spans
129 70 : let span = info_span!(
130 70 : "connect_request",
131 70 : %protocol,
132 70 : ?session_id,
133 70 : %conn_info,
134 70 : ep = tracing::field::Empty,
135 70 : role = tracing::field::Empty,
136 70 : );
137 :
138 70 : let inner = RequestContextInner {
139 70 : conn_info,
140 70 : session_id,
141 70 : protocol,
142 70 : first_packet: Utc::now(),
143 70 : region,
144 70 : span,
145 70 :
146 70 : project: None,
147 70 : branch: None,
148 70 : endpoint_id: None,
149 70 : dbname: None,
150 70 : user: None,
151 70 : application: None,
152 70 : error_kind: None,
153 70 : auth_method: None,
154 70 : jwt_issuer: None,
155 70 : success: false,
156 70 : rejected: None,
157 70 : cold_start_info: ColdStartInfo::Unknown,
158 70 : pg_options: None,
159 70 :
160 70 : sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
161 70 : disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
162 70 : latency_timer: LatencyTimer::new(protocol),
163 70 : disconnect_timestamp: None,
164 70 : };
165 70 :
166 70 : Self(TryLock::new(inner))
167 70 : }
168 :
169 : #[cfg(test)]
170 70 : pub(crate) fn test() -> Self {
171 : use std::net::SocketAddr;
172 70 : let ip = IpAddr::from([127, 0, 0, 1]);
173 70 : let addr = SocketAddr::new(ip, 5432);
174 70 : let conn_info = ConnectionInfo { addr, extra: None };
175 70 : RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
176 70 : }
177 :
178 0 : pub(crate) fn console_application_name(&self) -> String {
179 0 : let this = self.0.try_lock().expect("should not deadlock");
180 0 : format!(
181 0 : "{}/{}",
182 0 : this.application.as_deref().unwrap_or_default(),
183 0 : this.protocol
184 0 : )
185 0 : }
186 :
187 0 : pub(crate) fn set_rejected(&self, rejected: bool) {
188 0 : let mut this = self.0.try_lock().expect("should not deadlock");
189 0 : this.rejected = Some(rejected);
190 0 : }
191 :
192 0 : pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) {
193 0 : self.0
194 0 : .try_lock()
195 0 : .expect("should not deadlock")
196 0 : .set_cold_start_info(info);
197 0 : }
198 :
199 0 : pub(crate) fn set_db_options(&self, options: StartupMessageParams) {
200 0 : let mut this = self.0.try_lock().expect("should not deadlock");
201 0 : this.set_application(options.get("application_name").map(SmolStr::from));
202 0 : if let Some(user) = options.get("user") {
203 0 : this.set_user(user.into());
204 0 : }
205 0 : if let Some(dbname) = options.get("database") {
206 0 : this.set_dbname(dbname.into());
207 0 : }
208 :
209 0 : this.pg_options = Some(options);
210 0 : }
211 :
212 0 : pub(crate) fn set_project(&self, x: MetricsAuxInfo) {
213 0 : let mut this = self.0.try_lock().expect("should not deadlock");
214 0 : if this.endpoint_id.is_none() {
215 0 : this.set_endpoint_id(x.endpoint_id.as_str().into());
216 0 : }
217 0 : this.branch = Some(x.branch_id);
218 0 : this.project = Some(x.project_id);
219 0 : this.set_cold_start_info(x.cold_start_info);
220 0 : }
221 :
222 0 : pub(crate) fn set_project_id(&self, project_id: ProjectIdInt) {
223 0 : let mut this = self.0.try_lock().expect("should not deadlock");
224 0 : this.project = Some(project_id);
225 0 : }
226 :
227 28 : pub(crate) fn set_endpoint_id(&self, endpoint_id: EndpointId) {
228 28 : self.0
229 28 : .try_lock()
230 28 : .expect("should not deadlock")
231 28 : .set_endpoint_id(endpoint_id);
232 28 : }
233 :
234 0 : pub(crate) fn set_dbname(&self, dbname: DbName) {
235 0 : self.0
236 0 : .try_lock()
237 0 : .expect("should not deadlock")
238 0 : .set_dbname(dbname);
239 0 : }
240 :
241 0 : pub(crate) fn set_user(&self, user: RoleName) {
242 0 : self.0
243 0 : .try_lock()
244 0 : .expect("should not deadlock")
245 0 : .set_user(user);
246 0 : }
247 :
248 15 : pub(crate) fn set_auth_method(&self, auth_method: AuthMethod) {
249 15 : let mut this = self.0.try_lock().expect("should not deadlock");
250 15 : this.auth_method = Some(auth_method);
251 15 : }
252 :
253 12 : pub(crate) fn set_jwt_issuer(&self, jwt_issuer: String) {
254 12 : let mut this = self.0.try_lock().expect("should not deadlock");
255 12 : this.jwt_issuer = Some(jwt_issuer);
256 12 : }
257 :
258 0 : pub fn has_private_peer_addr(&self) -> bool {
259 0 : self.0
260 0 : .try_lock()
261 0 : .expect("should not deadlock")
262 0 : .has_private_peer_addr()
263 0 : }
264 :
265 0 : pub(crate) fn set_error_kind(&self, kind: ErrorKind) {
266 0 : let mut this = self.0.try_lock().expect("should not deadlock");
267 0 : // Do not record errors from the private address to metrics.
268 0 : if !this.has_private_peer_addr() {
269 0 : Metrics::get().proxy.errors_total.inc(kind);
270 0 : }
271 0 : if let Some(ep) = &this.endpoint_id {
272 0 : let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
273 0 : let label = metric.with_labels(kind);
274 0 : metric.get_metric(label).measure(ep);
275 0 : }
276 0 : this.error_kind = Some(kind);
277 0 : }
278 :
279 0 : pub fn set_success(&self) {
280 0 : let mut this = self.0.try_lock().expect("should not deadlock");
281 0 : this.success = true;
282 0 : }
283 :
284 0 : pub fn log_connect(self) -> DisconnectLogger {
285 0 : let mut this = self.0.into_inner();
286 0 : this.log_connect();
287 0 :
288 0 : // close current span.
289 0 : this.span = Span::none();
290 0 :
291 0 : DisconnectLogger(this)
292 0 : }
293 :
294 0 : pub(crate) fn protocol(&self) -> Protocol {
295 0 : self.0.try_lock().expect("should not deadlock").protocol
296 0 : }
297 :
298 0 : pub(crate) fn span(&self) -> Span {
299 0 : self.0.try_lock().expect("should not deadlock").span.clone()
300 0 : }
301 :
302 0 : pub(crate) fn session_id(&self) -> Uuid {
303 0 : self.0.try_lock().expect("should not deadlock").session_id
304 0 : }
305 :
306 6 : pub(crate) fn peer_addr(&self) -> IpAddr {
307 6 : self.0
308 6 : .try_lock()
309 6 : .expect("should not deadlock")
310 6 : .conn_info
311 6 : .addr
312 6 : .ip()
313 6 : }
314 :
315 0 : pub(crate) fn extra(&self) -> Option<ConnectionInfoExtra> {
316 0 : self.0
317 0 : .try_lock()
318 0 : .expect("should not deadlock")
319 0 : .conn_info
320 0 : .extra
321 0 : .clone()
322 0 : }
323 :
324 0 : pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
325 0 : self.0
326 0 : .try_lock()
327 0 : .expect("should not deadlock")
328 0 : .cold_start_info
329 0 : }
330 :
331 28 : pub(crate) fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
332 28 : LatencyTimerPause {
333 28 : ctx: self,
334 28 : start: tokio::time::Instant::now(),
335 28 : waiting_for,
336 28 : }
337 28 : }
338 :
339 4 : pub(crate) fn success(&self) {
340 4 : self.0
341 4 : .try_lock()
342 4 : .expect("should not deadlock")
343 4 : .latency_timer
344 4 : .success();
345 4 : }
346 : }
347 :
/// Guard returned by `RequestContext::latency_timer_pause`.
///
/// When dropped, it calls `unpause` on the context's latency timer with the stored
/// start instant and wait reason.
pub(crate) struct LatencyTimerPause<'a> {
    // Context whose latency timer is unpaused on drop.
    ctx: &'a RequestContext,
    // Instant captured when the guard was created.
    start: tokio::time::Instant,
    // What the request was waiting for during the pause.
    waiting_for: Waiting,
}
353 :
354 : impl Drop for LatencyTimerPause<'_> {
355 28 : fn drop(&mut self) {
356 28 : self.ctx
357 28 : .0
358 28 : .try_lock()
359 28 : .expect("should not deadlock")
360 28 : .latency_timer
361 28 : .unpause(self.start, self.waiting_for);
362 28 : }
363 : }
364 :
365 : impl RequestContextInner {
366 0 : fn set_cold_start_info(&mut self, info: ColdStartInfo) {
367 0 : self.cold_start_info = info;
368 0 : self.latency_timer.cold_start_info(info);
369 0 : }
370 :
371 28 : fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
372 28 : if self.endpoint_id.is_none() {
373 28 : self.span.record("ep", display(&endpoint_id));
374 28 : let metric = &Metrics::get().proxy.connecting_endpoints;
375 28 : let label = metric.with_labels(self.protocol);
376 28 : metric.get_metric(label).measure(&endpoint_id);
377 28 : self.endpoint_id = Some(endpoint_id);
378 28 : }
379 28 : }
380 :
381 0 : fn set_application(&mut self, app: Option<SmolStr>) {
382 0 : if let Some(app) = app {
383 0 : self.application = Some(app);
384 0 : }
385 0 : }
386 :
387 0 : fn set_dbname(&mut self, dbname: DbName) {
388 0 : self.dbname = Some(dbname);
389 0 : }
390 :
391 0 : fn set_user(&mut self, user: RoleName) {
392 0 : self.span.record("role", display(&user));
393 0 : self.user = Some(user);
394 0 : }
395 :
396 0 : fn has_private_peer_addr(&self) -> bool {
397 0 : match self.conn_info.addr.ip() {
398 0 : IpAddr::V4(ip) => ip.is_private(),
399 0 : IpAddr::V6(_) => false,
400 : }
401 0 : }
402 :
403 0 : fn log_connect(&mut self) {
404 0 : let outcome = if self.success {
405 0 : ConnectOutcome::Success
406 : } else {
407 0 : ConnectOutcome::Failed
408 : };
409 :
410 : // TODO: get rid of entirely/refactor
411 : // check for false positives
412 : // AND false negatives
413 0 : if let Some(rejected) = self.rejected {
414 0 : let ep = self
415 0 : .endpoint_id
416 0 : .as_ref()
417 0 : .map(|x| x.as_str())
418 0 : .unwrap_or_default();
419 0 : // This makes sense only if cache is disabled
420 0 : debug!(
421 : ?outcome,
422 : ?rejected,
423 : ?ep,
424 0 : "check endpoint is valid with outcome"
425 : );
426 0 : Metrics::get()
427 0 : .proxy
428 0 : .invalid_endpoints_total
429 0 : .inc(InvalidEndpointsGroup {
430 0 : protocol: self.protocol,
431 0 : rejected: rejected.into(),
432 0 : outcome,
433 0 : });
434 0 : }
435 :
436 0 : if let Some(tx) = self.sender.take() {
437 : // If type changes, this error handling needs to be updated.
438 0 : let tx: mpsc::UnboundedSender<RequestData> = tx;
439 0 : if let Err(e) = tx.send(RequestData::from(&*self)) {
440 0 : error!("log_connect channel send failed: {e}");
441 0 : }
442 0 : }
443 0 : }
444 :
445 0 : fn log_disconnect(&mut self) {
446 0 : // If we are here, it's guaranteed that the user successfully connected to the endpoint.
447 0 : // Here we log the length of the session.
448 0 : self.disconnect_timestamp = Some(Utc::now());
449 0 : if let Some(tx) = self.disconnect_sender.take() {
450 : // If type changes, this error handling needs to be updated.
451 0 : let tx: mpsc::UnboundedSender<RequestData> = tx;
452 0 : if let Err(e) = tx.send(RequestData::from(&*self)) {
453 0 : error!("log_disconnect channel send failed: {e}");
454 0 : }
455 0 : }
456 0 : }
457 : }
458 :
459 : impl Drop for RequestContextInner {
460 70 : fn drop(&mut self) {
461 70 : if self.sender.is_some() {
462 0 : self.log_connect();
463 70 : }
464 70 : }
465 : }
466 :
/// Owns the request state after the connection has been logged (it is returned by
/// `RequestContext::log_connect`) and logs the disconnect when dropped.
pub struct DisconnectLogger(RequestContextInner);
468 :
469 : impl Drop for DisconnectLogger {
470 0 : fn drop(&mut self) {
471 0 : self.0.log_disconnect();
472 0 : }
473 : }
|