LCOV - differential code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 90.3 % 300 271 29 271
Current Date: 2023-10-19 02:04:12 Functions: 64.4 % 73 47 26 47
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! Proxy Protocol V2 implementation
       2                 : 
       3                 : use std::{
       4                 :     future::poll_fn,
       5                 :     future::Future,
       6                 :     io,
       7                 :     net::SocketAddr,
       8                 :     pin::{pin, Pin},
       9                 :     task::{ready, Context, Poll},
      10                 : };
      11                 : 
      12                 : use bytes::{Buf, BytesMut};
      13                 : use hyper::server::conn::{AddrIncoming, AddrStream};
      14                 : use pin_project_lite::pin_project;
      15                 : use tls_listener::AsyncAccept;
      16                 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      17                 : 
      18                 : pub struct ProxyProtocolAccept {
      19                 :     pub incoming: AddrIncoming,
      20                 : }
      21                 : 
      22                 : pin_project! {
      23                 :     pub struct WithClientIp<T> {
      24                 :         #[pin]
      25                 :         pub inner: T,
      26                 :         buf: BytesMut,
      27                 :         tlv_bytes: u16,
      28                 :         state: ProxyParse,
      29                 :     }
      30                 : }
      31                 : 
      32 CBC           5 : #[derive(Clone, PartialEq, Debug)]
      33                 : enum ProxyParse {
      34                 :     NotStarted,
      35                 : 
      36                 :     Finished(SocketAddr),
      37                 :     None,
      38                 : }
      39                 : 
      40                 : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
      41                 :     #[inline]
      42              44 :     fn poll_write(
      43              44 :         self: Pin<&mut Self>,
      44              44 :         cx: &mut Context<'_>,
      45              44 :         buf: &[u8],
      46              44 :     ) -> Poll<Result<usize, io::Error>> {
      47              44 :         self.project().inner.poll_write(cx, buf)
      48              44 :     }
      49                 : 
      50                 :     #[inline]
      51             692 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      52             692 :         self.project().inner.poll_flush(cx)
      53             692 :     }
      54                 : 
      55                 :     #[inline]
      56              59 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      57              59 :         self.project().inner.poll_shutdown(cx)
      58              59 :     }
      59                 : 
      60                 :     #[inline]
      61             425 :     fn poll_write_vectored(
      62             425 :         self: Pin<&mut Self>,
      63             425 :         cx: &mut Context<'_>,
      64             425 :         bufs: &[io::IoSlice<'_>],
      65             425 :     ) -> Poll<Result<usize, io::Error>> {
      66             425 :         self.project().inner.poll_write_vectored(cx, bufs)
      67             425 :     }
      68                 : 
      69                 :     #[inline]
      70 UBC           0 :     fn is_write_vectored(&self) -> bool {
      71               0 :         self.inner.is_write_vectored()
      72               0 :     }
      73                 : }
      74                 : 
      75                 : impl<T> WithClientIp<T> {
      76 CBC          79 :     pub fn new(inner: T) -> Self {
      77              79 :         WithClientIp {
      78              79 :             inner,
      79              79 :             buf: BytesMut::with_capacity(128),
      80              79 :             tlv_bytes: 0,
      81              79 :             state: ProxyParse::NotStarted,
      82              79 :         }
      83              79 :     }
      84                 : 
      85              30 :     pub fn client_addr(&self) -> Option<SocketAddr> {
      86              30 :         match self.state {
      87 UBC           0 :             ProxyParse::Finished(socket) => Some(socket),
      88 CBC          30 :             _ => None,
      89                 :         }
      90              30 :     }
      91                 : }
      92                 : 
      93                 : impl<T: AsyncRead + Unpin> WithClientIp<T> {
      94              37 :     pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
      95              37 :         match self.state {
      96                 :             ProxyParse::NotStarted => {
      97              37 :                 let mut pin = Pin::new(&mut *self);
      98              66 :                 let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
      99              37 :                 match addr {
     100 UBC           0 :                     Some(addr) => self.state = ProxyParse::Finished(addr),
     101 CBC          37 :                     None => self.state = ProxyParse::None,
     102                 :                 }
     103              37 :                 Ok(addr)
     104                 :             }
     105 UBC           0 :             ProxyParse::Finished(addr) => Ok(Some(addr)),
     106               0 :             ProxyParse::None => Ok(None),
     107                 :         }
     108 CBC          37 :     }
     109                 : }
     110                 : 
     111                 : /// Proxy Protocol Version 2 Header
     112                 : const HEADER: [u8; 12] = [
     113                 :     0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
     114                 : ];
     115                 : 
     116                 : impl<T: AsyncRead> WithClientIp<T> {
     117                 :     /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
     118                 :     /// Version 2 (Binary Format)
     119             138 :     fn poll_client_ip(
     120             138 :         mut self: Pin<&mut Self>,
     121             138 :         cx: &mut Context<'_>,
     122             138 :     ) -> Poll<io::Result<Option<SocketAddr>>> {
     123                 :         // The binary header format starts with a constant 12 bytes block containing the protocol signature :
     124                 :         //    \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
     125             147 :         while self.buf.len() < 16 {
     126             144 :             let mut this = self.as_mut().project();
     127             144 :             let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
     128                 : 
     129                 :             // exit for bad header
     130             144 :             let len = usize::min(self.buf.len(), HEADER.len());
     131             144 :             if self.buf[..len] != HEADER[..len] {
     132              76 :                 return Poll::Ready(Ok(None));
     133              68 :             }
     134                 : 
     135                 :             // if no more bytes available then exit
     136              68 :             if ready!(bytes_read) == 0 {
     137 UBC           0 :                 return Poll::Ready(Ok(None));
     138 CBC           9 :             };
     139                 :         }
     140                 : 
     141                 :         // The next byte (the 13th one) is the protocol version and command.
     142                 :         // The highest four bits contains the version. As of this specification, it must
     143                 :         // always be sent as \x2 and the receiver must only accept this value.
     144               3 :         let vc = self.buf[12];
     145               3 :         let version = vc >> 4;
     146               3 :         let command = vc & 0b1111;
     147               3 :         if version != 2 {
     148 UBC           0 :             return Poll::Ready(Err(io::Error::new(
     149               0 :                 io::ErrorKind::Other,
     150               0 :                 "invalid proxy protocol version. expected version 2",
     151               0 :             )));
     152 CBC           3 :         }
     153               3 :         match command {
     154                 :             // the connection was established on purpose by the proxy
     155                 :             // without being relayed. The connection endpoints are the sender and the
     156                 :             // receiver. Such connections exist when the proxy sends health-checks to the
     157                 :             // server. The receiver must accept this connection as valid and must use the
     158                 :             // real connection endpoints and discard the protocol block including the
     159                 :             // family which is ignored.
     160 UBC           0 :             0 => {}
     161                 :             // the connection was established on behalf of another node,
     162                 :             // and reflects the original connection endpoints. The receiver must then use
     163                 :             // the information provided in the protocol block to get original the address.
     164 CBC           3 :             1 => {}
     165                 :             // other values are unassigned and must not be emitted by senders. Receivers
     166                 :             // must drop connections presenting unexpected values here.
     167                 :             _ => {
     168 UBC           0 :                 return Poll::Ready(Err(io::Error::new(
     169               0 :                     io::ErrorKind::Other,
     170               0 :                     "invalid proxy protocol command. expected local (0) or proxy (1)",
     171               0 :                 )))
     172                 :             }
     173                 :         };
     174                 : 
     175                 :         // The 14th byte contains the transport protocol and address family. The highest 4
     176                 :         // bits contain the address family, the lowest 4 bits contain the protocol.
     177 CBC           3 :         let ft = self.buf[13];
     178               3 :         let address_length = match ft {
     179                 :             // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
     180                 :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     181                 :             // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
     182                 :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     183               2 :             0x11 | 0x12 => 12,
     184                 :             // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
     185                 :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     186                 :             // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
     187                 :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     188               1 :             0x21 | 0x22 => 36,
     189                 :             // unspecified or unix stream. ignore the addresses
     190 UBC           0 :             _ => 0,
     191                 :         };
     192                 : 
     193                 :         // The 15th and 16th bytes is the address length in bytes in network endian order.
     194                 :         // It is used so that the receiver knows how many address bytes to skip even when
     195                 :         // it does not implement the presented protocol. Thus the length of the protocol
     196                 :         // header in bytes is always exactly 16 + this value. When a sender presents a
     197                 :         // LOCAL connection, it should not present any address so it sets this field to
     198                 :         // zero. Receivers MUST always consider this field to skip the appropriate number
     199                 :         // of bytes and must not assume zero is presented for LOCAL connections. When a
     200                 :         // receiver accepts an incoming connection showing an UNSPEC address family or
     201                 :         // protocol, it may or may not decide to log the address information if present.
     202 CBC           3 :         let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
     203               3 :         if remaining_length < address_length {
     204 UBC           0 :             return Poll::Ready(Err(io::Error::new(
     205               0 :                 io::ErrorKind::Other,
     206               0 :                 "invalid proxy protocol length. not enough to fit requested IP addresses",
     207               0 :             )));
     208 CBC           3 :         }
     209                 : 
     210              15 :         while self.buf.len() < 16 + address_length as usize {
     211              12 :             let mut this = self.as_mut().project();
     212              12 :             if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
     213 UBC           0 :                 return Poll::Ready(Err(io::Error::new(
     214               0 :                     io::ErrorKind::UnexpectedEof,
     215               0 :                     "stream closed while waiting for proxy protocol addresses",
     216               0 :                 )));
     217 CBC          12 :             }
     218                 :         }
     219                 : 
     220               3 :         let this = self.as_mut().project();
     221               3 : 
     222               3 :         // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
     223               3 :         // discard the header we have parsed
     224               3 :         this.buf.advance(16);
     225               3 : 
     226               3 :         // Starting from the 17th byte, addresses are presented in network byte order.
     227               3 :         // The address order is always the same :
     228               3 :         //   - source layer 3 address in network byte order
     229               3 :         //   - destination layer 3 address in network byte order
     230               3 :         //   - source layer 4 address if any, in network byte order (port)
     231               3 :         //   - destination layer 4 address if any, in network byte order (port)
     232               3 :         let addresses = this.buf.split_to(address_length as usize);
     233               3 :         let socket = match address_length {
     234                 :             12 => {
     235               2 :                 let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
     236               2 :                 let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
     237               2 :                 Some(SocketAddr::from((src_addr, src_port)))
     238                 :             }
     239                 :             36 => {
     240               1 :                 let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
     241               1 :                 let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
     242               1 :                 Some(SocketAddr::from((src_addr, src_port)))
     243                 :             }
     244 UBC           0 :             _ => None,
     245                 :         };
     246                 : 
     247 CBC           3 :         *this.tlv_bytes = remaining_length - address_length;
     248               3 :         self.as_mut().skip_tlv_inner();
     249               3 : 
     250               3 :         Poll::Ready(Ok(socket))
     251             138 :     }
     252                 : 
     253                 :     #[cold]
     254              72 :     fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     255              72 :         let ip = ready!(self.as_mut().poll_client_ip(cx)?);
     256              42 :         match ip {
     257               3 :             Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
     258              39 :             None => *self.as_mut().project().state = ProxyParse::None,
     259                 :         }
     260              42 :         Poll::Ready(Ok(()))
     261              72 :     }
     262                 : 
     263                 :     #[cold]
     264              34 :     fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     265              34 :         let mut this = self.as_mut().project();
     266                 :         // we know that this.buf is empty
     267              34 :         debug_assert_eq!(this.buf.len(), 0);
     268                 : 
     269              34 :         this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
     270              34 :         ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
     271              34 :         self.skip_tlv_inner();
     272              34 : 
     273              34 :         Poll::Ready(Ok(()))
     274              34 :     }
     275                 : 
     276              37 :     fn skip_tlv_inner(self: Pin<&mut Self>) {
     277              37 :         let tlv_bytes_read = match u16::try_from(self.buf.len()) {
     278                 :             // we read more than u16::MAX therefore we must have read the full tlv_bytes
     279 UBC           0 :             Err(_) => self.tlv_bytes,
     280                 :             // we might not have read the full tlv bytes yet
     281 CBC          37 :             Ok(n) => u16::min(n, self.tlv_bytes),
     282                 :         };
     283              37 :         let this = self.project();
     284              37 :         *this.tlv_bytes -= tlv_bytes_read;
     285              37 :         this.buf.advance(tlv_bytes_read as usize);
     286              37 :     }
     287                 : }
     288                 : 
     289                 : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
     290                 :     #[inline]
     291            1176 :     fn poll_read(
     292            1176 :         mut self: Pin<&mut Self>,
     293            1176 :         cx: &mut Context<'_>,
     294            1176 :         buf: &mut ReadBuf<'_>,
     295            1176 :     ) -> Poll<io::Result<()>> {
     296            1176 :         // I'm assuming these 3 comparisons will be easy to branch predict.
     297            1176 :         // especially with the cold attributes
     298            1176 :         // which should make this read wrapper almost invisible
     299            1176 : 
     300            1176 :         if let ProxyParse::NotStarted = self.state {
     301              72 :             ready!(self.as_mut().read_ip(cx)?);
     302            1104 :         }
     303                 : 
     304            1180 :         while self.tlv_bytes > 0 {
     305              34 :             ready!(self.as_mut().skip_tlv(cx)?)
     306                 :         }
     307                 : 
     308            1146 :         let this = self.project();
     309            1146 :         if this.buf.is_empty() {
     310            1068 :             this.inner.poll_read(cx, buf)
     311                 :         } else {
     312                 :             // we know that tlv_bytes is 0
     313              78 :             debug_assert_eq!(*this.tlv_bytes, 0);
     314                 : 
     315              78 :             let write = usize::min(this.buf.len(), buf.remaining());
     316              78 :             let slice = this.buf.split_to(write).freeze();
     317              78 :             buf.put_slice(&slice);
     318              78 : 
     319              78 :             // reset the allocation so it can be freed
     320              78 :             if this.buf.is_empty() {
     321              76 :                 *this.buf = BytesMut::new();
     322              76 :             }
     323                 : 
     324              78 :             Poll::Ready(Ok(()))
     325                 :         }
     326            1176 :     }
     327                 : }
     328                 : 
     329                 : impl AsyncAccept for ProxyProtocolAccept {
     330                 :     type Connection = WithClientIp<AddrStream>;
     331                 : 
     332                 :     type Error = io::Error;
     333                 : 
     334             226 :     fn poll_accept(
     335             226 :         mut self: Pin<&mut Self>,
     336             226 :         cx: &mut Context<'_>,
     337             226 :     ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
     338             226 :         let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
     339              30 :         let Some(conn) = conn else {
     340 UBC           0 :             return Poll::Ready(None);
     341                 :         };
     342                 : 
     343 CBC          30 :         Poll::Ready(Some(Ok(WithClientIp::new(conn))))
     344             226 :     }
     345                 : }
     346                 : 
     347                 : #[cfg(test)]
     348                 : mod tests {
     349                 :     use std::pin::pin;
     350                 : 
     351                 :     use tokio::io::AsyncReadExt;
     352                 : 
     353                 :     use crate::protocol2::{ProxyParse, WithClientIp};
     354                 : 
     355               1 :     #[tokio::test]
     356               1 :     async fn test_ipv4() {
     357               1 :         let header = super::HEADER
     358               1 :             // Proxy command, IPV4 | TCP
     359               1 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     360               1 :             // 12 + 3 bytes
     361               1 :             .chain([0, 15].as_slice())
     362               1 :             // src ip
     363               1 :             .chain([127, 0, 0, 1].as_slice())
     364               1 :             // dst ip
     365               1 :             .chain([192, 168, 0, 1].as_slice())
     366               1 :             // src port
     367               1 :             .chain([255, 255].as_slice())
     368               1 :             // dst port
     369               1 :             .chain([1, 1].as_slice())
     370               1 :             // TLV
     371               1 :             .chain([1, 2, 3].as_slice());
     372               1 : 
     373               1 :         let extra_data = [0x55; 256];
     374               1 : 
     375               1 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     376               1 : 
     377               1 :         let mut bytes = vec![];
     378               1 :         read.read_to_end(&mut bytes).await.unwrap();
     379               1 : 
     380               1 :         assert_eq!(bytes, extra_data);
     381               1 :         assert_eq!(
     382               1 :             read.state,
     383               1 :             ProxyParse::Finished(([127, 0, 0, 1], 65535).into())
     384               1 :         );
     385                 :     }
     386                 : 
     387               1 :     #[tokio::test]
     388               1 :     async fn test_ipv6() {
     389               1 :         let header = super::HEADER
     390               1 :             // Proxy command, IPV6 | UDP
     391               1 :             .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
     392               1 :             // 36 + 3 bytes
     393               1 :             .chain([0, 39].as_slice())
     394               1 :             // src ip
     395               1 :             .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
     396               1 :             // dst ip
     397               1 :             .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
     398               1 :             // src port
     399               1 :             .chain([1, 1].as_slice())
     400               1 :             // dst port
     401               1 :             .chain([255, 255].as_slice())
     402               1 :             // TLV
     403               1 :             .chain([1, 2, 3].as_slice());
     404               1 : 
     405               1 :         let extra_data = [0x55; 256];
     406               1 : 
     407               1 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     408               1 : 
     409               1 :         let mut bytes = vec![];
     410               1 :         read.read_to_end(&mut bytes).await.unwrap();
     411               1 : 
     412               1 :         assert_eq!(bytes, extra_data);
     413               1 :         assert_eq!(
     414               1 :             read.state,
     415               1 :             ProxyParse::Finished(
     416               1 :                 ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
     417               1 :             )
     418               1 :         );
     419                 :     }
     420                 : 
     421               1 :     #[tokio::test]
     422               1 :     async fn test_invalid() {
     423               1 :         let data = [0x55; 256];
     424               1 : 
     425               1 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     426               1 : 
     427               1 :         let mut bytes = vec![];
     428               1 :         read.read_to_end(&mut bytes).await.unwrap();
     429               1 :         assert_eq!(bytes, data);
     430               1 :         assert_eq!(read.state, ProxyParse::None);
     431                 :     }
     432                 : 
     433               1 :     #[tokio::test]
     434               1 :     async fn test_short() {
     435               1 :         let data = [0x55; 10];
     436               1 : 
     437               1 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     438               1 : 
     439               1 :         let mut bytes = vec![];
     440               1 :         read.read_to_end(&mut bytes).await.unwrap();
     441               1 :         assert_eq!(bytes, data);
     442               1 :         assert_eq!(read.state, ProxyParse::None);
     443                 :     }
     444                 : 
     445               1 :     #[tokio::test]
     446               1 :     async fn test_large_tlv() {
     447               1 :         let tlv = vec![0x55; 32768];
     448               1 :         let len = (12 + tlv.len() as u16).to_be_bytes();
     449               1 : 
     450               1 :         let header = super::HEADER
     451               1 :             // Proxy command, Inet << 4 | Stream
     452               1 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     453               1 :             // 12 + 3 bytes
     454               1 :             .chain(len.as_slice())
     455               1 :             // src ip
     456               1 :             .chain([55, 56, 57, 58].as_slice())
     457               1 :             // dst ip
     458               1 :             .chain([192, 168, 0, 1].as_slice())
     459               1 :             // src port
     460               1 :             .chain([255, 255].as_slice())
     461               1 :             // dst port
     462               1 :             .chain([1, 1].as_slice())
     463               1 :             // TLV
     464               1 :             .chain(tlv.as_slice());
     465               1 : 
     466               1 :         let extra_data = [0xaa; 256];
     467               1 : 
     468               1 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     469               1 : 
     470               1 :         let mut bytes = vec![];
     471               1 :         read.read_to_end(&mut bytes).await.unwrap();
     472               1 : 
     473               1 :         assert_eq!(bytes, extra_data);
     474               1 :         assert_eq!(
     475               1 :             read.state,
     476               1 :             ProxyParse::Finished(([55, 56, 57, 58], 65535).into())
     477               1 :         );
     478                 :     }
     479                 : }
        

Generated by: LCOV version 2.1-beta