Line data Source code
1 : //! Proxy Protocol V2 implementation
2 : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
3 :
4 : use core::fmt;
5 : use std::io;
6 : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
7 :
8 : use bytes::Buf;
9 : use smol_str::SmolStr;
10 : use strum_macros::FromRepr;
11 : use tokio::io::{AsyncRead, AsyncReadExt};
12 : use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned, network_endian};
13 :
/// The fixed 12-byte signature that opens every PROXY protocol v2 header
/// (`\r\n\r\n\0\r\nQUIT\n`).
const SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

/// `version_and_command` byte for protocol version 2 (high nibble) with the LOCAL
/// command (low nibble).
const LOCAL_V2: u8 = 0x20;
/// `version_and_command` byte for protocol version 2 (high nibble) with the PROXY
/// command (low nibble).
const PROXY_V2: u8 = 0x21;

// `protocol_and_family` byte values accepted by this module: the high nibble is the
// address family (1 = IPv4, 2 = IPv6) and the low nibble the transport (1 = TCP/stream,
// 2 = UDP/datagram).
const TCP_OVER_IPV4: u8 = 0x11;
const UDP_OVER_IPV4: u8 = 0x12;
const TCP_OVER_IPV6: u8 = 0x21;
const UDP_OVER_IPV6: u8 = 0x22;
26 :
/// Client information recovered from a PROXY (as opposed to LOCAL) v2 header.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ConnectionInfo {
    /// The original *source* address and port reported by the proxy.
    pub addr: SocketAddr,
    /// Cloud-provider metadata taken from a recognised AWS or Azure TLV, if one was
    /// present and parsed successfully. When several are present, the last one wins.
    pub extra: Option<ConnectionInfoExtra>,
}
32 :
/// The outcome of parsing a PROXY protocol v2 header.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectHeader {
    /// LOCAL command: the proxy opened the connection itself (e.g. a health check),
    /// so the socket's real endpoints should be used. Any address block is discarded.
    Local,
    /// PROXY command: the connection is relayed on behalf of the client described here.
    Proxy(ConnectionInfo),
}
38 :
39 : impl fmt::Display for ConnectionInfo {
40 6 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 0 : match &self.extra {
42 6 : None => self.addr.ip().fmt(f),
43 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => {
44 0 : write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
45 : }
46 0 : Some(ConnectionInfoExtra::Azure { link_id }) => {
47 0 : write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
48 : }
49 : }
50 6 : }
51 : }
52 :
/// Provider-specific connection metadata carried in a PROXY v2 TLV.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
    /// AWS: VPC endpoint ID from the `0xEA` TLV with subtype `0x01`, decoded as UTF-8.
    Aws { vpce_id: SmolStr },
    /// Azure: private endpoint LINKID from the `0xEE` TLV with subtype `0x01`,
    /// decoded as a little-endian `u32`.
    Azure { link_id: u32 },
}
58 :
59 6 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
60 6 : mut read: T,
61 6 : ) -> std::io::Result<(T, ConnectHeader)> {
62 6 : let mut header = [0; size_of::<ProxyProtocolV2Header>()];
63 6 : read.read_exact(&mut header).await?;
64 5 : let header: ProxyProtocolV2Header = zerocopy::transmute!(header);
65 5 : if header.signature != SIGNATURE {
66 1 : return Err(std::io::Error::other("invalid proxy protocol header"));
67 4 : }
68 :
69 4 : let mut payload = vec![0; usize::from(header.len.get())];
70 4 : read.read_exact(&mut payload).await?;
71 :
72 4 : let res = process_proxy_payload(header, &payload)?;
73 4 : Ok((read, res))
74 6 : }
75 :
76 4 : fn process_proxy_payload(
77 4 : header: ProxyProtocolV2Header,
78 4 : mut payload: &[u8],
79 4 : ) -> std::io::Result<ConnectHeader> {
80 4 : match header.version_and_command {
81 : // the connection was established on purpose by the proxy
82 : // without being relayed. The connection endpoints are the sender and the
83 : // receiver. Such connections exist when the proxy sends health-checks to the
84 : // server. The receiver must accept this connection as valid and must use the
85 : // real connection endpoints and discard the protocol block including the
86 : // family which is ignored.
87 1 : LOCAL_V2 => return Ok(ConnectHeader::Local),
88 : // the connection was established on behalf of another node,
89 : // and reflects the original connection endpoints. The receiver must then use
90 : // the information provided in the protocol block to get original the address.
91 3 : PROXY_V2 => {}
92 : // other values are unassigned and must not be emitted by senders. Receivers
93 : // must drop connections presenting unexpected values here.
94 : _ => {
95 0 : return Err(io::Error::other(format!(
96 0 : "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
97 0 : header.version_and_command
98 0 : )));
99 : }
100 : }
101 :
102 3 : let size_err =
103 3 : "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
104 3 : let addr = match header.protocol_and_family {
105 : TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
106 2 : let addr = payload
107 2 : .try_get::<ProxyProtocolV2HeaderV4>()
108 2 : .ok_or_else(|| io::Error::other(size_err))?;
109 :
110 2 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
111 : }
112 : TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
113 1 : let addr = payload
114 1 : .try_get::<ProxyProtocolV2HeaderV6>()
115 1 : .ok_or_else(|| io::Error::other(size_err))?;
116 :
117 1 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
118 : }
119 : // unspecified or unix stream. ignore the addresses
120 : _ => {
121 0 : return Err(io::Error::other(
122 0 : "invalid proxy protocol address family/transport protocol.",
123 0 : ));
124 : }
125 : };
126 :
127 3 : let mut extra = None;
128 :
129 4 : while let Some(mut tlv) = read_tlv(&mut payload) {
130 1 : match Pp2Kind::from_repr(tlv.kind) {
131 : Some(Pp2Kind::Aws) => {
132 0 : if tlv.value.is_empty() {
133 0 : tracing::warn!("invalid aws tlv: no subtype");
134 0 : }
135 0 : let subtype = tlv.value.get_u8();
136 0 : match Pp2AwsType::from_repr(subtype) {
137 0 : Some(Pp2AwsType::VpceId) => match std::str::from_utf8(tlv.value) {
138 0 : Ok(s) => {
139 0 : extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
140 0 : }
141 0 : Err(e) => {
142 0 : tracing::warn!("invalid aws vpce id: {e}");
143 : }
144 : },
145 : None => {
146 0 : tracing::warn!("unknown aws tlv: subtype={subtype}");
147 : }
148 : }
149 : }
150 : Some(Pp2Kind::Azure) => {
151 0 : if tlv.value.is_empty() {
152 0 : tracing::warn!("invalid azure tlv: no subtype");
153 0 : }
154 0 : let subtype = tlv.value.get_u8();
155 0 : match Pp2AzureType::from_repr(subtype) {
156 : Some(Pp2AzureType::PrivateEndpointLinkId) => {
157 0 : if tlv.value.len() != 4 {
158 0 : tracing::warn!("invalid azure link_id: {:?}", tlv.value);
159 0 : }
160 0 : extra = Some(ConnectionInfoExtra::Azure {
161 0 : link_id: tlv.value.get_u32_le(),
162 0 : });
163 : }
164 : None => {
165 0 : tracing::warn!("unknown azure tlv: subtype={subtype}");
166 : }
167 : }
168 : }
169 0 : Some(kind) => {
170 0 : tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
171 : }
172 : None => {
173 1 : tracing::debug!("unknown tlv: {tlv:?}");
174 : }
175 : }
176 : }
177 :
178 3 : Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
179 4 : }
180 :
/// PROXY v2 TLV type byte (`Tlv::kind`). Bytes that match none of these variants are
/// logged as unknown TLVs and skipped.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2Kind {
    // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
    // we don't use these but it would be interesting to know what's available
    Alpn = 0x01,
    Authority = 0x02,
    Crc32C = 0x03,
    Noop = 0x04,
    UniqueId = 0x05,
    Ssl = 0x20,
    NetNs = 0x30,

    /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
    Aws = 0xEA,

    /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
    Azure = 0xEE,
}
200 :
/// Subtype byte that starts the value of an AWS (`0xEA`) TLV.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AwsType {
    /// The remaining bytes are the VPC endpoint ID as a UTF-8 string.
    VpceId = 0x01,
}
206 :
/// Subtype byte that starts the value of an Azure (`0xEE`) TLV.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AzureType {
    /// The remaining bytes are the private endpoint LINKID, read as a little-endian `u32`.
    PrivateEndpointLinkId = 0x01,
}
212 :
/// One type-length-value record borrowed from the header payload.
#[derive(Debug)]
struct Tlv<'a> {
    /// Raw TLV type byte (see `Pp2Kind`).
    kind: u8,
    /// The value bytes; their length is the TLV's declared length.
    value: &'a [u8],
}
218 :
219 4 : fn read_tlv<'a>(b: &mut &'a [u8]) -> Option<Tlv<'a>> {
220 4 : let tlv_header = b.try_get::<TlvHeader>()?;
221 3 : let len = usize::from(tlv_header.len.get());
222 : Some(Tlv {
223 3 : kind: tlv_header.kind,
224 3 : value: b.split_off(..len)?,
225 : })
226 4 : }
227 :
228 : trait BufExt: Sized {
229 : fn try_get<T: FromBytes>(&mut self) -> Option<T>;
230 : }
231 : impl BufExt for &[u8] {
232 7 : fn try_get<T: FromBytes>(&mut self) -> Option<T> {
233 7 : let (res, rest) = T::read_from_prefix(self).ok()?;
234 6 : *self = rest;
235 6 : Some(res)
236 7 : }
237 : }
238 :
/// The fixed 16-byte PROXY v2 header, laid out exactly as on the wire so it can be
/// transmuted directly from raw bytes.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2Header {
    /// Must equal `SIGNATURE`.
    signature: [u8; 12],
    /// Protocol version (high nibble) and command (low nibble).
    version_and_command: u8,
    /// Address family (high nibble) and transport protocol (low nibble).
    protocol_and_family: u8,
    /// Number of bytes (address block + TLVs) that follow this header.
    len: network_endian::U16,
}
247 :
/// 12-byte address block for IPv4; all fields are in network byte order. Only the
/// source address/port are used by this module.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2HeaderV4 {
    src_addr: NetworkEndianIpv4,
    dst_addr: NetworkEndianIpv4,
    src_port: network_endian::U16,
    dst_port: network_endian::U16,
}
256 :
/// 36-byte address block for IPv6; all fields are in network byte order. Only the
/// source address/port are used by this module.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2HeaderV6 {
    src_addr: NetworkEndianIpv6,
    dst_addr: NetworkEndianIpv6,
    src_port: network_endian::U16,
    dst_port: network_endian::U16,
}
265 :
/// 3-byte prefix of every TLV: the type byte and the big-endian length of the value
/// that follows.
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct TlvHeader {
    kind: u8,
    len: network_endian::U16,
}
272 :
273 : #[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
274 : #[repr(transparent)]
275 : struct NetworkEndianIpv4(network_endian::U32);
276 : impl NetworkEndianIpv4 {
277 : #[inline]
278 2 : fn get(self) -> Ipv4Addr {
279 2 : Ipv4Addr::from_bits(self.0.get())
280 2 : }
281 : }
282 :
283 : #[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
284 : #[repr(transparent)]
285 : struct NetworkEndianIpv6(network_endian::U128);
286 : impl NetworkEndianIpv6 {
287 : #[inline]
288 1 : fn get(self) -> Ipv6Addr {
289 1 : Ipv6Addr::from_bits(self.0.get())
290 1 : }
291 : }
292 :
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::{
        ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6, read_proxy_protocol,
    };

    // Each test builds a header by chaining byte slices (tokio `AsyncReadExt::chain`),
    // appends trailing application data, and checks that the data is left unread.

    #[tokio::test]
    async fn test_ipv4() {
        let header = super::SIGNATURE
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 address bytes + 3 TLV bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV header only: type 0x01, declared length 0x0203 but no value bytes,
            // so the TLV is truncated and ignored
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
    }

    #[tokio::test]
    async fn test_ipv6() {
        let header = super::SIGNATURE
            // Proxy command, IPV6 | UDP
            .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
            // 36 address bytes + 3 TLV bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV header only: type 0x01, declared length 0x0203 but no value bytes,
            // so the TLV is truncated and ignored
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(
            info.addr,
            ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
        );
    }

    // Input without the v2 signature must be rejected.
    #[tokio::test]
    #[should_panic = "invalid proxy protocol header"]
    async fn test_invalid() {
        let data = [0x55; 256];

        read_proxy_protocol(data.as_slice()).await.unwrap();
    }

    // Input shorter than the 16-byte fixed header must fail while reading it.
    #[tokio::test]
    #[should_panic = "early eof"]
    async fn test_short() {
        let data = [0x55; 10];

        read_proxy_protocol(data.as_slice()).await.unwrap();
    }

    // A 32 KiB TLV of unknown type (0xFF) must be consumed and skipped.
    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let tlv_len = (tlv.len() as u16).to_be_bytes();
        let len = (12 + 3 + tlv.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            // Proxy command, Inet << 4 | Stream
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // 12 address bytes + 3 TLV header bytes + TLV value
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([255].as_slice())
            .chain(tlv_len.as_slice())
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
    }

    // LOCAL command with an empty payload.
    #[tokio::test]
    async fn test_local() {
        let len = 0u16.to_be_bytes();
        let header = super::SIGNATURE
            .chain([LOCAL_V2, 0x00].as_slice())
            .chain(len.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Local = info else { panic!() };
    }
}
|