|             Line data    Source code 
       1              : use std::convert::Infallible;
       2              : 
       3              : use smol_str::SmolStr;
       4              : use tokio::io::{AsyncRead, AsyncWrite};
       5              : use tracing::debug;
       6              : use utils::measured_stream::MeasuredStream;
       7              : 
       8              : use super::copy_bidirectional::ErrorSource;
       9              : use crate::compute::MaybeRustlsStream;
      10              : use crate::control_plane::messages::MetricsAuxInfo;
      11              : use crate::metrics::{
      12              :     Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
      13              :     NumDbConnectionsGuard,
      14              : };
      15              : use crate::stream::Stream;
      16              : use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
      17              : 
      18              : /// Forward bytes in both directions (client <-> compute).
      19              : #[tracing::instrument(skip_all)]
      20              : pub(crate) async fn proxy_pass(
      21              :     client: impl AsyncRead + AsyncWrite + Unpin,
      22              :     compute: impl AsyncRead + AsyncWrite + Unpin,
      23              :     aux: MetricsAuxInfo,
      24              :     private_link_id: Option<SmolStr>,
      25              : ) -> Result<(), ErrorSource> {
      26              :     // we will report ingress at a later date
      27              :     let usage_tx = USAGE_METRICS.register(Ids {
      28              :         endpoint_id: aux.endpoint_id,
      29              :         branch_id: aux.branch_id,
      30              :         private_link_id,
      31              :     });
      32              : 
      33              :     let metrics = &Metrics::get().proxy.io_bytes;
      34              :     let m_sent = metrics.with_labels(Direction::Tx);
      35              :     let mut client = MeasuredStream::new(
      36              :         client,
      37            0 :         |_| {},
      38            0 :         |cnt| {
      39              :             // Number of bytes we sent to the client (outbound).
      40            0 :             metrics.get_metric(m_sent).inc_by(cnt as u64);
      41            0 :             usage_tx.record_egress(cnt as u64);
      42            0 :         },
      43              :     );
      44              : 
      45              :     let m_recv = metrics.with_labels(Direction::Rx);
      46              :     let mut compute = MeasuredStream::new(
      47              :         compute,
      48            0 :         |_| {},
      49            0 :         |cnt| {
      50              :             // Number of bytes the client sent to the compute node (inbound).
      51            0 :             metrics.get_metric(m_recv).inc_by(cnt as u64);
      52            0 :             usage_tx.record_ingress(cnt as u64);
      53            0 :         },
      54              :     );
      55              : 
      56              :     // Starting from here we only proxy the client's traffic.
      57              :     debug!("performing the proxy pass...");
      58              :     let _ = crate::pglb::copy_bidirectional::copy_bidirectional_client_compute(
      59              :         &mut client,
      60              :         &mut compute,
      61              :     )
      62              :     .await?;
      63              : 
      64              :     Ok(())
      65              : }
      66              : 
/// State for one established client <-> compute session: the two streams, the
/// identifiers used for usage metrics, and values that must stay alive for the
/// whole session.
///
/// NOTE: Rust drops struct fields in declaration order. The
/// `ProxyPassthrough::proxy_pass` method moves only `client`, `compute`, `aux`
/// and `private_link_id` out of `self`. The remaining fields are dropped when
/// that method returns, i.e. after forwarding ends. Reordering those fields
/// changes the order in which they are released.
pub(crate) struct ProxyPassthrough<S> {
    /// Client-side stream, wrapped in the crate's `Stream` type.
    pub(crate) client: Stream<S>,
    /// Stream to the compute node. The type name suggests it may or may not be
    /// TLS-wrapped; confirm in `crate::compute`.
    pub(crate) compute: MaybeRustlsStream,

    /// Supplies `endpoint_id` / `branch_id` for usage-metrics registration.
    pub(crate) aux: MetricsAuxInfo,
    /// Optional private-link identifier, forwarded to usage metrics.
    pub(crate) private_link_id: Option<SmolStr>,

    /// No value can ever be sent (`Infallible` is uninhabited), so the only
    /// observable effect is the channel closing when this is dropped.
    /// Presumably the receiver treats that as a shutdown/cancellation signal;
    /// confirm at the construction site.
    pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender<Infallible>,

    // Metric guards held for the session's lifetime. Presumably they adjust
    // their connection gauges on drop; confirm in `crate::metrics`.
    pub(crate) _req: NumConnectionRequestsGuard<'static>,
    pub(crate) _conn: NumClientConnectionsGuard<'static>,
    pub(crate) _db_conn: NumDbConnectionsGuard<'static>,
}
      80              : 
      81              : impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
      82            0 :     pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
      83            0 :         proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await
      84            0 :     }
      85              : }
         |