Line data Source code
1 : use tokio::io::{AsyncRead, AsyncWrite};
2 : use tracing::debug;
3 : use utils::measured_stream::MeasuredStream;
4 :
5 : use super::copy_bidirectional::ErrorSource;
6 : use crate::cancellation;
7 : use crate::compute::PostgresConnection;
8 : use crate::control_plane::messages::MetricsAuxInfo;
9 : use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
10 : use crate::stream::Stream;
11 : use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
12 :
13 : /// Forward bytes in both directions (client <-> compute).
14 0 : #[tracing::instrument(skip_all)]
15 : pub(crate) async fn proxy_pass(
16 : client: impl AsyncRead + AsyncWrite + Unpin,
17 : compute: impl AsyncRead + AsyncWrite + Unpin,
18 : aux: MetricsAuxInfo,
19 : ) -> Result<(), ErrorSource> {
20 : let usage = USAGE_METRICS.register(Ids {
21 : endpoint_id: aux.endpoint_id,
22 : branch_id: aux.branch_id,
23 : });
24 :
25 : let metrics = &Metrics::get().proxy.io_bytes;
26 : let m_sent = metrics.with_labels(Direction::Tx);
27 : let mut client = MeasuredStream::new(
28 : client,
29 0 : |_| {},
30 0 : |cnt| {
31 0 : // Number of bytes we sent to the client (outbound).
32 0 : metrics.get_metric(m_sent).inc_by(cnt as u64);
33 0 : usage.record_egress(cnt as u64);
34 0 : },
35 : );
36 :
37 : let m_recv = metrics.with_labels(Direction::Rx);
38 : let mut compute = MeasuredStream::new(
39 : compute,
40 0 : |_| {},
41 0 : |cnt| {
42 0 : // Number of bytes the client sent to the compute node (inbound).
43 0 : metrics.get_metric(m_recv).inc_by(cnt as u64);
44 0 : },
45 : );
46 :
47 : // Starting from here we only proxy the client's traffic.
48 : debug!("performing the proxy pass...");
49 : let _ = crate::proxy::copy_bidirectional::copy_bidirectional_client_compute(
50 : &mut client,
51 : &mut compute,
52 : )
53 : .await?;
54 :
55 : Ok(())
56 : }
57 :
/// Everything needed to run a passthrough session between a client and a compute
/// connection: the two endpoints, the metrics metadata, and a few values that are
/// only held for the lifetime of the session.
///
/// `P` parameterizes the cancellation session and `S` the client's underlying
/// stream type.
58 : pub(crate) struct ProxyPassthrough<P, S> {
/// Client-facing side of the session.
59 : pub(crate) client: Stream<S>,
/// Compute-facing connection; the code in this file uses its `stream` for the
/// copy and its `cancel_closure` once the copy has finished.
60 : pub(crate) compute: PostgresConnection,
/// Supplies the endpoint/branch ids used when registering usage metrics.
61 : pub(crate) aux: MetricsAuxInfo,
62 :
// The underscore-prefixed fields below are never read in the code visible here.
// NOTE(review): presumably they are kept alive only for their drop-time effects
// (connection/request gauges, cancellation session) — confirm against their
// definitions. Field order determines drop order, so do not reorder casually.
63 : pub(crate) _req: NumConnectionRequestsGuard<'static>,
64 : pub(crate) _conn: NumClientConnectionsGuard<'static>,
65 : pub(crate) _cancel: cancellation::Session<P>,
66 : }
67 :
68 : impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
69 0 : pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
70 0 : let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
71 0 : if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
72 0 : tracing::warn!(?err, "could not cancel the query in the database");
73 0 : }
74 0 : res
75 0 : }
76 : }
|