LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 90.3 % 300 271
Test Date: 2023-09-06 10:18:01 Functions: 64.4 % 73 47

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : 
       3              : use std::{
       4              :     future::poll_fn,
       5              :     future::Future,
       6              :     io,
       7              :     net::SocketAddr,
       8              :     pin::{pin, Pin},
       9              :     task::{ready, Context, Poll},
      10              : };
      11              : 
      12              : use bytes::{Buf, BytesMut};
      13              : use hyper::server::conn::{AddrIncoming, AddrStream};
      14              : use pin_project_lite::pin_project;
      15              : use tls_listener::AsyncAccept;
      16              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      17              : 
      18              : pub struct ProxyProtocolAccept {
      19              :     pub incoming: AddrIncoming,
      20              : }
      21              : 
      22              : pin_project! {
      23              :     pub struct WithClientIp<T> {
      24              :         #[pin]
      25              :         pub inner: T,
      26              :         buf: BytesMut,
      27              :         tlv_bytes: u16,
      28              :         state: ProxyParse,
      29              :     }
      30              : }
      31              : 
      32            5 : #[derive(Clone, PartialEq, Debug)]
      33              : enum ProxyParse {
      34              :     NotStarted,
      35              : 
      36              :     Finished(SocketAddr),
      37              :     None,
      38              : }
      39              : 
      40              : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
      41              :     #[inline]
      42           42 :     fn poll_write(
      43           42 :         self: Pin<&mut Self>,
      44           42 :         cx: &mut Context<'_>,
      45           42 :         buf: &[u8],
      46           42 :     ) -> Poll<Result<usize, io::Error>> {
      47           42 :         self.project().inner.poll_write(cx, buf)
      48           42 :     }
      49              : 
      50              :     #[inline]
      51          550 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      52          550 :         self.project().inner.poll_flush(cx)
      53          550 :     }
      54              : 
      55              :     #[inline]
      56           49 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      57           49 :         self.project().inner.poll_shutdown(cx)
      58           49 :     }
      59              : 
      60              :     #[inline]
      61          377 :     fn poll_write_vectored(
      62          377 :         self: Pin<&mut Self>,
      63          377 :         cx: &mut Context<'_>,
      64          377 :         bufs: &[io::IoSlice<'_>],
      65          377 :     ) -> Poll<Result<usize, io::Error>> {
      66          377 :         self.project().inner.poll_write_vectored(cx, bufs)
      67          377 :     }
      68              : 
      69              :     #[inline]
      70            0 :     fn is_write_vectored(&self) -> bool {
      71            0 :         self.inner.is_write_vectored()
      72            0 :     }
      73              : }
      74              : 
      75              : impl<T> WithClientIp<T> {
      76           69 :     pub fn new(inner: T) -> Self {
      77           69 :         WithClientIp {
      78           69 :             inner,
      79           69 :             buf: BytesMut::with_capacity(128),
      80           69 :             tlv_bytes: 0,
      81           69 :             state: ProxyParse::NotStarted,
      82           69 :         }
      83           69 :     }
      84              : 
      85           22 :     pub fn client_addr(&self) -> Option<SocketAddr> {
      86           22 :         match self.state {
      87            0 :             ProxyParse::Finished(socket) => Some(socket),
      88           22 :             _ => None,
      89              :         }
      90           22 :     }
      91              : }
      92              : 
      93              : impl<T: AsyncRead + Unpin> WithClientIp<T> {
      94           35 :     pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
      95           35 :         match self.state {
      96              :             ProxyParse::NotStarted => {
      97           35 :                 let mut pin = Pin::new(&mut *self);
      98           70 :                 let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
      99           35 :                 match addr {
     100            0 :                     Some(addr) => self.state = ProxyParse::Finished(addr),
     101           35 :                     None => self.state = ProxyParse::None,
     102              :                 }
     103           35 :                 Ok(addr)
     104              :             }
     105            0 :             ProxyParse::Finished(addr) => Ok(Some(addr)),
     106            0 :             ProxyParse::None => Ok(None),
     107              :         }
     108           35 :     }
     109              : }
     110              : 
     111              : /// Proxy Protocol Version 2 Header
     112              : const HEADER: [u8; 12] = [
     113              :     0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
     114              : ];
     115              : 
     116              : impl<T: AsyncRead> WithClientIp<T> {
     117              :     /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
     118              :     /// Version 2 (Binary Format)
     119          126 :     fn poll_client_ip(
     120          126 :         mut self: Pin<&mut Self>,
     121          126 :         cx: &mut Context<'_>,
     122          126 :     ) -> Poll<io::Result<Option<SocketAddr>>> {
     123              :         // The binary header format starts with a constant 12 bytes block containing the protocol signature :
     124              :         //    \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
     125          135 :         while self.buf.len() < 16 {
     126          132 :             let mut this = self.as_mut().project();
     127          132 :             let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
     128              : 
     129              :             // exit for bad header
     130          132 :             let len = usize::min(self.buf.len(), HEADER.len());
     131          132 :             if self.buf[..len] != HEADER[..len] {
     132           66 :                 return Poll::Ready(Ok(None));
     133           66 :             }
     134              : 
     135              :             // if no more bytes available then exit
     136           66 :             if ready!(bytes_read) == 0 {
     137            0 :                 return Poll::Ready(Ok(None));
     138            9 :             };
     139              :         }
     140              : 
     141              :         // The next byte (the 13th one) is the protocol version and command.
     142              :         // The highest four bits contains the version. As of this specification, it must
     143              :         // always be sent as \x2 and the receiver must only accept this value.
     144            3 :         let vc = self.buf[12];
     145            3 :         let version = vc >> 4;
     146            3 :         let command = vc & 0b1111;
     147            3 :         if version != 2 {
     148            0 :             return Poll::Ready(Err(io::Error::new(
     149            0 :                 io::ErrorKind::Other,
     150            0 :                 "invalid proxy protocol version. expected version 2",
     151            0 :             )));
     152            3 :         }
     153            3 :         match command {
     154              :             // the connection was established on purpose by the proxy
     155              :             // without being relayed. The connection endpoints are the sender and the
     156              :             // receiver. Such connections exist when the proxy sends health-checks to the
     157              :             // server. The receiver must accept this connection as valid and must use the
     158              :             // real connection endpoints and discard the protocol block including the
     159              :             // family which is ignored.
     160            0 :             0 => {}
     161              :             // the connection was established on behalf of another node,
     162              :             // and reflects the original connection endpoints. The receiver must then use
     163              :             // the information provided in the protocol block to get original the address.
     164            3 :             1 => {}
     165              :             // other values are unassigned and must not be emitted by senders. Receivers
     166              :             // must drop connections presenting unexpected values here.
     167              :             _ => {
     168            0 :                 return Poll::Ready(Err(io::Error::new(
     169            0 :                     io::ErrorKind::Other,
     170            0 :                     "invalid proxy protocol command. expected local (0) or proxy (1)",
     171            0 :                 )))
     172              :             }
     173              :         };
     174              : 
     175              :         // The 14th byte contains the transport protocol and address family. The highest 4
     176              :         // bits contain the address family, the lowest 4 bits contain the protocol.
     177            3 :         let ft = self.buf[13];
     178            3 :         let address_length = match ft {
     179              :             // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
     180              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     181              :             // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
     182              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     183            2 :             0x11 | 0x12 => 12,
     184              :             // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
     185              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     186              :             // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
     187              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     188            1 :             0x21 | 0x22 => 36,
     189              :             // unspecified or unix stream. ignore the addresses
     190            0 :             _ => 0,
     191              :         };
     192              : 
     193              :         // The 15th and 16th bytes is the address length in bytes in network endian order.
     194              :         // It is used so that the receiver knows how many address bytes to skip even when
     195              :         // it does not implement the presented protocol. Thus the length of the protocol
     196              :         // header in bytes is always exactly 16 + this value. When a sender presents a
     197              :         // LOCAL connection, it should not present any address so it sets this field to
     198              :         // zero. Receivers MUST always consider this field to skip the appropriate number
     199              :         // of bytes and must not assume zero is presented for LOCAL connections. When a
     200              :         // receiver accepts an incoming connection showing an UNSPEC address family or
     201              :         // protocol, it may or may not decide to log the address information if present.
     202            3 :         let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
     203            3 :         if remaining_length < address_length {
     204            0 :             return Poll::Ready(Err(io::Error::new(
     205            0 :                 io::ErrorKind::Other,
     206            0 :                 "invalid proxy protocol length. not enough to fit requested IP addresses",
     207            0 :             )));
     208            3 :         }
     209              : 
     210           15 :         while self.buf.len() < 16 + address_length as usize {
     211           12 :             let mut this = self.as_mut().project();
     212           12 :             if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
     213            0 :                 return Poll::Ready(Err(io::Error::new(
     214            0 :                     io::ErrorKind::UnexpectedEof,
     215            0 :                     "stream closed while waiting for proxy protocol addresses",
     216            0 :                 )));
     217           12 :             }
     218              :         }
     219              : 
     220            3 :         let this = self.as_mut().project();
     221            3 : 
     222            3 :         // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
     223            3 :         // discard the header we have parsed
     224            3 :         this.buf.advance(16);
     225            3 : 
     226            3 :         // Starting from the 17th byte, addresses are presented in network byte order.
     227            3 :         // The address order is always the same :
     228            3 :         //   - source layer 3 address in network byte order
     229            3 :         //   - destination layer 3 address in network byte order
     230            3 :         //   - source layer 4 address if any, in network byte order (port)
     231            3 :         //   - destination layer 4 address if any, in network byte order (port)
     232            3 :         let addresses = this.buf.split_to(address_length as usize);
     233            3 :         let socket = match address_length {
     234              :             12 => {
     235            2 :                 let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
     236            2 :                 let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
     237            2 :                 Some(SocketAddr::from((src_addr, src_port)))
     238              :             }
     239              :             36 => {
     240            1 :                 let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
     241            1 :                 let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
     242            1 :                 Some(SocketAddr::from((src_addr, src_port)))
     243              :             }
     244            0 :             _ => None,
     245              :         };
     246              : 
     247            3 :         *this.tlv_bytes = remaining_length - address_length;
     248            3 :         self.as_mut().skip_tlv_inner();
     249            3 : 
     250            3 :         Poll::Ready(Ok(socket))
     251          126 :     }
     252              : 
     253              :     #[cold]
     254           56 :     fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     255           56 :         let ip = ready!(self.as_mut().poll_client_ip(cx)?);
     256           34 :         match ip {
     257            3 :             Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
     258           31 :             None => *self.as_mut().project().state = ProxyParse::None,
     259              :         }
     260           34 :         Poll::Ready(Ok(()))
     261           56 :     }
     262              : 
     263              :     #[cold]
     264           34 :     fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     265           34 :         let mut this = self.as_mut().project();
     266              :         // we know that this.buf is empty
     267           34 :         debug_assert_eq!(this.buf.len(), 0);
     268              : 
     269           34 :         this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
     270           34 :         ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
     271           34 :         self.skip_tlv_inner();
     272           34 : 
     273           34 :         Poll::Ready(Ok(()))
     274           34 :     }
     275              : 
     276           37 :     fn skip_tlv_inner(self: Pin<&mut Self>) {
     277           37 :         let tlv_bytes_read = match u16::try_from(self.buf.len()) {
     278              :             // we read more than u16::MAX therefore we must have read the full tlv_bytes
     279            0 :             Err(_) => self.tlv_bytes,
     280              :             // we might not have read the full tlv bytes yet
     281           37 :             Ok(n) => u16::min(n, self.tlv_bytes),
     282              :         };
     283           37 :         let this = self.project();
     284           37 :         *this.tlv_bytes -= tlv_bytes_read;
     285           37 :         this.buf.advance(tlv_bytes_read as usize);
     286           37 :     }
     287              : }
     288              : 
     289              : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
     290              :     #[inline]
     291          962 :     fn poll_read(
     292          962 :         mut self: Pin<&mut Self>,
     293          962 :         cx: &mut Context<'_>,
     294          962 :         buf: &mut ReadBuf<'_>,
     295          962 :     ) -> Poll<io::Result<()>> {
     296          962 :         // I'm assuming these 3 comparisons will be easy to branch predict.
     297          962 :         // especially with the cold attributes
     298          962 :         // which should make this read wrapper almost invisible
     299          962 : 
     300          962 :         if let ProxyParse::NotStarted = self.state {
     301           56 :             ready!(self.as_mut().read_ip(cx)?);
     302          906 :         }
     303              : 
     304          974 :         while self.tlv_bytes > 0 {
     305           34 :             ready!(self.as_mut().skip_tlv(cx)?)
     306              :         }
     307              : 
     308          940 :         let this = self.project();
     309          940 :         if this.buf.is_empty() {
     310          872 :             this.inner.poll_read(cx, buf)
     311              :         } else {
     312              :             // we know that tlv_bytes is 0
     313           68 :             debug_assert_eq!(*this.tlv_bytes, 0);
     314              : 
     315           68 :             let write = usize::min(this.buf.len(), buf.remaining());
     316           68 :             let slice = this.buf.split_to(write).freeze();
     317           68 :             buf.put_slice(&slice);
     318           68 : 
     319           68 :             // reset the allocation so it can be freed
     320           68 :             if this.buf.is_empty() {
     321           66 :                 *this.buf = BytesMut::new();
     322           66 :             }
     323              : 
     324           68 :             Poll::Ready(Ok(()))
     325              :         }
     326          962 :     }
     327              : }
     328              : 
     329              : impl AsyncAccept for ProxyProtocolAccept {
     330              :     type Connection = WithClientIp<AddrStream>;
     331              : 
     332              :     type Error = io::Error;
     333              : 
     334          168 :     fn poll_accept(
     335          168 :         mut self: Pin<&mut Self>,
     336          168 :         cx: &mut Context<'_>,
     337          168 :     ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
     338          168 :         let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
     339           22 :         let Some(conn) = conn else {
     340            0 :             return Poll::Ready(None);
     341              :         };
     342              : 
     343           22 :         Poll::Ready(Some(Ok(WithClientIp::new(conn))))
     344          168 :     }
     345              : }
     346              : 
     347              : #[cfg(test)]
     348              : mod tests {
     349              :     use std::pin::pin;
     350              : 
     351              :     use tokio::io::AsyncReadExt;
     352              : 
     353              :     use crate::protocol2::{ProxyParse, WithClientIp};
     354              : 
     355            1 :     #[tokio::test]
     356            1 :     async fn test_ipv4() {
     357            1 :         let header = super::HEADER
     358            1 :             // Proxy command, IPV4 | TCP
     359            1 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     360            1 :             // 12 + 3 bytes
     361            1 :             .chain([0, 15].as_slice())
     362            1 :             // src ip
     363            1 :             .chain([127, 0, 0, 1].as_slice())
     364            1 :             // dst ip
     365            1 :             .chain([192, 168, 0, 1].as_slice())
     366            1 :             // src port
     367            1 :             .chain([255, 255].as_slice())
     368            1 :             // dst port
     369            1 :             .chain([1, 1].as_slice())
     370            1 :             // TLV
     371            1 :             .chain([1, 2, 3].as_slice());
     372            1 : 
     373            1 :         let extra_data = [0x55; 256];
     374            1 : 
     375            1 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     376            1 : 
     377            1 :         let mut bytes = vec![];
     378            1 :         read.read_to_end(&mut bytes).await.unwrap();
     379            1 : 
     380            1 :         assert_eq!(bytes, extra_data);
     381            1 :         assert_eq!(
     382            1 :             read.state,
     383            1 :             ProxyParse::Finished(([127, 0, 0, 1], 65535).into())
     384            1 :         );
     385              :     }
     386              : 
     387            1 :     #[tokio::test]
     388            1 :     async fn test_ipv6() {
     389            1 :         let header = super::HEADER
     390            1 :             // Proxy command, IPV6 | UDP
     391            1 :             .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
     392            1 :             // 36 + 3 bytes
     393            1 :             .chain([0, 39].as_slice())
     394            1 :             // src ip
     395            1 :             .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
     396            1 :             // dst ip
     397            1 :             .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
     398            1 :             // src port
     399            1 :             .chain([1, 1].as_slice())
     400            1 :             // dst port
     401            1 :             .chain([255, 255].as_slice())
     402            1 :             // TLV
     403            1 :             .chain([1, 2, 3].as_slice());
     404            1 : 
     405            1 :         let extra_data = [0x55; 256];
     406            1 : 
     407            1 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     408            1 : 
     409            1 :         let mut bytes = vec![];
     410            1 :         read.read_to_end(&mut bytes).await.unwrap();
     411            1 : 
     412            1 :         assert_eq!(bytes, extra_data);
     413            1 :         assert_eq!(
     414            1 :             read.state,
     415            1 :             ProxyParse::Finished(
     416            1 :                 ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
     417            1 :             )
     418            1 :         );
     419              :     }
     420              : 
     421            1 :     #[tokio::test]
     422            1 :     async fn test_invalid() {
     423            1 :         let data = [0x55; 256];
     424            1 : 
     425            1 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     426            1 : 
     427            1 :         let mut bytes = vec![];
     428            1 :         read.read_to_end(&mut bytes).await.unwrap();
     429            1 :         assert_eq!(bytes, data);
     430            1 :         assert_eq!(read.state, ProxyParse::None);
     431              :     }
     432              : 
     433            1 :     #[tokio::test]
     434            1 :     async fn test_short() {
     435            1 :         let data = [0x55; 10];
     436            1 : 
     437            1 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     438            1 : 
     439            1 :         let mut bytes = vec![];
     440            1 :         read.read_to_end(&mut bytes).await.unwrap();
     441            1 :         assert_eq!(bytes, data);
     442            1 :         assert_eq!(read.state, ProxyParse::None);
     443              :     }
     444              : 
     445            1 :     #[tokio::test]
     446            1 :     async fn test_large_tlv() {
     447            1 :         let tlv = vec![0x55; 32768];
     448            1 :         let len = (12 + tlv.len() as u16).to_be_bytes();
     449            1 : 
     450            1 :         let header = super::HEADER
     451            1 :             // Proxy command, Inet << 4 | Stream
     452            1 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     453            1 :             // 12 + 3 bytes
     454            1 :             .chain(len.as_slice())
     455            1 :             // src ip
     456            1 :             .chain([55, 56, 57, 58].as_slice())
     457            1 :             // dst ip
     458            1 :             .chain([192, 168, 0, 1].as_slice())
     459            1 :             // src port
     460            1 :             .chain([255, 255].as_slice())
     461            1 :             // dst port
     462            1 :             .chain([1, 1].as_slice())
     463            1 :             // TLV
     464            1 :             .chain(tlv.as_slice());
     465            1 : 
     466            1 :         let extra_data = [0xaa; 256];
     467            1 : 
     468            1 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     469            1 : 
     470            1 :         let mut bytes = vec![];
     471            1 :         read.read_to_end(&mut bytes).await.unwrap();
     472            1 : 
     473            1 :         assert_eq!(bytes, extra_data);
     474            1 :         assert_eq!(
     475            1 :             read.state,
     476            1 :             ProxyParse::Finished(([55, 56, 57, 58], 65535).into())
     477            1 :         );
     478              :     }
     479              : }
        

Generated by: LCOV version 2.1-beta