Line data Source code
1 : //! Connection request monitoring contexts
2 :
3 : use std::net::IpAddr;
4 :
5 : use chrono::Utc;
6 : use once_cell::sync::OnceCell;
7 : use pq_proto::StartupMessageParams;
8 : use smol_str::SmolStr;
9 : use tokio::sync::mpsc;
10 : use tracing::field::display;
11 : use tracing::{Span, debug, error, info_span};
12 : use try_lock::TryLock;
13 : use uuid::Uuid;
14 :
15 : use self::parquet::RequestData;
16 : use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
17 : use crate::error::ErrorKind;
18 : use crate::intern::{BranchIdInt, ProjectIdInt};
19 : use crate::metrics::{
20 : ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
21 : Waiting,
22 : };
23 : use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
24 : use crate::types::{DbName, EndpointId, RoleName};
25 :
26 : pub mod parquet;
27 :
/// Weak handle to the channel on which connection-attempt records are sent.
///
/// [`RequestContext::new`] upgrades this to a strong sender and keeps it in the context;
/// if the cell is unset or the upgrade fails, the context has no sender and nothing is sent.
pub(crate) static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
/// Weak handle to the channel on which disconnect records are sent.
///
/// Upgraded by [`RequestContext::new`] in the same way as [`LOG_CHAN`]; the resulting
/// sender is consumed when the `DisconnectLogger` is dropped.
pub(crate) static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>> =
    OnceCell::new();
31 :
/// Context data for a single request to connect to a database.
///
/// This data should **not** be used for connection logic, only for observability and limiting purposes.
/// All connection logic should instead use strongly typed state machines, not a bunch of Options.
///
/// # Panics
///
/// Accessors on this type take the inner lock with `try_lock` and panic with
/// "should not deadlock" if the lock is already held at that moment.
pub struct RequestContext(
    /// To allow easier use of the ctx object, we have interior mutability.
    /// I would typically use a RefCell but that would break the `Send` requirements
    /// so we need something with thread-safety. `TryLock` is a cheap alternative
    /// that offers similar semantics to a `RefCell` but with synchronisation.
    TryLock<RequestContextInner>,
);
43 :
/// The mutable state behind a [`RequestContext`].
///
/// Dropping this value while it still holds `sender` logs the connection attempt
/// (see the `Drop` impl).
struct RequestContextInner {
    pub(crate) conn_info: ConnectionInfo,
    pub(crate) session_id: Uuid,
    pub(crate) protocol: Protocol,
    // UTC time at which the context was created.
    first_packet: chrono::DateTime<Utc>,
    region: &'static str,
    pub(crate) span: Span,

    // filled in as they are discovered
    project: Option<ProjectIdInt>,
    branch: Option<BranchIdInt>,
    endpoint_id: Option<EndpointId>,
    dbname: Option<DbName>,
    user: Option<RoleName>,
    application: Option<SmolStr>,
    user_agent: Option<SmolStr>,
    error_kind: Option<ErrorKind>,
    pub(crate) auth_method: Option<AuthMethod>,
    jwt_issuer: Option<String>,
    // Set by `RequestContext::set_success`; selects the outcome reported by `log_connect`.
    success: bool,
    pub(crate) cold_start_info: ColdStartInfo,
    pg_options: Option<StartupMessageParams>,
    testodrome_query_id: Option<String>,

    // extra
    // This sender is here to keep the request monitoring channel open while requests are taking place.
    sender: Option<mpsc::UnboundedSender<RequestData>>,
    // This sender is only used to log the length of session in case of success.
    disconnect_sender: Option<mpsc::UnboundedSender<RequestData>>,
    pub(crate) latency_timer: LatencyTimer,
    // Whether proxy decided that it's not a valid endpoint and rejected it before going to cplane.
    rejected: Option<bool>,
    // Set to the current UTC time by `log_disconnect`.
    disconnect_timestamp: Option<chrono::DateTime<Utc>>,
}
78 :
/// Authentication method recorded on a request context via
/// `RequestContext::set_auth_method`.
#[derive(Clone, Debug)]
pub(crate) enum AuthMethod {
    // aka passwordless, fka link
    ConsoleRedirect,
    ScramSha256,
    ScramSha256Plus,
    Cleartext,
    Jwt,
}
88 :
89 : impl Clone for RequestContext {
90 0 : fn clone(&self) -> Self {
91 0 : let inner = self.0.try_lock().expect("should not deadlock");
92 0 : let new = RequestContextInner {
93 0 : conn_info: inner.conn_info.clone(),
94 0 : session_id: inner.session_id,
95 0 : protocol: inner.protocol,
96 0 : first_packet: inner.first_packet,
97 0 : region: inner.region,
98 0 : span: info_span!("background_task"),
99 :
100 0 : project: inner.project,
101 0 : branch: inner.branch,
102 0 : endpoint_id: inner.endpoint_id.clone(),
103 0 : dbname: inner.dbname.clone(),
104 0 : user: inner.user.clone(),
105 0 : application: inner.application.clone(),
106 0 : user_agent: inner.user_agent.clone(),
107 0 : error_kind: inner.error_kind,
108 0 : auth_method: inner.auth_method.clone(),
109 0 : jwt_issuer: inner.jwt_issuer.clone(),
110 0 : success: inner.success,
111 0 : rejected: inner.rejected,
112 0 : cold_start_info: inner.cold_start_info,
113 0 : pg_options: inner.pg_options.clone(),
114 0 : testodrome_query_id: inner.testodrome_query_id.clone(),
115 0 :
116 0 : sender: None,
117 0 : disconnect_sender: None,
118 0 : latency_timer: LatencyTimer::noop(inner.protocol),
119 0 : disconnect_timestamp: inner.disconnect_timestamp,
120 0 : };
121 0 :
122 0 : Self(TryLock::new(new))
123 0 : }
124 : }
125 :
126 : impl RequestContext {
127 70 : pub fn new(
128 70 : session_id: Uuid,
129 70 : conn_info: ConnectionInfo,
130 70 : protocol: Protocol,
131 70 : region: &'static str,
132 70 : ) -> Self {
133 : // TODO: be careful with long lived spans
134 70 : let span = info_span!(
135 70 : "connect_request",
136 70 : %protocol,
137 70 : ?session_id,
138 70 : %conn_info,
139 70 : ep = tracing::field::Empty,
140 70 : role = tracing::field::Empty,
141 70 : );
142 :
143 70 : let inner = RequestContextInner {
144 70 : conn_info,
145 70 : session_id,
146 70 : protocol,
147 70 : first_packet: Utc::now(),
148 70 : region,
149 70 : span,
150 70 :
151 70 : project: None,
152 70 : branch: None,
153 70 : endpoint_id: None,
154 70 : dbname: None,
155 70 : user: None,
156 70 : application: None,
157 70 : user_agent: None,
158 70 : error_kind: None,
159 70 : auth_method: None,
160 70 : jwt_issuer: None,
161 70 : success: false,
162 70 : rejected: None,
163 70 : cold_start_info: ColdStartInfo::Unknown,
164 70 : pg_options: None,
165 70 : testodrome_query_id: None,
166 70 :
167 70 : sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
168 70 : disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
169 70 : latency_timer: LatencyTimer::new(protocol),
170 70 : disconnect_timestamp: None,
171 70 : };
172 70 :
173 70 : Self(TryLock::new(inner))
174 70 : }
175 :
176 : #[cfg(test)]
177 70 : pub(crate) fn test() -> Self {
178 : use std::net::SocketAddr;
179 70 : let ip = IpAddr::from([127, 0, 0, 1]);
180 70 : let addr = SocketAddr::new(ip, 5432);
181 70 : let conn_info = ConnectionInfo { addr, extra: None };
182 70 : RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
183 70 : }
184 :
185 0 : pub(crate) fn console_application_name(&self) -> String {
186 0 : let this = self.0.try_lock().expect("should not deadlock");
187 0 : format!(
188 0 : "{}/{}",
189 0 : this.application.as_deref().unwrap_or_default(),
190 0 : this.protocol
191 0 : )
192 0 : }
193 :
194 0 : pub(crate) fn set_rejected(&self, rejected: bool) {
195 0 : let mut this = self.0.try_lock().expect("should not deadlock");
196 0 : this.rejected = Some(rejected);
197 0 : }
198 :
199 0 : pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) {
200 0 : self.0
201 0 : .try_lock()
202 0 : .expect("should not deadlock")
203 0 : .set_cold_start_info(info);
204 0 : }
205 :
206 0 : pub(crate) fn set_db_options(&self, options: StartupMessageParams) {
207 0 : let mut this = self.0.try_lock().expect("should not deadlock");
208 0 : this.set_application(options.get("application_name").map(SmolStr::from));
209 0 : if let Some(user) = options.get("user") {
210 0 : this.set_user(user.into());
211 0 : }
212 0 : if let Some(dbname) = options.get("database") {
213 0 : this.set_dbname(dbname.into());
214 0 : }
215 :
216 : // Try to get testodrome_query_id directly from parameters
217 0 : if let Some(options_str) = options.get("options") {
218 : // If not found directly, try to extract it from the options string
219 0 : for option in options_str.split_whitespace() {
220 0 : if option.starts_with("neon_query_id:") {
221 0 : if let Some(value) = option.strip_prefix("neon_query_id:") {
222 0 : this.set_testodrome_id(value.to_string());
223 0 : break;
224 0 : }
225 0 : }
226 : }
227 0 : }
228 :
229 0 : this.pg_options = Some(options);
230 0 : }
231 :
232 0 : pub(crate) fn set_project(&self, x: MetricsAuxInfo) {
233 0 : let mut this = self.0.try_lock().expect("should not deadlock");
234 0 : if this.endpoint_id.is_none() {
235 0 : this.set_endpoint_id(x.endpoint_id.as_str().into());
236 0 : }
237 0 : this.branch = Some(x.branch_id);
238 0 : this.project = Some(x.project_id);
239 0 : this.set_cold_start_info(x.cold_start_info);
240 0 : }
241 :
242 0 : pub(crate) fn set_project_id(&self, project_id: ProjectIdInt) {
243 0 : let mut this = self.0.try_lock().expect("should not deadlock");
244 0 : this.project = Some(project_id);
245 0 : }
246 :
247 28 : pub(crate) fn set_endpoint_id(&self, endpoint_id: EndpointId) {
248 28 : self.0
249 28 : .try_lock()
250 28 : .expect("should not deadlock")
251 28 : .set_endpoint_id(endpoint_id);
252 28 : }
253 :
254 0 : pub(crate) fn set_dbname(&self, dbname: DbName) {
255 0 : self.0
256 0 : .try_lock()
257 0 : .expect("should not deadlock")
258 0 : .set_dbname(dbname);
259 0 : }
260 :
261 0 : pub(crate) fn set_user(&self, user: RoleName) {
262 0 : self.0
263 0 : .try_lock()
264 0 : .expect("should not deadlock")
265 0 : .set_user(user);
266 0 : }
267 :
268 0 : pub(crate) fn set_user_agent(&self, user_agent: Option<SmolStr>) {
269 0 : self.0
270 0 : .try_lock()
271 0 : .expect("should not deadlock")
272 0 : .set_user_agent(user_agent);
273 0 : }
274 :
275 0 : pub(crate) fn set_testodrome_id(&self, query_id: String) {
276 0 : self.0
277 0 : .try_lock()
278 0 : .expect("should not deadlock")
279 0 : .set_testodrome_id(query_id);
280 0 : }
281 :
282 15 : pub(crate) fn set_auth_method(&self, auth_method: AuthMethod) {
283 15 : let mut this = self.0.try_lock().expect("should not deadlock");
284 15 : this.auth_method = Some(auth_method);
285 15 : }
286 :
287 12 : pub(crate) fn set_jwt_issuer(&self, jwt_issuer: String) {
288 12 : let mut this = self.0.try_lock().expect("should not deadlock");
289 12 : this.jwt_issuer = Some(jwt_issuer);
290 12 : }
291 :
292 0 : pub fn has_private_peer_addr(&self) -> bool {
293 0 : self.0
294 0 : .try_lock()
295 0 : .expect("should not deadlock")
296 0 : .has_private_peer_addr()
297 0 : }
298 :
299 0 : pub(crate) fn set_error_kind(&self, kind: ErrorKind) {
300 0 : let mut this = self.0.try_lock().expect("should not deadlock");
301 0 : // Do not record errors from the private address to metrics.
302 0 : if !this.has_private_peer_addr() {
303 0 : Metrics::get().proxy.errors_total.inc(kind);
304 0 : }
305 0 : if let Some(ep) = &this.endpoint_id {
306 0 : let metric = &Metrics::get().proxy.endpoints_affected_by_errors;
307 0 : let label = metric.with_labels(kind);
308 0 : metric.get_metric(label).measure(ep);
309 0 : }
310 0 : this.error_kind = Some(kind);
311 0 : }
312 :
313 0 : pub fn set_success(&self) {
314 0 : let mut this = self.0.try_lock().expect("should not deadlock");
315 0 : this.success = true;
316 0 : }
317 :
318 0 : pub fn log_connect(self) -> DisconnectLogger {
319 0 : let mut this = self.0.into_inner();
320 0 : this.log_connect();
321 0 :
322 0 : // close current span.
323 0 : this.span = Span::none();
324 0 :
325 0 : DisconnectLogger(this)
326 0 : }
327 :
328 0 : pub(crate) fn protocol(&self) -> Protocol {
329 0 : self.0.try_lock().expect("should not deadlock").protocol
330 0 : }
331 :
332 0 : pub(crate) fn span(&self) -> Span {
333 0 : self.0.try_lock().expect("should not deadlock").span.clone()
334 0 : }
335 :
336 0 : pub(crate) fn session_id(&self) -> Uuid {
337 0 : self.0.try_lock().expect("should not deadlock").session_id
338 0 : }
339 :
340 6 : pub(crate) fn peer_addr(&self) -> IpAddr {
341 6 : self.0
342 6 : .try_lock()
343 6 : .expect("should not deadlock")
344 6 : .conn_info
345 6 : .addr
346 6 : .ip()
347 6 : }
348 :
349 0 : pub(crate) fn extra(&self) -> Option<ConnectionInfoExtra> {
350 0 : self.0
351 0 : .try_lock()
352 0 : .expect("should not deadlock")
353 0 : .conn_info
354 0 : .extra
355 0 : .clone()
356 0 : }
357 :
358 0 : pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
359 0 : self.0
360 0 : .try_lock()
361 0 : .expect("should not deadlock")
362 0 : .cold_start_info
363 0 : }
364 :
365 28 : pub(crate) fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
366 28 : LatencyTimerPause {
367 28 : ctx: self,
368 28 : start: tokio::time::Instant::now(),
369 28 : waiting_for,
370 28 : }
371 28 : }
372 :
373 0 : pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated {
374 0 : self.0
375 0 : .try_lock()
376 0 : .expect("should not deadlock")
377 0 : .latency_timer
378 0 : .accumulated()
379 0 : }
380 :
381 0 : pub(crate) fn get_testodrome_id(&self) -> Option<String> {
382 0 : self.0
383 0 : .try_lock()
384 0 : .expect("should not deadlock")
385 0 : .testodrome_query_id
386 0 : .clone()
387 0 : }
388 :
389 4 : pub(crate) fn success(&self) {
390 4 : self.0
391 4 : .try_lock()
392 4 : .expect("should not deadlock")
393 4 : .latency_timer
394 4 : .success();
395 4 : }
396 : }
397 :
/// Guard created by `RequestContext::latency_timer_pause`.
///
/// When dropped, it calls `unpause(start, waiting_for)` on the context's latency timer.
pub(crate) struct LatencyTimerPause<'a> {
    // Context whose latency timer is unpaused on drop.
    ctx: &'a RequestContext,
    // Instant at which the guard was created.
    start: tokio::time::Instant,
    // What the request was waiting for while the guard was alive.
    waiting_for: Waiting,
}
403 :
404 : impl Drop for LatencyTimerPause<'_> {
405 28 : fn drop(&mut self) {
406 28 : self.ctx
407 28 : .0
408 28 : .try_lock()
409 28 : .expect("should not deadlock")
410 28 : .latency_timer
411 28 : .unpause(self.start, self.waiting_for);
412 28 : }
413 : }
414 :
415 : impl RequestContextInner {
416 0 : fn set_cold_start_info(&mut self, info: ColdStartInfo) {
417 0 : self.cold_start_info = info;
418 0 : self.latency_timer.cold_start_info(info);
419 0 : }
420 :
421 28 : fn set_endpoint_id(&mut self, endpoint_id: EndpointId) {
422 28 : if self.endpoint_id.is_none() {
423 28 : self.span.record("ep", display(&endpoint_id));
424 28 : let metric = &Metrics::get().proxy.connecting_endpoints;
425 28 : let label = metric.with_labels(self.protocol);
426 28 : metric.get_metric(label).measure(&endpoint_id);
427 28 : self.endpoint_id = Some(endpoint_id);
428 28 : }
429 28 : }
430 :
431 0 : fn set_application(&mut self, app: Option<SmolStr>) {
432 0 : if let Some(app) = app {
433 0 : self.application = Some(app);
434 0 : }
435 0 : }
436 :
437 0 : fn set_user_agent(&mut self, user_agent: Option<SmolStr>) {
438 0 : self.user_agent = user_agent;
439 0 : }
440 :
441 0 : fn set_dbname(&mut self, dbname: DbName) {
442 0 : self.dbname = Some(dbname);
443 0 : }
444 :
445 0 : fn set_user(&mut self, user: RoleName) {
446 0 : self.span.record("role", display(&user));
447 0 : self.user = Some(user);
448 0 : }
449 :
450 0 : fn set_testodrome_id(&mut self, query_id: String) {
451 0 : self.testodrome_query_id = Some(query_id);
452 0 : }
453 :
454 0 : fn has_private_peer_addr(&self) -> bool {
455 0 : match self.conn_info.addr.ip() {
456 0 : IpAddr::V4(ip) => ip.is_private(),
457 0 : IpAddr::V6(_) => false,
458 : }
459 0 : }
460 :
461 0 : fn log_connect(&mut self) {
462 0 : let outcome = if self.success {
463 0 : ConnectOutcome::Success
464 : } else {
465 0 : ConnectOutcome::Failed
466 : };
467 :
468 : // TODO: get rid of entirely/refactor
469 : // check for false positives
470 : // AND false negatives
471 0 : if let Some(rejected) = self.rejected {
472 0 : let ep = self
473 0 : .endpoint_id
474 0 : .as_ref()
475 0 : .map(|x| x.as_str())
476 0 : .unwrap_or_default();
477 0 : // This makes sense only if cache is disabled
478 0 : debug!(
479 : ?outcome,
480 : ?rejected,
481 : ?ep,
482 0 : "check endpoint is valid with outcome"
483 : );
484 0 : Metrics::get()
485 0 : .proxy
486 0 : .invalid_endpoints_total
487 0 : .inc(InvalidEndpointsGroup {
488 0 : protocol: self.protocol,
489 0 : rejected: rejected.into(),
490 0 : outcome,
491 0 : });
492 0 : }
493 :
494 0 : if let Some(tx) = self.sender.take() {
495 : // If type changes, this error handling needs to be updated.
496 0 : let tx: mpsc::UnboundedSender<RequestData> = tx;
497 0 : if let Err(e) = tx.send(RequestData::from(&*self)) {
498 0 : error!("log_connect channel send failed: {e}");
499 0 : }
500 0 : }
501 0 : }
502 :
503 0 : fn log_disconnect(&mut self) {
504 0 : // If we are here, it's guaranteed that the user successfully connected to the endpoint.
505 0 : // Here we log the length of the session.
506 0 : self.disconnect_timestamp = Some(Utc::now());
507 0 : if let Some(tx) = self.disconnect_sender.take() {
508 : // If type changes, this error handling needs to be updated.
509 0 : let tx: mpsc::UnboundedSender<RequestData> = tx;
510 0 : if let Err(e) = tx.send(RequestData::from(&*self)) {
511 0 : error!("log_disconnect channel send failed: {e}");
512 0 : }
513 0 : }
514 0 : }
515 : }
516 :
517 : impl Drop for RequestContextInner {
518 70 : fn drop(&mut self) {
519 70 : if self.sender.is_some() {
520 0 : self.log_connect();
521 70 : }
522 70 : }
523 : }
524 :
/// Returned by `RequestContext::log_connect`; logs the disconnect when dropped.
pub struct DisconnectLogger(RequestContextInner);
526 :
527 : impl Drop for DisconnectLogger {
528 0 : fn drop(&mut self) {
529 0 : self.0.log_disconnect();
530 0 : }
531 : }
|