Line data Source code
1 : //! Proxy Protocol V2 implementation
2 : //! Compatible with <https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt>
3 :
4 : use core::fmt;
5 : use std::io;
6 : use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
7 : use std::pin::Pin;
8 : use std::task::{Context, Poll};
9 :
10 : use bytes::{Buf, Bytes, BytesMut};
11 : use pin_project_lite::pin_project;
12 : use strum_macros::FromRepr;
13 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
14 : use zerocopy::{FromBytes, FromZeroes};
15 :
16 : pin_project! {
17 : /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
18 : pub(crate) struct ChainRW<T> {
19 : #[pin]
20 : pub(crate) inner: T,
21 : buf: BytesMut,
22 : }
23 : }
24 :
25 : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
26 : #[inline]
27 15 : fn poll_write(
28 15 : self: Pin<&mut Self>,
29 15 : cx: &mut Context<'_>,
30 15 : buf: &[u8],
31 15 : ) -> Poll<Result<usize, io::Error>> {
32 15 : self.project().inner.poll_write(cx, buf)
33 15 : }
34 :
35 : #[inline]
36 74 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
37 74 : self.project().inner.poll_flush(cx)
38 74 : }
39 :
40 : #[inline]
41 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
42 0 : self.project().inner.poll_shutdown(cx)
43 0 : }
44 :
45 : #[inline]
46 59 : fn poll_write_vectored(
47 59 : self: Pin<&mut Self>,
48 59 : cx: &mut Context<'_>,
49 59 : bufs: &[io::IoSlice<'_>],
50 59 : ) -> Poll<Result<usize, io::Error>> {
51 59 : self.project().inner.poll_write_vectored(cx, bufs)
52 59 : }
53 :
54 : #[inline]
55 0 : fn is_write_vectored(&self) -> bool {
56 0 : self.inner.is_write_vectored()
57 0 : }
58 : }
59 :
/// The 12-byte magic that opens every Proxy Protocol v2 header: `\r\n\r\n\0\r\nQUIT\n`.
const SIGNATURE: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";

// Header byte 13: high nibble = protocol version (2), low nibble = command.
/// Version 2, LOCAL command (connection opened by the proxy itself).
const LOCAL_V2: u8 = 2 << 4;
/// Version 2, PROXY command (connection relayed on behalf of a client).
const PROXY_V2: u8 = (2 << 4) | 1;

// Header byte 14: high nibble = address family (1 = inet, 2 = inet6),
// low nibble = transport (1 = stream, 2 = datagram).
const TCP_OVER_IPV4: u8 = (1 << 4) | 1;
const UDP_OVER_IPV4: u8 = (1 << 4) | 2;
const TCP_OVER_IPV6: u8 = (2 << 4) | 1;
const UDP_OVER_IPV6: u8 = (2 << 4) | 2;
72 :
73 : #[derive(PartialEq, Eq, Clone, Debug)]
74 : pub struct ConnectionInfo {
75 : pub addr: SocketAddr,
76 : pub extra: Option<ConnectionInfoExtra>,
77 : }
78 :
79 : #[derive(PartialEq, Eq, Clone, Debug)]
80 : pub enum ConnectHeader {
81 : Missing,
82 : Local,
83 : Proxy(ConnectionInfo),
84 : }
85 :
86 : impl fmt::Display for ConnectionInfo {
87 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 0 : match &self.extra {
89 0 : None => self.addr.ip().fmt(f),
90 0 : Some(ConnectionInfoExtra::Aws { vpce_id }) => {
91 0 : write!(f, "vpce_id[{vpce_id:?}]:addr[{}]", self.addr.ip())
92 : }
93 0 : Some(ConnectionInfoExtra::Azure { link_id }) => {
94 0 : write!(f, "link_id[{link_id}]:addr[{}]", self.addr.ip())
95 : }
96 : }
97 0 : }
98 : }
99 :
100 : #[derive(PartialEq, Eq, Clone, Debug)]
101 : pub enum ConnectionInfoExtra {
102 : Aws { vpce_id: Bytes },
103 : Azure { link_id: u32 },
104 : }
105 :
106 21 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
107 21 : mut read: T,
108 21 : ) -> std::io::Result<(ChainRW<T>, ConnectHeader)> {
109 21 : let mut buf = BytesMut::with_capacity(128);
110 4 : let header = loop {
111 29 : let bytes_read = read.read_buf(&mut buf).await?;
112 :
113 : // exit for bad header signature
114 29 : let len = usize::min(buf.len(), SIGNATURE.len());
115 29 : if buf[..len] != SIGNATURE[..len] {
116 17 : return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
117 12 : }
118 12 :
119 12 : // if no more bytes available then exit
120 12 : if bytes_read == 0 {
121 0 : return Ok((ChainRW { inner: read, buf }, ConnectHeader::Missing));
122 12 : };
123 :
124 : // check if we have enough bytes to continue
125 12 : if let Some(header) = buf.try_get::<ProxyProtocolV2Header>() {
126 4 : break header;
127 8 : }
128 : };
129 :
130 4 : let remaining_length = usize::from(header.len.get());
131 :
132 30 : while buf.len() < remaining_length {
133 26 : if read.read_buf(&mut buf).await? == 0 {
134 0 : return Err(io::Error::new(
135 0 : io::ErrorKind::UnexpectedEof,
136 0 : "stream closed while waiting for proxy protocol addresses",
137 0 : ));
138 26 : }
139 : }
140 4 : let payload = buf.split_to(remaining_length);
141 :
142 4 : let res = process_proxy_payload(header, payload)?;
143 4 : Ok((ChainRW { inner: read, buf }, res))
144 21 : }
145 :
146 4 : fn process_proxy_payload(
147 4 : header: ProxyProtocolV2Header,
148 4 : mut payload: BytesMut,
149 4 : ) -> std::io::Result<ConnectHeader> {
150 4 : match header.version_and_command {
151 : // the connection was established on purpose by the proxy
152 : // without being relayed. The connection endpoints are the sender and the
153 : // receiver. Such connections exist when the proxy sends health-checks to the
154 : // server. The receiver must accept this connection as valid and must use the
155 : // real connection endpoints and discard the protocol block including the
156 : // family which is ignored.
157 1 : LOCAL_V2 => return Ok(ConnectHeader::Local),
158 : // the connection was established on behalf of another node,
159 : // and reflects the original connection endpoints. The receiver must then use
160 : // the information provided in the protocol block to get original the address.
161 3 : PROXY_V2 => {}
162 : // other values are unassigned and must not be emitted by senders. Receivers
163 : // must drop connections presenting unexpected values here.
164 : #[rustfmt::skip] // https://github.com/rust-lang/rustfmt/issues/6384
165 0 : _ => return Err(io::Error::new(
166 0 : io::ErrorKind::Other,
167 0 : format!(
168 0 : "invalid proxy protocol command 0x{:02X}. expected local (0x20) or proxy (0x21)",
169 0 : header.version_and_command
170 0 : ),
171 0 : )),
172 : };
173 :
174 3 : let size_err =
175 3 : "invalid proxy protocol length. payload not large enough to fit requested IP addresses";
176 3 : let addr = match header.protocol_and_family {
177 : TCP_OVER_IPV4 | UDP_OVER_IPV4 => {
178 2 : let addr = payload
179 2 : .try_get::<ProxyProtocolV2HeaderV4>()
180 2 : .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
181 :
182 2 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
183 : }
184 : TCP_OVER_IPV6 | UDP_OVER_IPV6 => {
185 1 : let addr = payload
186 1 : .try_get::<ProxyProtocolV2HeaderV6>()
187 1 : .ok_or_else(|| io::Error::new(io::ErrorKind::Other, size_err))?;
188 :
189 1 : SocketAddr::from((addr.src_addr.get(), addr.src_port.get()))
190 : }
191 : // unspecified or unix stream. ignore the addresses
192 : _ => {
193 0 : return Err(io::Error::new(
194 0 : io::ErrorKind::Other,
195 0 : "invalid proxy protocol address family/transport protocol.",
196 0 : ))
197 : }
198 : };
199 :
200 3 : let mut extra = None;
201 :
202 4 : while let Some(mut tlv) = read_tlv(&mut payload) {
203 1 : match Pp2Kind::from_repr(tlv.kind) {
204 : Some(Pp2Kind::Aws) => {
205 0 : if tlv.value.is_empty() {
206 0 : tracing::warn!("invalid aws tlv: no subtype");
207 0 : }
208 0 : let subtype = tlv.value.get_u8();
209 0 : match Pp2AwsType::from_repr(subtype) {
210 0 : Some(Pp2AwsType::VpceId) => {
211 0 : extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value });
212 0 : }
213 : None => {
214 0 : tracing::warn!("unknown aws tlv: subtype={subtype}");
215 : }
216 : }
217 : }
218 : Some(Pp2Kind::Azure) => {
219 0 : if tlv.value.is_empty() {
220 0 : tracing::warn!("invalid azure tlv: no subtype");
221 0 : }
222 0 : let subtype = tlv.value.get_u8();
223 0 : match Pp2AzureType::from_repr(subtype) {
224 : Some(Pp2AzureType::PrivateEndpointLinkId) => {
225 0 : if tlv.value.len() != 4 {
226 0 : tracing::warn!("invalid azure link_id: {:?}", tlv.value);
227 0 : }
228 0 : extra = Some(ConnectionInfoExtra::Azure {
229 0 : link_id: tlv.value.get_u32_le(),
230 0 : });
231 : }
232 : None => {
233 0 : tracing::warn!("unknown azure tlv: subtype={subtype}");
234 : }
235 : }
236 : }
237 0 : Some(kind) => {
238 0 : tracing::debug!("unused tlv[{kind:?}]: {:?}", tlv.value);
239 : }
240 : None => {
241 1 : tracing::debug!("unknown tlv: {tlv:?}");
242 : }
243 : }
244 : }
245 :
246 3 : Ok(ConnectHeader::Proxy(ConnectionInfo { addr, extra }))
247 4 : }
248 :
249 1 : #[derive(FromRepr, Debug, Copy, Clone)]
250 : #[repr(u8)]
251 : enum Pp2Kind {
252 : // The following are defined by https://www.haproxy.org/download/3.1/doc/proxy-protocol.txt
253 : // we don't use these but it would be interesting to know what's available
254 : Alpn = 0x01,
255 : Authority = 0x02,
256 : Crc32C = 0x03,
257 : Noop = 0x04,
258 : UniqueId = 0x05,
259 : Ssl = 0x20,
260 : NetNs = 0x30,
261 :
262 : /// <https://docs.aws.amazon.com/elasticloadbalancing/latest/network/edit-target-group-attributes.html#proxy-protocol>
263 : Aws = 0xEA,
264 :
265 : /// <https://learn.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2>
266 : Azure = 0xEE,
267 : }
268 :
269 0 : #[derive(FromRepr, Debug, Copy, Clone)]
270 : #[repr(u8)]
271 : enum Pp2AwsType {
272 : VpceId = 0x01,
273 : }
274 :
275 0 : #[derive(FromRepr, Debug, Copy, Clone)]
276 : #[repr(u8)]
277 : enum Pp2AzureType {
278 : PrivateEndpointLinkId = 0x01,
279 : }
280 :
281 : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
282 : #[inline]
283 172 : fn poll_read(
284 172 : self: Pin<&mut Self>,
285 172 : cx: &mut Context<'_>,
286 172 : buf: &mut ReadBuf<'_>,
287 172 : ) -> Poll<io::Result<()>> {
288 172 : if self.buf.is_empty() {
289 153 : self.project().inner.poll_read(cx, buf)
290 : } else {
291 19 : self.read_from_buf(buf)
292 : }
293 172 : }
294 : }
295 :
296 : impl<T: AsyncRead> ChainRW<T> {
297 : #[cold]
298 19 : fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
299 19 : debug_assert!(!self.buf.is_empty());
300 19 : let this = self.project();
301 19 :
302 19 : let write = usize::min(this.buf.len(), buf.remaining());
303 19 : let slice = this.buf.split_to(write).freeze();
304 19 : buf.put_slice(&slice);
305 19 :
306 19 : // reset the allocation so it can be freed
307 19 : if this.buf.is_empty() {
308 17 : *this.buf = BytesMut::new();
309 17 : }
310 :
311 19 : Poll::Ready(Ok(()))
312 19 : }
313 : }
314 :
315 : #[derive(Debug)]
316 : struct Tlv {
317 : kind: u8,
318 : value: Bytes,
319 : }
320 :
321 4 : fn read_tlv(b: &mut BytesMut) -> Option<Tlv> {
322 4 : let tlv_header = b.try_get::<TlvHeader>()?;
323 3 : let len = usize::from(tlv_header.len.get());
324 3 : if b.len() < len {
325 2 : return None;
326 1 : }
327 1 : Some(Tlv {
328 1 : kind: tlv_header.kind,
329 1 : value: b.split_to(len).freeze(),
330 1 : })
331 4 : }
332 :
333 : trait BufExt: Sized {
334 : fn try_get<T: FromBytes>(&mut self) -> Option<T>;
335 : }
336 : impl BufExt for BytesMut {
337 19 : fn try_get<T: FromBytes>(&mut self) -> Option<T> {
338 19 : let res = T::read_from_prefix(self)?;
339 10 : self.advance(size_of::<T>());
340 10 : Some(res)
341 19 : }
342 : }
343 :
344 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
345 : #[repr(C)]
346 : struct ProxyProtocolV2Header {
347 : signature: [u8; 12],
348 : version_and_command: u8,
349 : protocol_and_family: u8,
350 : len: zerocopy::byteorder::network_endian::U16,
351 : }
352 :
353 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
354 : #[repr(C)]
355 : struct ProxyProtocolV2HeaderV4 {
356 : src_addr: NetworkEndianIpv4,
357 : dst_addr: NetworkEndianIpv4,
358 : src_port: zerocopy::byteorder::network_endian::U16,
359 : dst_port: zerocopy::byteorder::network_endian::U16,
360 : }
361 :
362 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
363 : #[repr(C)]
364 : struct ProxyProtocolV2HeaderV6 {
365 : src_addr: NetworkEndianIpv6,
366 : dst_addr: NetworkEndianIpv6,
367 : src_port: zerocopy::byteorder::network_endian::U16,
368 : dst_port: zerocopy::byteorder::network_endian::U16,
369 : }
370 :
371 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
372 : #[repr(C)]
373 : struct TlvHeader {
374 : kind: u8,
375 : len: zerocopy::byteorder::network_endian::U16,
376 : }
377 :
378 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
379 : #[repr(transparent)]
380 : struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
381 : impl NetworkEndianIpv4 {
382 : #[inline]
383 2 : fn get(self) -> Ipv4Addr {
384 2 : Ipv4Addr::from_bits(self.0.get())
385 2 : }
386 : }
387 :
388 0 : #[derive(FromBytes, FromZeroes, Copy, Clone)]
389 : #[repr(transparent)]
390 : struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
391 : impl NetworkEndianIpv6 {
392 : #[inline]
393 1 : fn get(self) -> Ipv6Addr {
394 1 : Ipv6Addr::from_bits(self.0.get())
395 1 : }
396 : }
397 :
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::{
        read_proxy_protocol, ConnectHeader, LOCAL_V2, PROXY_V2, TCP_OVER_IPV4, UDP_OVER_IPV6,
    };

    #[tokio::test]
    async fn test_ipv4() {
        let header = super::SIGNATURE
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // payload length: 12 address bytes + 3 trailing bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // truncated TLV: kind 0x01 with declared length 0x0203, which overruns the
            // payload, so parsing must stop without touching the data that follows
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([127, 0, 0, 1], 65535).into());
    }

    #[tokio::test]
    async fn test_ipv6() {
        let header = super::SIGNATURE
            // Proxy command, IPV6 | UDP
            .chain([PROXY_V2, UDP_OVER_IPV6].as_slice())
            // payload length: 36 address bytes + 3 trailing bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // truncated TLV whose declared length (0x0203) overruns the payload
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(
            info.addr,
            ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
        );
    }

    #[tokio::test]
    async fn test_invalid() {
        // No signature at all: every byte must be replayed unchanged.
        let data = [0x55; 256];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    #[tokio::test]
    async fn test_short() {
        // Stream shorter than the signature.
        let data = [0x55; 10];

        let (mut read, info) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(info, ConnectHeader::Missing);
    }

    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let tlv_len = (tlv.len() as u16).to_be_bytes();
        let len = (12 + 3 + tlv.len() as u16).to_be_bytes();

        let header = super::SIGNATURE
            // Proxy command, Inet << 4 | Stream
            .chain([PROXY_V2, TCP_OVER_IPV4].as_slice())
            // payload length: 12 address bytes + 3-byte TLV header + TLV value
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV: unknown kind 0xFF followed by its length and a 32 KiB value
            .chain([255].as_slice())
            .chain(tlv_len.as_slice())
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Proxy(info) = info else {
            panic!()
        };
        assert_eq!(info.addr, ([55, 56, 57, 58], 65535).into());
    }

    #[tokio::test]
    async fn test_local() {
        // LOCAL command with an empty payload.
        let len = 0u16.to_be_bytes();
        let header = super::SIGNATURE
            .chain([LOCAL_V2, 0x00].as_slice())
            .chain(len.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, info) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);

        let ConnectHeader::Local = info else { panic!() };
    }
}
|