Line data Source code
1 : use std::convert::Infallible;
2 :
3 : use smol_str::SmolStr;
4 : use tokio::io::{AsyncRead, AsyncWrite};
5 : use tracing::debug;
6 : use utils::measured_stream::MeasuredStream;
7 :
8 : use super::copy_bidirectional::ErrorSource;
9 : use crate::compute::MaybeRustlsStream;
10 : use crate::control_plane::messages::MetricsAuxInfo;
11 : use crate::metrics::{
12 : Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
13 : NumDbConnectionsGuard,
14 : };
15 : use crate::stream::Stream;
16 : use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
17 :
18 : /// Forward bytes in both directions (client <-> compute).
19 : #[tracing::instrument(skip_all)]
20 : pub(crate) async fn proxy_pass(
21 : client: impl AsyncRead + AsyncWrite + Unpin,
22 : compute: impl AsyncRead + AsyncWrite + Unpin,
23 : aux: MetricsAuxInfo,
24 : private_link_id: Option<SmolStr>,
25 : ) -> Result<(), ErrorSource> {
26 : // we will report ingress at a later date
27 : let usage_tx = USAGE_METRICS.register(Ids {
28 : endpoint_id: aux.endpoint_id,
29 : branch_id: aux.branch_id,
30 : private_link_id,
31 : });
32 :
33 : let metrics = &Metrics::get().proxy.io_bytes;
34 : let m_sent = metrics.with_labels(Direction::Tx);
35 : let mut client = MeasuredStream::new(
36 : client,
37 0 : |_| {},
38 0 : |cnt| {
39 : // Number of bytes we sent to the client (outbound).
40 0 : metrics.get_metric(m_sent).inc_by(cnt as u64);
41 0 : usage_tx.record_egress(cnt as u64);
42 0 : },
43 : );
44 :
45 : let m_recv = metrics.with_labels(Direction::Rx);
46 : let mut compute = MeasuredStream::new(
47 : compute,
48 0 : |_| {},
49 0 : |cnt| {
50 : // Number of bytes the client sent to the compute node (inbound).
51 0 : metrics.get_metric(m_recv).inc_by(cnt as u64);
52 0 : usage_tx.record_ingress(cnt as u64);
53 0 : },
54 : );
55 :
56 : // Starting from here we only proxy the client's traffic.
57 : debug!("performing the proxy pass...");
58 : let _ = crate::pglb::copy_bidirectional::copy_bidirectional_client_compute(
59 : &mut client,
60 : &mut compute,
61 : )
62 : .await?;
63 :
64 : Ok(())
65 : }
66 :
/// State for a connection whose only remaining job is to forward bytes
/// between `client` and `compute`.
///
/// Fields prefixed with `_` are never read. They are held only so that they
/// are dropped together with this value. Fields are dropped in declaration
/// order, so the order below is significant.
pub(crate) struct ProxyPassthrough<S> {
    /// Client-side stream.
    pub(crate) client: Stream<S>,
    /// Connection to the compute node (by its type name, possibly TLS-wrapped).
    pub(crate) compute: MaybeRustlsStream,

    /// Endpoint/branch identifiers used when registering usage metrics.
    pub(crate) aux: MetricsAuxInfo,
    /// Optional private-link identifier, also used for usage metrics.
    pub(crate) private_link_id: Option<SmolStr>,

    /// Sender whose message type is `Infallible`, so no value can ever be
    /// sent on it; the receiving side can only observe it being dropped.
    /// NOTE(review): presumably used to cancel this connection on shutdown;
    /// confirm at the construction site.
    pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,

    // NOTE(review): presumably these guards keep the connection-count metrics
    // accounted for as long as they are alive; confirm in `crate::metrics`.
    pub(crate) _req: NumConnectionRequestsGuard<'static>,
    pub(crate) _conn: NumClientConnectionsGuard<'static>,
    pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
}
80 :
81 : impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
82 0 : pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
83 0 : proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
84 0 : }
85 : }
|