Line data Source code
1 : //! Proxy Protocol V2 implementation
2 :
3 : use std::{
4 : future::poll_fn,
5 : future::Future,
6 : io,
7 : net::SocketAddr,
8 : pin::{pin, Pin},
9 : task::{ready, Context, Poll},
10 : };
11 :
12 : use bytes::{Buf, BytesMut};
13 : use hyper::server::conn::{AddrIncoming, AddrStream};
14 : use pin_project_lite::pin_project;
15 : use tls_listener::AsyncAccept;
16 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
17 :
18 : pub struct ProxyProtocolAccept {
19 : pub incoming: AddrIncoming,
20 : }
21 :
22 : pin_project! {
23 : pub struct WithClientIp<T> {
24 : #[pin]
25 : pub inner: T,
26 : buf: BytesMut,
27 : tlv_bytes: u16,
28 : state: ProxyParse,
29 : }
30 : }
31 :
/// Parsing state of the PROXY protocol header for one connection.
#[derive(Clone, PartialEq, Debug)]
enum ProxyParse {
    /// No attempt to parse the header has completed yet.
    NotStarted,

    /// A header was parsed and it carried this source address.
    Finished(SocketAddr),
    /// Parsing completed without producing an address.
    None,
}
39 :
40 : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
41 : #[inline]
42 91 : fn poll_write(
43 91 : self: Pin<&mut Self>,
44 91 : cx: &mut Context<'_>,
45 91 : buf: &[u8],
46 91 : ) -> Poll<Result<usize, io::Error>> {
47 91 : self.project().inner.poll_write(cx, buf)
48 91 : }
49 :
50 : #[inline]
51 1599 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
52 1599 : self.project().inner.poll_flush(cx)
53 1599 : }
54 :
55 : #[inline]
56 85 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
57 85 : self.project().inner.poll_shutdown(cx)
58 85 : }
59 :
60 : #[inline]
61 643 : fn poll_write_vectored(
62 643 : self: Pin<&mut Self>,
63 643 : cx: &mut Context<'_>,
64 643 : bufs: &[io::IoSlice<'_>],
65 643 : ) -> Poll<Result<usize, io::Error>> {
66 643 : self.project().inner.poll_write_vectored(cx, bufs)
67 643 : }
68 :
69 : #[inline]
70 0 : fn is_write_vectored(&self) -> bool {
71 0 : self.inner.is_write_vectored()
72 0 : }
73 : }
74 :
75 : impl<T> WithClientIp<T> {
76 147 : pub fn new(inner: T) -> Self {
77 147 : WithClientIp {
78 147 : inner,
79 147 : buf: BytesMut::with_capacity(128),
80 147 : tlv_bytes: 0,
81 147 : state: ProxyParse::NotStarted,
82 147 : }
83 147 : }
84 :
85 46 : pub fn client_addr(&self) -> Option<SocketAddr> {
86 46 : match self.state {
87 0 : ProxyParse::Finished(socket) => Some(socket),
88 46 : _ => None,
89 : }
90 46 : }
91 : }
92 :
93 : impl<T: AsyncRead + Unpin> WithClientIp<T> {
94 61 : pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
95 61 : match self.state {
96 : ProxyParse::NotStarted => {
97 61 : let mut pin = Pin::new(&mut *self);
98 119 : let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
99 61 : match addr {
100 0 : Some(addr) => self.state = ProxyParse::Finished(addr),
101 61 : None => self.state = ProxyParse::None,
102 : }
103 61 : Ok(addr)
104 : }
105 0 : ProxyParse::Finished(addr) => Ok(Some(addr)),
106 0 : ProxyParse::None => Ok(None),
107 : }
108 61 : }
109 : }
110 :
/// Proxy Protocol Version 2 signature: the constant 12 bytes every binary
/// header starts with (`\r\n\r\n\0\r\nQUIT\n`).
const HEADER: [u8; 12] = *b"\r\n\r\n\x00\r\nQUIT\n";
115 :
116 : impl<T: AsyncRead> WithClientIp<T> {
117 : /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
118 : /// Version 2 (Binary Format)
119 247 : fn poll_client_ip(
120 247 : mut self: Pin<&mut Self>,
121 247 : cx: &mut Context<'_>,
122 247 : ) -> Poll<io::Result<Option<SocketAddr>>> {
123 : // The binary header format starts with a constant 12 bytes block containing the protocol signature :
124 : // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
125 265 : while self.buf.len() < 16 {
126 259 : let mut this = self.as_mut().project();
127 259 : let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
128 :
129 : // exit for bad header
130 259 : let len = usize::min(self.buf.len(), HEADER.len());
131 259 : if self.buf[..len] != HEADER[..len] {
132 141 : return Poll::Ready(Ok(None));
133 118 : }
134 :
135 : // if no more bytes available then exit
136 118 : if ready!(bytes_read) == 0 {
137 0 : return Poll::Ready(Ok(None));
138 18 : };
139 : }
140 :
141 : // The next byte (the 13th one) is the protocol version and command.
142 : // The highest four bits contains the version. As of this specification, it must
143 : // always be sent as \x2 and the receiver must only accept this value.
144 6 : let vc = self.buf[12];
145 6 : let version = vc >> 4;
146 6 : let command = vc & 0b1111;
147 6 : if version != 2 {
148 0 : return Poll::Ready(Err(io::Error::new(
149 0 : io::ErrorKind::Other,
150 0 : "invalid proxy protocol version. expected version 2",
151 0 : )));
152 6 : }
153 6 : match command {
154 : // the connection was established on purpose by the proxy
155 : // without being relayed. The connection endpoints are the sender and the
156 : // receiver. Such connections exist when the proxy sends health-checks to the
157 : // server. The receiver must accept this connection as valid and must use the
158 : // real connection endpoints and discard the protocol block including the
159 : // family which is ignored.
160 0 : 0 => {}
161 : // the connection was established on behalf of another node,
162 : // and reflects the original connection endpoints. The receiver must then use
163 : // the information provided in the protocol block to get original the address.
164 6 : 1 => {}
165 : // other values are unassigned and must not be emitted by senders. Receivers
166 : // must drop connections presenting unexpected values here.
167 : _ => {
168 0 : return Poll::Ready(Err(io::Error::new(
169 0 : io::ErrorKind::Other,
170 0 : "invalid proxy protocol command. expected local (0) or proxy (1)",
171 0 : )))
172 : }
173 : };
174 :
175 : // The 14th byte contains the transport protocol and address family. The highest 4
176 : // bits contain the address family, the lowest 4 bits contain the protocol.
177 6 : let ft = self.buf[13];
178 6 : let address_length = match ft {
179 : // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
180 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
181 : // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
182 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
183 4 : 0x11 | 0x12 => 12,
184 : // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
185 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
186 : // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
187 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
188 2 : 0x21 | 0x22 => 36,
189 : // unspecified or unix stream. ignore the addresses
190 0 : _ => 0,
191 : };
192 :
193 : // The 15th and 16th bytes is the address length in bytes in network endian order.
194 : // It is used so that the receiver knows how many address bytes to skip even when
195 : // it does not implement the presented protocol. Thus the length of the protocol
196 : // header in bytes is always exactly 16 + this value. When a sender presents a
197 : // LOCAL connection, it should not present any address so it sets this field to
198 : // zero. Receivers MUST always consider this field to skip the appropriate number
199 : // of bytes and must not assume zero is presented for LOCAL connections. When a
200 : // receiver accepts an incoming connection showing an UNSPEC address family or
201 : // protocol, it may or may not decide to log the address information if present.
202 6 : let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
203 6 : if remaining_length < address_length {
204 0 : return Poll::Ready(Err(io::Error::new(
205 0 : io::ErrorKind::Other,
206 0 : "invalid proxy protocol length. not enough to fit requested IP addresses",
207 0 : )));
208 6 : }
209 :
210 30 : while self.buf.len() < 16 + address_length as usize {
211 24 : let mut this = self.as_mut().project();
212 24 : if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
213 0 : return Poll::Ready(Err(io::Error::new(
214 0 : io::ErrorKind::UnexpectedEof,
215 0 : "stream closed while waiting for proxy protocol addresses",
216 0 : )));
217 24 : }
218 : }
219 :
220 6 : let this = self.as_mut().project();
221 6 :
222 6 : // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
223 6 : // discard the header we have parsed
224 6 : this.buf.advance(16);
225 6 :
226 6 : // Starting from the 17th byte, addresses are presented in network byte order.
227 6 : // The address order is always the same :
228 6 : // - source layer 3 address in network byte order
229 6 : // - destination layer 3 address in network byte order
230 6 : // - source layer 4 address if any, in network byte order (port)
231 6 : // - destination layer 4 address if any, in network byte order (port)
232 6 : let addresses = this.buf.split_to(address_length as usize);
233 6 : let socket = match address_length {
234 : 12 => {
235 4 : let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
236 4 : let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
237 4 : Some(SocketAddr::from((src_addr, src_port)))
238 : }
239 : 36 => {
240 2 : let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
241 2 : let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
242 2 : Some(SocketAddr::from((src_addr, src_port)))
243 : }
244 0 : _ => None,
245 : };
246 :
247 6 : *this.tlv_bytes = remaining_length - address_length;
248 6 : self.as_mut().skip_tlv_inner();
249 6 :
250 6 : Poll::Ready(Ok(socket))
251 247 : }
252 :
253 : #[cold]
254 128 : fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
255 128 : let ip = ready!(self.as_mut().poll_client_ip(cx)?);
256 86 : match ip {
257 6 : Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
258 80 : None => *self.as_mut().project().state = ProxyParse::None,
259 : }
260 86 : Poll::Ready(Ok(()))
261 128 : }
262 :
263 : #[cold]
264 68 : fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
265 68 : let mut this = self.as_mut().project();
266 : // we know that this.buf is empty
267 68 : debug_assert_eq!(this.buf.len(), 0);
268 :
269 68 : this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
270 68 : ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
271 68 : self.skip_tlv_inner();
272 68 :
273 68 : Poll::Ready(Ok(()))
274 68 : }
275 :
276 74 : fn skip_tlv_inner(self: Pin<&mut Self>) {
277 74 : let tlv_bytes_read = match u16::try_from(self.buf.len()) {
278 : // we read more than u16::MAX therefore we must have read the full tlv_bytes
279 0 : Err(_) => self.tlv_bytes,
280 : // we might not have read the full tlv bytes yet
281 74 : Ok(n) => u16::min(n, self.tlv_bytes),
282 : };
283 74 : let this = self.project();
284 74 : *this.tlv_bytes -= tlv_bytes_read;
285 74 : this.buf.advance(tlv_bytes_read as usize);
286 74 : }
287 : }
288 :
289 : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
290 : #[inline]
291 2308 : fn poll_read(
292 2308 : mut self: Pin<&mut Self>,
293 2308 : cx: &mut Context<'_>,
294 2308 : buf: &mut ReadBuf<'_>,
295 2308 : ) -> Poll<io::Result<()>> {
296 2308 : // I'm assuming these 3 comparisons will be easy to branch predict.
297 2308 : // especially with the cold attributes
298 2308 : // which should make this read wrapper almost invisible
299 2308 :
300 2308 : if let ProxyParse::NotStarted = self.state {
301 128 : ready!(self.as_mut().read_ip(cx)?);
302 2180 : }
303 :
304 2334 : while self.tlv_bytes > 0 {
305 68 : ready!(self.as_mut().skip_tlv(cx)?)
306 : }
307 :
308 2266 : let this = self.project();
309 2266 : if this.buf.is_empty() {
310 2121 : this.inner.poll_read(cx, buf)
311 : } else {
312 : // we know that tlv_bytes is 0
313 145 : debug_assert_eq!(*this.tlv_bytes, 0);
314 :
315 145 : let write = usize::min(this.buf.len(), buf.remaining());
316 145 : let slice = this.buf.split_to(write).freeze();
317 145 : buf.put_slice(&slice);
318 145 :
319 145 : // reset the allocation so it can be freed
320 145 : if this.buf.is_empty() {
321 141 : *this.buf = BytesMut::new();
322 141 : }
323 :
324 145 : Poll::Ready(Ok(()))
325 : }
326 2308 : }
327 : }
328 :
329 : impl AsyncAccept for ProxyProtocolAccept {
330 : type Connection = WithClientIp<AddrStream>;
331 :
332 : type Error = io::Error;
333 :
334 305 : fn poll_accept(
335 305 : mut self: Pin<&mut Self>,
336 305 : cx: &mut Context<'_>,
337 305 : ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
338 305 : let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
339 46 : let Some(conn) = conn else {
340 0 : return Poll::Ready(None);
341 : };
342 :
343 46 : Poll::Ready(Some(Ok(WithClientIp::new(conn))))
344 305 : }
345 : }
346 :
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use tokio::io::AsyncReadExt;

    use crate::protocol2::{ProxyParse, WithClientIp};

    /// PROXY command over TCP/IPv4 followed by a small TLV section.
    #[tokio::test]
    async fn test_ipv4() {
        let header = super::HEADER
            // version 2 + PROXY command, IPv4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // length: 12 address bytes + 3 TLV bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let payload = [0x55; 256];

        let mut reader = pin!(WithClientIp::new(header.chain(payload.as_slice())));

        let mut received = vec![];
        reader.read_to_end(&mut received).await.unwrap();

        // only the payload comes out; header and TLVs were consumed
        assert_eq!(received, payload);
        assert_eq!(
            reader.state,
            ProxyParse::Finished(([127, 0, 0, 1], 65535).into())
        );
    }

    /// PROXY command over UDP/IPv6 followed by a small TLV section.
    #[tokio::test]
    async fn test_ipv6() {
        let header = super::HEADER
            // version 2 + PROXY command, IPv6 | UDP
            .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
            // length: 36 address bytes + 3 TLV bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let payload = [0x55; 256];

        let mut reader = pin!(WithClientIp::new(header.chain(payload.as_slice())));

        let mut received = vec![];
        reader.read_to_end(&mut received).await.unwrap();

        assert_eq!(received, payload);
        assert_eq!(
            reader.state,
            ProxyParse::Finished(
                ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
            )
        );
    }

    /// A stream that does not start with the signature is passed through untouched.
    #[tokio::test]
    async fn test_invalid() {
        let payload = [0x55; 256];

        let mut reader = pin!(WithClientIp::new(payload.as_slice()));

        let mut received = vec![];
        reader.read_to_end(&mut received).await.unwrap();
        assert_eq!(received, payload);
        assert_eq!(reader.state, ProxyParse::None);
    }

    /// Same as `test_invalid`, but with fewer bytes than the signature is long.
    #[tokio::test]
    async fn test_short() {
        let payload = [0x55; 10];

        let mut reader = pin!(WithClientIp::new(payload.as_slice()));

        let mut received = vec![];
        reader.read_to_end(&mut received).await.unwrap();
        assert_eq!(received, payload);
        assert_eq!(reader.state, ProxyParse::None);
    }

    /// A TLV section much larger than the internal buffer is skipped completely.
    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let len = (12 + tlv.len() as u16).to_be_bytes();

        let header = super::HEADER
            // version 2 + PROXY command, IPv4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // length: 12 address bytes + the whole TLV section
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain(tlv.as_slice());

        let payload = [0xaa; 256];

        let mut reader = pin!(WithClientIp::new(header.chain(payload.as_slice())));

        let mut received = vec![];
        reader.read_to_end(&mut received).await.unwrap();

        assert_eq!(received, payload);
        assert_eq!(
            reader.state,
            ProxyParse::Finished(([55, 56, 57, 58], 65535).into())
        );
    }
}
|