Line data Source code
1 : //! Proxy Protocol V2 implementation
2 : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
3 :
4 : use core::fmt;
5 : use std::io;
6 : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
7 : use std::pin::Pin;
8 : use std::task::{Context, Poll};
9 :
10 : use bytes::{Buf, Bytes, BytesMut};
11 : use pin_project_lite::pin_project;
12 : use smol_str::SmolStr;
13 : use strum_macros::FromRepr;
14 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
15 : use zerocopy::{FromBytes, FromZeroes};
16 :
17 : pin_project! {
18 : /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
19 : pub(crate) struct ChainRW<T> {
20 : #[pin]
21 : pub(crate) inner: T,
22 : buf: BytesMut,
23 : }
24 : }
25 :
26 : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
27 : #[inline]
28 15 : fn poll_write(
29 15 : self: Pin<&mut Self>,
30 15 : cx: &mut Context<'_>,
31 15 : buf: &[u8],
32 15 : ) -> Poll<Result<usize, io::Error>> {
33 15 : self.project().inner.poll_write(cx, buf)
34 15 : }
35 :
36 : #[inline]
37 69 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
38 69 : self.project().inner.poll_flush(cx)
39 69 : }
40 :
41 : #[inline]
42 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
43 0 : self.project().inner.poll_shutdown(cx)
44 0 : }
45 :
46 : #[inline]
47 54 : fn poll_write_vectored(
48 54 : self: Pin<&mut Self>,
49 54 : cx: &mut Context<'_>,
50 54 : bufs: &[io::IoSlice<'_>],
51 54 : ) -> Poll<Result<usize, io::Error>> {
52 54 : self.project().inner.poll_write_vectored(cx, bufs)
53 54 : }
54 :
55 : #[inline]
56 0 : fn is_write_vectored(&self) -> bool {
57 0 : self.inner.is_write_vectored()
58 0 : }
59 : }
60 :
/// Twelve-byte signature that opens every Proxy Protocol Version 2 header:
/// `\r\n\r\n\0\r\nQUIT\n`.
const SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
65 :
/// Version 2 (high nibble) combined with the `LOCAL` command (low nibble 0x0).
const LOCAL_V2: u8 = 0x2 << 4;
/// Version 2 (high nibble) combined with the `PROXY` command (low nibble 0x1).
const PROXY_V2: u8 = (0x2 << 4) | 0x1;

/// IPv4 family (high nibble 0x1) with stream transport (low nibble 0x1).
const TCP_OVER_IPV4: u8 = (0x1 << 4) | 0x1;
/// IPv4 family (high nibble 0x1) with datagram transport (low nibble 0x2).
const UDP_OVER_IPV4: u8 = (0x1 << 4) | 0x2;
/// IPv6 family (high nibble 0x2) with stream transport (low nibble 0x1).
const TCP_OVER_IPV6: u8 = (0x2 << 4) | 0x1;
/// IPv6 family (high nibble 0x2) with datagram transport (low nibble 0x2).
const UDP_OVER_IPV6: u8 = (0x2 << 4) | 0x2;
73 :
74 : #[derive(PartialEq, Eq, Clone, Debug)]
75 : pub struct ConnectionInfo {
76 : pub addr: SocketAddr,
77 : pub extra: Option<ConnectionInfoExtra>,
78 : }
79 :
80 : #[derive(PartialEq, Eq, Clone, Debug)]
81 : pub enum ConnectHeader {
82 : Missing,
83 : Local,
84 : Proxy(ConnectionInfo),
85 : }
86 :
87 : impl fmt::Display for ConnectionInfo {
88 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 0 : match &self.extra {
90 0 : None => self.addr.ip().fmt(f),
91 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => {
92 0 : write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
93 : }
94 0 : Some(ConnectionInfoExtra::Azure { link_id }) => {
95 0 : write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
96 : }
97 : }
98 0 : }
99 : }
100 :
101 : #[derive(PartialEq, Eq, Clone, Debug)]
102 : pub enum ConnectionInfoExtra {
103 : Aws { vpce_id: SmolStr },
104 : Azure { link_id: u32 },
105 : }
106 :
107 21 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
108 21 : mut read: T,
109 21 : ) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
110 21 : let mut buf = BytesMut::with_capacity(128);
111 4 : let header = loop {
112 29 : let bytes_read = read.read_buf(&mut buf).await?;
113 :
114 : // exit for bad header signature
115 29 : let len = usize::min(buf.len(), SIGNATURE.len());
116 29 : if buf[..len] != SIGNATURE[..len] {
117 17 : return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
118 12 : }
119 12 :
120 12 : // if no more bytes available then exit
121 12 : if bytes_read == 0 {
122 0 : return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
123 12 : }
124 :
125 : // check if we have enough bytes to continue
126 12 : if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
127 4 : break header;
128 8 : }
129 : };
130 :
131 4 : let remaining_length = usize::from(header.len.get());
132 :
133 30 : while buf.len() < remaining_length {
134 26 : if read.read_buf(&mut buf).await? == 0 {
135 0 : return Err(io::Error::new(
136 0 : io::ErrorKind::UnexpectedEof,
137 0 : "stream closed while waiting for proxy protocol addresses",
138 0 : ));
139 26 : }
140 : }
141 4 : let payload = buf.split_to(remaining_length);
142 :
143 4 : let res = process_proxy_payload(header, payload)?;
144 4 : Ok((ChainRW { inner: read, buf }, res))
145 21 : }
146 :
147 4 : fn process_proxy_payload(
148 4 : header: ProxyProtocolV2Header,
149 4 : mut payload: BytesMut,
150 4 : ) -> std::io::Result<ConnectHeader> {
151 4 : match header.version_and_command {
152 : // the connection was established on purpose by the proxy
153 : // without being relayed. The connection endpoints are the sender and the
154 : // receiver. Such connections exist when the proxy sends health-checks to the
155 : // server. The receiver must accept this connection as valid and must use the
156 : // real connection endpoints and discard the protocol block including the
157 : // family which is ignored.
158 1 : LOCAL_V2 => return Ok(ConnectHeader::Local),
159 : // the connection was established on behalf of another node,
160 : // and reflects the original connection endpoints. The receiver must then use
161 : // the information provided in the protocol block to get original the address.
162 3 : PROXY_V2 => {}
163 : // other values are unassigned and must not be emitted by senders. Receivers
164 : // must drop connections presenting unexpected values here.
165 : #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
166 0 : _ => return Err(io::Error::other(
167 0 : format!(
168 0 : "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
169 0 : header.version_and_command
170 0 : ),
171 0 : )),
172 : }
173 :
174 3 : let size_err =
175 3 : "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
176 3 : let addr = match header.protocol_and_family {
177 : TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
178 2 : let addr = payload
179 2 : .try_get::<ProxyProtocolV2HeaderV4>()
180 2 : .ok_or_else(|| io::Error::other(size_err))?;
181 :
182 2 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
183 : }
184 : TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
185 1 : let addr = payload
186 1 : .try_get::<ProxyProtocolV2HeaderV6>()
187 1 : .ok_or_else(|| io::Error::other(size_err))?;
188 :
189 1 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
190 : }
191 : // unspecified or unix stream. ignore the addresses
192 : _ => {
193 0 : return Err(io::Error::other(
194 0 : "invalid proxy protocol address family/transport protocol.",
195 0 : ));
196 : }
197 : };
198 :
199 3 : let mut extra = None;
200 :
201 4 : while let Some(mut tlv) = read_tlv(&mut payload) {
202 1 : match Pp2Kind::from_repr(tlv.kind) {
203 : Some(Pp2Kind::Aws) => {
204 0 : if tlv.value.is_empty() {
205 0 : tracing::warn!("invalid aws tlv: no subtype");
206 0 : }
207 0 : let subtype = tlv.value.get_u8();
208 0 : match Pp2AwsType::from_repr(subtype) {
209 0 : Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) {
210 0 : Ok(s) => {
211 0 : extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
212 0 : }
213 0 : Err(e) => {
214 0 : tracing::warn!("invalid aws vpce id: {e}");
215 : }
216 : },
217 : None => {
218 0 : tracing::warn!("unknown aws tlv: subtype={subtype}");
219 : }
220 : }
221 : }
222 : Some(Pp2Kind::Azure) => {
223 0 : if tlv.value.is_empty() {
224 0 : tracing::warn!("invalid azure tlv: no subtype");
225 0 : }
226 0 : let subtype = tlv.value.get_u8();
227 0 : match Pp2AzureType::from_repr(subtype) {
228 : Some(Pp2AzureType::PrivateEndpointLinkId) => {
229 0 : if tlv.value.len() != 4 {
230 0 : tracing::warn!("invalid azure link_id: {:?}", tlv.value);
231 0 : }
232 0 : extra = Some(ConnectionInfoExtra::Azure {
233 0 : link_id: tlv.value.get_u32_le(),
234 0 : });
235 : }
236 : None => {
237 0 : tracing::warn!("unknown azure tlv: subtype={subtype}");
238 : }
239 : }
240 : }
241 0 : Some(kind) => {
242 0 : tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
243 : }
244 : None => {
245 1 : tracing::debug!("unknown tlv: {tlv:?}");
246 : }
247 : }
248 : }
249 :
250 3 : Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
251 4 : }
252 :
253 : #[derive(FromRepr, Debug, Copy, Clone)]
254 : #[repr(u8)]
255 : enum Pp2Kind {
256 : // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
257 : // we don't use these but it would be interesting to know what's available
258 : Alpn = 0x01,
259 : Authority = 0x02,
260 : Crc32C = 0x03,
261 : Noop = 0x04,
262 : UniqueId = 0x05,
263 : Ssl = 0x20,
264 : NetNs = 0x30,
265 :
266 : /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
267 : Aws = 0xEA,
268 :
269 : /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
270 : Azure = 0xEE,
271 : }
272 :
273 : #[derive(FromRepr, Debug, Copy, Clone)]
274 : #[repr(u8)]
275 : enum Pp2AwsType {
276 : VpceId = 0x01,
277 : }
278 :
279 : #[derive(FromRepr, Debug, Copy, Clone)]
280 : #[repr(u8)]
281 : enum Pp2AzureType {
282 : PrivateEndpointLinkId = 0x01,
283 : }
284 :
285 : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
286 : #[inline]
287 172 : fn poll_read(
288 172 : self: Pin<&mut Self>,
289 172 : cx: &mut Context<'_>,
290 172 : buf: &mut ReadBuf<'_>,
291 172 : ) -> Poll<io::Result<()>> {
292 172 : if self.buf.is_empty() {
293 153 : self.project().inner.poll_read(cx, buf)
294 : } else {
295 19 : self.read_from_buf(buf)
296 : }
297 172 : }
298 : }
299 :
300 : impl<T: AsyncRead> ChainRW<T> {
301 : #[cold]
302 19 : fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
303 19 : debug_assert!(!self.buf.is_empty());
304 19 : let this = self.project();
305 19 :
306 19 : let write = usize::min(this.buf.len(), buf.remaining());
307 19 : let slice = this.buf.split_to(write).freeze();
308 19 : buf.put_slice(&slice);
309 19 :
310 19 : // reset the allocation so it can be freed
311 19 : if this.buf.is_empty() {
312 17 : *this.buf = BytesMut::new();
313 17 : }
314 :
315 19 : Poll::Ready(Ok(()))
316 19 : }
317 : }
318 :
319 : #[derive(Debug)]
320 : struct Tlv {
321 : kind: u8,
322 : value: Bytes,
323 : }
324 :
325 4 : fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
326 4 : let tlv_header = b.try_get::<TlvHeader>()?;
327 3 : let len = usize::from(tlv_header.len.get());
328 3 : if b.len() < len {
329 2 : return None;
330 1 : }
331 1 : Some(Tlv {
332 1 : kind: tlv_header.kind,
333 1 : value: b.split_to(len).freeze(),
334 1 : })
335 4 : }
336 :
337 : trait BufExt: Sized {
338 : fn try_get<T: FromBytes>(&mut self) -> Option<T>;
339 : }
340 : impl BufExt for BytesMut {
341 19 : fn try_get<T: FromBytes>(&mut self) -> Option<T> {
342 19 : let res = T::read_from_prefix(self)?;
343 10 : self.advance(size_of::<T>());
344 10 : Some(res)
345 19 : }
346 : }
347 :
348 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
349 : #[repr(C)]
350 : struct ProxyProtocolV2Header {
351 : signature: [u8; 12],
352 : version_and_command: u8,
353 : protocol_and_family: u8,
354 : len: zerocopy::byteorder::network_endian::U16,
355 : }
356 :
357 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
358 : #[repr(C)]
359 : struct ProxyProtocolV2HeaderV4 {
360 : src_addr: NetworkEndianIpv4,
361 : dst_addr: NetworkEndianIpv4,
362 : src_port: zerocopy::byteorder::network_endian::U16,
363 : dst_port: zerocopy::byteorder::network_endian::U16,
364 : }
365 :
366 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
367 : #[repr(C)]
368 : struct ProxyProtocolV2HeaderV6 {
369 : src_addr: NetworkEndianIpv6,
370 : dst_addr: NetworkEndianIpv6,
371 : src_port: zerocopy::byteorder::network_endian::U16,
372 : dst_port: zerocopy::byteorder::network_endian::U16,
373 : }
374 :
375 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
376 : #[repr(C)]
377 : struct TlvHeader {
378 : kind: u8,
379 : len: zerocopy::byteorder::network_endian::U16,
380 : }
381 :
382 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
383 : #[repr(transparent)]
384 : struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
385 : impl NetworkEndianIpv4 {
386 : #[inline]
387 2 : fn get(self) -> Ipv4Addr {
388 2 : Ipv4Addr::from_bits(self.0.get())
389 2 : }
390 : }
391 :
392 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
393 : #[repr(transparent)]
394 : struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
395 : impl NetworkEndianIpv6 {
396 : #[inline]
397 1 : fn get(self) -> Ipv6Addr {
398 1 : Ipv6Addr::from_bits(self.0.get())
399 1 : }
400 : }
401 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::{
        ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6, read_proxy_protocol,
    };

    /// Parses the header from `stream`, then drains everything that is left of it.
    async fn parse(stream: impl AsyncReadExt + Unpin) -> (ConnectHeader, Vec<u8>) {
        let (mut rest, header) = read_proxy_protocol(stream).await.unwrap();

        let mut leftover = vec![];
        rest.read_to_end(&mut leftover).await.unwrap();

        (header, leftover)
    }

    /// Unwraps the source address of a `PROXY` header, panicking on any other variant.
    fn proxied_addr(header: ConnectHeader) -> std::net::SocketAddr {
        match header {
            ConnectHeader::Proxy(info) => info.addr,
            _ => panic!(),
        }
    }

    #[tokio::test]
    async fn test_ipv4() {
        let header = super::SIGNATURE
            // Proxy command, IPV4 | TCP
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // 12 address bytes + 3 TLV bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (info, bytes) = parse(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(proxied_addr(info), ([127, 0, 0, 1], 65535).into());
    }

    #[tokio::test]
    async fn test_ipv6() {
        let header = super::SIGNATURE
            // Proxy command, IPV6 | UDP
            .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
            // 36 address bytes + 3 TLV bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (info, bytes) = parse(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(
            proxied_addr(info),
            ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
        );
    }

    #[tokio::test]
    async fn test_invalid() {
        // No signature at all: every byte must be handed back untouched.
        let data = [0x55; 256];

        let (info, bytes) = parse(data.as_slice()).await;

        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    #[tokio::test]
    async fn test_short() {
        // Shorter than the signature itself.
        let data = [0x55; 10];

        let (info, bytes) = parse(data.as_slice()).await;

        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let tlv_len = (tlv.len() as u16).to_be_bytes();
        let len = (12 + 3 + tlv.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            // Proxy command, Inet << 4 | Stream
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // 12 address bytes + 3 TLV header bytes + the TLV value
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([255].as_slice())
            .chain(tlv_len.as_slice())
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (info, bytes) = parse(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(proxied_addr(info), ([55, 56, 57, 58], 65535).into());
    }

    #[tokio::test]
    async fn test_local() {
        let len = 0u16.to_be_bytes();
        let header = super::SIGNATURE
            .chain([LOCAL_V2, 0x00].as_slice())
            .chain(len.as_slice());

        let extra_data = [0xaa; 256];

        let (info, bytes) = parse(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Local = info else { panic!() };
    }
}
|