LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: 5445d246133daeceb0507e6cc0797ab7c1c70cb8.info Lines: 80.6 % 320 258
Test Date: 2025-03-12 18:05:02 Functions: 58.0 % 69 40

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
       3              : 
       4              : use core::fmt;
       5              : use std::io;
       6              : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
       7              : use std::pin::Pin;
       8              : use std::task::{Context, Poll};
       9              : 
      10              : use bytes::{Buf, Bytes, BytesMut};
      11              : use pin_project_lite::pin_project;
      12              : use smol_str::SmolStr;
      13              : use strum_macros::FromRepr;
      14              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      15              : use zerocopy::{FromBytes, FromZeroes};
      16              : 
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough.
    ///
    /// Reads are served from the internal buffer while it holds data and from
    /// `inner` once it is empty; writes always go straight to `inner`.
    pub(crate) struct ChainRW<T> {
        // The wrapped stream. Pinned so that its poll methods can be reached
        // through the projection.
        #[pin]
        pub(crate) inner: T,
        // Bytes that were already read from `inner` but have not been handed
        // to the consumer of this reader yet.
        buf: BytesMut,
    }
}
      25              : 
      26              : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
      27              :     #[inline]
      28           15 :     fn poll_write(
      29           15 :         self: Pin<&mut Self>,
      30           15 :         cx: &mut Context<'_>,
      31           15 :         buf: &[u8],
      32           15 :     ) -> Poll<Result<usize, io::Error>> {
      33           15 :         self.project().inner.poll_write(cx, buf)
      34           15 :     }
      35              : 
      36              :     #[inline]
      37           69 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      38           69 :         self.project().inner.poll_flush(cx)
      39           69 :     }
      40              : 
      41              :     #[inline]
      42            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      43            0 :         self.project().inner.poll_shutdown(cx)
      44            0 :     }
      45              : 
      46              :     #[inline]
      47           54 :     fn poll_write_vectored(
      48           54 :         self: Pin<&mut Self>,
      49           54 :         cx: &mut Context<'_>,
      50           54 :         bufs: &[io::IoSlice<'_>],
      51           54 :     ) -> Poll<Result<usize, io::Error>> {
      52           54 :         self.project().inner.poll_write_vectored(cx, bufs)
      53           54 :     }
      54              : 
      55              :     #[inline]
      56            0 :     fn is_write_vectored(&self) -> bool {
      57            0 :         self.inner.is_write_vectored()
      58            0 :     }
      59              : }
      60              : 
/// Proxy Protocol Version 2 signature: the fixed 12-byte prefix that every
/// v2 header starts with. Incoming data is compared against it to decide
/// whether a proxy protocol header is present at all.
const SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

// Values of the "version and command" byte: high nibble is the protocol
// version (2), low nibble is the command (0 = LOCAL, 1 = PROXY).
const LOCAL_V2: u8 = 0x20;
const PROXY_V2: u8 = 0x21;

// Values of the "protocol and family" byte that carry usable IP addresses:
// high nibble is the address family (1 = IPv4, 2 = IPv6), low nibble is the
// transport (1 = TCP/stream, 2 = UDP/datagram).
const TCP_OVER_IPV4: u8 = 0x11;
const UDP_OVER_IPV4: u8 = 0x12;
const TCP_OVER_IPV6: u8 = 0x21;
const UDP_OVER_IPV6: u8 = 0x22;
      73              : 
/// Information about the original client connection, as conveyed by a
/// PROXY-command proxy protocol header.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ConnectionInfo {
    /// Source address and port taken from the header's address block.
    pub addr: SocketAddr,
    /// Cloud-provider specific identifier extracted from the header's TLVs,
    /// if one was present.
    pub extra: Option<ConnectionInfoExtra>,
}
      79              : 
/// Outcome of looking for a proxy protocol v2 header at the start of a stream.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectHeader {
    /// The stream did not start with the proxy protocol signature (or ended
    /// before a full header arrived).
    Missing,
    /// A header with the LOCAL command was received; it carries no client
    /// address.
    Local,
    /// A header with the PROXY command was received, describing the original
    /// client connection.
    Proxy(ConnectionInfo),
}
      86              : 
      87              : impl fmt::Display for ConnectionInfo {
      88            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      89            0 :         match &self.extra {
      90            0 :             None => self.addr.ip().fmt(f),
      91            0 :             Some(ConnectionInfoExtra::Aws { vpce_id }) => {
      92            0 :                 write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
      93              :             }
      94            0 :             Some(ConnectionInfoExtra::Azure { link_id }) => {
      95            0 :                 write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
      96              :             }
      97              :         }
      98            0 :     }
      99              : }
     100              : 
/// Cloud-provider specific connection identifier carried in a vendor TLV.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
    /// AWS: the VPC endpoint id, taken verbatim (as UTF-8) from the TLV value.
    Aws { vpce_id: SmolStr },
    /// Azure: the private endpoint link id, decoded as a little-endian `u32`.
    Azure { link_id: u32 },
}
     106              : 
     107           21 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
     108           21 :     mut read: T,
     109           21 : ) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
     110           21 :     let mut buf = BytesMut::with_capacity(128);
     111            4 :     let header = loop {
     112           29 :         let bytes_read = read.read_buf(&mut buf).await?;
     113              : 
     114              :         // exit for bad header signature
     115           29 :         let len = usize::min(buf.len(), SIGNATURE.len());
     116           29 :         if buf[..len] != SIGNATURE[..len] {
     117           17 :             return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
     118           12 :         }
     119           12 : 
     120           12 :         // if no more bytes available then exit
     121           12 :         if bytes_read == 0 {
     122            0 :             return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
     123           12 :         }
     124              : 
     125              :         // check if we have enough bytes to continue
     126           12 :         if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
     127            4 :             break header;
     128            8 :         }
     129              :     };
     130              : 
     131            4 :     let remaining_length = usize::from(header.len.get());
     132              : 
     133           30 :     while buf.len() < remaining_length {
     134           26 :         if read.read_buf(&mut buf).await? == 0 {
     135            0 :             return Err(io::Error::new(
     136            0 :                 io::ErrorKind::UnexpectedEof,
     137            0 :                 "stream closed while waiting for proxy protocol addresses",
     138            0 :             ));
     139           26 :         }
     140              :     }
     141            4 :     let payload = buf.split_to(remaining_length);
     142              : 
     143            4 :     let res = process_proxy_payload(header, payload)?;
     144            4 :     Ok((ChainRW { inner: read, buf }, res))
     145           21 : }
     146              : 
     147            4 : fn process_proxy_payload(
     148            4 :     header: ProxyProtocolV2Header,
     149            4 :     mut payload: BytesMut,
     150            4 : ) -> std::io::Result<ConnectHeader> {
     151            4 :     match header.version_and_command {
     152              :         // the connection was established on purpose by the proxy
     153              :         // without being relayed. The connection endpoints are the sender and the
     154              :         // receiver. Such connections exist when the proxy sends health-checks to the
     155              :         // server. The receiver must accept this connection as valid and must use the
     156              :         // real connection endpoints and discard the protocol block including the
     157              :         // family which is ignored.
     158            1 :         LOCAL_V2 => return Ok(ConnectHeader::Local),
     159              :         // the connection was established on behalf of another node,
     160              :         // and reflects the original connection endpoints. The receiver must then use
     161              :         // the information provided in the protocol block to get original the address.
     162            3 :         PROXY_V2 => {}
     163              :         // other values are unassigned and must not be emitted by senders. Receivers
     164              :         // must drop connections presenting unexpected values here.
     165              :         #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
     166            0 :         _ => return Err(io::Error::new(
     167            0 :             io::ErrorKind::Other,
     168            0 :             format!(
     169            0 :                 "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
     170            0 :                 header.version_and_command
     171            0 :             ),
     172            0 :         )),
     173              :     }
     174              : 
     175            3 :     let size_err =
     176            3 :         "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
     177            3 :     let addr = match header.protocol_and_family {
     178              :         TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
     179            2 :             let addr = payload
     180            2 :                 .try_get::<ProxyProtocolV2HeaderV4>()
     181            2 :                 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
     182              : 
     183            2 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     184              :         }
     185              :         TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
     186            1 :             let addr = payload
     187            1 :                 .try_get::<ProxyProtocolV2HeaderV6>()
     188            1 :                 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
     189              : 
     190            1 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     191              :         }
     192              :         // unspecified or unix stream. ignore the addresses
     193              :         _ => {
     194            0 :             return Err(io::Error::new(
     195            0 :                 io::ErrorKind::Other,
     196            0 :                 "invalid proxy protocol address family/transport protocol.",
     197            0 :             ));
     198              :         }
     199              :     };
     200              : 
     201            3 :     let mut extra = None;
     202              : 
     203            4 :     while let Some(mut tlv) = read_tlv(&mut payload) {
     204            1 :         match Pp2Kind::from_repr(tlv.kind) {
     205              :             Some(Pp2Kind::Aws) => {
     206            0 :                 if tlv.value.is_empty() {
     207            0 :                     tracing::warn!("invalid aws tlv: no subtype");
     208            0 :                 }
     209            0 :                 let subtype = tlv.value.get_u8();
     210            0 :                 match Pp2AwsType::from_repr(subtype) {
     211            0 :                     Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) {
     212            0 :                         Ok(s) => {
     213            0 :                             extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
     214            0 :                         }
     215            0 :                         Err(e) => {
     216            0 :                             tracing::warn!("invalid aws vpce id: {e}");
     217              :                         }
     218              :                     },
     219              :                     None => {
     220            0 :                         tracing::warn!("unknown aws tlv: subtype={subtype}");
     221              :                     }
     222              :                 }
     223              :             }
     224              :             Some(Pp2Kind::Azure) => {
     225            0 :                 if tlv.value.is_empty() {
     226            0 :                     tracing::warn!("invalid azure tlv: no subtype");
     227            0 :                 }
     228            0 :                 let subtype = tlv.value.get_u8();
     229            0 :                 match Pp2AzureType::from_repr(subtype) {
     230              :                     Some(Pp2AzureType::PrivateEndpointLinkId) => {
     231            0 :                         if tlv.value.len() != 4 {
     232            0 :                             tracing::warn!("invalid azure link_id: {:?}", tlv.value);
     233            0 :                         }
     234            0 :                         extra = Some(ConnectionInfoExtra::Azure {
     235            0 :                             link_id: tlv.value.get_u32_le(),
     236            0 :                         });
     237              :                     }
     238              :                     None => {
     239            0 :                         tracing::warn!("unknown azure tlv: subtype={subtype}");
     240              :                     }
     241              :                 }
     242              :             }
     243            0 :             Some(kind) => {
     244            0 :                 tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
     245              :             }
     246              :             None => {
     247            1 :                 tracing::debug!("unknown tlv: {tlv:?}");
     248              :             }
     249              :         }
     250              :     }
     251              : 
     252            3 :     Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
     253            4 : }
     254              : 
/// TLV type byte values recognised when walking the TLVs of a proxy protocol
/// v2 header. `FromRepr` provides the `from_repr(u8)` lookup used by the parser.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2Kind {
    // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
    // we don't use these but it would be interesting to know what's available
    Alpn = 0x01,
    Authority = 0x02,
    Crc32C = 0x03,
    Noop = 0x04,
    UniqueId = 0x05,
    Ssl = 0x20,
    NetNs = 0x30,

    /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
    Aws = 0xEA,

    /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
    Azure = 0xEE,
}
     274              : 
/// Subtype byte that leads the value of an AWS (`0xEA`) TLV.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AwsType {
    /// The rest of the TLV value is the VPC endpoint id.
    VpceId = 0x01,
}
     280              : 
/// Subtype byte that leads the value of an Azure (`0xEE`) TLV.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AzureType {
    /// The rest of the TLV value is the private endpoint link id.
    PrivateEndpointLinkId = 0x01,
}
     286              : 
     287              : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
     288              :     #[inline]
     289          172 :     fn poll_read(
     290          172 :         self: Pin<&mut Self>,
     291          172 :         cx: &mut Context<'_>,
     292          172 :         buf: &mut ReadBuf<'_>,
     293          172 :     ) -> Poll<io::Result<()>> {
     294          172 :         if self.buf.is_empty() {
     295          153 :             self.project().inner.poll_read(cx, buf)
     296              :         } else {
     297           19 :             self.read_from_buf(buf)
     298              :         }
     299          172 :     }
     300              : }
     301              : 
     302              : impl<T: AsyncRead> ChainRW<T> {
     303              :     #[cold]
     304           19 :     fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
     305           19 :         debug_assert!(!self.buf.is_empty());
     306           19 :         let this = self.project();
     307           19 : 
     308           19 :         let write = usize::min(this.buf.len(), buf.remaining());
     309           19 :         let slice = this.buf.split_to(write).freeze();
     310           19 :         buf.put_slice(&slice);
     311           19 : 
     312           19 :         // reset the allocation so it can be freed
     313           19 :         if this.buf.is_empty() {
     314           17 :             *this.buf = BytesMut::new();
     315           17 :         }
     316              : 
     317           19 :         Poll::Ready(Ok(()))
     318           19 :     }
     319              : }
     320              : 
/// One decoded type-length-value entry from the proxy protocol payload.
#[derive(Debug)]
struct Tlv {
    /// Raw type byte; see `Pp2Kind` for the values that are recognised.
    kind: u8,
    /// The value bytes, without the 3-byte TLV header.
    value: Bytes,
}
     326              : 
     327            4 : fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
     328            4 :     let tlv_header = b.try_get::<TlvHeader>()?;
     329            3 :     let len = usize::from(tlv_header.len.get());
     330            3 :     if b.len() < len {
     331            2 :         return None;
     332            1 :     }
     333            1 :     Some(Tlv {
     334            1 :         kind: tlv_header.kind,
     335            1 :         value: b.split_to(len).freeze(),
     336            1 :     })
     337            4 : }
     338              : 
     339              : trait BufExt: Sized {
     340              :     fn try_get<T: FromBytes>(&mut self) -> Option<T>;
     341              : }
     342              : impl BufExt for BytesMut {
     343           19 :     fn try_get<T: FromBytes>(&mut self) -> Option<T> {
     344           19 :         let res = T::read_from_prefix(self)?;
     345           10 :         self.advance(size_of::<T>());
     346           10 :         Some(res)
     347           19 :     }
     348              : }
     349              : 
/// Wire layout of the fixed-size part of a proxy protocol v2 header.
/// Decoded directly from the received bytes via `FromBytes`.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2Header {
    // Expected to equal `SIGNATURE`.
    signature: [u8; 12],
    // Compared against `LOCAL_V2` / `PROXY_V2`.
    version_and_command: u8,
    // Compared against the `*_OVER_IPV4` / `*_OVER_IPV6` constants.
    protocol_and_family: u8,
    // Number of bytes that follow this header (address block plus TLVs),
    // in network byte order.
    len: zerocopy::byteorder::network_endian::U16,
}
     358              : 
/// Wire layout of the address block for TCP/UDP over IPv4 (12 bytes).
/// All fields are in network byte order. Only the source address and port
/// are used by this module.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV4 {
    src_addr: NetworkEndianIpv4,
    dst_addr: NetworkEndianIpv4,
    src_port: zerocopy::byteorder::network_endian::U16,
    dst_port: zerocopy::byteorder::network_endian::U16,
}
     367              : 
/// Wire layout of the address block for TCP/UDP over IPv6 (36 bytes).
/// All fields are in network byte order. Only the source address and port
/// are used by this module.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV6 {
    src_addr: NetworkEndianIpv6,
    dst_addr: NetworkEndianIpv6,
    src_port: zerocopy::byteorder::network_endian::U16,
    dst_port: zerocopy::byteorder::network_endian::U16,
}
     376              : 
/// Wire layout of the 3-byte header that precedes every TLV value.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct TlvHeader {
    // Raw TLV type byte.
    kind: u8,
    // Length of the value that follows, in network byte order.
    len: zerocopy::byteorder::network_endian::U16,
}
     383              : 
     384            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     385              : #[repr(transparent)]
     386              : struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
     387              : impl NetworkEndianIpv4 {
     388              :     #[inline]
     389            2 :     fn get(self) -> Ipv4Addr {
     390            2 :         Ipv4Addr::from_bits(self.0.get())
     391            2 :     }
     392              : }
     393              : 
     394            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     395              : #[repr(transparent)]
     396              : struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
     397              : impl NetworkEndianIpv6 {
     398              :     #[inline]
     399            1 :     fn get(self) -> Ipv6Addr {
     400            1 :         Ipv6Addr::from_bits(self.0.get())
     401            1 :     }
     402              : }
     403              : 
// Tests for `read_proxy_protocol`. Inputs are built by chaining byte slices,
// and whatever follows the header must come back unchanged from the returned
// reader.
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::{
        ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6, read_proxy_protocol,
    };

    /// PROXY command over TCP/IPv4: the source address is reported and the
    /// trailing application data is passed through untouched.
    #[tokio::test]
    async fn test_ipv4() {
        let header = super::SIGNATURE
            // Proxy command (0x21), IPV4 | TCP (0x11)
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // payload length: 12 address bytes + 3 TLV header bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV header only: kind 1, declared length 0x0203. No value bytes
            // are left in the payload, so the parser drops it.
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
    }

    /// PROXY command over UDP/IPv6: the source address is reported and the
    /// trailing application data is passed through untouched.
    #[tokio::test]
    async fn test_ipv6() {
        let header = super::SIGNATURE
            // Proxy command, IPV6 | UDP
            .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
            // payload length: 36 address bytes + 3 TLV header bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV header only: kind 1, declared length 0x0203. No value bytes
            // are left in the payload, so the parser drops it.
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(
            info.addr,
            ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
        );
    }

    /// Data that does not start with the signature is reported as `Missing`
    /// and replayed in full.
    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    /// Input shorter than the signature itself is also reported as `Missing`
    /// and replayed in full.
    #[tokio::test]
    async fn test_short() {
        let data = [0x55; 10];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    /// A payload far larger than the initial read buffer (one 32 KiB TLV of
    /// an unknown kind) is consumed completely before application data.
    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let tlv_len = (tlv.len() as u16).to_be_bytes();
        // 12 address bytes + 3 TLV header bytes + the TLV value
        let len = (12 + 3 + tlv.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            // Proxy command, Inet << 4 | Stream
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // payload length computed above
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV: kind 0xFF (not a known kind), length, value
            .chain([255].as_slice())
            .chain(tlv_len.as_slice())
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
    }

    /// LOCAL command with an empty payload yields `ConnectHeader::Local`.
    #[tokio::test]
    async fn test_local() {
        let len = 0u16.to_be_bytes();
        let header = super::SIGNATURE
            .chain([LOCAL_V2, 0x00].as_slice())
            .chain(len.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Local = info else { panic!() };
    }
}
        

Generated by: LCOV version 2.1-beta