LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: b837401fb09d2d9818b70e630fdb67e9799b7b0d.info Lines: 85.2 % 291 248
Test Date: 2024-04-18 15:32:49 Functions: 45.8 % 59 27

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : 
       3              : use std::{
       4              :     future::{poll_fn, Future},
       5              :     io,
       6              :     net::SocketAddr,
       7              :     pin::{pin, Pin},
       8              :     task::{ready, Context, Poll},
       9              : };
      10              : 
      11              : use bytes::{Buf, BytesMut};
      12              : use hyper::server::conn::AddrIncoming;
      13              : use pin_project_lite::pin_project;
      14              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      15              : 
      16              : pub struct ProxyProtocolAccept {
      17              :     pub incoming: AddrIncoming,
      18              :     pub protocol: &'static str,
      19              : }
      20              : 
      21              : pin_project! {
      22              :     pub struct WithClientIp<T> {
      23              :         #[pin]
      24              :         pub inner: T,
      25              :         buf: BytesMut,
      26              :         tlv_bytes: u16,
      27              :         state: ProxyParse,
      28              :     }
      29              : }
      30              : 
      31              : #[derive(Clone, PartialEq, Debug)]
      32              : enum ProxyParse {
      33              :     NotStarted,
      34              : 
      35              :     Finished(SocketAddr),
      36              :     None,
      37              : }
      38              : 
      39              : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
      40              :     #[inline]
      41           30 :     fn poll_write(
      42           30 :         self: Pin<&mut Self>,
      43           30 :         cx: &mut Context<'_>,
      44           30 :         buf: &[u8],
      45           30 :     ) -> Poll<Result<usize, io::Error>> {
      46           30 :         self.project().inner.poll_write(cx, buf)
      47           30 :     }
      48              : 
      49              :     #[inline]
      50          148 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      51          148 :         self.project().inner.poll_flush(cx)
      52          148 :     }
      53              : 
      54              :     #[inline]
      55            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      56            0 :         self.project().inner.poll_shutdown(cx)
      57            0 :     }
      58              : 
      59              :     #[inline]
      60          118 :     fn poll_write_vectored(
      61          118 :         self: Pin<&mut Self>,
      62          118 :         cx: &mut Context<'_>,
      63          118 :         bufs: &[io::IoSlice<'_>],
      64          118 :     ) -> Poll<Result<usize, io::Error>> {
      65          118 :         self.project().inner.poll_write_vectored(cx, bufs)
      66          118 :     }
      67              : 
      68              :     #[inline]
      69            0 :     fn is_write_vectored(&self) -> bool {
      70            0 :         self.inner.is_write_vectored()
      71            0 :     }
      72              : }
      73              : 
      74              : impl<T> WithClientIp<T> {
      75           40 :     pub fn new(inner: T) -> Self {
      76           40 :         WithClientIp {
      77           40 :             inner,
      78           40 :             buf: BytesMut::with_capacity(128),
      79           40 :             tlv_bytes: 0,
      80           40 :             state: ProxyParse::NotStarted,
      81           40 :         }
      82           40 :     }
      83              : 
      84            0 :     pub fn client_addr(&self) -> Option<SocketAddr> {
      85            0 :         match self.state {
      86            0 :             ProxyParse::Finished(socket) => Some(socket),
      87            0 :             _ => None,
      88              :         }
      89            0 :     }
      90              : }
      91              : 
      92              : impl<T: AsyncRead + Unpin> WithClientIp<T> {
      93            0 :     pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
      94            0 :         match self.state {
      95              :             ProxyParse::NotStarted => {
      96            0 :                 let mut pin = Pin::new(&mut *self);
      97            0 :                 let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
      98            0 :                 match addr {
      99            0 :                     Some(addr) => self.state = ProxyParse::Finished(addr),
     100            0 :                     None => self.state = ProxyParse::None,
     101              :                 }
     102            0 :                 Ok(addr)
     103              :             }
     104            0 :             ProxyParse::Finished(addr) => Ok(Some(addr)),
     105            0 :             ProxyParse::None => Ok(None),
     106              :         }
     107            0 :     }
     108              : }
     109              : 
     110              : /// Proxy Protocol Version 2 Header
     111              : const HEADER: [u8; 12] = [
     112              :     0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
     113              : ];
     114              : 
     115              : impl<T: AsyncRead> WithClientIp<T> {
     116              :     /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
     117              :     /// Version 2 (Binary Format)
     118           40 :     fn poll_client_ip(
     119           40 :         mut self: Pin<&mut Self>,
     120           40 :         cx: &mut Context<'_>,
     121           40 :     ) -> Poll<io::Result<Option<SocketAddr>>> {
     122              :         // The binary header format starts with a constant 12 bytes block containing the protocol signature :
     123              :         //    \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
     124           58 :         while self.buf.len() < 16 {
     125           52 :             let mut this = self.as_mut().project();
     126           52 :             let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
     127              : 
     128              :             // exit for bad header
     129           52 :             let len = usize::min(self.buf.len(), HEADER.len());
     130           52 :             if self.buf[..len] != HEADER[..len] {
     131           34 :                 return Poll::Ready(Ok(None));
     132           18 :             }
     133              : 
     134              :             // if no more bytes available then exit
     135           18 :             if ready!(bytes_read) == 0 {
     136            0 :                 return Poll::Ready(Ok(None));
     137           18 :             };
     138              :         }
     139              : 
     140              :         // The next byte (the 13th one) is the protocol version and command.
     141              :         // The highest four bits contains the version. As of this specification, it must
     142              :         // always be sent as \x2 and the receiver must only accept this value.
     143            6 :         let vc = self.buf[12];
     144            6 :         let version = vc >> 4;
     145            6 :         let command = vc & 0b1111;
     146            6 :         if version != 2 {
     147            0 :             return Poll::Ready(Err(io::Error::new(
     148            0 :                 io::ErrorKind::Other,
     149            0 :                 "invalid proxy protocol version. expected version 2",
     150            0 :             )));
     151            6 :         }
     152            6 :         match command {
     153              :             // the connection was established on purpose by the proxy
     154              :             // without being relayed. The connection endpoints are the sender and the
     155              :             // receiver. Such connections exist when the proxy sends health-checks to the
     156              :             // server. The receiver must accept this connection as valid and must use the
     157              :             // real connection endpoints and discard the protocol block including the
     158              :             // family which is ignored.
     159            0 :             0 => {}
     160              :             // the connection was established on behalf of another node,
     161              :             // and reflects the original connection endpoints. The receiver must then use
     162              :             // the information provided in the protocol block to get original the address.
     163            6 :             1 => {}
     164              :             // other values are unassigned and must not be emitted by senders. Receivers
     165              :             // must drop connections presenting unexpected values here.
     166              :             _ => {
     167            0 :                 return Poll::Ready(Err(io::Error::new(
     168            0 :                     io::ErrorKind::Other,
     169            0 :                     "invalid proxy protocol command. expected local (0) or proxy (1)",
     170            0 :                 )))
     171              :             }
     172              :         };
     173              : 
     174              :         // The 14th byte contains the transport protocol and address family. The highest 4
     175              :         // bits contain the address family, the lowest 4 bits contain the protocol.
     176            6 :         let ft = self.buf[13];
     177            6 :         let address_length = match ft {
     178              :             // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
     179              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     180              :             // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
     181              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     182            4 :             0x11 | 0x12 => 12,
     183              :             // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
     184              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     185              :             // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
     186              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     187            2 :             0x21 | 0x22 => 36,
     188              :             // unspecified or unix stream. ignore the addresses
     189            0 :             _ => 0,
     190              :         };
     191              : 
     192              :         // The 15th and 16th bytes is the address length in bytes in network endian order.
     193              :         // It is used so that the receiver knows how many address bytes to skip even when
     194              :         // it does not implement the presented protocol. Thus the length of the protocol
     195              :         // header in bytes is always exactly 16 + this value. When a sender presents a
     196              :         // LOCAL connection, it should not present any address so it sets this field to
     197              :         // zero. Receivers MUST always consider this field to skip the appropriate number
     198              :         // of bytes and must not assume zero is presented for LOCAL connections. When a
     199              :         // receiver accepts an incoming connection showing an UNSPEC address family or
     200              :         // protocol, it may or may not decide to log the address information if present.
     201            6 :         let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
     202            6 :         if remaining_length < address_length {
     203            0 :             return Poll::Ready(Err(io::Error::new(
     204            0 :                 io::ErrorKind::Other,
     205            0 :                 "invalid proxy protocol length. not enough to fit requested IP addresses",
     206            0 :             )));
     207            6 :         }
     208              : 
     209           30 :         while self.buf.len() < 16 + address_length as usize {
     210           24 :             let mut this = self.as_mut().project();
     211           24 :             if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
     212            0 :                 return Poll::Ready(Err(io::Error::new(
     213            0 :                     io::ErrorKind::UnexpectedEof,
     214            0 :                     "stream closed while waiting for proxy protocol addresses",
     215            0 :                 )));
     216           24 :             }
     217              :         }
     218              : 
     219            6 :         let this = self.as_mut().project();
     220            6 : 
     221            6 :         // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
     222            6 :         // discard the header we have parsed
     223            6 :         this.buf.advance(16);
     224            6 : 
     225            6 :         // Starting from the 17th byte, addresses are presented in network byte order.
     226            6 :         // The address order is always the same :
     227            6 :         //   - source layer 3 address in network byte order
     228            6 :         //   - destination layer 3 address in network byte order
     229            6 :         //   - source layer 4 address if any, in network byte order (port)
     230            6 :         //   - destination layer 4 address if any, in network byte order (port)
     231            6 :         let addresses = this.buf.split_to(address_length as usize);
     232            6 :         let socket = match address_length {
     233              :             12 => {
     234            4 :                 let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
     235            4 :                 let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
     236            4 :                 Some(SocketAddr::from((src_addr, src_port)))
     237              :             }
     238              :             36 => {
     239            2 :                 let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
     240            2 :                 let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
     241            2 :                 Some(SocketAddr::from((src_addr, src_port)))
     242              :             }
     243            0 :             _ => None,
     244              :         };
     245              : 
     246            6 :         *this.tlv_bytes = remaining_length - address_length;
     247            6 :         self.as_mut().skip_tlv_inner();
     248            6 : 
     249            6 :         Poll::Ready(Ok(socket))
     250           40 :     }
     251              : 
     252              :     #[cold]
     253           40 :     fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     254           40 :         let ip = ready!(self.as_mut().poll_client_ip(cx)?);
     255           40 :         match ip {
     256            6 :             Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
     257           34 :             None => *self.as_mut().project().state = ProxyParse::None,
     258              :         }
     259           40 :         Poll::Ready(Ok(()))
     260           40 :     }
     261              : 
     262              :     #[cold]
     263           68 :     fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     264           68 :         let mut this = self.as_mut().project();
     265           68 :         // we know that this.buf is empty
     266           68 :         debug_assert_eq!(this.buf.len(), 0);
     267              : 
     268           68 :         this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
     269           68 :         ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
     270           68 :         self.skip_tlv_inner();
     271           68 : 
     272           68 :         Poll::Ready(Ok(()))
     273           68 :     }
     274              : 
     275           74 :     fn skip_tlv_inner(self: Pin<&mut Self>) {
     276           74 :         let tlv_bytes_read = match u16::try_from(self.buf.len()) {
     277              :             // we read more than u16::MAX therefore we must have read the full tlv_bytes
     278            0 :             Err(_) => self.tlv_bytes,
     279              :             // we might not have read the full tlv bytes yet
     280           74 :             Ok(n) => u16::min(n, self.tlv_bytes),
     281              :         };
     282           74 :         let this = self.project();
     283           74 :         *this.tlv_bytes -= tlv_bytes_read;
     284           74 :         this.buf.advance(tlv_bytes_read as usize);
     285           74 :     }
     286              : }
     287              : 
     288              : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
     289              :     #[inline]
     290          334 :     fn poll_read(
     291          334 :         mut self: Pin<&mut Self>,
     292          334 :         cx: &mut Context<'_>,
     293          334 :         buf: &mut ReadBuf<'_>,
     294          334 :     ) -> Poll<io::Result<()>> {
     295          334 :         // I'm assuming these 3 comparisons will be easy to branch predict.
     296          334 :         // especially with the cold attributes
     297          334 :         // which should make this read wrapper almost invisible
     298          334 : 
     299          334 :         if let ProxyParse::NotStarted = self.state {
     300           40 :             ready!(self.as_mut().read_ip(cx)?);
     301          294 :         }
     302              : 
     303          402 :         while self.tlv_bytes > 0 {
     304           68 :             ready!(self.as_mut().skip_tlv(cx)?)
     305              :         }
     306              : 
     307          334 :         let this = self.project();
     308          334 :         if this.buf.is_empty() {
     309          296 :             this.inner.poll_read(cx, buf)
     310              :         } else {
     311              :             // we know that tlv_bytes is 0
     312           38 :             debug_assert_eq!(*this.tlv_bytes, 0);
     313              : 
     314           38 :             let write = usize::min(this.buf.len(), buf.remaining());
     315           38 :             let slice = this.buf.split_to(write).freeze();
     316           38 :             buf.put_slice(&slice);
     317           38 : 
     318           38 :             // reset the allocation so it can be freed
     319           38 :             if this.buf.is_empty() {
     320           34 :                 *this.buf = BytesMut::new();
     321           34 :             }
     322              : 
     323           38 :             Poll::Ready(Ok(()))
     324              :         }
     325          334 :     }
     326              : }
     327              : 
     328              : #[cfg(test)]
     329              : mod tests {
     330              :     use std::pin::pin;
     331              : 
     332              :     use tokio::io::AsyncReadExt;
     333              : 
     334              :     use crate::protocol2::{ProxyParse, WithClientIp};
     335              : 
     336              :     #[tokio::test]
     337            2 :     async fn test_ipv4() {
     338            2 :         let header = super::HEADER
     339            2 :             // Proxy command, IPV4 | TCP
     340            2 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     341            2 :             // 12 + 3 bytes
     342            2 :             .chain([0, 15].as_slice())
     343            2 :             // src ip
     344            2 :             .chain([127, 0, 0, 1].as_slice())
     345            2 :             // dst ip
     346            2 :             .chain([192, 168, 0, 1].as_slice())
     347            2 :             // src port
     348            2 :             .chain([255, 255].as_slice())
     349            2 :             // dst port
     350            2 :             .chain([1, 1].as_slice())
     351            2 :             // TLV
     352            2 :             .chain([1, 2, 3].as_slice());
     353            2 : 
     354            2 :         let extra_data = [0x55; 256];
     355            2 : 
     356            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     357            2 : 
     358            2 :         let mut bytes = vec![];
     359            2 :         read.read_to_end(&mut bytes).await.unwrap();
     360            2 : 
     361            2 :         assert_eq!(bytes, extra_data);
     362            2 :         assert_eq!(
     363            2 :             read.state,
     364            2 :             ProxyParse::Finished(([127, 0, 0, 1], 65535).into())
     365            2 :         );
     366            2 :     }
     367              : 
     368              :     #[tokio::test]
     369            2 :     async fn test_ipv6() {
     370            2 :         let header = super::HEADER
     371            2 :             // Proxy command, IPV6 | UDP
     372            2 :             .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
     373            2 :             // 36 + 3 bytes
     374            2 :             .chain([0, 39].as_slice())
     375            2 :             // src ip
     376            2 :             .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
     377            2 :             // dst ip
     378            2 :             .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
     379            2 :             // src port
     380            2 :             .chain([1, 1].as_slice())
     381            2 :             // dst port
     382            2 :             .chain([255, 255].as_slice())
     383            2 :             // TLV
     384            2 :             .chain([1, 2, 3].as_slice());
     385            2 : 
     386            2 :         let extra_data = [0x55; 256];
     387            2 : 
     388            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     389            2 : 
     390            2 :         let mut bytes = vec![];
     391            2 :         read.read_to_end(&mut bytes).await.unwrap();
     392            2 : 
     393            2 :         assert_eq!(bytes, extra_data);
     394            2 :         assert_eq!(
     395            2 :             read.state,
     396            2 :             ProxyParse::Finished(
     397            2 :                 ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
     398            2 :             )
     399            2 :         );
     400            2 :     }
     401              : 
     402              :     #[tokio::test]
     403            2 :     async fn test_invalid() {
     404            2 :         let data = [0x55; 256];
     405            2 : 
     406            2 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     407            2 : 
     408            2 :         let mut bytes = vec![];
     409            2 :         read.read_to_end(&mut bytes).await.unwrap();
     410            2 :         assert_eq!(bytes, data);
     411            2 :         assert_eq!(read.state, ProxyParse::None);
     412            2 :     }
     413              : 
     414              :     #[tokio::test]
     415            2 :     async fn test_short() {
     416            2 :         let data = [0x55; 10];
     417            2 : 
     418            2 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     419            2 : 
     420            2 :         let mut bytes = vec![];
     421            2 :         read.read_to_end(&mut bytes).await.unwrap();
     422            2 :         assert_eq!(bytes, data);
     423            2 :         assert_eq!(read.state, ProxyParse::None);
     424            2 :     }
     425              : 
     426              :     #[tokio::test]
     427            2 :     async fn test_large_tlv() {
     428            2 :         let tlv = vec![0x55; 32768];
     429            2 :         let len = (12 + tlv.len() as u16).to_be_bytes();
     430            2 : 
     431            2 :         let header = super::HEADER
     432            2 :             // Proxy command, Inet << 4 | Stream
     433            2 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     434            2 :             // 12 + 3 bytes
     435            2 :             .chain(len.as_slice())
     436            2 :             // src ip
     437            2 :             .chain([55, 56, 57, 58].as_slice())
     438            2 :             // dst ip
     439            2 :             .chain([192, 168, 0, 1].as_slice())
     440            2 :             // src port
     441            2 :             .chain([255, 255].as_slice())
     442            2 :             // dst port
     443            2 :             .chain([1, 1].as_slice())
     444            2 :             // TLV
     445            2 :             .chain(tlv.as_slice());
     446            2 : 
     447            2 :         let extra_data = [0xaa; 256];
     448            2 : 
     449            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     450            2 : 
     451            2 :         let mut bytes = vec![];
     452            2 :         read.read_to_end(&mut bytes).await.unwrap();
     453            2 : 
     454            2 :         assert_eq!(bytes, extra_data);
     455            2 :         assert_eq!(
     456            2 :             read.state,
     457            2 :             ProxyParse::Finished(([55, 56, 57, 58], 65535).into())
     458            2 :         );
     459            2 :     }
     460              : }
        

Generated by: LCOV version 2.1-beta