Line data Source code
1 : //! Proxy Protocol V2 implementation
2 :
3 : use std::{
4 : future::{poll_fn, Future},
5 : io,
6 : net::SocketAddr,
7 : pin::{pin, Pin},
8 : sync::Mutex,
9 : task::{ready, Context, Poll},
10 : };
11 :
12 : use bytes::{Buf, BytesMut};
13 : use hyper::server::accept::Accept;
14 : use hyper::server::conn::{AddrIncoming, AddrStream};
15 : use metrics::IntCounterPairGuard;
16 : use pin_project_lite::pin_project;
17 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
18 : use uuid::Uuid;
19 :
20 : use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
21 :
22 : pub struct ProxyProtocolAccept {
23 : pub incoming: AddrIncoming,
24 : pub protocol: &'static str,
25 : }
26 :
27 : pin_project! {
28 : pub struct WithClientIp<T> {
29 : #[pin]
30 : pub inner: T,
31 : buf: BytesMut,
32 : tlv_bytes: u16,
33 : state: ProxyParse,
34 : }
35 : }
36 :
/// Outcome of looking for a Proxy Protocol V2 header at the start of a stream.
///
/// The enum only wraps a [`SocketAddr`], which is `Copy + Eq`, so it derives both as well;
/// this lets callers read the state out of a borrowed wrapper without cloning.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum ProxyParse {
    /// No attempt to parse the header has completed yet.
    NotStarted,
    /// A header was parsed and it carried this source (client) address.
    Finished(SocketAddr),
    /// Parsing completed without producing a client address.
    None,
}
44 :
45 : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
46 : #[inline]
47 30 : fn poll_write(
48 30 : self: Pin<&mut Self>,
49 30 : cx: &mut Context<'_>,
50 30 : buf: &[u8],
51 30 : ) -> Poll<Result<usize, io::Error>> {
52 30 : self.project().inner.poll_write(cx, buf)
53 30 : }
54 :
55 : #[inline]
56 148 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
57 148 : self.project().inner.poll_flush(cx)
58 148 : }
59 :
60 : #[inline]
61 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
62 0 : self.project().inner.poll_shutdown(cx)
63 0 : }
64 :
65 : #[inline]
66 118 : fn poll_write_vectored(
67 118 : self: Pin<&mut Self>,
68 118 : cx: &mut Context<'_>,
69 118 : bufs: &[io::IoSlice<'_>],
70 118 : ) -> Poll<Result<usize, io::Error>> {
71 118 : self.project().inner.poll_write_vectored(cx, bufs)
72 118 : }
73 :
74 : #[inline]
75 0 : fn is_write_vectored(&self) -> bool {
76 0 : self.inner.is_write_vectored()
77 0 : }
78 : }
79 :
80 : impl<T> WithClientIp<T> {
81 40 : pub fn new(inner: T) -> Self {
82 40 : WithClientIp {
83 40 : inner,
84 40 : buf: BytesMut::with_capacity(128),
85 40 : tlv_bytes: 0,
86 40 : state: ProxyParse::NotStarted,
87 40 : }
88 40 : }
89 :
90 0 : pub fn client_addr(&self) -> Option<SocketAddr> {
91 0 : match self.state {
92 0 : ProxyParse::Finished(socket) => Some(socket),
93 0 : _ => None,
94 : }
95 0 : }
96 : }
97 :
98 : impl<T: AsyncRead + Unpin> WithClientIp<T> {
99 0 : pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
100 0 : match self.state {
101 : ProxyParse::NotStarted => {
102 0 : let mut pin = Pin::new(&mut *self);
103 0 : let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
104 0 : match addr {
105 0 : Some(addr) => self.state = ProxyParse::Finished(addr),
106 0 : None => self.state = ProxyParse::None,
107 : }
108 0 : Ok(addr)
109 : }
110 0 : ProxyParse::Finished(addr) => Ok(Some(addr)),
111 0 : ProxyParse::None => Ok(None),
112 : }
113 0 : }
114 : }
115 :
/// Proxy Protocol Version 2 signature: the constant 12-byte block
/// `0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A` that every binary header starts with.
const HEADER: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
120 :
121 : impl<T: AsyncRead> WithClientIp<T> {
122 : /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
123 : /// Version 2 (Binary Format)
124 40 : fn poll_client_ip(
125 40 : mut self: Pin<&mut Self>,
126 40 : cx: &mut Context<'_>,
127 40 : ) -> Poll<io::Result<Option<SocketAddr>>> {
128 : // The binary header format starts with a constant 12 bytes block containing the protocol signature :
129 : // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
130 58 : while self.buf.len() < 16 {
131 52 : let mut this = self.as_mut().project();
132 52 : let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
133 :
134 : // exit for bad header
135 52 : let len = usize::min(self.buf.len(), HEADER.len());
136 52 : if self.buf[..len] != HEADER[..len] {
137 34 : return Poll::Ready(Ok(None));
138 18 : }
139 :
140 : // if no more bytes available then exit
141 18 : if ready!(bytes_read) == 0 {
142 0 : return Poll::Ready(Ok(None));
143 18 : };
144 : }
145 :
146 : // The next byte (the 13th one) is the protocol version and command.
147 : // The highest four bits contains the version. As of this specification, it must
148 : // always be sent as \x2 and the receiver must only accept this value.
149 6 : let vc = self.buf[12];
150 6 : let version = vc >> 4;
151 6 : let command = vc & 0b1111;
152 6 : if version != 2 {
153 0 : return Poll::Ready(Err(io::Error::new(
154 0 : io::ErrorKind::Other,
155 0 : "invalid proxy protocol version. expected version 2",
156 0 : )));
157 6 : }
158 6 : match command {
159 : // the connection was established on purpose by the proxy
160 : // without being relayed. The connection endpoints are the sender and the
161 : // receiver. Such connections exist when the proxy sends health-checks to the
162 : // server. The receiver must accept this connection as valid and must use the
163 : // real connection endpoints and discard the protocol block including the
164 : // family which is ignored.
165 0 : 0 => {}
166 : // the connection was established on behalf of another node,
167 : // and reflects the original connection endpoints. The receiver must then use
168 : // the information provided in the protocol block to get original the address.
169 6 : 1 => {}
170 : // other values are unassigned and must not be emitted by senders. Receivers
171 : // must drop connections presenting unexpected values here.
172 : _ => {
173 0 : return Poll::Ready(Err(io::Error::new(
174 0 : io::ErrorKind::Other,
175 0 : "invalid proxy protocol command. expected local (0) or proxy (1)",
176 0 : )))
177 : }
178 : };
179 :
180 : // The 14th byte contains the transport protocol and address family. The highest 4
181 : // bits contain the address family, the lowest 4 bits contain the protocol.
182 6 : let ft = self.buf[13];
183 6 : let address_length = match ft {
184 : // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
185 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
186 : // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
187 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
188 4 : 0x11 | 0x12 => 12,
189 : // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
190 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
191 : // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
192 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
193 2 : 0x21 | 0x22 => 36,
194 : // unspecified or unix stream. ignore the addresses
195 0 : _ => 0,
196 : };
197 :
198 : // The 15th and 16th bytes is the address length in bytes in network endian order.
199 : // It is used so that the receiver knows how many address bytes to skip even when
200 : // it does not implement the presented protocol. Thus the length of the protocol
201 : // header in bytes is always exactly 16 + this value. When a sender presents a
202 : // LOCAL connection, it should not present any address so it sets this field to
203 : // zero. Receivers MUST always consider this field to skip the appropriate number
204 : // of bytes and must not assume zero is presented for LOCAL connections. When a
205 : // receiver accepts an incoming connection showing an UNSPEC address family or
206 : // protocol, it may or may not decide to log the address information if present.
207 6 : let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
208 6 : if remaining_length < address_length {
209 0 : return Poll::Ready(Err(io::Error::new(
210 0 : io::ErrorKind::Other,
211 0 : "invalid proxy protocol length. not enough to fit requested IP addresses",
212 0 : )));
213 6 : }
214 :
215 30 : while self.buf.len() < 16 + address_length as usize {
216 24 : let mut this = self.as_mut().project();
217 24 : if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
218 0 : return Poll::Ready(Err(io::Error::new(
219 0 : io::ErrorKind::UnexpectedEof,
220 0 : "stream closed while waiting for proxy protocol addresses",
221 0 : )));
222 24 : }
223 : }
224 :
225 6 : let this = self.as_mut().project();
226 6 :
227 6 : // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
228 6 : // discard the header we have parsed
229 6 : this.buf.advance(16);
230 6 :
231 6 : // Starting from the 17th byte, addresses are presented in network byte order.
232 6 : // The address order is always the same :
233 6 : // - source layer 3 address in network byte order
234 6 : // - destination layer 3 address in network byte order
235 6 : // - source layer 4 address if any, in network byte order (port)
236 6 : // - destination layer 4 address if any, in network byte order (port)
237 6 : let addresses = this.buf.split_to(address_length as usize);
238 6 : let socket = match address_length {
239 : 12 => {
240 4 : let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
241 4 : let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
242 4 : Some(SocketAddr::from((src_addr, src_port)))
243 : }
244 : 36 => {
245 2 : let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
246 2 : let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
247 2 : Some(SocketAddr::from((src_addr, src_port)))
248 : }
249 0 : _ => None,
250 : };
251 :
252 6 : *this.tlv_bytes = remaining_length - address_length;
253 6 : self.as_mut().skip_tlv_inner();
254 6 :
255 6 : Poll::Ready(Ok(socket))
256 40 : }
257 :
258 : #[cold]
259 40 : fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
260 40 : let ip = ready!(self.as_mut().poll_client_ip(cx)?);
261 40 : match ip {
262 6 : Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
263 34 : None => *self.as_mut().project().state = ProxyParse::None,
264 : }
265 40 : Poll::Ready(Ok(()))
266 40 : }
267 :
268 : #[cold]
269 68 : fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
270 68 : let mut this = self.as_mut().project();
271 68 : // we know that this.buf is empty
272 68 : debug_assert_eq!(this.buf.len(), 0);
273 :
274 68 : this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
275 68 : ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
276 68 : self.skip_tlv_inner();
277 68 :
278 68 : Poll::Ready(Ok(()))
279 68 : }
280 :
281 74 : fn skip_tlv_inner(self: Pin<&mut Self>) {
282 74 : let tlv_bytes_read = match u16::try_from(self.buf.len()) {
283 : // we read more than u16::MAX therefore we must have read the full tlv_bytes
284 0 : Err(_) => self.tlv_bytes,
285 : // we might not have read the full tlv bytes yet
286 74 : Ok(n) => u16::min(n, self.tlv_bytes),
287 : };
288 74 : let this = self.project();
289 74 : *this.tlv_bytes -= tlv_bytes_read;
290 74 : this.buf.advance(tlv_bytes_read as usize);
291 74 : }
292 : }
293 :
294 : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
295 : #[inline]
296 334 : fn poll_read(
297 334 : mut self: Pin<&mut Self>,
298 334 : cx: &mut Context<'_>,
299 334 : buf: &mut ReadBuf<'_>,
300 334 : ) -> Poll<io::Result<()>> {
301 334 : // I'm assuming these 3 comparisons will be easy to branch predict.
302 334 : // especially with the cold attributes
303 334 : // which should make this read wrapper almost invisible
304 334 :
305 334 : if let ProxyParse::NotStarted = self.state {
306 40 : ready!(self.as_mut().read_ip(cx)?);
307 294 : }
308 :
309 402 : while self.tlv_bytes > 0 {
310 68 : ready!(self.as_mut().skip_tlv(cx)?)
311 : }
312 :
313 334 : let this = self.project();
314 334 : if this.buf.is_empty() {
315 296 : this.inner.poll_read(cx, buf)
316 : } else {
317 : // we know that tlv_bytes is 0
318 38 : debug_assert_eq!(*this.tlv_bytes, 0);
319 :
320 38 : let write = usize::min(this.buf.len(), buf.remaining());
321 38 : let slice = this.buf.split_to(write).freeze();
322 38 : buf.put_slice(&slice);
323 38 :
324 38 : // reset the allocation so it can be freed
325 38 : if this.buf.is_empty() {
326 34 : *this.buf = BytesMut::new();
327 34 : }
328 :
329 38 : Poll::Ready(Ok(()))
330 : }
331 334 : }
332 : }
333 :
334 : impl Accept for ProxyProtocolAccept {
335 : type Conn = WithConnectionGuard<WithClientIp<AddrStream>>;
336 :
337 : type Error = io::Error;
338 :
339 0 : fn poll_accept(
340 0 : mut self: Pin<&mut Self>,
341 0 : cx: &mut Context<'_>,
342 0 : ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
343 0 : let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
344 :
345 0 : let conn_id = uuid::Uuid::new_v4();
346 0 : let span = tracing::info_span!("http_conn", ?conn_id);
347 : {
348 0 : let _enter = span.enter();
349 0 : tracing::info!("accepted new TCP connection");
350 : }
351 :
352 0 : let Some(conn) = conn else {
353 0 : return Poll::Ready(None);
354 : };
355 :
356 0 : Poll::Ready(Some(Ok(WithConnectionGuard {
357 0 : inner: WithClientIp::new(conn),
358 0 : connection_id: Uuid::new_v4(),
359 0 : gauge: Mutex::new(Some(
360 0 : NUM_CLIENT_CONNECTION_GAUGE
361 0 : .with_label_values(&[self.protocol])
362 0 : .guard(),
363 0 : )),
364 0 : span,
365 0 : })))
366 0 : }
367 : }
368 :
369 : pin_project! {
370 : pub struct WithConnectionGuard<T> {
371 : #[pin]
372 : pub inner: T,
373 : pub connection_id: Uuid,
374 : pub gauge: Mutex<Option<IntCounterPairGuard>>,
375 : pub span: tracing::Span,
376 : }
377 :
378 : impl<S> PinnedDrop for WithConnectionGuard<S> {
379 : fn drop(this: Pin<&mut Self>) {
380 : let _enter = this.span.enter();
381 0 : tracing::info!("HTTP connection closed")
382 : }
383 : }
384 : }
385 :
386 : impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
387 : #[inline]
388 0 : fn poll_write(
389 0 : self: Pin<&mut Self>,
390 0 : cx: &mut Context<'_>,
391 0 : buf: &[u8],
392 0 : ) -> Poll<Result<usize, io::Error>> {
393 0 : self.project().inner.poll_write(cx, buf)
394 0 : }
395 :
396 : #[inline]
397 0 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
398 0 : self.project().inner.poll_flush(cx)
399 0 : }
400 :
401 : #[inline]
402 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
403 0 : self.project().inner.poll_shutdown(cx)
404 0 : }
405 :
406 : #[inline]
407 0 : fn poll_write_vectored(
408 0 : self: Pin<&mut Self>,
409 0 : cx: &mut Context<'_>,
410 0 : bufs: &[io::IoSlice<'_>],
411 0 : ) -> Poll<Result<usize, io::Error>> {
412 0 : self.project().inner.poll_write_vectored(cx, bufs)
413 0 : }
414 :
415 : #[inline]
416 0 : fn is_write_vectored(&self) -> bool {
417 0 : self.inner.is_write_vectored()
418 0 : }
419 : }
420 :
421 : impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
422 0 : fn poll_read(
423 0 : self: Pin<&mut Self>,
424 0 : cx: &mut Context<'_>,
425 0 : buf: &mut ReadBuf<'_>,
426 0 : ) -> Poll<io::Result<()>> {
427 0 : self.project().inner.poll_read(cx, buf)
428 0 : }
429 : }
430 :
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use tokio::io::AsyncReadExt;

    use crate::protocol2::{ProxyParse, WithClientIp};

    /// Reads `input` to the end through a fresh [`WithClientIp`] and checks both the
    /// bytes that reach the reader and the final parse state.
    async fn assert_stream<R: tokio::io::AsyncRead>(
        input: R,
        expected_payload: &[u8],
        expected_state: ProxyParse,
    ) {
        let mut read = pin!(WithClientIp::new(input));

        let mut payload = vec![];
        read.read_to_end(&mut payload).await.unwrap();

        assert_eq!(payload, expected_payload);
        assert_eq!(read.state, expected_state);
    }

    /// PROXY command over TCP/IPv4 with three trailing TLV bytes.
    #[tokio::test]
    async fn test_ipv4() {
        let header = super::HEADER
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 + 3 bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        assert_stream(
            header.chain(extra_data.as_slice()),
            &extra_data,
            ProxyParse::Finished(([127, 0, 0, 1], 65535).into()),
        )
        .await;
    }

    /// PROXY command over UDP/IPv6 with three trailing TLV bytes.
    #[tokio::test]
    async fn test_ipv6() {
        let header = super::HEADER
            // Proxy command, IPV6 | UDP
            .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
            // 36 + 3 bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        assert_stream(
            header.chain(extra_data.as_slice()),
            &extra_data,
            ProxyParse::Finished(
                ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into(),
            ),
        )
        .await;
    }

    /// Data that does not start with the signature is passed through untouched.
    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        assert_stream(data.as_slice(), &data, ProxyParse::None).await;
    }

    /// Same as `test_invalid`, but with fewer bytes than a full header.
    #[tokio::test]
    async fn test_short() {
        let data = [0x55; 10];

        assert_stream(data.as_slice(), &data, ProxyParse::None).await;
    }

    /// A 32 KiB TLV section is skipped completely before payload is served.
    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let len = (12 + tlv.len() as u16).to_be_bytes();

        let header = super::HEADER
            // Proxy command, Inet << 4 | Stream
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 address bytes + the TLV length
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        assert_stream(
            header.chain(extra_data.as_slice()),
            &extra_data,
            ProxyParse::Finished(([55, 56, 57, 58], 65535).into()),
        )
        .await;
    }
}
|