LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: c639aa5f7ab62b43d647b10f40d15a15686ce8a9.info Lines: 90.5 % 305 276
Test Date: 2024-02-12 20:26:03 Functions: 64.4 % 73 47

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : 
       3              : use std::{
       4              :     future::poll_fn,
       5              :     future::Future,
       6              :     io,
       7              :     net::SocketAddr,
       8              :     pin::{pin, Pin},
       9              :     task::{ready, Context, Poll},
      10              : };
      11              : 
      12              : use bytes::{Buf, BytesMut};
      13              : use hyper::server::conn::{AddrIncoming, AddrStream};
      14              : use pin_project_lite::pin_project;
      15              : use tls_listener::AsyncAccept;
      16              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      17              : 
      18              : pub struct ProxyProtocolAccept {
      19              :     pub incoming: AddrIncoming,
      20              : }
      21              : 
      22              : pin_project! {
      23              :     pub struct WithClientIp<T> {
      24              :         #[pin]
      25              :         pub inner: T,
      26              :         buf: BytesMut,
      27              :         tlv_bytes: u16,
      28              :         state: ProxyParse,
      29              :     }
      30              : }
      31              : 
      32           10 : #[derive(Clone, PartialEq, Debug)]
      33              : enum ProxyParse {
      34              :     NotStarted,
      35              : 
      36              :     Finished(SocketAddr),
      37              :     None,
      38              : }
      39              : 
      40              : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
      41              :     #[inline]
      42           93 :     fn poll_write(
      43           93 :         self: Pin<&mut Self>,
      44           93 :         cx: &mut Context<'_>,
      45           93 :         buf: &[u8],
      46           93 :     ) -> Poll<Result<usize, io::Error>> {
      47           93 :         self.project().inner.poll_write(cx, buf)
      48           93 :     }
      49              : 
      50              :     #[inline]
      51         1688 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      52         1688 :         self.project().inner.poll_flush(cx)
      53         1688 :     }
      54              : 
      55              :     #[inline]
      56           88 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      57           88 :         self.project().inner.poll_shutdown(cx)
      58           88 :     }
      59              : 
      60              :     #[inline]
      61          663 :     fn poll_write_vectored(
      62          663 :         self: Pin<&mut Self>,
      63          663 :         cx: &mut Context<'_>,
      64          663 :         bufs: &[io::IoSlice<'_>],
      65          663 :     ) -> Poll<Result<usize, io::Error>> {
      66          663 :         self.project().inner.poll_write_vectored(cx, bufs)
      67          663 :     }
      68              : 
      69              :     #[inline]
      70            0 :     fn is_write_vectored(&self) -> bool {
      71            0 :         self.inner.is_write_vectored()
      72            0 :     }
      73              : }
      74              : 
      75              : impl<T> WithClientIp<T> {
      76          150 :     pub fn new(inner: T) -> Self {
      77          150 :         WithClientIp {
      78          150 :             inner,
      79          150 :             buf: BytesMut::with_capacity(128),
      80          150 :             tlv_bytes: 0,
      81          150 :             state: ProxyParse::NotStarted,
      82          150 :         }
      83          150 :     }
      84              : 
      85           47 :     pub fn client_addr(&self) -> Option<SocketAddr> {
      86           47 :         match self.state {
      87            0 :             ProxyParse::Finished(socket) => Some(socket),
      88           47 :             _ => None,
      89              :         }
      90           47 :     }
      91              : }
      92              : 
      93              : impl<T: AsyncRead + Unpin> WithClientIp<T> {
      94           63 :     pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
      95           63 :         match self.state {
      96              :             ProxyParse::NotStarted => {
      97           63 :                 let mut pin = Pin::new(&mut *self);
      98          126 :                 let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
      99           63 :                 match addr {
     100            0 :                     Some(addr) => self.state = ProxyParse::Finished(addr),
     101           63 :                     None => self.state = ProxyParse::None,
     102              :                 }
     103           63 :                 Ok(addr)
     104              :             }
     105            0 :             ProxyParse::Finished(addr) => Ok(Some(addr)),
     106            0 :             ProxyParse::None => Ok(None),
     107              :         }
     108           63 :     }
     109              : }
     110              : 
     111              : /// Proxy Protocol Version 2 Header
     112              : const HEADER: [u8; 12] = [
     113              :     0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
     114              : ];
     115              : 
     116              : impl<T: AsyncRead> WithClientIp<T> {
     117              :     /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
     118              :     /// Version 2 (Binary Format)
     119          256 :     fn poll_client_ip(
     120          256 :         mut self: Pin<&mut Self>,
     121          256 :         cx: &mut Context<'_>,
     122          256 :     ) -> Poll<io::Result<Option<SocketAddr>>> {
     123              :         // The binary header format starts with a constant 12 bytes block containing the protocol signature :
     124              :         //    \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
     125          274 :         while self.buf.len() < 16 {
     126          268 :             let mut this = self.as_mut().project();
     127          268 :             let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
     128              : 
     129              :             // exit for bad header
     130          268 :             let len = usize::min(self.buf.len(), HEADER.len());
     131          268 :             if self.buf[..len] != HEADER[..len] {
     132          144 :                 return Poll::Ready(Ok(None));
     133          124 :             }
     134              : 
     135              :             // if no more bytes available then exit
     136          124 :             if ready!(bytes_read) == 0 {
     137            0 :                 return Poll::Ready(Ok(None));
     138           18 :             };
     139              :         }
     140              : 
     141              :         // The next byte (the 13th one) is the protocol version and command.
     142              :         // The highest four bits contains the version. As of this specification, it must
     143              :         // always be sent as \x2 and the receiver must only accept this value.
     144            6 :         let vc = self.buf[12];
     145            6 :         let version = vc >> 4;
     146            6 :         let command = vc & 0b1111;
     147            6 :         if version != 2 {
     148            0 :             return Poll::Ready(Err(io::Error::new(
     149            0 :                 io::ErrorKind::Other,
     150            0 :                 "invalid proxy protocol version. expected version 2",
     151            0 :             )));
     152            6 :         }
     153            6 :         match command {
     154              :             // the connection was established on purpose by the proxy
     155              :             // without being relayed. The connection endpoints are the sender and the
     156              :             // receiver. Such connections exist when the proxy sends health-checks to the
     157              :             // server. The receiver must accept this connection as valid and must use the
     158              :             // real connection endpoints and discard the protocol block including the
     159              :             // family which is ignored.
     160            0 :             0 => {}
     161              :             // the connection was established on behalf of another node,
     162              :             // and reflects the original connection endpoints. The receiver must then use
     163              :             // the information provided in the protocol block to get original the address.
     164            6 :             1 => {}
     165              :             // other values are unassigned and must not be emitted by senders. Receivers
     166              :             // must drop connections presenting unexpected values here.
     167              :             _ => {
     168            0 :                 return Poll::Ready(Err(io::Error::new(
     169            0 :                     io::ErrorKind::Other,
     170            0 :                     "invalid proxy protocol command. expected local (0) or proxy (1)",
     171            0 :                 )))
     172              :             }
     173              :         };
     174              : 
     175              :         // The 14th byte contains the transport protocol and address family. The highest 4
     176              :         // bits contain the address family, the lowest 4 bits contain the protocol.
     177            6 :         let ft = self.buf[13];
     178            6 :         let address_length = match ft {
     179              :             // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
     180              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     181              :             // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
     182              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     183            4 :             0x11 | 0x12 => 12,
     184              :             // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
     185              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     186              :             // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
     187              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     188            2 :             0x21 | 0x22 => 36,
     189              :             // unspecified or unix stream. ignore the addresses
     190            0 :             _ => 0,
     191              :         };
     192              : 
     193              :         // The 15th and 16th bytes is the address length in bytes in network endian order.
     194              :         // It is used so that the receiver knows how many address bytes to skip even when
     195              :         // it does not implement the presented protocol. Thus the length of the protocol
     196              :         // header in bytes is always exactly 16 + this value. When a sender presents a
     197              :         // LOCAL connection, it should not present any address so it sets this field to
     198              :         // zero. Receivers MUST always consider this field to skip the appropriate number
     199              :         // of bytes and must not assume zero is presented for LOCAL connections. When a
     200              :         // receiver accepts an incoming connection showing an UNSPEC address family or
     201              :         // protocol, it may or may not decide to log the address information if present.
     202            6 :         let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
     203            6 :         if remaining_length < address_length {
     204            0 :             return Poll::Ready(Err(io::Error::new(
     205            0 :                 io::ErrorKind::Other,
     206            0 :                 "invalid proxy protocol length. not enough to fit requested IP addresses",
     207            0 :             )));
     208            6 :         }
     209              : 
     210           30 :         while self.buf.len() < 16 + address_length as usize {
     211           24 :             let mut this = self.as_mut().project();
     212           24 :             if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
     213            0 :                 return Poll::Ready(Err(io::Error::new(
     214            0 :                     io::ErrorKind::UnexpectedEof,
     215            0 :                     "stream closed while waiting for proxy protocol addresses",
     216            0 :                 )));
     217           24 :             }
     218              :         }
     219              : 
     220            6 :         let this = self.as_mut().project();
     221            6 : 
     222            6 :         // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
     223            6 :         // discard the header we have parsed
     224            6 :         this.buf.advance(16);
     225            6 : 
     226            6 :         // Starting from the 17th byte, addresses are presented in network byte order.
     227            6 :         // The address order is always the same :
     228            6 :         //   - source layer 3 address in network byte order
     229            6 :         //   - destination layer 3 address in network byte order
     230            6 :         //   - source layer 4 address if any, in network byte order (port)
     231            6 :         //   - destination layer 4 address if any, in network byte order (port)
     232            6 :         let addresses = this.buf.split_to(address_length as usize);
     233            6 :         let socket = match address_length {
     234              :             12 => {
     235            4 :                 let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
     236            4 :                 let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
     237            4 :                 Some(SocketAddr::from((src_addr, src_port)))
     238              :             }
     239              :             36 => {
     240            2 :                 let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
     241            2 :                 let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
     242            2 :                 Some(SocketAddr::from((src_addr, src_port)))
     243              :             }
     244            0 :             _ => None,
     245              :         };
     246              : 
     247            6 :         *this.tlv_bytes = remaining_length - address_length;
     248            6 :         self.as_mut().skip_tlv_inner();
     249            6 : 
     250            6 :         Poll::Ready(Ok(socket))
     251          256 :     }
     252              : 
     253              :     #[cold]
     254          130 :     fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     255          130 :         let ip = ready!(self.as_mut().poll_client_ip(cx)?);
     256           87 :         match ip {
     257            6 :             Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
     258           81 :             None => *self.as_mut().project().state = ProxyParse::None,
     259              :         }
     260           87 :         Poll::Ready(Ok(()))
     261          130 :     }
     262              : 
     263              :     #[cold]
     264           68 :     fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     265           68 :         let mut this = self.as_mut().project();
     266              :         // we know that this.buf is empty
     267           68 :         debug_assert_eq!(this.buf.len(), 0);
     268              : 
     269           68 :         this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
     270           68 :         ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
     271           68 :         self.skip_tlv_inner();
     272           68 : 
     273           68 :         Poll::Ready(Ok(()))
     274           68 :     }
     275              : 
     276           74 :     fn skip_tlv_inner(self: Pin<&mut Self>) {
     277           74 :         let tlv_bytes_read = match u16::try_from(self.buf.len()) {
     278              :             // we read more than u16::MAX therefore we must have read the full tlv_bytes
     279            0 :             Err(_) => self.tlv_bytes,
     280              :             // we might not have read the full tlv bytes yet
     281           74 :             Ok(n) => u16::min(n, self.tlv_bytes),
     282              :         };
     283           74 :         let this = self.project();
     284           74 :         *this.tlv_bytes -= tlv_bytes_read;
     285           74 :         this.buf.advance(tlv_bytes_read as usize);
     286           74 :     }
     287              : }
     288              : 
     289              : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
     290              :     #[inline]
     291         2450 :     fn poll_read(
     292         2450 :         mut self: Pin<&mut Self>,
     293         2450 :         cx: &mut Context<'_>,
     294         2450 :         buf: &mut ReadBuf<'_>,
     295         2450 :     ) -> Poll<io::Result<()>> {
     296         2450 :         // I'm assuming these 3 comparisons will be easy to branch predict.
     297         2450 :         // especially with the cold attributes
     298         2450 :         // which should make this read wrapper almost invisible
     299         2450 : 
     300         2450 :         if let ProxyParse::NotStarted = self.state {
     301          130 :             ready!(self.as_mut().read_ip(cx)?);
     302         2320 :         }
     303              : 
     304         2475 :         while self.tlv_bytes > 0 {
     305           68 :             ready!(self.as_mut().skip_tlv(cx)?)
     306              :         }
     307              : 
     308         2407 :         let this = self.project();
     309         2407 :         if this.buf.is_empty() {
     310         2259 :             this.inner.poll_read(cx, buf)
     311              :         } else {
     312              :             // we know that tlv_bytes is 0
     313          148 :             debug_assert_eq!(*this.tlv_bytes, 0);
     314              : 
     315          148 :             let write = usize::min(this.buf.len(), buf.remaining());
     316          148 :             let slice = this.buf.split_to(write).freeze();
     317          148 :             buf.put_slice(&slice);
     318          148 : 
     319          148 :             // reset the allocation so it can be freed
     320          148 :             if this.buf.is_empty() {
     321          144 :                 *this.buf = BytesMut::new();
     322          144 :             }
     323              : 
     324          148 :             Poll::Ready(Ok(()))
     325              :         }
     326         2450 :     }
     327              : }
     328              : 
     329              : impl AsyncAccept for ProxyProtocolAccept {
     330              :     type Connection = WithClientIp<AddrStream>;
     331              : 
     332              :     type Error = io::Error;
     333              : 
     334          312 :     fn poll_accept(
     335          312 :         mut self: Pin<&mut Self>,
     336          312 :         cx: &mut Context<'_>,
     337          312 :     ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
     338          312 :         let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
     339           47 :         let Some(conn) = conn else {
     340            0 :             return Poll::Ready(None);
     341              :         };
     342              : 
     343           47 :         Poll::Ready(Some(Ok(WithClientIp::new(conn))))
     344          312 :     }
     345              : }
     346              : 
     347              : #[cfg(test)]
     348              : mod tests {
     349              :     use std::pin::pin;
     350              : 
     351              :     use tokio::io::AsyncReadExt;
     352              : 
     353              :     use crate::protocol2::{ProxyParse, WithClientIp};
     354              : 
     355            2 :     #[tokio::test]
     356            2 :     async fn test_ipv4() {
     357            2 :         let header = super::HEADER
     358            2 :             // Proxy command, IPV4 | TCP
     359            2 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     360            2 :             // 12 + 3 bytes
     361            2 :             .chain([0, 15].as_slice())
     362            2 :             // src ip
     363            2 :             .chain([127, 0, 0, 1].as_slice())
     364            2 :             // dst ip
     365            2 :             .chain([192, 168, 0, 1].as_slice())
     366            2 :             // src port
     367            2 :             .chain([255, 255].as_slice())
     368            2 :             // dst port
     369            2 :             .chain([1, 1].as_slice())
     370            2 :             // TLV
     371            2 :             .chain([1, 2, 3].as_slice());
     372            2 : 
     373            2 :         let extra_data = [0x55; 256];
     374            2 : 
     375            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     376            2 : 
     377            2 :         let mut bytes = vec![];
     378            2 :         read.read_to_end(&mut bytes).await.unwrap();
     379            2 : 
     380            2 :         assert_eq!(bytes, extra_data);
     381            2 :         assert_eq!(
     382            2 :             read.state,
     383            2 :             ProxyParse::Finished(([127, 0, 0, 1], 65535).into())
     384            2 :         );
     385            2 :     }
     386              : 
     387            2 :     #[tokio::test]
     388            2 :     async fn test_ipv6() {
     389            2 :         let header = super::HEADER
     390            2 :             // Proxy command, IPV6 | UDP
     391            2 :             .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
     392            2 :             // 36 + 3 bytes
     393            2 :             .chain([0, 39].as_slice())
     394            2 :             // src ip
     395            2 :             .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
     396            2 :             // dst ip
     397            2 :             .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
     398            2 :             // src port
     399            2 :             .chain([1, 1].as_slice())
     400            2 :             // dst port
     401            2 :             .chain([255, 255].as_slice())
     402            2 :             // TLV
     403            2 :             .chain([1, 2, 3].as_slice());
     404            2 : 
     405            2 :         let extra_data = [0x55; 256];
     406            2 : 
     407            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     408            2 : 
     409            2 :         let mut bytes = vec![];
     410            2 :         read.read_to_end(&mut bytes).await.unwrap();
     411            2 : 
     412            2 :         assert_eq!(bytes, extra_data);
     413            2 :         assert_eq!(
     414            2 :             read.state,
     415            2 :             ProxyParse::Finished(
     416            2 :                 ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
     417            2 :             )
     418            2 :         );
     419            2 :     }
     420              : 
     421            2 :     #[tokio::test]
     422            2 :     async fn test_invalid() {
     423            2 :         let data = [0x55; 256];
     424            2 : 
     425            2 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     426            2 : 
     427            2 :         let mut bytes = vec![];
     428            2 :         read.read_to_end(&mut bytes).await.unwrap();
     429            2 :         assert_eq!(bytes, data);
     430            2 :         assert_eq!(read.state, ProxyParse::None);
     431            2 :     }
     432              : 
     433            2 :     #[tokio::test]
     434            2 :     async fn test_short() {
     435            2 :         let data = [0x55; 10];
     436            2 : 
     437            2 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     438            2 : 
     439            2 :         let mut bytes = vec![];
     440            2 :         read.read_to_end(&mut bytes).await.unwrap();
     441            2 :         assert_eq!(bytes, data);
     442            2 :         assert_eq!(read.state, ProxyParse::None);
     443            2 :     }
     444              : 
     445            2 :     #[tokio::test]
     446            2 :     async fn test_large_tlv() {
     447            2 :         let tlv = vec![0x55; 32768];
     448            2 :         let len = (12 + tlv.len() as u16).to_be_bytes();
     449            2 : 
     450            2 :         let header = super::HEADER
     451            2 :             // Proxy command, Inet << 4 | Stream
     452            2 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     453            2 :             // 12 + 3 bytes
     454            2 :             .chain(len.as_slice())
     455            2 :             // src ip
     456            2 :             .chain([55, 56, 57, 58].as_slice())
     457            2 :             // dst ip
     458            2 :             .chain([192, 168, 0, 1].as_slice())
     459            2 :             // src port
     460            2 :             .chain([255, 255].as_slice())
     461            2 :             // dst port
     462            2 :             .chain([1, 1].as_slice())
     463            2 :             // TLV
     464            2 :             .chain(tlv.as_slice());
     465            2 : 
     466            2 :         let extra_data = [0xaa; 256];
     467            2 : 
     468            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     469            2 : 
     470            2 :         let mut bytes = vec![];
     471            2 :         read.read_to_end(&mut bytes).await.unwrap();
     472            2 : 
     473            2 :         assert_eq!(bytes, extra_data);
     474            2 :         assert_eq!(
     475            2 :             read.state,
     476            2 :             ProxyParse::Finished(([55, 56, 57, 58], 65535).into())
     477            2 :         );
     478            2 :     }
     479              : }
        

Generated by: LCOV version 2.1-beta