Line data Source code
1 : use futures::FutureExt;
2 : use smol_str::SmolStr;
3 : use tokio::io::{AsyncRead, AsyncWrite};
4 : use tracing::debug;
5 : use utils::measured_stream::MeasuredStream;
6 :
7 : use super::copy_bidirectional::ErrorSource;
8 : use crate::cancellation;
9 : use crate::compute::PostgresConnection;
10 : use crate::config::ComputeConfig;
11 : use crate::control_plane::messages::MetricsAuxInfo;
12 : use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
13 : use crate::stream::Stream;
14 : use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
15 :
16 : /// Forward bytes in both directions (client <-> compute).
17 : #[tracing::instrument(skip_all)]
18 : pub(crate) async fn proxy_pass(
19 : client: impl AsyncRead + AsyncWrite + Unpin,
20 : compute: impl AsyncRead + AsyncWrite + Unpin,
21 : aux: MetricsAuxInfo,
22 : private_link_id: Option<SmolStr>,
23 : ) -> Result<(), ErrorSource> {
24 : // we will report ingress at a later date
25 : let usage_tx = USAGE_METRICS.register(Ids {
26 : endpoint_id: aux.endpoint_id,
27 : branch_id: aux.branch_id,
28 : private_link_id,
29 : });
30 :
31 : let metrics = &Metrics::get().proxy.io_bytes;
32 : let m_sent = metrics.with_labels(Direction::Tx);
33 : let mut client = MeasuredStream::new(
34 : client,
35 0 : |_| {},
36 0 : |cnt| {
37 0 : // Number of bytes we sent to the client (outbound).
38 0 : metrics.get_metric(m_sent).inc_by(cnt as u64);
39 0 : usage_tx.record_egress(cnt as u64);
40 0 : },
41 : );
42 :
43 : let m_recv = metrics.with_labels(Direction::Rx);
44 : let mut compute = MeasuredStream::new(
45 : compute,
46 0 : |_| {},
47 0 : |cnt| {
48 0 : // Number of bytes the client sent to the compute node (inbound).
49 0 : metrics.get_metric(m_recv).inc_by(cnt as u64);
50 0 : usage_tx.record_ingress(cnt as u64);
51 0 : },
52 : );
53 :
54 : // Starting from here we only proxy the client's traffic.
55 : debug!("performing the proxy pass...");
56 : let _ = crate::pglb::copy_bidirectional::copy_bidirectional_client_compute(
57 : &mut client,
58 : &mut compute,
59 : )
60 : .await?;
61 :
62 : Ok(())
63 : }
64 :
/// Everything needed to forward one client session's traffic to a compute node.
///
/// Consumed by [`ProxyPassthrough::proxy_pass`].
pub(crate) struct ProxyPassthrough<S> {
    /// Client-side stream whose traffic is forwarded to the compute node.
    pub(crate) client: Stream<S>,
    /// Connection to the compute node.
    ///
    /// Its `stream` carries the proxied traffic. Its `cancel_closure` is used to try
    /// to cancel a query in the database once proxying ends.
    pub(crate) compute: PostgresConnection,
    /// Endpoint/branch identifiers used to key the usage metrics.
    pub(crate) aux: MetricsAuxInfo,
    /// Session identifier, included in the warning logged if query cancellation fails.
    pub(crate) session_id: uuid::Uuid,
    /// Optional private link id, forwarded into the usage-metrics ids.
    pub(crate) private_link_id: Option<SmolStr>,
    /// Cancellation session; its cancel key is removed after proxying finishes.
    pub(crate) cancel: cancellation::Session,

    // NOTE(review): these guards are never read. They appear to be held only so that
    // dropping them (when `proxy_pass` consumes `self`) updates the connection
    // metrics — confirm. Field order matters here because it fixes drop order.
    pub(crate) _req: NumConnectionRequestsGuard<'static>,
    pub(crate) _conn: NumClientConnectionsGuard<'static>,
}
76 :
77 : impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
78 0 : pub(crate) async fn proxy_pass(
79 0 : self,
80 0 : compute_config: &ComputeConfig,
81 0 : ) -> Result<(), ErrorSource> {
82 0 : let res = proxy_pass(
83 0 : self.client,
84 0 : self.compute.stream,
85 0 : self.aux,
86 0 : self.private_link_id,
87 0 : )
88 0 : .await;
89 0 : if let Err(err) = self
90 0 : .compute
91 0 : .cancel_closure
92 0 : .try_cancel_query(compute_config)
93 0 : .boxed()
94 0 : .await
95 : {
96 0 : tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
97 0 : }
98 :
99 0 : drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
100 0 :
101 0 : res
102 0 : }
103 : }
|