LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: f8d8f5b90fa487a9e82c42da223f012f5d4fece7.info Lines: 88.5 % 226 200
Test Date: 2024-09-19 20:36:02 Functions: 66.7 % 36 24

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : 
       3              : use std::{
       4              :     io,
       5              :     net::SocketAddr,
       6              :     pin::Pin,
       7              :     task::{Context, Poll},
       8              : };
       9              : 
      10              : use bytes::BytesMut;
      11              : use pin_project_lite::pin_project;
      12              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      13              : 
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough.
    ///
    /// Reads are served from `buf` first; these are bytes that were already pulled
    /// off `inner` but not consumed as part of a header. Only once `buf` is empty
    /// are reads forwarded to `inner`. Every write operation goes straight to `inner`.
    pub(crate) struct ChainRW<T> {
        #[pin]
        pub(crate) inner: T,
        buf: BytesMut,
    }
}
      22              : 
      23              : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
      24              :     #[inline]
      25           15 :     fn poll_write(
      26           15 :         self: Pin<&mut Self>,
      27           15 :         cx: &mut Context<'_>,
      28           15 :         buf: &[u8],
      29           15 :     ) -> Poll<Result<usize, io::Error>> {
      30           15 :         self.project().inner.poll_write(cx, buf)
      31           15 :     }
      32              : 
      33              :     #[inline]
      34           74 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      35           74 :         self.project().inner.poll_flush(cx)
      36           74 :     }
      37              : 
      38              :     #[inline]
      39            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      40            0 :         self.project().inner.poll_shutdown(cx)
      41            0 :     }
      42              : 
      43              :     #[inline]
      44           59 :     fn poll_write_vectored(
      45           59 :         self: Pin<&mut Self>,
      46           59 :         cx: &mut Context<'_>,
      47           59 :         bufs: &[io::IoSlice<'_>],
      48           59 :     ) -> Poll<Result<usize, io::Error>> {
      49           59 :         self.project().inner.poll_write_vectored(cx, bufs)
      50           59 :     }
      51              : 
      52              :     #[inline]
      53            0 :     fn is_write_vectored(&self) -> bool {
      54            0 :         self.inner.is_write_vectored()
      55            0 :     }
      56              : }
      57              : 
      58              : /// Proxy Protocol Version 2 Header
      59              : const HEADER: [u8; 12] = [
      60              :     0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
      61              : ];
      62              : 
      63           20 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
      64           20 :     mut read: T,
      65           20 : ) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
      66           20 :     let mut buf = BytesMut::with_capacity(128);
      67           29 :     while buf.len() < 16 {
      68           26 :         let bytes_read = read.read_buf(&mut buf).await?;
      69              : 
      70              :         // exit for bad header
      71           26 :         let len = usize::min(buf.len(), HEADER.len());
      72           26 :         if buf[..len] != HEADER[..len] {
      73           17 :             return Ok((ChainRW { inner: read, buf }, None));
      74            9 :         }
      75            9 : 
      76            9 :         // if no more bytes available then exit
      77            9 :         if bytes_read == 0 {
      78            0 :             return Ok((ChainRW { inner: read, buf }, None));
      79            9 :         };
      80              :     }
      81              : 
      82            3 :     let header = buf.split_to(16);
      83            3 : 
      84            3 :     // The next byte (the 13th one) is the protocol version and command.
      85            3 :     // The highest four bits contains the version. As of this specification, it must
      86            3 :     // always be sent as \x2 and the receiver must only accept this value.
      87            3 :     let vc = header[12];
      88            3 :     let version = vc >> 4;
      89            3 :     let command = vc & 0b1111;
      90            3 :     if version != 2 {
      91            0 :         return Err(io::Error::new(
      92            0 :             io::ErrorKind::Other,
      93            0 :             "invalid proxy protocol version. expected version 2",
      94            0 :         ));
      95            3 :     }
      96            3 :     match command {
      97              :         // the connection was established on purpose by the proxy
      98              :         // without being relayed. The connection endpoints are the sender and the
      99              :         // receiver. Such connections exist when the proxy sends health-checks to the
     100              :         // server. The receiver must accept this connection as valid and must use the
     101              :         // real connection endpoints and discard the protocol block including the
     102              :         // family which is ignored.
     103            0 :         0 => {}
     104              :         // the connection was established on behalf of another node,
     105              :         // and reflects the original connection endpoints. The receiver must then use
     106              :         // the information provided in the protocol block to get original the address.
     107            3 :         1 => {}
     108              :         // other values are unassigned and must not be emitted by senders. Receivers
     109              :         // must drop connections presenting unexpected values here.
     110              :         _ => {
     111            0 :             return Err(io::Error::new(
     112            0 :                 io::ErrorKind::Other,
     113            0 :                 "invalid proxy protocol command. expected local (0) or proxy (1)",
     114            0 :             ))
     115              :         }
     116              :     };
     117              : 
     118              :     // The 14th byte contains the transport protocol and address family. The highest 4
     119              :     // bits contain the address family, the lowest 4 bits contain the protocol.
     120            3 :     let ft = header[13];
     121            3 :     let address_length = match ft {
     122              :         // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
     123              :         //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     124              :         // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
     125              :         //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     126            2 :         0x11 | 0x12 => 12,
     127              :         // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
     128              :         //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     129              :         // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
     130              :         //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     131            1 :         0x21 | 0x22 => 36,
     132              :         // unspecified or unix stream. ignore the addresses
     133            0 :         _ => 0,
     134              :     };
     135              : 
     136              :     // The 15th and 16th bytes is the address length in bytes in network endian order.
     137              :     // It is used so that the receiver knows how many address bytes to skip even when
     138              :     // it does not implement the presented protocol. Thus the length of the protocol
     139              :     // header in bytes is always exactly 16 + this value. When a sender presents a
     140              :     // LOCAL connection, it should not present any address so it sets this field to
     141              :     // zero. Receivers MUST always consider this field to skip the appropriate number
     142              :     // of bytes and must not assume zero is presented for LOCAL connections. When a
     143              :     // receiver accepts an incoming connection showing an UNSPEC address family or
     144              :     // protocol, it may or may not decide to log the address information if present.
     145            3 :     let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
     146            3 :     if remaining_length < address_length {
     147            0 :         return Err(io::Error::new(
     148            0 :             io::ErrorKind::Other,
     149            0 :             "invalid proxy protocol length. not enough to fit requested IP addresses",
     150            0 :         ));
     151            3 :     }
     152            3 :     drop(header);
     153              : 
     154           27 :     while buf.len() < remaining_length as usize {
     155           24 :         if read.read_buf(&mut buf).await? == 0 {
     156            0 :             return Err(io::Error::new(
     157            0 :                 io::ErrorKind::UnexpectedEof,
     158            0 :                 "stream closed while waiting for proxy protocol addresses",
     159            0 :             ));
     160           24 :         }
     161              :     }
     162              : 
     163              :     // Starting from the 17th byte, addresses are presented in network byte order.
     164              :     // The address order is always the same :
     165              :     //   - source layer 3 address in network byte order
     166              :     //   - destination layer 3 address in network byte order
     167              :     //   - source layer 4 address if any, in network byte order (port)
     168              :     //   - destination layer 4 address if any, in network byte order (port)
     169            3 :     let addresses = buf.split_to(remaining_length as usize);
     170            3 :     let socket = match address_length {
     171              :         12 => {
     172            2 :             let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
     173            2 :             let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
     174            2 :             Some(SocketAddr::from((src_addr, src_port)))
     175              :         }
     176              :         36 => {
     177            1 :             let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
     178            1 :             let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
     179            1 :             Some(SocketAddr::from((src_addr, src_port)))
     180              :         }
     181            0 :         _ => None,
     182              :     };
     183              : 
     184            3 :     Ok((ChainRW { inner: read, buf }, socket))
     185           20 : }
     186              : 
     187              : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
     188              :     #[inline]
     189          167 :     fn poll_read(
     190          167 :         self: Pin<&mut Self>,
     191          167 :         cx: &mut Context<'_>,
     192          167 :         buf: &mut ReadBuf<'_>,
     193          167 :     ) -> Poll<io::Result<()>> {
     194          167 :         if self.buf.is_empty() {
     195          148 :             self.project().inner.poll_read(cx, buf)
     196              :         } else {
     197           19 :             self.read_from_buf(buf)
     198              :         }
     199          167 :     }
     200              : }
     201              : 
     202              : impl<T: AsyncRead> ChainRW<T> {
     203              :     #[cold]
     204           19 :     fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
     205           19 :         debug_assert!(!self.buf.is_empty());
     206           19 :         let this = self.project();
     207           19 : 
     208           19 :         let write = usize::min(this.buf.len(), buf.remaining());
     209           19 :         let slice = this.buf.split_to(write).freeze();
     210           19 :         buf.put_slice(&slice);
     211           19 : 
     212           19 :         // reset the allocation so it can be freed
     213           19 :         if this.buf.is_empty() {
     214           17 :             *this.buf = BytesMut::new();
     215           17 :         }
     216              : 
     217           19 :         Poll::Ready(Ok(()))
     218           19 :     }
     219              : }
     220              : 
     221              : #[cfg(test)]
     222              : mod tests {
     223              :     use tokio::io::AsyncReadExt;
     224              : 
     225              :     use crate::protocol2::read_proxy_protocol;
     226              : 
     227              :     #[tokio::test]
     228            1 :     async fn test_ipv4() {
     229            1 :         let header = super::HEADER
     230            1 :             // Proxy command, IPV4 | TCP
     231            1 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     232            1 :             // 12 + 3 bytes
     233            1 :             .chain([0, 15].as_slice())
     234            1 :             // src ip
     235            1 :             .chain([127, 0, 0, 1].as_slice())
     236            1 :             // dst ip
     237            1 :             .chain([192, 168, 0, 1].as_slice())
     238            1 :             // src port
     239            1 :             .chain([255, 255].as_slice())
     240            1 :             // dst port
     241            1 :             .chain([1, 1].as_slice())
     242            1 :             // TLV
     243            1 :             .chain([1, 2, 3].as_slice());
     244            1 : 
     245            1 :         let extra_data = [0x55; 256];
     246            1 : 
     247            1 :         let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
     248            1 :             .await
     249            1 :             .unwrap();
     250            1 : 
     251            1 :         let mut bytes = vec![];
     252            1 :         read.read_to_end(&mut bytes).await.unwrap();
     253            1 : 
     254            1 :         assert_eq!(bytes, extra_data);
     255            1 :         assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
     256            1 :     }
     257              : 
     258              :     #[tokio::test]
     259            1 :     async fn test_ipv6() {
     260            1 :         let header = super::HEADER
     261            1 :             // Proxy command, IPV6 | UDP
     262            1 :             .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
     263            1 :             // 36 + 3 bytes
     264            1 :             .chain([0, 39].as_slice())
     265            1 :             // src ip
     266            1 :             .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
     267            1 :             // dst ip
     268            1 :             .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
     269            1 :             // src port
     270            1 :             .chain([1, 1].as_slice())
     271            1 :             // dst port
     272            1 :             .chain([255, 255].as_slice())
     273            1 :             // TLV
     274            1 :             .chain([1, 2, 3].as_slice());
     275            1 : 
     276            1 :         let extra_data = [0x55; 256];
     277            1 : 
     278            1 :         let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
     279            1 :             .await
     280            1 :             .unwrap();
     281            1 : 
     282            1 :         let mut bytes = vec![];
     283            1 :         read.read_to_end(&mut bytes).await.unwrap();
     284            1 : 
     285            1 :         assert_eq!(bytes, extra_data);
     286            1 :         assert_eq!(
     287            1 :             addr,
     288            1 :             Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
     289            1 :         );
     290            1 :     }
     291              : 
     292              :     #[tokio::test]
     293            1 :     async fn test_invalid() {
     294            1 :         let data = [0x55; 256];
     295            1 : 
     296            1 :         let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
     297            1 : 
     298            1 :         let mut bytes = vec![];
     299            1 :         read.read_to_end(&mut bytes).await.unwrap();
     300            1 :         assert_eq!(bytes, data);
     301            1 :         assert_eq!(addr, None);
     302            1 :     }
     303              : 
     304              :     #[tokio::test]
     305            1 :     async fn test_short() {
     306            1 :         let data = [0x55; 10];
     307            1 : 
     308            1 :         let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();
     309            1 : 
     310            1 :         let mut bytes = vec![];
     311            1 :         read.read_to_end(&mut bytes).await.unwrap();
     312            1 :         assert_eq!(bytes, data);
     313            1 :         assert_eq!(addr, None);
     314            1 :     }
     315              : 
     316              :     #[tokio::test]
     317            1 :     async fn test_large_tlv() {
     318            1 :         let tlv = vec![0x55; 32768];
     319            1 :         let len = (12 + tlv.len() as u16).to_be_bytes();
     320            1 : 
     321            1 :         let header = super::HEADER
     322            1 :             // Proxy command, Inet << 4 | Stream
     323            1 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     324            1 :             // 12 + 3 bytes
     325            1 :             .chain(len.as_slice())
     326            1 :             // src ip
     327            1 :             .chain([55, 56, 57, 58].as_slice())
     328            1 :             // dst ip
     329            1 :             .chain([192, 168, 0, 1].as_slice())
     330            1 :             // src port
     331            1 :             .chain([255, 255].as_slice())
     332            1 :             // dst port
     333            1 :             .chain([1, 1].as_slice())
     334            1 :             // TLV
     335            1 :             .chain(tlv.as_slice());
     336            1 : 
     337            1 :         let extra_data = [0xaa; 256];
     338            1 : 
     339            1 :         let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
     340            1 :             .await
     341            1 :             .unwrap();
     342            1 : 
     343            1 :         let mut bytes = vec![];
     344            1 :         read.read_to_end(&mut bytes).await.unwrap();
     345            1 : 
     346            1 :         assert_eq!(bytes, extra_data);
     347            1 :         assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
     348            1 :     }
     349              : }
        

Generated by: LCOV version 2.1-beta