Line data Source code
1 : //! Proxy Protocol V2 implementation
2 : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
3 :
4 : use core::fmt;
5 : use std::io;
6 : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
7 : use std::pin::Pin;
8 : use std::task::{Context, Poll};
9 :
10 : use bytes::{Buf, Bytes, BytesMut};
11 : use pin_project_lite::pin_project;
12 : use smol_str::SmolStr;
13 : use strum_macros::FromRepr;
14 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
15 : use zerocopy::{FromBytes, FromZeroes};
16 :
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough.
    ///
    /// Reads are served from `buf` until it is empty and only then from
    /// `inner`; writes, flushes and shutdowns are forwarded to `inner`.
    pub(crate) struct ChainRW<T> {
        // The wrapped stream. Pinned so its poll methods can be reached through `project()`.
        #[pin]
        pub(crate) inner: T,
        // Bytes already taken out of `inner` that no reader has consumed yet.
        buf: BytesMut,
    }
}
25 :
26 : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
27 : #[inline]
28 15 : fn poll_write(
29 15 : self: Pin<&mut Self>,
30 15 : cx: &mut Context<'_>,
31 15 : buf: &[u8],
32 15 : ) -> Poll<Result<usize, io::Error>> {
33 15 : self.project().inner.poll_write(cx, buf)
34 15 : }
35 :
36 : #[inline]
37 69 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
38 69 : self.project().inner.poll_flush(cx)
39 69 : }
40 :
41 : #[inline]
42 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
43 0 : self.project().inner.poll_shutdown(cx)
44 0 : }
45 :
46 : #[inline]
47 54 : fn poll_write_vectored(
48 54 : self: Pin<&mut Self>,
49 54 : cx: &mut Context<'_>,
50 54 : bufs: &[io::IoSlice<'_>],
51 54 : ) -> Poll<Result<usize, io::Error>> {
52 54 : self.project().inner.poll_write_vectored(cx, bufs)
53 54 : }
54 :
55 : #[inline]
56 0 : fn is_write_vectored(&self) -> bool {
57 0 : self.inner.is_write_vectored()
58 0 : }
59 : }
60 :
/// The 12-byte signature that opens every Proxy Protocol Version 2 header.
const SIGNATURE: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

// Accepted values of the header's `version_and_command` byte.
// The parser's error message names them "local (0x20)" and "proxy (0x21)".
const LOCAL_V2: u8 = 0x20;
const PROXY_V2: u8 = 0x21;

// Accepted values of the header's `protocol_and_family` byte.
// Any other value is rejected by the payload parser.
const TCP_OVER_IPV4: u8 = 0x11;
const UDP_OVER_IPV4: u8 = 0x12;
const TCP_OVER_IPV6: u8 = 0x21;
const UDP_OVER_IPV6: u8 = 0x22;
73 :
/// What a PROXY-command header told us about the original connection.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct ConnectionInfo {
    /// Source address and source port taken from the header's address block.
    pub addr: SocketAddr,
    /// Vendor-specific details from an AWS or Azure TLV, if one was present and valid.
    pub extra: Option<ConnectionInfoExtra>,
}
79 :
/// Outcome of looking for a proxy protocol v2 header at the start of a stream.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectHeader {
    /// The stream did not start with the v2 signature, or ended before a full header arrived.
    Missing,
    /// The header carried the LOCAL command, so it provides no address information.
    Local,
    /// The header carried the PROXY command along with the original connection details.
    Proxy(ConnectionInfo),
}
86 :
87 : impl fmt::Display for ConnectionInfo {
88 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 0 : match &self.extra {
90 0 : None => self.addr.ip().fmt(f),
91 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => {
92 0 : write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
93 : }
94 0 : Some(ConnectionInfoExtra::Azure { link_id }) => {
95 0 : write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
96 : }
97 : }
98 0 : }
99 : }
100 :
/// Vendor-specific connection details carried in a proxy protocol TLV.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
    /// From the AWS TLV: the VPC endpoint id, decoded as UTF-8.
    Aws { vpce_id: SmolStr },
    /// From the Azure TLV: the private endpoint link id, decoded as a little-endian `u32`.
    Azure { link_id: u32 },
}
106 :
107 21 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
108 21 : mut read: T,
109 21 : ) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
110 21 : let mut buf = BytesMut::with_capacity(128);
111 4 : let header = loop {
112 29 : let bytes_read = read.read_buf(&mut buf).await?;
113 :
114 : // exit for bad header signature
115 29 : let len = usize::min(buf.len(), SIGNATURE.len());
116 29 : if buf[..len] != SIGNATURE[..len] {
117 17 : return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
118 12 : }
119 12 :
120 12 : // if no more bytes available then exit
121 12 : if bytes_read == 0 {
122 0 : return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
123 12 : }
124 :
125 : // check if we have enough bytes to continue
126 12 : if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
127 4 : break header;
128 8 : }
129 : };
130 :
131 4 : let remaining_length = usize::from(header.len.get());
132 :
133 30 : while buf.len() < remaining_length {
134 26 : if read.read_buf(&mut buf).await? == 0 {
135 0 : return Err(io::Error::new(
136 0 : io::ErrorKind::UnexpectedEof,
137 0 : "stream closed while waiting for proxy protocol addresses",
138 0 : ));
139 26 : }
140 : }
141 4 : let payload = buf.split_to(remaining_length);
142 :
143 4 : let res = process_proxy_payload(header, payload)?;
144 4 : Ok((ChainRW { inner: read, buf }, res))
145 21 : }
146 :
147 4 : fn process_proxy_payload(
148 4 : header: ProxyProtocolV2Header,
149 4 : mut payload: BytesMut,
150 4 : ) -> std::io::Result<ConnectHeader> {
151 4 : match header.version_and_command {
152 : // the connection was established on purpose by the proxy
153 : // without being relayed. The connection endpoints are the sender and the
154 : // receiver. Such connections exist when the proxy sends health-checks to the
155 : // server. The receiver must accept this connection as valid and must use the
156 : // real connection endpoints and discard the protocol block including the
157 : // family which is ignored.
158 1 : LOCAL_V2 => return Ok(ConnectHeader::Local),
159 : // the connection was established on behalf of another node,
160 : // and reflects the original connection endpoints. The receiver must then use
161 : // the information provided in the protocol block to get original the address.
162 3 : PROXY_V2 => {}
163 : // other values are unassigned and must not be emitted by senders. Receivers
164 : // must drop connections presenting unexpected values here.
165 : #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
166 0 : _ => return Err(io::Error::new(
167 0 : io::ErrorKind::Other,
168 0 : format!(
169 0 : "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
170 0 : header.version_and_command
171 0 : ),
172 0 : )),
173 : }
174 :
175 3 : let size_err =
176 3 : "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
177 3 : let addr = match header.protocol_and_family {
178 : TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
179 2 : let addr = payload
180 2 : .try_get::<ProxyProtocolV2HeaderV4>()
181 2 : .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
182 :
183 2 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
184 : }
185 : TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
186 1 : let addr = payload
187 1 : .try_get::<ProxyProtocolV2HeaderV6>()
188 1 : .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
189 :
190 1 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
191 : }
192 : // unspecified or unix stream. ignore the addresses
193 : _ => {
194 0 : return Err(io::Error::new(
195 0 : io::ErrorKind::Other,
196 0 : "invalid proxy protocol address family/transport protocol.",
197 0 : ));
198 : }
199 : };
200 :
201 3 : let mut extra = None;
202 :
203 4 : while let Some(mut tlv) = read_tlv(&mut payload) {
204 1 : match Pp2Kind::from_repr(tlv.kind) {
205 : Some(Pp2Kind::Aws) => {
206 0 : if tlv.value.is_empty() {
207 0 : tracing::warn!("invalid aws tlv: no subtype");
208 0 : }
209 0 : let subtype = tlv.value.get_u8();
210 0 : match Pp2AwsType::from_repr(subtype) {
211 0 : Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) {
212 0 : Ok(s) => {
213 0 : extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
214 0 : }
215 0 : Err(e) => {
216 0 : tracing::warn!("invalid aws vpce id: {e}");
217 : }
218 : },
219 : None => {
220 0 : tracing::warn!("unknown aws tlv: subtype={subtype}");
221 : }
222 : }
223 : }
224 : Some(Pp2Kind::Azure) => {
225 0 : if tlv.value.is_empty() {
226 0 : tracing::warn!("invalid azure tlv: no subtype");
227 0 : }
228 0 : let subtype = tlv.value.get_u8();
229 0 : match Pp2AzureType::from_repr(subtype) {
230 : Some(Pp2AzureType::PrivateEndpointLinkId) => {
231 0 : if tlv.value.len() != 4 {
232 0 : tracing::warn!("invalid azure link_id: {:?}", tlv.value);
233 0 : }
234 0 : extra = Some(ConnectionInfoExtra::Azure {
235 0 : link_id: tlv.value.get_u32_le(),
236 0 : });
237 : }
238 : None => {
239 0 : tracing::warn!("unknown azure tlv: subtype={subtype}");
240 : }
241 : }
242 : }
243 0 : Some(kind) => {
244 0 : tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
245 : }
246 : None => {
247 1 : tracing::debug!("unknown tlv: {tlv:?}");
248 : }
249 : }
250 : }
251 :
252 3 : Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
253 4 : }
254 :
/// Known values of a TLV's `kind` byte.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2Kind {
    // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
    // We don't act on these, but naming them lets the debug log of unused TLVs show the kind.
    Alpn = 0x01,
    Authority = 0x02,
    Crc32C = 0x03,
    Noop = 0x04,
    UniqueId = 0x05,
    Ssl = 0x20,
    NetNs = 0x30,

    /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
    Aws = 0xEA,

    /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
    Azure = 0xEE,
}
274 :
/// Subtype byte that opens the value of an AWS TLV.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AwsType {
    // The rest of the TLV value is the VPC endpoint id as UTF-8 text.
    VpceId = 0x01,
}
280 :
/// Subtype byte that opens the value of an Azure TLV.
#[derive(FromRepr, Debug, Copy, Clone)]
#[repr(u8)]
enum Pp2AzureType {
    // The rest of the TLV value is the link id, read as a little-endian `u32`.
    PrivateEndpointLinkId = 0x01,
}
286 :
287 : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
288 : #[inline]
289 172 : fn poll_read(
290 172 : self: Pin<&mut Self>,
291 172 : cx: &mut Context<'_>,
292 172 : buf: &mut ReadBuf<'_>,
293 172 : ) -> Poll<io::Result<()>> {
294 172 : if self.buf.is_empty() {
295 153 : self.project().inner.poll_read(cx, buf)
296 : } else {
297 19 : self.read_from_buf(buf)
298 : }
299 172 : }
300 : }
301 :
302 : impl<T: AsyncRead> ChainRW<T> {
303 : #[cold]
304 19 : fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
305 19 : debug_assert!(!self.buf.is_empty());
306 19 : let this = self.project();
307 19 :
308 19 : let write = usize::min(this.buf.len(), buf.remaining());
309 19 : let slice = this.buf.split_to(write).freeze();
310 19 : buf.put_slice(&slice);
311 19 :
312 19 : // reset the allocation so it can be freed
313 19 : if this.buf.is_empty() {
314 17 : *this.buf = BytesMut::new();
315 17 : }
316 :
317 19 : Poll::Ready(Ok(()))
318 19 : }
319 : }
320 :
/// One type-length-value entry taken from the proxy protocol payload.
#[derive(Debug)]
struct Tlv {
    // Raw type byte; mapped to `Pp2Kind` by the payload parser.
    kind: u8,
    // The value bytes, without the 3-byte kind/length prefix.
    value: Bytes,
}
326 :
327 4 : fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
328 4 : let tlv_header = b.try_get::<TlvHeader>()?;
329 3 : let len = usize::from(tlv_header.len.get());
330 3 : if b.len() < len {
331 2 : return None;
332 1 : }
333 1 : Some(Tlv {
334 1 : kind: tlv_header.kind,
335 1 : value: b.split_to(len).freeze(),
336 1 : })
337 4 : }
338 :
339 : trait BufExt: Sized {
340 : fn try_get<T: FromBytes>(&mut self) -> Option<T>;
341 : }
342 : impl BufExt for BytesMut {
343 19 : fn try_get<T: FromBytes>(&mut self) -> Option<T> {
344 19 : let res = T::read_from_prefix(self)?;
345 10 : self.advance(size_of::<T>());
346 10 : Some(res)
347 19 : }
348 : }
349 :
/// The fixed-size part at the start of a proxy protocol v2 stream
/// (12 + 1 + 1 + 2 bytes), read straight from the received bytes.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2Header {
    // The reader checks the stream prefix against `SIGNATURE` before parsing this struct.
    signature: [u8; 12],
    // Matched against `LOCAL_V2` / `PROXY_V2`.
    version_and_command: u8,
    // Matched against the `*_OVER_IPV4` / `*_OVER_IPV6` constants.
    protocol_and_family: u8,
    // Number of bytes following this header (address block plus TLVs), big-endian on the wire.
    len: zerocopy::byteorder::network_endian::U16,
}
358 :
/// Address block used for TCP/UDP over IPv4. All fields are in network byte order.
/// Only the source address and port are consumed by the parser.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV4 {
    src_addr: NetworkEndianIpv4,
    dst_addr: NetworkEndianIpv4,
    src_port: zerocopy::byteorder::network_endian::U16,
    dst_port: zerocopy::byteorder::network_endian::U16,
}
367 :
/// Address block used for TCP/UDP over IPv6. All fields are in network byte order.
/// Only the source address and port are consumed by the parser.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct ProxyProtocolV2HeaderV6 {
    src_addr: NetworkEndianIpv6,
    dst_addr: NetworkEndianIpv6,
    src_port: zerocopy::byteorder::network_endian::U16,
    dst_port: zerocopy::byteorder::network_endian::U16,
}
376 :
/// The three bytes in front of every TLV value.
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
struct TlvHeader {
    // Raw type byte of the TLV.
    kind: u8,
    // Length of the value that follows, big-endian on the wire.
    len: zerocopy::byteorder::network_endian::U16,
}
383 :
384 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
385 : #[repr(transparent)]
386 : struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
387 : impl NetworkEndianIpv4 {
388 : #[inline]
389 2 : fn get(self) -> Ipv4Addr {
390 2 : Ipv4Addr::from_bits(self.0.get())
391 2 : }
392 : }
393 :
394 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
395 : #[repr(transparent)]
396 : struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
397 : impl NetworkEndianIpv6 {
398 : #[inline]
399 1 : fn get(self) -> Ipv6Addr {
400 1 : Ipv6Addr::from_bits(self.0.get())
401 1 : }
402 : }
403 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    //! Each test builds its input by chaining many small slices, so the
    //! reader sees the header in pieces and the retry loops in
    //! `read_proxy_protocol` are exercised.

    use tokio::io::AsyncReadExt;

    use crate::protocol2::{
        ConnectHeader, ConnectionInfoExtra, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6,
        read_proxy_protocol,
    };

    #[tokio::test]
    async fn test_ipv4() {
        let header = super::SIGNATURE
            // Proxy command, IPV4 | TCP
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // 12 + 3 bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
    }

    #[tokio::test]
    async fn test_ipv6() {
        let header = super::SIGNATURE
            // Proxy command, IPV6 | UDP
            .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
            // 36 + 3 bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(
            info.addr,
            ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
        );
    }

    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        // Nothing may be swallowed when no header is present.
        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    #[tokio::test]
    async fn test_short() {
        // Shorter than the signature itself.
        let data = [0x55; 10];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let tlv_len = (tlv.len() as u16).to_be_bytes();
        let len = (12 + 3 + tlv.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            // Proxy command, Inet << 4 | Stream
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // 12 address bytes + 3 TLV header bytes + the TLV value
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([255].as_slice())
            .chain(tlv_len.as_slice())
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
    }

    #[tokio::test]
    async fn test_local() {
        let len = 0u16.to_be_bytes();
        let header = super::SIGNATURE
            .chain([LOCAL_V2, 0x00].as_slice())
            .chain(len.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Local = info else { panic!() };
    }

    #[tokio::test]
    async fn test_aws_vpce_id() {
        let vpce_id = b"vpce-0123456789abcdef";
        // TLV value = subtype byte + id
        let tlv_len = (1 + vpce_id.len() as u16).to_be_bytes();
        // 12 address bytes + 3 TLV header bytes + the TLV value
        let len = (12 + 3 + 1 + vpce_id.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            .chain(len.as_slice())
            // src ip
            .chain([10, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([0, 80].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV: AWS kind, length, VPCE id subtype, id
            .chain([0xEA].as_slice())
            .chain(tlv_len.as_slice())
            .chain([0x01].as_slice())
            .chain(vpce_id.as_slice());

        let extra_data = [0xaa; 16];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([10, 0, 0, 1], 80).into());
        assert_eq!(
            info.extra,
            Some(ConnectionInfoExtra::Aws {
                vpce_id: "vpce-0123456789abcdef".into()
            })
        );
    }

    #[tokio::test]
    async fn test_azure_link_id() {
        let link_id: u32 = 0x0102_0304;
        // The link id travels in little-endian byte order.
        let link_id_bytes = link_id.to_le_bytes();
        // TLV value = subtype byte + 4-byte id
        let tlv_len = 5u16.to_be_bytes();
        // 12 address bytes + 3 TLV header bytes + 5 value bytes
        let len = (12u16 + 3 + 5).to_be_bytes();

        let header = super::SIGNATURE
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            .chain(len.as_slice())
            // src ip
            .chain([10, 0, 0, 2].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([0, 80].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV: Azure kind, length, link id subtype, id
            .chain([0xEE].as_slice())
            .chain(tlv_len.as_slice())
            .chain([0x01].as_slice())
            .chain(link_id_bytes.as_slice());

        let extra_data = [0xaa; 16];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([10, 0, 0, 2], 80).into());
        assert_eq!(info.extra, Some(ConnectionInfoExtra::Azure { link_id }));
    }
}
|