LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: 36bb8dd7c7efcb53483d1a7d9f7cb33e8406dcf0.info Lines: 72.1 % 344 248
Test Date: 2024-04-08 10:22:05 Functions: 31.0 % 87 27

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : 
       3              : use std::{
       4              :     future::{poll_fn, Future},
       5              :     io,
       6              :     net::SocketAddr,
       7              :     pin::{pin, Pin},
       8              :     sync::Mutex,
       9              :     task::{ready, Context, Poll},
      10              : };
      11              : 
      12              : use bytes::{Buf, BytesMut};
      13              : use hyper::server::accept::Accept;
      14              : use hyper::server::conn::{AddrIncoming, AddrStream};
      15              : use metrics::IntCounterPairGuard;
      16              : use pin_project_lite::pin_project;
      17              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      18              : use uuid::Uuid;
      19              : 
      20              : use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
      21              : 
      22              : pub struct ProxyProtocolAccept {
      23              :     pub incoming: AddrIncoming,
      24              :     pub protocol: &'static str,
      25              : }
      26              : 
      27              : pin_project! {
      28              :     pub struct WithClientIp<T> {
      29              :         #[pin]
      30              :         pub inner: T,
      31              :         buf: BytesMut,
      32              :         tlv_bytes: u16,
      33              :         state: ProxyParse,
      34              :     }
      35              : }
      36              : 
      37              : #[derive(Clone, PartialEq, Debug)]
      38              : enum ProxyParse {
      39              :     NotStarted,
      40              : 
      41              :     Finished(SocketAddr),
      42              :     None,
      43              : }
      44              : 
      45              : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
      46              :     #[inline]
      47           30 :     fn poll_write(
      48           30 :         self: Pin<&mut Self>,
      49           30 :         cx: &mut Context<'_>,
      50           30 :         buf: &[u8],
      51           30 :     ) -> Poll<Result<usize, io::Error>> {
      52           30 :         self.project().inner.poll_write(cx, buf)
      53           30 :     }
      54              : 
      55              :     #[inline]
      56          148 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      57          148 :         self.project().inner.poll_flush(cx)
      58          148 :     }
      59              : 
      60              :     #[inline]
      61            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      62            0 :         self.project().inner.poll_shutdown(cx)
      63            0 :     }
      64              : 
      65              :     #[inline]
      66          118 :     fn poll_write_vectored(
      67          118 :         self: Pin<&mut Self>,
      68          118 :         cx: &mut Context<'_>,
      69          118 :         bufs: &[io::IoSlice<'_>],
      70          118 :     ) -> Poll<Result<usize, io::Error>> {
      71          118 :         self.project().inner.poll_write_vectored(cx, bufs)
      72          118 :     }
      73              : 
      74              :     #[inline]
      75            0 :     fn is_write_vectored(&self) -> bool {
      76            0 :         self.inner.is_write_vectored()
      77            0 :     }
      78              : }
      79              : 
      80              : impl<T> WithClientIp<T> {
      81           40 :     pub fn new(inner: T) -> Self {
      82           40 :         WithClientIp {
      83           40 :             inner,
      84           40 :             buf: BytesMut::with_capacity(128),
      85           40 :             tlv_bytes: 0,
      86           40 :             state: ProxyParse::NotStarted,
      87           40 :         }
      88           40 :     }
      89              : 
      90            0 :     pub fn client_addr(&self) -> Option<SocketAddr> {
      91            0 :         match self.state {
      92            0 :             ProxyParse::Finished(socket) => Some(socket),
      93            0 :             _ => None,
      94              :         }
      95            0 :     }
      96              : }
      97              : 
      98              : impl<T: AsyncRead + Unpin> WithClientIp<T> {
      99            0 :     pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
     100            0 :         match self.state {
     101              :             ProxyParse::NotStarted => {
     102            0 :                 let mut pin = Pin::new(&mut *self);
     103            0 :                 let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
     104            0 :                 match addr {
     105            0 :                     Some(addr) => self.state = ProxyParse::Finished(addr),
     106            0 :                     None => self.state = ProxyParse::None,
     107              :                 }
     108            0 :                 Ok(addr)
     109              :             }
     110            0 :             ProxyParse::Finished(addr) => Ok(Some(addr)),
     111            0 :             ProxyParse::None => Ok(None),
     112              :         }
     113            0 :     }
     114              : }
     115              : 
     116              : /// Proxy Protocol Version 2 Header
     117              : const HEADER: [u8; 12] = [
     118              :     0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
     119              : ];
     120              : 
     121              : impl<T: AsyncRead> WithClientIp<T> {
     122              :     /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
     123              :     /// Version 2 (Binary Format)
     124           40 :     fn poll_client_ip(
     125           40 :         mut self: Pin<&mut Self>,
     126           40 :         cx: &mut Context<'_>,
     127           40 :     ) -> Poll<io::Result<Option<SocketAddr>>> {
     128              :         // The binary header format starts with a constant 12 bytes block containing the protocol signature :
     129              :         //    \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
     130           58 :         while self.buf.len() < 16 {
     131           52 :             let mut this = self.as_mut().project();
     132           52 :             let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
     133              : 
     134              :             // exit for bad header
     135           52 :             let len = usize::min(self.buf.len(), HEADER.len());
     136           52 :             if self.buf[..len] != HEADER[..len] {
     137           34 :                 return Poll::Ready(Ok(None));
     138           18 :             }
     139              : 
     140              :             // if no more bytes available then exit
     141           18 :             if ready!(bytes_read) == 0 {
     142            0 :                 return Poll::Ready(Ok(None));
     143           18 :             };
     144              :         }
     145              : 
     146              :         // The next byte (the 13th one) is the protocol version and command.
     147              :         // The highest four bits contains the version. As of this specification, it must
     148              :         // always be sent as \x2 and the receiver must only accept this value.
     149            6 :         let vc = self.buf[12];
     150            6 :         let version = vc >> 4;
     151            6 :         let command = vc & 0b1111;
     152            6 :         if version != 2 {
     153            0 :             return Poll::Ready(Err(io::Error::new(
     154            0 :                 io::ErrorKind::Other,
     155            0 :                 "invalid proxy protocol version. expected version 2",
     156            0 :             )));
     157            6 :         }
     158            6 :         match command {
     159              :             // the connection was established on purpose by the proxy
     160              :             // without being relayed. The connection endpoints are the sender and the
     161              :             // receiver. Such connections exist when the proxy sends health-checks to the
     162              :             // server. The receiver must accept this connection as valid and must use the
     163              :             // real connection endpoints and discard the protocol block including the
     164              :             // family which is ignored.
     165            0 :             0 => {}
     166              :             // the connection was established on behalf of another node,
     167              :             // and reflects the original connection endpoints. The receiver must then use
     168              :             // the information provided in the protocol block to get original the address.
     169            6 :             1 => {}
     170              :             // other values are unassigned and must not be emitted by senders. Receivers
     171              :             // must drop connections presenting unexpected values here.
     172              :             _ => {
     173            0 :                 return Poll::Ready(Err(io::Error::new(
     174            0 :                     io::ErrorKind::Other,
     175            0 :                     "invalid proxy protocol command. expected local (0) or proxy (1)",
     176            0 :                 )))
     177              :             }
     178              :         };
     179              : 
     180              :         // The 14th byte contains the transport protocol and address family. The highest 4
     181              :         // bits contain the address family, the lowest 4 bits contain the protocol.
     182            6 :         let ft = self.buf[13];
     183            6 :         let address_length = match ft {
     184              :             // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
     185              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     186              :             // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
     187              :             //   protocol family. Address length is 2*4 + 2*2 = 12 bytes.
     188            4 :             0x11 | 0x12 => 12,
     189              :             // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
     190              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     191              :             // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
     192              :             //   protocol family. Address length is 2*16 + 2*2 = 36 bytes.
     193            2 :             0x21 | 0x22 => 36,
     194              :             // unspecified or unix stream. ignore the addresses
     195            0 :             _ => 0,
     196              :         };
     197              : 
     198              :         // The 15th and 16th bytes is the address length in bytes in network endian order.
     199              :         // It is used so that the receiver knows how many address bytes to skip even when
     200              :         // it does not implement the presented protocol. Thus the length of the protocol
     201              :         // header in bytes is always exactly 16 + this value. When a sender presents a
     202              :         // LOCAL connection, it should not present any address so it sets this field to
     203              :         // zero. Receivers MUST always consider this field to skip the appropriate number
     204              :         // of bytes and must not assume zero is presented for LOCAL connections. When a
     205              :         // receiver accepts an incoming connection showing an UNSPEC address family or
     206              :         // protocol, it may or may not decide to log the address information if present.
     207            6 :         let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
     208            6 :         if remaining_length < address_length {
     209            0 :             return Poll::Ready(Err(io::Error::new(
     210            0 :                 io::ErrorKind::Other,
     211            0 :                 "invalid proxy protocol length. not enough to fit requested IP addresses",
     212            0 :             )));
     213            6 :         }
     214              : 
     215           30 :         while self.buf.len() < 16 + address_length as usize {
     216           24 :             let mut this = self.as_mut().project();
     217           24 :             if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
     218            0 :                 return Poll::Ready(Err(io::Error::new(
     219            0 :                     io::ErrorKind::UnexpectedEof,
     220            0 :                     "stream closed while waiting for proxy protocol addresses",
     221            0 :                 )));
     222           24 :             }
     223              :         }
     224              : 
     225            6 :         let this = self.as_mut().project();
     226            6 : 
     227            6 :         // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
     228            6 :         // discard the header we have parsed
     229            6 :         this.buf.advance(16);
     230            6 : 
     231            6 :         // Starting from the 17th byte, addresses are presented in network byte order.
     232            6 :         // The address order is always the same :
     233            6 :         //   - source layer 3 address in network byte order
     234            6 :         //   - destination layer 3 address in network byte order
     235            6 :         //   - source layer 4 address if any, in network byte order (port)
     236            6 :         //   - destination layer 4 address if any, in network byte order (port)
     237            6 :         let addresses = this.buf.split_to(address_length as usize);
     238            6 :         let socket = match address_length {
     239              :             12 => {
     240            4 :                 let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
     241            4 :                 let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
     242            4 :                 Some(SocketAddr::from((src_addr, src_port)))
     243              :             }
     244              :             36 => {
     245            2 :                 let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
     246            2 :                 let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
     247            2 :                 Some(SocketAddr::from((src_addr, src_port)))
     248              :             }
     249            0 :             _ => None,
     250              :         };
     251              : 
     252            6 :         *this.tlv_bytes = remaining_length - address_length;
     253            6 :         self.as_mut().skip_tlv_inner();
     254            6 : 
     255            6 :         Poll::Ready(Ok(socket))
     256           40 :     }
     257              : 
     258              :     #[cold]
     259           40 :     fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     260           40 :         let ip = ready!(self.as_mut().poll_client_ip(cx)?);
     261           40 :         match ip {
     262            6 :             Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
     263           34 :             None => *self.as_mut().project().state = ProxyParse::None,
     264              :         }
     265           40 :         Poll::Ready(Ok(()))
     266           40 :     }
     267              : 
     268              :     #[cold]
     269           68 :     fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
     270           68 :         let mut this = self.as_mut().project();
     271           68 :         // we know that this.buf is empty
     272           68 :         debug_assert_eq!(this.buf.len(), 0);
     273              : 
     274           68 :         this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
     275           68 :         ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
     276           68 :         self.skip_tlv_inner();
     277           68 : 
     278           68 :         Poll::Ready(Ok(()))
     279           68 :     }
     280              : 
     281           74 :     fn skip_tlv_inner(self: Pin<&mut Self>) {
     282           74 :         let tlv_bytes_read = match u16::try_from(self.buf.len()) {
     283              :             // we read more than u16::MAX therefore we must have read the full tlv_bytes
     284            0 :             Err(_) => self.tlv_bytes,
     285              :             // we might not have read the full tlv bytes yet
     286           74 :             Ok(n) => u16::min(n, self.tlv_bytes),
     287              :         };
     288           74 :         let this = self.project();
     289           74 :         *this.tlv_bytes -= tlv_bytes_read;
     290           74 :         this.buf.advance(tlv_bytes_read as usize);
     291           74 :     }
     292              : }
     293              : 
     294              : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
     295              :     #[inline]
     296          334 :     fn poll_read(
     297          334 :         mut self: Pin<&mut Self>,
     298          334 :         cx: &mut Context<'_>,
     299          334 :         buf: &mut ReadBuf<'_>,
     300          334 :     ) -> Poll<io::Result<()>> {
     301          334 :         // I'm assuming these 3 comparisons will be easy to branch predict.
     302          334 :         // especially with the cold attributes
     303          334 :         // which should make this read wrapper almost invisible
     304          334 : 
     305          334 :         if let ProxyParse::NotStarted = self.state {
     306           40 :             ready!(self.as_mut().read_ip(cx)?);
     307          294 :         }
     308              : 
     309          402 :         while self.tlv_bytes > 0 {
     310           68 :             ready!(self.as_mut().skip_tlv(cx)?)
     311              :         }
     312              : 
     313          334 :         let this = self.project();
     314          334 :         if this.buf.is_empty() {
     315          296 :             this.inner.poll_read(cx, buf)
     316              :         } else {
     317              :             // we know that tlv_bytes is 0
     318           38 :             debug_assert_eq!(*this.tlv_bytes, 0);
     319              : 
     320           38 :             let write = usize::min(this.buf.len(), buf.remaining());
     321           38 :             let slice = this.buf.split_to(write).freeze();
     322           38 :             buf.put_slice(&slice);
     323           38 : 
     324           38 :             // reset the allocation so it can be freed
     325           38 :             if this.buf.is_empty() {
     326           34 :                 *this.buf = BytesMut::new();
     327           34 :             }
     328              : 
     329           38 :             Poll::Ready(Ok(()))
     330              :         }
     331          334 :     }
     332              : }
     333              : 
     334              : impl Accept for ProxyProtocolAccept {
     335              :     type Conn = WithConnectionGuard<WithClientIp<AddrStream>>;
     336              : 
     337              :     type Error = io::Error;
     338              : 
     339            0 :     fn poll_accept(
     340            0 :         mut self: Pin<&mut Self>,
     341            0 :         cx: &mut Context<'_>,
     342            0 :     ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
     343            0 :         let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
     344              : 
     345            0 :         let conn_id = uuid::Uuid::new_v4();
     346            0 :         let span = tracing::info_span!("http_conn", ?conn_id);
     347              :         {
     348            0 :             let _enter = span.enter();
     349            0 :             tracing::info!("accepted new TCP connection");
     350              :         }
     351              : 
     352            0 :         let Some(conn) = conn else {
     353            0 :             return Poll::Ready(None);
     354              :         };
     355              : 
     356            0 :         Poll::Ready(Some(Ok(WithConnectionGuard {
     357            0 :             inner: WithClientIp::new(conn),
     358            0 :             connection_id: Uuid::new_v4(),
     359            0 :             gauge: Mutex::new(Some(
     360            0 :                 NUM_CLIENT_CONNECTION_GAUGE
     361            0 :                     .with_label_values(&[self.protocol])
     362            0 :                     .guard(),
     363            0 :             )),
     364            0 :             span,
     365            0 :         })))
     366            0 :     }
     367              : }
     368              : 
     369              : pin_project! {
     370              :     pub struct WithConnectionGuard<T> {
     371              :         #[pin]
     372              :         pub inner: T,
     373              :         pub connection_id: Uuid,
     374              :         pub gauge: Mutex<Option<IntCounterPairGuard>>,
     375              :         pub span: tracing::Span,
     376              :     }
     377              : 
     378              :     impl<S> PinnedDrop for WithConnectionGuard<S> {
     379              :         fn drop(this: Pin<&mut Self>) {
     380              :             let _enter = this.span.enter();
     381            0 :             tracing::info!("HTTP connection closed")
     382              :         }
     383              :     }
     384              : }
     385              : 
     386              : impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
     387              :     #[inline]
     388            0 :     fn poll_write(
     389            0 :         self: Pin<&mut Self>,
     390            0 :         cx: &mut Context<'_>,
     391            0 :         buf: &[u8],
     392            0 :     ) -> Poll<Result<usize, io::Error>> {
     393            0 :         self.project().inner.poll_write(cx, buf)
     394            0 :     }
     395              : 
     396              :     #[inline]
     397            0 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
     398            0 :         self.project().inner.poll_flush(cx)
     399            0 :     }
     400              : 
     401              :     #[inline]
     402            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
     403            0 :         self.project().inner.poll_shutdown(cx)
     404            0 :     }
     405              : 
     406              :     #[inline]
     407            0 :     fn poll_write_vectored(
     408            0 :         self: Pin<&mut Self>,
     409            0 :         cx: &mut Context<'_>,
     410            0 :         bufs: &[io::IoSlice<'_>],
     411            0 :     ) -> Poll<Result<usize, io::Error>> {
     412            0 :         self.project().inner.poll_write_vectored(cx, bufs)
     413            0 :     }
     414              : 
     415              :     #[inline]
     416            0 :     fn is_write_vectored(&self) -> bool {
     417            0 :         self.inner.is_write_vectored()
     418            0 :     }
     419              : }
     420              : 
     421              : impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
     422            0 :     fn poll_read(
     423            0 :         self: Pin<&mut Self>,
     424            0 :         cx: &mut Context<'_>,
     425            0 :         buf: &mut ReadBuf<'_>,
     426            0 :     ) -> Poll<io::Result<()>> {
     427            0 :         self.project().inner.poll_read(cx, buf)
     428            0 :     }
     429              : }
     430              : 
     431              : #[cfg(test)]
     432              : mod tests {
     433              :     use std::pin::pin;
     434              : 
     435              :     use tokio::io::AsyncReadExt;
     436              : 
     437              :     use crate::protocol2::{ProxyParse, WithClientIp};
     438              : 
     439              :     #[tokio::test]
     440            2 :     async fn test_ipv4() {
     441            2 :         let header = super::HEADER
     442            2 :             // Proxy command, IPV4 | TCP
     443            2 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     444            2 :             // 12 + 3 bytes
     445            2 :             .chain([0, 15].as_slice())
     446            2 :             // src ip
     447            2 :             .chain([127, 0, 0, 1].as_slice())
     448            2 :             // dst ip
     449            2 :             .chain([192, 168, 0, 1].as_slice())
     450            2 :             // src port
     451            2 :             .chain([255, 255].as_slice())
     452            2 :             // dst port
     453            2 :             .chain([1, 1].as_slice())
     454            2 :             // TLV
     455            2 :             .chain([1, 2, 3].as_slice());
     456            2 : 
     457            2 :         let extra_data = [0x55; 256];
     458            2 : 
     459            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     460            2 : 
     461            2 :         let mut bytes = vec![];
     462            2 :         read.read_to_end(&mut bytes).await.unwrap();
     463            2 : 
     464            2 :         assert_eq!(bytes, extra_data);
     465            2 :         assert_eq!(
     466            2 :             read.state,
     467            2 :             ProxyParse::Finished(([127, 0, 0, 1], 65535).into())
     468            2 :         );
     469            2 :     }
     470              : 
     471              :     #[tokio::test]
     472            2 :     async fn test_ipv6() {
     473            2 :         let header = super::HEADER
     474            2 :             // Proxy command, IPV6 | UDP
     475            2 :             .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
     476            2 :             // 36 + 3 bytes
     477            2 :             .chain([0, 39].as_slice())
     478            2 :             // src ip
     479            2 :             .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
     480            2 :             // dst ip
     481            2 :             .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
     482            2 :             // src port
     483            2 :             .chain([1, 1].as_slice())
     484            2 :             // dst port
     485            2 :             .chain([255, 255].as_slice())
     486            2 :             // TLV
     487            2 :             .chain([1, 2, 3].as_slice());
     488            2 : 
     489            2 :         let extra_data = [0x55; 256];
     490            2 : 
     491            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     492            2 : 
     493            2 :         let mut bytes = vec![];
     494            2 :         read.read_to_end(&mut bytes).await.unwrap();
     495            2 : 
     496            2 :         assert_eq!(bytes, extra_data);
     497            2 :         assert_eq!(
     498            2 :             read.state,
     499            2 :             ProxyParse::Finished(
     500            2 :                 ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
     501            2 :             )
     502            2 :         );
     503            2 :     }
     504              : 
     505              :     #[tokio::test]
     506            2 :     async fn test_invalid() {
     507            2 :         let data = [0x55; 256];
     508            2 : 
     509            2 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     510            2 : 
     511            2 :         let mut bytes = vec![];
     512            2 :         read.read_to_end(&mut bytes).await.unwrap();
     513            2 :         assert_eq!(bytes, data);
     514            2 :         assert_eq!(read.state, ProxyParse::None);
     515            2 :     }
     516              : 
     517              :     #[tokio::test]
     518            2 :     async fn test_short() {
     519            2 :         let data = [0x55; 10];
     520            2 : 
     521            2 :         let mut read = pin!(WithClientIp::new(data.as_slice()));
     522            2 : 
     523            2 :         let mut bytes = vec![];
     524            2 :         read.read_to_end(&mut bytes).await.unwrap();
     525            2 :         assert_eq!(bytes, data);
     526            2 :         assert_eq!(read.state, ProxyParse::None);
     527            2 :     }
     528              : 
     529              :     #[tokio::test]
     530            2 :     async fn test_large_tlv() {
     531            2 :         let tlv = vec![0x55; 32768];
     532            2 :         let len = (12 + tlv.len() as u16).to_be_bytes();
     533            2 : 
     534            2 :         let header = super::HEADER
     535            2 :             // Proxy command, Inet << 4 | Stream
     536            2 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     537            2 :             // 12 + 3 bytes
     538            2 :             .chain(len.as_slice())
     539            2 :             // src ip
     540            2 :             .chain([55, 56, 57, 58].as_slice())
     541            2 :             // dst ip
     542            2 :             .chain([192, 168, 0, 1].as_slice())
     543            2 :             // src port
     544            2 :             .chain([255, 255].as_slice())
     545            2 :             // dst port
     546            2 :             .chain([1, 1].as_slice())
     547            2 :             // TLV
     548            2 :             .chain(tlv.as_slice());
     549            2 : 
     550            2 :         let extra_data = [0xaa; 256];
     551            2 : 
     552            2 :         let mut read = pin!(WithClientIp::new(header.chain(extra_data.as_slice())));
     553            2 : 
     554            2 :         let mut bytes = vec![];
     555            2 :         read.read_to_end(&mut bytes).await.unwrap();
     556            2 : 
     557            2 :         assert_eq!(bytes, extra_data);
     558            2 :         assert_eq!(
     559            2 :             read.state,
     560            2 :             ProxyParse::Finished(([55, 56, 57, 58], 65535).into())
     561            2 :         );
     562            2 :     }
     563              : }
        

Generated by: LCOV version 2.1-beta