Line data Source code
1 : use std::{
2 : convert::Infallible,
3 : pin::Pin,
4 : task::{Context, Poll},
5 : time::Duration,
6 : };
7 :
8 : use hyper::server::{accept::Accept, conn::AddrStream};
9 : use pin_project_lite::pin_project;
10 : use tokio::{
11 : io::{AsyncRead, AsyncWrite},
12 : task::JoinSet,
13 : time::timeout,
14 : };
15 : use tokio_rustls::{server::TlsStream, TlsAcceptor};
16 : use tracing::{info, warn, Instrument};
17 :
18 : use crate::{
19 : metrics::TLS_HANDSHAKE_FAILURES,
20 : protocol2::{WithClientIp, WithConnectionGuard},
21 : };
22 :
pin_project! {
    /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
    /// encrypted using TLS.
    ///
    /// Accepted connections are handed to background tasks that perform the TLS handshake;
    /// `poll_accept` yields only the connections whose handshake completed successfully.
    pub(crate) struct TlsListener<A: Accept> {
        // Underlying connection source; polled for new plaintext connections.
        #[pin]
        listener: A,
        // TLS acceptor, cloned into every spawned handshake task.
        tls: TlsAcceptor,
        // In-flight handshake tasks. A task resolves to `Some(stream)` on success and to `None`
        // when the connection was rejected (bad header, handshake error, or timeout).
        waiting: JoinSet<Option<TlsStream<A::Conn>>>,
        // Upper bound applied to each TLS handshake.
        timeout: Duration,
    }
}
34 :
35 : impl<A: Accept> TlsListener<A> {
36 : /// Create a `TlsListener` with default options.
37 0 : pub(crate) fn new(tls: TlsAcceptor, listener: A, timeout: Duration) -> Self {
38 0 : TlsListener {
39 0 : listener,
40 0 : tls,
41 0 : waiting: JoinSet::new(),
42 0 : timeout,
43 0 : }
44 0 : }
45 : }
46 :
47 : impl<A> Accept for TlsListener<A>
48 : where
49 : A: Accept<Conn = WithConnectionGuard<WithClientIp<AddrStream>>>,
50 : A::Error: std::error::Error,
51 : A::Conn: AsyncRead + AsyncWrite + Unpin + Send + 'static,
52 : {
53 : type Conn = TlsStream<A::Conn>;
54 :
55 : type Error = Infallible;
56 :
57 0 : fn poll_accept(
58 0 : self: Pin<&mut Self>,
59 0 : cx: &mut Context<'_>,
60 0 : ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
61 0 : let mut this = self.project();
62 :
63 : loop {
64 0 : match this.listener.as_mut().poll_accept(cx) {
65 0 : Poll::Pending => break,
66 0 : Poll::Ready(Some(Ok(mut conn))) => {
67 0 : let t = *this.timeout;
68 0 : let tls = this.tls.clone();
69 0 : let span = conn.span.clone();
70 0 : this.waiting.spawn(async move {
71 0 : let peer_addr = match conn.inner.wait_for_addr().await {
72 0 : Ok(Some(addr)) => addr,
73 0 : Err(e) => {
74 0 : tracing::error!("failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}");
75 0 : return None;
76 : }
77 0 : Ok(None) => conn.inner.inner.remote_addr()
78 : };
79 :
80 0 : let accept = tls.accept(conn);
81 0 : match timeout(t, accept).await {
82 0 : Ok(Ok(conn)) => {
83 0 : info!(%peer_addr, "accepted new TLS connection");
84 0 : Some(conn)
85 : },
86 : // The handshake failed, try getting another connection from the queue
87 0 : Ok(Err(e)) => {
88 0 : TLS_HANDSHAKE_FAILURES.inc();
89 0 : warn!(%peer_addr, "failed to accept TLS connection: {e:?}");
90 0 : None
91 : }
92 : // The handshake timed out, try getting another connection from the queue
93 : Err(_) => {
94 0 : TLS_HANDSHAKE_FAILURES.inc();
95 0 : warn!(%peer_addr, "failed to accept TLS connection: timeout");
96 0 : None
97 : }
98 : }
99 0 : }.instrument(span));
100 0 : }
101 0 : Poll::Ready(Some(Err(e))) => {
102 0 : tracing::error!("error accepting TCP connection: {e}");
103 0 : continue;
104 : }
105 0 : Poll::Ready(None) => return Poll::Ready(None),
106 : }
107 : }
108 :
109 : loop {
110 0 : return match this.waiting.poll_join_next(cx) {
111 0 : Poll::Ready(Some(Ok(Some(conn)))) => Poll::Ready(Some(Ok(conn))),
112 : // The handshake failed to complete, try getting another connection from the queue
113 0 : Poll::Ready(Some(Ok(None))) => continue,
114 : // The handshake panicked or was cancelled. ignore and get another connection
115 0 : Poll::Ready(Some(Err(e))) => {
116 0 : tracing::warn!("handshake aborted: {e}");
117 0 : continue;
118 : }
119 0 : _ => Poll::Pending,
120 : };
121 : }
122 0 : }
123 : }
|