LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: f5f94ec0366b63fd2cbbe02edc2087dbd893d04d.info Lines: 80.9 % 320 259
Test Date: 2024-11-20 05:34:23 Functions: 56.9 % 72 41

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
       3              : 
       4              : use core::fmt;
       5              : use std::io;
       6              : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
       7              : use std::pin::Pin;
       8              : use std::task::{Context, Poll};
       9              : 
      10              : use bytes::{Buf, Bytes, BytesMut};
      11              : use pin_project_lite::pin_project;
      12              : use strum_macros::FromRepr;
      13              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
      14              : use zerocopy::{FromBytes, FromZeroes};
      15              : 
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
    ///
    /// Reads first drain `buf` (bytes that were read while parsing the PROXY header but
    /// were not part of it) and only then fall through to `inner`. Writes always go
    /// directly to `inner`.
    pub(crate) struct ChainRW<T> {
        // The underlying stream.
        #[pin]
        pub(crate) inner: T,
        // Leftover bytes from header parsing, replayed before any read from `inner`.
        buf: BytesMut,
    }
}
      24              : 
      25              : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
      26              :     #[inline]
      27           15 :     fn poll_write(
      28           15 :         self: Pin<&mut Self>,
      29           15 :         cx: &mut Context<'_>,
      30           15 :         buf: &[u8],
      31           15 :     ) -> Poll<Result<usize, io::Error>> {
      32           15 :         self.project().inner.poll_write(cx, buf)
      33           15 :     }
      34              : 
      35              :     #[inline]
      36           74 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      37           74 :         self.project().inner.poll_flush(cx)
      38           74 :     }
      39              : 
      40              :     #[inline]
      41            0 :     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
      42            0 :         self.project().inner.poll_shutdown(cx)
      43            0 :     }
      44              : 
      45              :     #[inline]
      46           59 :     fn poll_write_vectored(
      47           59 :         self: Pin<&mut Self>,
      48           59 :         cx: &mut Context<'_>,
      49           59 :         bufs: &[io::IoSlice<'_>],
      50           59 :     ) -> Poll<Result<usize, io::Error>> {
      51           59 :         self.project().inner.poll_write_vectored(cx, bufs)
      52           59 :     }
      53              : 
      54              :     #[inline]
      55            0 :     fn is_write_vectored(&self) -> bool {
      56            0 :         self.inner.is_write_vectored()
      57            0 :     }
      58              : }
      59              : 
/// The 12-byte signature that every Proxy Protocol Version 2 header starts with.
const SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

// Header byte 13: protocol version in the high nibble (2) and command in the low nibble.
/// Version 2, LOCAL command: the proxy opened the connection itself (no relayed client).
const LOCAL_V2: u8 = 0x20;
/// Version 2, PROXY command: the connection is relayed on behalf of a client.
const PROXY_V2: u8 = 0x21;

// Header byte 14: address family in the high nibble and transport protocol in the low nibble.
const TCP_OVER_IPV4: u8 = 0x11;
const UDP_OVER_IPV4: u8 = 0x12;
const TCP_OVER_IPV6: u8 = 0x21;
const UDP_OVER_IPV6: u8 = 0x22;
      72              : 
/// Client details recovered from a PROXY v2 header sent with the PROXY command.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ConnectionInfo {
    /// Source (client) address and port taken from the header's address block.
    pub addr: SocketAddr,
    /// Private-link metadata from a recognised AWS/Azure TLV, if one was present.
    pub extra: Option<ConnectionInfoExtra>,
}
/// Result of looking for a PROXY v2 header at the start of a connection.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectHeader {
    /// The stream did not start with the v2 signature, or ended before the fixed
    /// 16-byte header was complete; every byte read is replayed to the caller.
    Missing,
    /// LOCAL command: the proxy itself opened the connection (e.g. a health check),
    /// so the real socket endpoints should be used.
    Local,
    /// PROXY command: the connection was relayed on behalf of the client described here.
    Proxy(ConnectionInfo),
}
      85              : 
      86              : impl fmt::Display for ConnectionInfo {
      87            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      88            0 :         match &self.extra {
      89            0 :             None => self.addr.ip().fmt(f),
      90            0 :             Some(ConnectionInfoExtra::Aws { vpce_id }) => {
      91            0 :                 write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
      92              :             }
      93            0 :             Some(ConnectionInfoExtra::Azure { link_id }) => {
      94            0 :                 write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
      95              :             }
      96              :         }
      97            0 :     }
      98              : }
      99              : 
/// Private-link metadata extracted from vendor-specific PROXY v2 TLVs.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
    /// AWS TLV with the VPC-endpoint-ID subtype; `vpce_id` holds the value bytes that
    /// follow the subtype byte.
    Aws { vpce_id: Bytes },
    /// Azure TLV with the private-endpoint-link-ID subtype, decoded as a little-endian `u32`.
    Azure { link_id: u32 },
}
     105              : 
     106           21 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
     107           21 :     mut read: T,
     108           21 : ) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
     109           21 :     let mut buf = BytesMut::with_capacity(128);
     110            4 :     let header = loop {
     111           29 :         let bytes_read = read.read_buf(&mut buf).await?;
     112              : 
     113              :         // exit for bad header signature
     114           29 :         let len = usize::min(buf.len(), SIGNATURE.len());
     115           29 :         if buf[..len] != SIGNATURE[..len] {
     116           17 :             return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
     117           12 :         }
     118           12 : 
     119           12 :         // if no more bytes available then exit
     120           12 :         if bytes_read == 0 {
     121            0 :             return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
     122           12 :         };
     123              : 
     124              :         // check if we have enough bytes to continue
     125           12 :         if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
     126            4 :             break header;
     127            8 :         }
     128              :     };
     129              : 
     130            4 :     let remaining_length = usize::from(header.len.get());
     131              : 
     132           30 :     while buf.len() < remaining_length {
     133           26 :         if read.read_buf(&mut buf).await? == 0 {
     134            0 :             return Err(io::Error::new(
     135            0 :                 io::ErrorKind::UnexpectedEof,
     136            0 :                 "stream closed while waiting for proxy protocol addresses",
     137            0 :             ));
     138           26 :         }
     139              :     }
     140            4 :     let payload = buf.split_to(remaining_length);
     141              : 
     142            4 :     let res = process_proxy_payload(header, payload)?;
     143            4 :     Ok((ChainRW { inner: read, buf }, res))
     144           21 : }
     145              : 
     146            4 : fn process_proxy_payload(
     147            4 :     header: ProxyProtocolV2Header,
     148            4 :     mut payload: BytesMut,
     149            4 : ) -> std::io::Result<ConnectHeader> {
     150            4 :     match header.version_and_command {
     151              :         // the connection was established on purpose by the proxy
     152              :         // without being relayed. The connection endpoints are the sender and the
     153              :         // receiver. Such connections exist when the proxy sends health-checks to the
     154              :         // server. The receiver must accept this connection as valid and must use the
     155              :         // real connection endpoints and discard the protocol block including the
     156              :         // family which is ignored.
     157            1 :         LOCAL_V2 => return Ok(ConnectHeader::Local),
     158              :         // the connection was established on behalf of another node,
     159              :         // and reflects the original connection endpoints. The receiver must then use
     160              :         // the information provided in the protocol block to get original the address.
     161            3 :         PROXY_V2 => {}
     162              :         // other values are unassigned and must not be emitted by senders. Receivers
     163              :         // must drop connections presenting unexpected values here.
     164              :         #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
     165            0 :         _ => return Err(io::Error::new(
     166            0 :             io::ErrorKind::Other,
     167            0 :             format!(
     168            0 :                 "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
     169            0 :                 header.version_and_command
     170            0 :             ),
     171            0 :         )),
     172              :     };
     173              : 
     174            3 :     let size_err =
     175            3 :         "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
     176            3 :     let addr = match header.protocol_and_family {
     177              :         TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
     178            2 :             let addr = payload
     179            2 :                 .try_get::<ProxyProtocolV2HeaderV4>()
     180            2 :                 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
     181              : 
     182            2 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     183              :         }
     184              :         TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
     185            1 :             let addr = payload
     186            1 :                 .try_get::<ProxyProtocolV2HeaderV6>()
     187            1 :                 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
     188              : 
     189            1 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     190              :         }
     191              :         // unspecified or unix stream. ignore the addresses
     192              :         _ => {
     193            0 :             return Err(io::Error::new(
     194            0 :                 io::ErrorKind::Other,
     195            0 :                 "invalid proxy protocol address family/transport protocol.",
     196            0 :             ))
     197              :         }
     198              :     };
     199              : 
     200            3 :     let mut extra = None;
     201              : 
     202            4 :     while let Some(mut tlv) = read_tlv(&mut payload) {
     203            1 :         match Pp2Kind::from_repr(tlv.kind) {
     204              :             Some(Pp2Kind::Aws) => {
     205            0 :                 if tlv.value.is_empty() {
     206            0 :                     tracing::warn!("invalid aws tlv: no subtype");
     207            0 :                 }
     208            0 :                 let subtype = tlv.value.get_u8();
     209            0 :                 match Pp2AwsType::from_repr(subtype) {
     210            0 :                     Some(Pp2AwsType::VpceId) => {
     211            0 :                         extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value });
     212            0 :                     }
     213              :                     None => {
     214            0 :                         tracing::warn!("unknown aws tlv: subtype={subtype}");
     215              :                     }
     216              :                 }
     217              :             }
     218              :             Some(Pp2Kind::Azure) => {
     219            0 :                 if tlv.value.is_empty() {
     220            0 :                     tracing::warn!("invalid azure tlv: no subtype");
     221            0 :                 }
     222            0 :                 let subtype = tlv.value.get_u8();
     223            0 :                 match Pp2AzureType::from_repr(subtype) {
     224              :                     Some(Pp2AzureType::PrivateEndpointLinkId) => {
     225            0 :                         if tlv.value.len() != 4 {
     226            0 :                             tracing::warn!("invalid azure link_id: {:?}", tlv.value);
     227            0 :                         }
     228            0 :                         extra = Some(ConnectionInfoExtra::Azure {
     229            0 :                             link_id: tlv.value.get_u32_le(),
     230            0 :                         });
     231              :                     }
     232              :                     None => {
     233            0 :                         tracing::warn!("unknown azure tlv: subtype={subtype}");
     234              :                     }
     235              :                 }
     236              :             }
     237            0 :             Some(kind) => {
     238            0 :                 tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
     239              :             }
     240              :             None => {
     241            1 :                 tracing::debug!("unknown tlv: {tlv:?}");
     242              :             }
     243              :         }
     244              :     }
     245              : 
     246            3 :     Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
     247            4 : }
     248              : 
/// Known TLV `type` byte values that may follow the address block.
/// Types not listed here are logged as unknown by `process_proxy_payload`.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2Kind {
    // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
    // we don't use these but it would be interesting to know what's available
    Alpn = 0x01,
    Authority = 0x02,
    Crc32C = 0x03,
    Noop = 0x04,
    UniqueId = 0x05,
    Ssl = 0x20,
    NetNs = 0x30,

    /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
    Aws = 0xEA,

    /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
    Azure = 0xEE,
}
     268              : 
/// Subtype byte at the start of an AWS (`0xEA`) TLV value.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AwsType {
    /// The remaining value bytes are the VPC endpoint ID.
    VpceId = 0x01,
}
     274              : 
/// Subtype byte at the start of an Azure (`0xEE`) TLV value.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AzureType {
    /// The remaining value bytes are the private endpoint link ID (little-endian `u32`).
    PrivateEndpointLinkId = 0x01,
}
     280              : 
     281              : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
     282              :     #[inline]
     283          172 :     fn poll_read(
     284          172 :         self: Pin<&mut Self>,
     285          172 :         cx: &mut Context<'_>,
     286          172 :         buf: &mut ReadBuf<'_>,
     287          172 :     ) -> Poll<io::Result<()>> {
     288          172 :         if self.buf.is_empty() {
     289          153 :             self.project().inner.poll_read(cx, buf)
     290              :         } else {
     291           19 :             self.read_from_buf(buf)
     292              :         }
     293          172 :     }
     294              : }
     295              : 
     296              : impl<T: AsyncRead> ChainRW<T> {
     297              :     #[cold]
     298           19 :     fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
     299           19 :         debug_assert!(!self.buf.is_empty());
     300           19 :         let this = self.project();
     301           19 : 
     302           19 :         let write = usize::min(this.buf.len(), buf.remaining());
     303           19 :         let slice = this.buf.split_to(write).freeze();
     304           19 :         buf.put_slice(&slice);
     305           19 : 
     306           19 :         // reset the allocation so it can be freed
     307           19 :         if this.buf.is_empty() {
     308           17 :             *this.buf = BytesMut::new();
     309           17 :         }
     310              : 
     311           19 :         Poll::Ready(Ok(()))
     312           19 :     }
     313              : }
     314              : 
/// A single type-length-value entry split off the PROXY v2 payload.
#[derive(Debug)]
struct Tlv {
    /// TLV type byte (see `Pp2Kind`).
    kind: u8,
    /// Exactly `length` bytes of value, as declared by the TLV header.
    value: Bytes,
}
     320              : 
     321            4 : fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
     322            4 :     let tlv_header = b.try_get::<TlvHeader>()?;
     323            3 :     let len = usize::from(tlv_header.len.get());
     324            3 :     if b.len() < len {
     325            2 :         return None;
     326            1 :     }
     327            1 :     Some(Tlv {
     328            1 :         kind: tlv_header.kind,
     329            1 :         value: b.split_to(len).freeze(),
     330            1 :     })
     331            4 : }
     332              : 
     333              : trait BufExt: Sized {
     334              :     fn try_get<T: FromBytes>(&mut self) -> Option<T>;
     335              : }
     336              : impl BufExt for BytesMut {
     337           19 :     fn try_get<T: FromBytes>(&mut self) -> Option<T> {
     338           19 :         let res = T::read_from_prefix(self)?;
     339           10 :         self.advance(size_of::<T>());
     340           10 :         Some(res)
     341           19 :     }
     342              : }
     343              : 
/// The fixed PROXY v2 header (12 + 1 + 1 + 2 = 16 bytes), decoded in place via `zerocopy`.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2Header {
    signature: [u8; 12],
    /// Version (high nibble) and command (low nibble); see `LOCAL_V2` / `PROXY_V2`.
    version_and_command: u8,
    /// Address family (high nibble) and transport (low nibble).
    protocol_and_family: u8,
    /// Number of payload bytes (address block + TLVs) following this header.
    len: zerocopy::byteorder::network_endian::U16,
}
     352              : 
/// IPv4 address block (4 + 4 + 2 + 2 = 12 bytes). Only the source fields are used.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV4 {
    src_addr: NetworkEndianIpv4,
    dst_addr: NetworkEndianIpv4,
    src_port: zerocopy::byteorder::network_endian::U16,
    dst_port: zerocopy::byteorder::network_endian::U16,
}
     361              : 
/// IPv6 address block (16 + 16 + 2 + 2 = 36 bytes). Only the source fields are used.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV6 {
    src_addr: NetworkEndianIpv6,
    dst_addr: NetworkEndianIpv6,
    src_port: zerocopy::byteorder::network_endian::U16,
    dst_port: zerocopy::byteorder::network_endian::U16,
}
     370              : 
/// The 3-byte TLV header: type byte followed by the big-endian length of the value.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct TlvHeader {
    kind: u8,
    len: zerocopy::byteorder::network_endian::U16,
}
     377              : 
     378            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     379              : #[repr(transparent)]
     380              : struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
     381              : impl NetworkEndianIpv4 {
     382              :     #[inline]
     383            2 :     fn get(self) -> Ipv4Addr {
     384            2 :         Ipv4Addr::from_bits(self.0.get())
     385            2 :     }
     386              : }
     387              : 
     388            0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
     389              : #[repr(transparent)]
     390              : struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
     391              : impl NetworkEndianIpv6 {
     392              :     #[inline]
     393            1 :     fn get(self) -> Ipv6Addr {
     394            1 :         Ipv6Addr::from_bits(self.0.get())
     395            1 :     }
     396              : }
     397              : 
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::{
        read_proxy_protocol, ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6,
    };

    /// PROXY over TCP/IPv4: the source address is parsed and the bytes after the header
    /// are replayed unchanged.
    #[tokio::test]
    async fn test_ipv4() {
        let header = super::SIGNATURE
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // payload length: 12 address bytes + 3 TLV bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV header only (type 0x01, length 0x0203) with no value bytes after it,
            // so the truncated TLV is ignored
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
    }

    /// PROXY over UDP/IPv6: the source address is parsed and trailing data is replayed.
    #[tokio::test]
    async fn test_ipv6() {
        let header = super::SIGNATURE
            // Proxy command, IPV6 | UDP
            .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
            // 36 + 3 bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV header only (type 0x01, length 0x0203) with no value bytes; ignored
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(
            info.addr,
            ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
        );
    }

    /// Data without a signature is reported as `Missing` and passed through intact.
    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    /// A stream shorter than the signature is reported as `Missing` and passed through.
    #[tokio::test]
    async fn test_short() {
        let data = [0x55; 10];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    /// A payload much larger than the initial 128-byte read buffer (an unknown 32 KiB
    /// TLV) is read completely before the trailing data.
    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let tlv_len = (tlv.len() as u16).to_be_bytes();
        let len = (12 + 3 + tlv.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            // Proxy command, Inet << 4 | Stream
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // payload length: 12 address bytes + 3 TLV header bytes + TLV value
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV: unknown type 0xFF, big-endian length, then the value
            .chain([255].as_slice())
            .chain(tlv_len.as_slice())
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
    }

    /// A LOCAL command with an empty payload yields `ConnectHeader::Local`.
    #[tokio::test]
    async fn test_local() {
        let len = 0u16.to_be_bytes();
        let header = super::SIGNATURE
            .chain([LOCAL_V2, 0x00].as_slice())
            .chain(len.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Local = info else { panic!() };
    }
}
        

Generated by: LCOV version 2.1-beta