LCOV - code coverage report
Current view: top level - proxy/src/serverless - tls_listener.rs (source / functions) Coverage Total Hit
Test: 36bb8dd7c7efcb53483d1a7d9f7cb33e8406dcf0.info Lines: 0.0 % 52 0
Test Date: 2024-04-08 10:22:05 Functions: 0.0 % 18 0

            Line data    Source code
       1              : use std::{
       2              :     convert::Infallible,
       3              :     pin::Pin,
       4              :     task::{Context, Poll},
       5              :     time::Duration,
       6              : };
       7              : 
       8              : use hyper::server::{accept::Accept, conn::AddrStream};
       9              : use pin_project_lite::pin_project;
      10              : use tokio::{
      11              :     io::{AsyncRead, AsyncWrite},
      12              :     task::JoinSet,
      13              :     time::timeout,
      14              : };
      15              : use tokio_rustls::{server::TlsStream, TlsAcceptor};
      16              : use tracing::{info, warn, Instrument};
      17              : 
      18              : use crate::{
      19              :     metrics::TLS_HANDSHAKE_FAILURES,
      20              :     protocol2::{WithClientIp, WithConnectionGuard},
      21              : };
      22              : 
pin_project! {
    /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
    /// encrypted using TLS.
    ///
    /// Connections produced by the inner listener are handed to background tasks that perform
    /// the TLS handshake; only connections whose handshake succeeded are yielded to the caller.
    pub(crate) struct TlsListener<A: Accept> {
        // Underlying acceptor that produces the raw (not yet encrypted) connections.
        #[pin]
        listener: A,
        // TLS acceptor; cloned once per incoming connection to drive its handshake.
        tls: TlsAcceptor,
        // In-flight handshake tasks. A task resolves to `Some(stream)` when the handshake
        // succeeded and to `None` when the connection had to be discarded.
        waiting: JoinSet<Option<TlsStream<A::Conn>>>,
        // Upper bound on how long a single TLS handshake may take.
        timeout: Duration,
    }
}
      34              : 
      35              : impl<A: Accept> TlsListener<A> {
      36              :     /// Create a `TlsListener` with default options.
      37            0 :     pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self {
      38            0 :         TlsListener {
      39            0 :             listener,
      40            0 :             tls,
      41            0 :             waiting: JoinSet::new(),
      42            0 :             timeout,
      43            0 :         }
      44            0 :     }
      45              : }
      46              : 
      47              : impl<A> Accept for TlsListener<A>
      48              : where
      49              :     A: Accept<Conn = WithConnectionGuard<WithClientIp<AddrStream>>>,
      50              :     A::Error: std::error::Error,
      51              :     A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
      52              : {
      53              :     type Conn = TlsStream<A::Conn>;
      54              : 
      55              :     type Error = Infallible;
      56              : 
      57            0 :     fn poll_accept(
      58            0 :         self: Pin<&mut Self>,
      59            0 :         cx: &mut Context<'_>,
      60            0 :     ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
      61            0 :         let mut this = self.project();
      62              : 
      63              :         loop {
      64            0 :             match this.listener.as_mut().poll_accept(cx) {
      65            0 :                 Poll::Pending => break,
      66            0 :                 Poll::Ready(Some(Ok(mut conn))) => {
      67            0 :                     let t = *this.timeout;
      68            0 :                     let tls = this.tls.clone();
      69            0 :                     let span = conn.span.clone();
      70            0 :                     this.waiting.spawn(async move {
      71            0 :                         let peer_addr = match conn.inner.wait_for_addr().await {
      72            0 :                             Ok(Some(addr)) => addr,
      73            0 :                             Err(e) => {
      74            0 :                                 tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
      75            0 :                                 return None;
      76              :                             }
      77            0 :                             Ok(None) => conn.inner.inner.remote_addr()
      78              :                         };
      79              : 
      80            0 :                         let accept = tls.accept(conn);
      81            0 :                         match timeout(t, accept).await {
      82            0 :                             Ok(Ok(conn)) => {
      83            0 :                                 info!(%peer_addr, "accepted new TLS connection");
      84            0 :                                 Some(conn)
      85              :                             },
      86              :                             // The handshake failed, try getting another connection from the queue
      87            0 :                             Ok(Err(e)) => {
      88            0 :                                 TLS_HANDSHAKE_FAILURES.inc();
      89            0 :                                 warn!(%peer_addr, "failed to accept TLS connection: {e:?}");
      90            0 :                                 None
      91              :                             }
      92              :                             // The handshake timed out, try getting another connection from the queue
      93              :                             Err(_) => {
      94            0 :                                 TLS_HANDSHAKE_FAILURES.inc();
      95            0 :                                 warn!(%peer_addr, "failed to accept TLS connection: timeout");
      96            0 :                                 None
      97              :                             }
      98              :                         }
      99            0 :                     }.instrument(span));
     100            0 :                 }
     101            0 :                 Poll::Ready(Some(Err(e))) => {
     102            0 :                     tracing::error!("error accepting TCP connection: {e}");
     103            0 :                     continue;
     104              :                 }
     105            0 :                 Poll::Ready(None) => return Poll::Ready(None),
     106              :             }
     107              :         }
     108              : 
     109              :         loop {
     110            0 :             return match this.waiting.poll_join_next(cx) {
     111            0 :                 Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))),
     112              :                 // The handshake failed to complete, try getting another connection from the queue
     113            0 :                 Poll::Ready(Some(Ok(None))) => continue,
     114              :                 // The handshake panicked or was cancelled. ignore and get another connection
     115            0 :                 Poll::Ready(Some(Err(e))) => {
     116            0 :                     tracing::warn!("handshake aborted: {e}");
     117            0 :                     continue;
     118              :                 }
     119            0 :                 _ => Poll::Pending,
     120              :             };
     121              :         }
     122            0 :     }
     123              : }
        

Generated by: LCOV version 2.1-beta