LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: 62212f4d57a7ad0f69dc82a04629a0bbd5f7c824.info Lines: 81.1 % 318 258
Test Date: 2025-03-17 10:41:39 Functions: 58.0 % 69 40

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
       3              : 
       4              : use core::fmt;
       5              : use std::io;
       6              : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
       7              : use std::pin::Pin;
       8              : use std::task::{Context, Poll};
       9              : 
      10              : use bytes::{Buf, Bytes, BytesMut};
      11              : use pin_project_lite::pin_project;
      12              : use smol_str::SmolStr;
      13              : use strum_macros::FromRepr;
      14              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      15              : use zerocopy::{FromBytes, FromZeroes};
      16              : 
      17              : pin_project! {
      18              :     /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
      19              :     pub(crate) struct ChainRW<T> {
      20              :         #[pin]
      21              :         pub(crate) inner: T,
      22              :         buf: BytesMut,
      23              :     }
      24              : }
      25              : 
      26              : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
      27              :     #[inline]
      28           15 :     fn poll_write(
      29           15 :         self: Pin<&mut Self>,
      30           15 :         cx: &mut Context<'_>,
      31           15 :         buf: &[u8],
      32           15 :     ) -> Poll<Result<usize, io::Error>> {
      33           15 :         self.project().inner.poll_write(cx, buf)
      34           15 :     }
      35              : 
      36              :     #[inline]
      37           69 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      38           69 :         self.project().inner.poll_flush(cx)
      39           69 :     }
      40              : 
      41              :     #[inline]
      42            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      43            0 :         self.project().inner.poll_shutdown(cx)
      44            0 :     }
      45              : 
      46              :     #[inline]
      47           54 :     fn poll_write_vectored(
      48           54 :         self: Pin<&mut Self>,
      49           54 :         cx: &mut Context<'_>,
      50           54 :         bufs: &[io::IoSlice<'_>],
      51           54 :     ) -> Poll<Result<usize, io::Error>> {
      52           54 :         self.project().inner.poll_write_vectored(cx, bufs)
      53           54 :     }
      54              : 
      55              :     #[inline]
      56            0 :     fn is_write_vectored(&self) -> bool {
      57            0 :         self.inner.is_write_vectored()
      58            0 :     }
      59              : }
      60              : 
      61              : /// Proxy Protocol Version 2 Header
      62              : const SIGNATURE: [u8; 12] = [
      63              :     0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
      64              : ];
      65              : 
      66              : const LOCAL_V2: u8 = 0x20;
      67              : const PROXY_V2: u8 = 0x21;
      68              : 
      69              : const TCP_OVER_IPV4: u8 = 0x11;
      70              : const UDP_OVER_IPV4: u8 = 0x12;
      71              : const TCP_OVER_IPV6: u8 = 0x21;
      72              : const UDP_OVER_IPV6: u8 = 0x22;
      73              : 
      74              : #[derive(PartialEq, Eq, Clone, Debug)]
      75              : pub struct ConnectionInfo {
      76              :     pub addr: SocketAddr,
      77              :     pub extra: Option<ConnectionInfoExtra>,
      78              : }
      79              : 
      80              : #[derive(PartialEq, Eq, Clone, Debug)]
      81              : pub enum ConnectHeader {
      82              :     Missing,
      83              :     Local,
      84              :     Proxy(ConnectionInfo),
      85              : }
      86              : 
      87              : impl fmt::Display for ConnectionInfo {
      88            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      89            0 :         match &self.extra {
      90            0 :             None => self.addr.ip().fmt(f),
      91            0 :             Some(ConnectionInfoExtra::Aws { vpce_id }) => {
      92            0 :                 write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
      93              :             }
      94            0 :             Some(ConnectionInfoExtra::Azure { link_id }) => {
      95            0 :                 write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
      96              :             }
      97              :         }
      98            0 :     }
      99              : }
     100              : 
     101              : #[derive(PartialEq, Eq, Clone, Debug)]
     102              : pub enum ConnectionInfoExtra {
     103              :     Aws { vpce_id: SmolStr },
     104              :     Azure { link_id: u32 },
     105              : }
     106              : 
     107           21 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
     108           21 :     mut read: T,
     109           21 : ) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
     110           21 :     let mut buf = BytesMut::with_capacity(128);
     111            4 :     let header = loop {
     112           29 :         let bytes_read = read.read_buf(&mut buf).await?;
     113              : 
     114              :         // exit for bad header signature
     115           29 :         let len = usize::min(buf.len(), SIGNATURE.len());
     116           29 :         if buf[..len] != SIGNATURE[..len] {
     117           17 :             return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
     118           12 :         }
     119           12 : 
     120           12 :         // if no more bytes available then exit
     121           12 :         if bytes_read == 0 {
     122            0 :             return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
     123           12 :         }
     124              : 
     125              :         // check if we have enough bytes to continue
     126           12 :         if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
     127            4 :             break header;
     128            8 :         }
     129              :     };
     130              : 
     131            4 :     let remaining_length = usize::from(header.len.get());
     132              : 
     133           30 :     while buf.len() < remaining_length {
     134           26 :         if read.read_buf(&mut buf).await? == 0 {
     135            0 :             return Err(io::Error::new(
     136            0 :                 io::ErrorKind::UnexpectedEof,
     137            0 :                 "stream closed while waiting for proxy protocol addresses",
     138            0 :             ));
     139           26 :         }
     140              :     }
     141            4 :     let payload = buf.split_to(remaining_length);
     142              : 
     143            4 :     let res = process_proxy_payload(header, payload)?;
     144            4 :     Ok((ChainRW { inner: read, buf }, res))
     145           21 : }
     146              : 
     147            4 : fn process_proxy_payload(
     148            4 :     header: ProxyProtocolV2Header,
     149            4 :     mut payload: BytesMut,
     150            4 : ) -> std::io::Result<ConnectHeader> {
     151            4 :     match header.version_and_command {
     152              :         // the connection was established on purpose by the proxy
     153              :         // without being relayed. The connection endpoints are the sender and the
     154              :         // receiver. Such connections exist when the proxy sends health-checks to the
     155              :         // server. The receiver must accept this connection as valid and must use the
     156              :         // real connection endpoints and discard the protocol block including the
     157              :         // family which is ignored.
     158            1 :         LOCAL_V2 => return Ok(ConnectHeader::Local),
     159              :         // the connection was established on behalf of another node,
     160              :         // and reflects the original connection endpoints. The receiver must then use
     161              :         // the information provided in the protocol block to get original the address.
     162            3 :         PROXY_V2 => {}
     163              :         // other values are unassigned and must not be emitted by senders. Receivers
     164              :         // must drop connections presenting unexpected values here.
     165              :         #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
     166            0 :         _ => return Err(io::Error::other(
     167            0 :             format!(
     168            0 :                 "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
     169            0 :                 header.version_and_command
     170            0 :             ),
     171            0 :         )),
     172              :     }
     173              : 
     174            3 :     let size_err =
     175            3 :         "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
     176            3 :     let addr = match header.protocol_and_family {
     177              :         TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
     178            2 :             let addr = payload
     179            2 :                 .try_get::<ProxyProtocolV2HeaderV4>()
     180            2 :                 .ok_or_else(|| io::Error::other(size_err))?;
     181              : 
     182            2 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     183              :         }
     184              :         TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
     185            1 :             let addr = payload
     186            1 :                 .try_get::<ProxyProtocolV2HeaderV6>()
     187            1 :                 .ok_or_else(|| io::Error::other(size_err))?;
     188              : 
     189            1 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     190              :         }
     191              :         // unspecified or unix stream. ignore the addresses
     192              :         _ => {
     193            0 :             return Err(io::Error::other(
     194            0 :                 "invalid proxy protocol address family/transport protocol.",
     195            0 :             ));
     196              :         }
     197              :     };
     198              : 
     199            3 :     let mut extra = None;
     200              : 
     201            4 :     while let Some(mut tlv) = read_tlv(&mut payload) {
     202            1 :         match Pp2Kind::from_repr(tlv.kind) {
     203              :             Some(Pp2Kind::Aws) => {
     204            0 :                 if tlv.value.is_empty() {
     205            0 :                     tracing::warn!("invalid aws tlv: no subtype");
     206            0 :                 }
     207            0 :                 let subtype = tlv.value.get_u8();
     208            0 :                 match Pp2AwsType::from_repr(subtype) {
     209            0 :                     Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) {
     210            0 :                         Ok(s) => {
     211            0 :                             extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
     212            0 :                         }
     213            0 :                         Err(e) => {
     214            0 :                             tracing::warn!("invalid aws vpce id: {e}");
     215              :                         }
     216              :                     },
     217              :                     None => {
     218            0 :                         tracing::warn!("unknown aws tlv: subtype={subtype}");
     219              :                     }
     220              :                 }
     221              :             }
     222              :             Some(Pp2Kind::Azure) => {
     223            0 :                 if tlv.value.is_empty() {
     224            0 :                     tracing::warn!("invalid azure tlv: no subtype");
     225            0 :                 }
     226            0 :                 let subtype = tlv.value.get_u8();
     227            0 :                 match Pp2AzureType::from_repr(subtype) {
     228              :                     Some(Pp2AzureType::PrivateEndpointLinkId) => {
     229            0 :                         if tlv.value.len() != 4 {
     230            0 :                             tracing::warn!("invalid azure link_id: {:?}", tlv.value);
     231            0 :                         }
     232            0 :                         extra = Some(ConnectionInfoExtra::Azure {
     233            0 :                             link_id: tlv.value.get_u32_le(),
     234            0 :                         });
     235              :                     }
     236              :                     None => {
     237            0 :                         tracing::warn!("unknown azure tlv: subtype={subtype}");
     238              :                     }
     239              :                 }
     240              :             }
     241            0 :             Some(kind) => {
     242            0 :                 tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
     243              :             }
     244              :             None => {
     245            1 :                 tracing::debug!("unknown tlv: {tlv:?}");
     246              :             }
     247              :         }
     248              :     }
     249              : 
     250            3 :     Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
     251            4 : }
     252              : 
     253              : #[derive(FromRepr, Debug, Copy, Clone)]
     254              : #[repr(u8)]
     255              : enum Pp2Kind {
     256              :     // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
     257              :     // we don't use these but it would be interesting to know what's available
     258              :     Alpn = 0x01,
     259              :     Authority = 0x02,
     260              :     Crc32C = 0x03,
     261              :     Noop = 0x04,
     262              :     UniqueId = 0x05,
     263              :     Ssl = 0x20,
     264              :     NetNs = 0x30,
     265              : 
     266              :     /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
     267              :     Aws = 0xEA,
     268              : 
     269              :     /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
     270              :     Azure = 0xEE,
     271              : }
     272              : 
     273              : #[derive(FromRepr, Debug, Copy, Clone)]
     274              : #[repr(u8)]
     275              : enum Pp2AwsType {
     276              :     VpceId = 0x01,
     277              : }
     278              : 
     279              : #[derive(FromRepr, Debug, Copy, Clone)]
     280              : #[repr(u8)]
     281              : enum Pp2AzureType {
     282              :     PrivateEndpointLinkId = 0x01,
     283              : }
     284              : 
     285              : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
     286              :     #[inline]
     287          172 :     fn poll_read(
     288          172 :         self: Pin<&mut Self>,
     289          172 :         cx: &mut Context<'_>,
     290          172 :         buf: &mut ReadBuf<'_>,
     291          172 :     ) -> Poll<io::Result<()>> {
     292          172 :         if self.buf.is_empty() {
     293          153 :             self.project().inner.poll_read(cx, buf)
     294              :         } else {
     295           19 :             self.read_from_buf(buf)
     296              :         }
     297          172 :     }
     298              : }
     299              : 
     300              : impl<T: AsyncRead> ChainRW<T> {
     301              :     #[cold]
     302           19 :     fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
     303           19 :         debug_assert!(!self.buf.is_empty());
     304           19 :         let this = self.project();
     305           19 : 
     306           19 :         let write = usize::min(this.buf.len(), buf.remaining());
     307           19 :         let slice = this.buf.split_to(write).freeze();
     308           19 :         buf.put_slice(&slice);
     309           19 : 
     310           19 :         // reset the allocation so it can be freed
     311           19 :         if this.buf.is_empty() {
     312           17 :             *this.buf = BytesMut::new();
     313           17 :         }
     314              : 
     315           19 :         Poll::Ready(Ok(()))
     316           19 :     }
     317              : }
     318              : 
     319              : #[derive(Debug)]
     320              : struct Tlv {
     321              :     kind: u8,
     322              :     value: Bytes,
     323              : }
     324              : 
     325            4 : fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
     326            4 :     let tlv_header = b.try_get::<TlvHeader>()?;
     327            3 :     let len = usize::from(tlv_header.len.get());
     328            3 :     if b.len() < len {
     329            2 :         return None;
     330            1 :     }
     331            1 :     Some(Tlv {
     332            1 :         kind: tlv_header.kind,
     333            1 :         value: b.split_to(len).freeze(),
     334            1 :     })
     335            4 : }
     336              : 
     337              : trait BufExt: Sized {
     338              :     fn try_get<T: FromBytes>(&mut self) -> Option<T>;
     339              : }
     340              : impl BufExt for BytesMut {
     341           19 :     fn try_get<T: FromBytes>(&mut self) -> Option<T> {
     342           19 :         let res = T::read_from_prefix(self)?;
     343           10 :         self.advance(size_of::<T>());
     344           10 :         Some(res)
     345           19 :     }
     346              : }
     347              : 
     348            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     349              : #[repr(C)]
     350              : struct ProxyProtocolV2Header {
     351              :     signature: [u8; 12],
     352              :     version_and_command: u8,
     353              :     protocol_and_family: u8,
     354              :     len: zerocopy::byteorder::network_endian::U16,
     355              : }
     356              : 
     357            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     358              : #[repr(C)]
     359              : struct ProxyProtocolV2HeaderV4 {
     360              :     src_addr: NetworkEndianIpv4,
     361              :     dst_addr: NetworkEndianIpv4,
     362              :     src_port: zerocopy::byteorder::network_endian::U16,
     363              :     dst_port: zerocopy::byteorder::network_endian::U16,
     364              : }
     365              : 
     366            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     367              : #[repr(C)]
     368              : struct ProxyProtocolV2HeaderV6 {
     369              :     src_addr: NetworkEndianIpv6,
     370              :     dst_addr: NetworkEndianIpv6,
     371              :     src_port: zerocopy::byteorder::network_endian::U16,
     372              :     dst_port: zerocopy::byteorder::network_endian::U16,
     373              : }
     374              : 
     375            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     376              : #[repr(C)]
     377              : struct TlvHeader {
     378              :     kind: u8,
     379              :     len: zerocopy::byteorder::network_endian::U16,
     380              : }
     381              : 
     382            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     383              : #[repr(transparent)]
     384              : struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
     385              : impl NetworkEndianIpv4 {
     386              :     #[inline]
     387            2 :     fn get(self) -> Ipv4Addr {
     388            2 :         Ipv4Addr::from_bits(self.0.get())
     389            2 :     }
     390              : }
     391              : 
     392            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     393              : #[repr(transparent)]
     394              : struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
     395              : impl NetworkEndianIpv6 {
     396              :     #[inline]
     397            1 :     fn get(self) -> Ipv6Addr {
     398            1 :         Ipv6Addr::from_bits(self.0.get())
     399            1 :     }
     400              : }
     401              : 
     402              : #[cfg(test)]
     403              : #[expect(clippy::unwrap_used)]
     404              : mod tests {
     405              :     use tokio::io::AsyncReadExt;
     406              : 
     407              :     use crate::protocol2::{
     408              :         ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6, read_proxy_protocol,
     409              :     };
     410              : 
     411              :     #[tokio::test]
     412            1 :     async fn test_ipv4() {
     413            1 :         let header = super::SIGNATURE
     414            1 :             // Proxy command, IPV4 | TCP
     415            1 :             .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
     416            1 :             // 12 + 3 bytes
     417            1 :             .chain([0, 15].as_slice())
     418            1 :             // src ip
     419            1 :             .chain([127, 0, 0, 1].as_slice())
     420            1 :             // dst ip
     421            1 :             .chain([192, 168, 0, 1].as_slice())
     422            1 :             // src port
     423            1 :             .chain([255, 255].as_slice())
     424            1 :             // dst port
     425            1 :             .chain([1, 1].as_slice())
     426            1 :             // TLV
     427            1 :             .chain([1, 2, 3].as_slice());
     428            1 : 
     429            1 :         let extra_data = [0x55; 256];
     430            1 : 
     431            1 :         let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
     432            1 :             .await
     433            1 :             .unwrap();
     434            1 : 
     435            1 :         let mut bytes = vec![];
     436            1 :         read.read_to_end(&mut bytes).await.unwrap();
     437            1 : 
     438            1 :         assert_eq!(bytes, extra_data);
     439            1 : 
     440            1 :         let ConnectHeader::Proxy(info) = info else {
     441            1 :             panic!()
     442            1 :         };
     443            1 :         assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
     444            1 :     }
     445              : 
     446              :     #[tokio::test]
     447            1 :     async fn test_ipv6() {
     448            1 :         let header = super::SIGNATURE
     449            1 :             // Proxy command, IPV6 | UDP
     450            1 :             .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
     451            1 :             // 36 + 3 bytes
     452            1 :             .chain([0, 39].as_slice())
     453            1 :             // src ip
     454            1 :             .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
     455            1 :             // dst ip
     456            1 :             .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
     457            1 :             // src port
     458            1 :             .chain([1, 1].as_slice())
     459            1 :             // dst port
     460            1 :             .chain([255, 255].as_slice())
     461            1 :             // TLV
     462            1 :             .chain([1, 2, 3].as_slice());
     463            1 : 
     464            1 :         let extra_data = [0x55; 256];
     465            1 : 
     466            1 :         let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
     467            1 :             .await
     468            1 :             .unwrap();
     469            1 : 
     470            1 :         let mut bytes = vec![];
     471            1 :         read.read_to_end(&mut bytes).await.unwrap();
     472            1 : 
     473            1 :         assert_eq!(bytes, extra_data);
     474            1 : 
     475            1 :         let ConnectHeader::Proxy(info) = info else {
     476            1 :             panic!()
     477            1 :         };
     478            1 :         assert_eq!(
     479            1 :             info.addr,
     480            1 :             ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
     481            1 :         );
     482            1 :     }
     483              : 
     484              :     #[tokio::test]
     485            1 :     async fn test_invalid() {
     486            1 :         let data = [0x55; 256];
     487            1 : 
     488            1 :         let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
     489            1 : 
     490            1 :         let mut bytes = vec![];
     491            1 :         read.read_to_end(&mut bytes).await.unwrap();
     492            1 :         assert_eq!(bytes, data);
     493            1 :         assert_eq!(info, ConnectHeader::Missing);
     494            1 :     }
     495              : 
     496              :     #[tokio::test]
     497            1 :     async fn test_short() {
     498            1 :         let data = [0x55; 10];
     499            1 : 
     500            1 :         let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();
     501            1 : 
     502            1 :         let mut bytes = vec![];
     503            1 :         read.read_to_end(&mut bytes).await.unwrap();
     504            1 :         assert_eq!(bytes, data);
     505            1 :         assert_eq!(info, ConnectHeader::Missing);
     506            1 :     }
     507              : 
     508              :     #[tokio::test]
     509            1 :     async fn test_large_tlv() {
     510            1 :         let tlv = vec![0x55; 32768];
     511            1 :         let tlv_len = (tlv.len() as u16).to_be_bytes();
     512            1 :         let len = (12 + 3 + tlv.len() as u16).to_be_bytes();
     513            1 : 
     514            1 :         let header = super::SIGNATURE
     515            1 :             // Proxy command, Inet << 4 | Stream
     516            1 :             .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
     517            1 :             // 12 + 3 bytes
     518            1 :             .chain(len.as_slice())
     519            1 :             // src ip
     520            1 :             .chain([55, 56, 57, 58].as_slice())
     521            1 :             // dst ip
     522            1 :             .chain([192, 168, 0, 1].as_slice())
     523            1 :             // src port
     524            1 :             .chain([255, 255].as_slice())
     525            1 :             // dst port
     526            1 :             .chain([1, 1].as_slice())
     527            1 :             // TLV
     528            1 :             .chain([255].as_slice())
     529            1 :             .chain(tlv_len.as_slice())
     530            1 :             .chain(tlv.as_slice());
     531            1 : 
     532            1 :         let extra_data = [0xaa; 256];
     533            1 : 
     534            1 :         let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
     535            1 :             .await
     536            1 :             .unwrap();
     537            1 : 
     538            1 :         let mut bytes = vec![];
     539            1 :         read.read_to_end(&mut bytes).await.unwrap();
     540            1 : 
     541            1 :         assert_eq!(bytes, extra_data);
     542            1 : 
     543            1 :         let ConnectHeader::Proxy(info) = info else {
     544            1 :             panic!()
     545            1 :         };
     546            1 :         assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
     547            1 :     }
     548              : 
     549              :     #[tokio::test]
     550            1 :     async fn test_local() {
     551            1 :         let len = 0u16.to_be_bytes();
     552            1 :         let header = super::SIGNATURE
     553            1 :             .chain([LOCAL_V2, 0x00].as_slice())
     554            1 :             .chain(len.as_slice());
     555            1 : 
     556            1 :         let extra_data = [0xaa; 256];
     557            1 : 
     558            1 :         let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
     559            1 :             .await
     560            1 :             .unwrap();
     561            1 : 
     562            1 :         let mut bytes = vec![];
     563            1 :         read.read_to_end(&mut bytes).await.unwrap();
     564            1 : 
     565            1 :         assert_eq!(bytes, extra_data);
     566            1 : 
     567            1 :         let ConnectHeader::Local = info else { panic!() };
     568            1 :     }
     569              : }
        

Generated by: LCOV version 2.1-beta