LCOV - code coverage report
Current view: top level - proxy/src - protocol2.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 78.0 % 186 145
Test Date: 2025-07-16 12:29:03 Functions: 87.5 % 32 28

            Line data    Source code
       1              : //! Proxy Protocol V2 implementation
       2              : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
       3              : 
       4              : use core::fmt;
       5              : use std::io;
       6              : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
       7              : 
       8              : use bytes::Buf;
       9              : use smol_str::SmolStr;
      10              : use strum_macros::FromRepr;
      11              : use tokio::io::{AsyncRead, AsyncReadExt};
      12              : use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned, network_endian};
      13              : 
/// Fixed 12-byte signature that must open every Proxy Protocol Version 2 header.
const SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

// Accepted values of the header's `version_and_command` byte:
// LOCAL (0x20) and PROXY (0x21).
const LOCAL_V2: u8 = 0x20;
const PROXY_V2: u8 = 0x21;

// Accepted values of the header's `protocol_and_family` byte. Only the address
// family matters to the parser; TCP and UDP are handled identically.
const TCP_OVER_IPV4: u8 = 0x11;
const UDP_OVER_IPV4: u8 = 0x12;
const TCP_OVER_IPV6: u8 = 0x21;
const UDP_OVER_IPV6: u8 = 0x22;
      26              : 
/// Connection details recovered from a proxy protocol header with the PROXY command.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ConnectionInfo {
    /// Source address and port read from the header's address block.
    pub addr: SocketAddr,
    /// Cloud-provider specific information taken from a recognised TLV, if any.
    pub extra: Option<ConnectionInfoExtra>,
}
      32              : 
/// Outcome of parsing a proxy protocol header.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectHeader {
    /// The header carried the LOCAL command; no address information is reported.
    Local,
    /// The header carried the PROXY command together with the original endpoints.
    Proxy(ConnectionInfo),
}
      38              : 
      39              : impl fmt::Display for ConnectionInfo {
      40            6 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      41            0 :         match &self.extra {
      42            6 :             None => self.addr.ip().fmt(f),
      43            0 :             Some(ConnectionInfoExtra::Aws { vpce_id }) => {
      44            0 :                 write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
      45              :             }
      46            0 :             Some(ConnectionInfoExtra::Azure { link_id }) => {
      47            0 :                 write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
      48              :             }
      49              :         }
      50            6 :     }
      51              : }
      52              : 
/// Provider-specific connection metadata extracted from proxy protocol TLVs.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
    /// VPC endpoint id: the UTF-8 value of the AWS TLV's `VpceId` subtype.
    Aws { vpce_id: SmolStr },
    /// Private endpoint link id, decoded as a little-endian `u32` from the Azure TLV.
    Azure { link_id: u32 },
}
      58              : 
      59            6 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
      60            6 :     mut read: T,
      61            6 : ) -> std::io::Result<(T, ConnectHeader)> {
      62            6 :     let mut header = [0; size_of::<ProxyProtocolV2Header>()];
      63            6 :     read.read_exact(&mut header).await?;
      64            5 :     let header: ProxyProtocolV2Header = zerocopy::transmute!(header);
      65            5 :     if header.signature != SIGNATURE {
      66            1 :         return Err(std::io::Error::other("invalid proxy protocol header"));
      67            4 :     }
      68              : 
      69            4 :     let mut payload = vec![0; usize::from(header.len.get())];
      70            4 :     read.read_exact(&mut payload).await?;
      71              : 
      72            4 :     let res = process_proxy_payload(header, &payload)?;
      73            4 :     Ok((read, res))
      74            6 : }
      75              : 
      76            4 : fn process_proxy_payload(
      77            4 :     header: ProxyProtocolV2Header,
      78            4 :     mut payload: &[u8],
      79            4 : ) -> std::io::Result<ConnectHeader> {
      80            4 :     match header.version_and_command {
      81              :         // the connection was established on purpose by the proxy
      82              :         // without being relayed. The connection endpoints are the sender and the
      83              :         // receiver. Such connections exist when the proxy sends health-checks to the
      84              :         // server. The receiver must accept this connection as valid and must use the
      85              :         // real connection endpoints and discard the protocol block including the
      86              :         // family which is ignored.
      87            1 :         LOCAL_V2 => return Ok(ConnectHeader::Local),
      88              :         // the connection was established on behalf of another node,
      89              :         // and reflects the original connection endpoints. The receiver must then use
      90              :         // the information provided in the protocol block to get original the address.
      91            3 :         PROXY_V2 => {}
      92              :         // other values are unassigned and must not be emitted by senders. Receivers
      93              :         // must drop connections presenting unexpected values here.
      94              :         _ => {
      95            0 :             return Err(io::Error::other(format!(
      96            0 :                 "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
      97            0 :                 header.version_and_command
      98            0 :             )));
      99              :         }
     100              :     }
     101              : 
     102            3 :     let size_err =
     103            3 :         "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
     104            3 :     let addr = match header.protocol_and_family {
     105              :         TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
     106            2 :             let addr = payload
     107            2 :                 .try_get::<ProxyProtocolV2HeaderV4>()
     108            2 :                 .ok_or_else(|| io::Error::other(size_err))?;
     109              : 
     110            2 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     111              :         }
     112              :         TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
     113            1 :             let addr = payload
     114            1 :                 .try_get::<ProxyProtocolV2HeaderV6>()
     115            1 :                 .ok_or_else(|| io::Error::other(size_err))?;
     116              : 
     117            1 :             SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
     118              :         }
     119              :         // unspecified or unix stream. ignore the addresses
     120              :         _ => {
     121            0 :             return Err(io::Error::other(
     122            0 :                 "invalid proxy protocol address family/transport protocol.",
     123            0 :             ));
     124              :         }
     125              :     };
     126              : 
     127            3 :     let mut extra = None;
     128              : 
     129            4 :     while let Some(mut tlv) = read_tlv(&mut payload) {
     130            1 :         match Pp2Kind::from_repr(tlv.kind) {
     131              :             Some(Pp2Kind::Aws) => {
     132            0 :                 if tlv.value.is_empty() {
     133            0 :                     tracing::warn!("invalid aws tlv: no subtype");
     134            0 :                 }
     135            0 :                 let subtype = tlv.value.get_u8();
     136            0 :                 match Pp2AwsType::from_repr(subtype) {
     137            0 :                     Some(Pp2AwsType::VpceId) => match std::str::from_utf8(tlv.value) {
     138            0 :                         Ok(s) => {
     139            0 :                             extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
     140            0 :                         }
     141            0 :                         Err(e) => {
     142            0 :                             tracing::warn!("invalid aws vpce id: {e}");
     143              :                         }
     144              :                     },
     145              :                     None => {
     146            0 :                         tracing::warn!("unknown aws tlv: subtype={subtype}");
     147              :                     }
     148              :                 }
     149              :             }
     150              :             Some(Pp2Kind::Azure) => {
     151            0 :                 if tlv.value.is_empty() {
     152            0 :                     tracing::warn!("invalid azure tlv: no subtype");
     153            0 :                 }
     154            0 :                 let subtype = tlv.value.get_u8();
     155            0 :                 match Pp2AzureType::from_repr(subtype) {
     156              :                     Some(Pp2AzureType::PrivateEndpointLinkId) => {
     157            0 :                         if tlv.value.len() != 4 {
     158            0 :                             tracing::warn!("invalid azure link_id: {:?}", tlv.value);
     159            0 :                         }
     160            0 :                         extra = Some(ConnectionInfoExtra::Azure {
     161            0 :                             link_id: tlv.value.get_u32_le(),
     162            0 :                         });
     163              :                     }
     164              :                     None => {
     165            0 :                         tracing::warn!("unknown azure tlv: subtype={subtype}");
     166              :                     }
     167              :                 }
     168              :             }
     169            0 :             Some(kind) => {
     170            0 :                 tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
     171              :             }
     172              :             None => {
     173            1 :                 tracing::debug!("unknown tlv: {tlv:?}");
     174              :             }
     175              :         }
     176              :     }
     177              : 
     178            3 :     Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
     179            4 : }
     180              : 
/// TLV type bytes known to the parser, looked up from the raw byte via the
/// derived `from_repr`. Only `Aws` and `Azure` are interpreted; the other
/// known kinds are merely logged.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2Kind {
    // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
    // we don't use these but it would be interesting to know what's available
    Alpn = 0x01,
    Authority = 0x02,
    Crc32C = 0x03,
    Noop = 0x04,
    UniqueId = 0x05,
    Ssl = 0x20,
    NetNs = 0x30,

    /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
    Aws = 0xEA,

    /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
    Azure = 0xEE,
}
     200              : 
/// Subtype byte found at the start of an AWS TLV value.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AwsType {
    /// The rest of the TLV value is the VPC endpoint id.
    VpceId = 0x01,
}
     206              : 
/// Subtype byte found at the start of an Azure TLV value.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AzureType {
    /// The rest of the TLV value is the private endpoint link id.
    PrivateEndpointLinkId = 0x01,
}
     212              : 
/// A single type-length-value record borrowed from the header payload.
#[derive(Debug)]
struct Tlv<'a> {
    /// Raw type byte; may or may not correspond to a `Pp2Kind` variant.
    kind: u8,
    /// Value bytes, exactly as long as the record's length field declared.
    value: &'a [u8],
}
     218              : 
     219            4 : fn read_tlv<'a>(b: &mut &'a [u8]) -> Option<Tlv<'a>> {
     220            4 :     let tlv_header = b.try_get::<TlvHeader>()?;
     221            3 :     let len = usize::from(tlv_header.len.get());
     222              :     Some(Tlv {
     223            3 :         kind: tlv_header.kind,
     224            3 :         value: b.split_off(..len)?,
     225              :     })
     226            4 : }
     227              : 
     228              : trait BufExt: Sized {
     229              :     fn try_get<T: FromBytes>(&mut self) -> Option<T>;
     230              : }
     231              : impl BufExt for &[u8] {
     232            7 :     fn try_get<T: FromBytes>(&mut self) -> Option<T> {
     233            7 :         let (res, rest) = T::read_from_prefix(self).ok()?;
     234            6 :         *self = rest;
     235            6 :         Some(res)
     236            7 :     }
     237              : }
     238              : 
/// Wire layout of the fixed-size part of a Proxy Protocol V2 header.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2Header {
    /// Must equal `SIGNATURE`.
    signature: [u8; 12],
    /// Compared against `LOCAL_V2` / `PROXY_V2`.
    version_and_command: u8,
    /// Compared against the `*_OVER_IPV4` / `*_OVER_IPV6` constants.
    protocol_and_family: u8,
    /// Number of payload bytes following this header, in network byte order.
    len: network_endian::U16,
}
     247              : 
/// Wire layout of the IPv4 address block at the start of the payload.
/// Only the source address and port are read by the parser.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2HeaderV4 {
    src_addr: NetworkEndianIpv4,
    dst_addr: NetworkEndianIpv4,
    src_port: network_endian::U16,
    dst_port: network_endian::U16,
}
     256              : 
/// Wire layout of the IPv6 address block at the start of the payload.
/// Only the source address and port are read by the parser.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2HeaderV6 {
    src_addr: NetworkEndianIpv6,
    dst_addr: NetworkEndianIpv6,
    src_port: network_endian::U16,
    dst_port: network_endian::U16,
}
     265              : 
/// Wire layout of the 3-byte header that precedes every TLV value.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct TlvHeader {
    /// Raw TLV type byte.
    kind: u8,
    /// Length of the value that follows, in network byte order.
    len: network_endian::U16,
}
     272              : 
     273              : #[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
     274              : #[repr(transparent)]
     275              : struct NetworkEndianIpv4(network_endian::U32);
     276              : impl NetworkEndianIpv4 {
     277              :     #[inline]
     278            2 :     fn get(self) -> Ipv4Addr {
     279            2 :         Ipv4Addr::from_bits(self.0.get())
     280            2 :     }
     281              : }
     282              : 
     283              : #[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
     284              : #[repr(transparent)]
     285              : struct NetworkEndianIpv6(network_endian::U128);
     286              : impl NetworkEndianIpv6 {
     287              :     #[inline]
     288            1 :     fn get(self) -> Ipv6Addr {
     289            1 :         Ipv6Addr::from_bits(self.0.get())
     290            1 :     }
     291              : }
     292              : 
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::{
        ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6, read_proxy_protocol,
    };

    /// PROXY command over IPv4: the source address is extracted and only the
    /// declared payload is consumed from the stream.
    #[tokio::test]
    async fn test_ipv4() {
        let header = super::SIGNATURE
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 + 3 bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV header only: kind 1 declaring 0x0203 value bytes that are not
            // present, so TLV parsing stops here without failing the header.
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        // Everything after the declared payload must still be readable.
        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
    }

    /// PROXY command over IPv6 (UDP family byte): the source address is extracted
    /// and only the declared payload is consumed from the stream.
    #[tokio::test]
    async fn test_ipv6() {
        let header = super::SIGNATURE
            // Proxy command, IPV6 | UDP
            .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
            // 36 + 3 bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV header only: kind 1 declaring 0x0203 value bytes that are not
            // present, so TLV parsing stops here without failing the header.
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        // Everything after the declared payload must still be readable.
        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(
            info.addr,
            ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
        );
    }

    /// Enough bytes for a header, but the signature does not match.
    #[tokio::test]
    #[should_panic = "invalid proxy protocol header"]
    async fn test_invalid() {
        let data = [0x55; 256];

        read_proxy_protocol(data.as_slice()).await.unwrap();
    }

    /// The stream ends before a full fixed-size header could be read.
    #[tokio::test]
    #[should_panic = "early eof"]
    async fn test_short() {
        let data = [0x55; 10];

        read_proxy_protocol(data.as_slice()).await.unwrap();
    }

    /// A well-formed but very large TLV of an unknown kind is read and ignored.
    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let tlv_len = (tlv.len() as u16).to_be_bytes();
        let len = (12 + 3 + tlv.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            // Proxy command, Inet << 4 | Stream
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // 12 bytes of addresses + 3-byte TLV header + the TLV value
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV: kind 255 (not a known kind), then length, then value
            .chain([255].as_slice())
            .chain(tlv_len.as_slice())
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        // Everything after the declared payload must still be readable.
        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
    }

    /// LOCAL command with an empty payload yields `ConnectHeader::Local`.
    #[tokio::test]
    async fn test_local() {
        let len = 0u16.to_be_bytes();
        let header = super::SIGNATURE
            .chain([LOCAL_V2, 0x00].as_slice())
            .chain(len.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        // Everything after the header must still be readable.
        assert_eq!(bytes, extra_data);

        let ConnectHeader::Local = info else { panic!() };
    }
}
        

Generated by: LCOV version 2.1-beta