Line data Source code
1 : //! Proxy Protocol V2 implementation
2 :
3 : use std::{
4 : future::{poll_fn, Future},
5 : io,
6 : net::SocketAddr,
7 : pin::{pin, Pin},
8 : task::{ready, Context, Poll},
9 : };
10 :
11 : use bytes::{Buf, BytesMut};
12 : use hyper::server::conn::AddrIncoming;
13 : use pin_project_lite::pin_project;
14 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
15 :
16 : pub struct ProxyProtocolAccept {
17 : pub incoming: AddrIncoming,
18 : pub protocol: &'static str,
19 : }
20 :
21 : pin_project! {
22 : pub struct WithClientIp<T> {
23 : #[pin]
24 : pub inner: T,
25 : buf: BytesMut,
26 : tlv_bytes: u16,
27 : state: ProxyParse,
28 : }
29 : }
30 :
/// Progress of the PROXY protocol header parse for one connection.
#[derive(Clone, PartialEq, Debug)]
enum ProxyParse {
    /// No attempt to parse a header has completed yet.
    NotStarted,

    /// A header was parsed and it carried this source address.
    Finished(SocketAddr),
    /// Parsing completed without producing a source address.
    None,
}
38 :
39 : impl<T: AsyncWrite> AsyncWrite for WithClientIp<T> {
40 : #[inline]
41 30 : fn poll_write(
42 30 : self: Pin<&mut Self>,
43 30 : cx: &mut Context<'_>,
44 30 : buf: &[u8],
45 30 : ) -> Poll<Result<usize, io::Error>> {
46 30 : self.project().inner.poll_write(cx, buf)
47 30 : }
48 :
49 : #[inline]
50 148 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
51 148 : self.project().inner.poll_flush(cx)
52 148 : }
53 :
54 : #[inline]
55 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
56 0 : self.project().inner.poll_shutdown(cx)
57 0 : }
58 :
59 : #[inline]
60 118 : fn poll_write_vectored(
61 118 : self: Pin<&mut Self>,
62 118 : cx: &mut Context<'_>,
63 118 : bufs: &[io::IoSlice<'_>],
64 118 : ) -> Poll<Result<usize, io::Error>> {
65 118 : self.project().inner.poll_write_vectored(cx, bufs)
66 118 : }
67 :
68 : #[inline]
69 0 : fn is_write_vectored(&self) -> bool {
70 0 : self.inner.is_write_vectored()
71 0 : }
72 : }
73 :
74 : impl<T> WithClientIp<T> {
75 40 : pub fn new(inner: T) -> Self {
76 40 : WithClientIp {
77 40 : inner,
78 40 : buf: BytesMut::with_capacity(128),
79 40 : tlv_bytes: 0,
80 40 : state: ProxyParse::NotStarted,
81 40 : }
82 40 : }
83 :
84 0 : pub fn client_addr(&self) -> Option<SocketAddr> {
85 0 : match self.state {
86 0 : ProxyParse::Finished(socket) => Some(socket),
87 0 : _ => None,
88 : }
89 0 : }
90 : }
91 :
92 : impl<T: AsyncRead + Unpin> WithClientIp<T> {
93 0 : pub async fn wait_for_addr(&mut self) -> io::Result<Option<SocketAddr>> {
94 0 : match self.state {
95 : ProxyParse::NotStarted => {
96 0 : let mut pin = Pin::new(&mut *self);
97 0 : let addr = poll_fn(|cx| pin.as_mut().poll_client_ip(cx)).await?;
98 0 : match addr {
99 0 : Some(addr) => self.state = ProxyParse::Finished(addr),
100 0 : None => self.state = ProxyParse::None,
101 : }
102 0 : Ok(addr)
103 : }
104 0 : ProxyParse::Finished(addr) => Ok(Some(addr)),
105 0 : ProxyParse::None => Ok(None),
106 : }
107 0 : }
108 : }
109 :
/// Proxy Protocol Version 2 signature: the constant 12-byte block that opens every
/// binary header (`0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A`).
const HEADER: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
114 :
115 : impl<T: AsyncRead> WithClientIp<T> {
116 : /// implementation of <https://www.haproxy.org/download/2.4/doc/proxy-protocol.txt>
117 : /// Version 2 (Binary Format)
118 40 : fn poll_client_ip(
119 40 : mut self: Pin<&mut Self>,
120 40 : cx: &mut Context<'_>,
121 40 : ) -> Poll<io::Result<Option<SocketAddr>>> {
122 : // The binary header format starts with a constant 12 bytes block containing the protocol signature :
123 : // \x0D \x0A \x0D \x0A \x00 \x0D \x0A \x51 \x55 \x49 \x54 \x0A
124 58 : while self.buf.len() < 16 {
125 52 : let mut this = self.as_mut().project();
126 52 : let bytes_read = pin!(this.inner.read_buf(this.buf)).poll(cx)?;
127 :
128 : // exit for bad header
129 52 : let len = usize::min(self.buf.len(), HEADER.len());
130 52 : if self.buf[..len] != HEADER[..len] {
131 34 : return Poll::Ready(Ok(None));
132 18 : }
133 :
134 : // if no more bytes available then exit
135 18 : if ready!(bytes_read) == 0 {
136 0 : return Poll::Ready(Ok(None));
137 18 : };
138 : }
139 :
140 : // The next byte (the 13th one) is the protocol version and command.
141 : // The highest four bits contains the version. As of this specification, it must
142 : // always be sent as \x2 and the receiver must only accept this value.
143 6 : let vc = self.buf[12];
144 6 : let version = vc >> 4;
145 6 : let command = vc & 0b1111;
146 6 : if version != 2 {
147 0 : return Poll::Ready(Err(io::Error::new(
148 0 : io::ErrorKind::Other,
149 0 : "invalid proxy protocol version. expected version 2",
150 0 : )));
151 6 : }
152 6 : match command {
153 : // the connection was established on purpose by the proxy
154 : // without being relayed. The connection endpoints are the sender and the
155 : // receiver. Such connections exist when the proxy sends health-checks to the
156 : // server. The receiver must accept this connection as valid and must use the
157 : // real connection endpoints and discard the protocol block including the
158 : // family which is ignored.
159 0 : 0 => {}
160 : // the connection was established on behalf of another node,
161 : // and reflects the original connection endpoints. The receiver must then use
162 : // the information provided in the protocol block to get original the address.
163 6 : 1 => {}
164 : // other values are unassigned and must not be emitted by senders. Receivers
165 : // must drop connections presenting unexpected values here.
166 : _ => {
167 0 : return Poll::Ready(Err(io::Error::new(
168 0 : io::ErrorKind::Other,
169 0 : "invalid proxy protocol command. expected local (0) or proxy (1)",
170 0 : )))
171 : }
172 : };
173 :
174 : // The 14th byte contains the transport protocol and address family. The highest 4
175 : // bits contain the address family, the lowest 4 bits contain the protocol.
176 6 : let ft = self.buf[13];
177 6 : let address_length = match ft {
178 : // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
179 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
180 : // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
181 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
182 4 : 0x11 | 0x12 => 12,
183 : // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
184 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
185 : // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
186 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
187 2 : 0x21 | 0x22 => 36,
188 : // unspecified or unix stream. ignore the addresses
189 0 : _ => 0,
190 : };
191 :
192 : // The 15th and 16th bytes is the address length in bytes in network endian order.
193 : // It is used so that the receiver knows how many address bytes to skip even when
194 : // it does not implement the presented protocol. Thus the length of the protocol
195 : // header in bytes is always exactly 16 + this value. When a sender presents a
196 : // LOCAL connection, it should not present any address so it sets this field to
197 : // zero. Receivers MUST always consider this field to skip the appropriate number
198 : // of bytes and must not assume zero is presented for LOCAL connections. When a
199 : // receiver accepts an incoming connection showing an UNSPEC address family or
200 : // protocol, it may or may not decide to log the address information if present.
201 6 : let remaining_length = u16::from_be_bytes(self.buf[14..16].try_into().unwrap());
202 6 : if remaining_length < address_length {
203 0 : return Poll::Ready(Err(io::Error::new(
204 0 : io::ErrorKind::Other,
205 0 : "invalid proxy protocol length. not enough to fit requested IP addresses",
206 0 : )));
207 6 : }
208 :
209 30 : while self.buf.len() < 16 + address_length as usize {
210 24 : let mut this = self.as_mut().project();
211 24 : if ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?) == 0 {
212 0 : return Poll::Ready(Err(io::Error::new(
213 0 : io::ErrorKind::UnexpectedEof,
214 0 : "stream closed while waiting for proxy protocol addresses",
215 0 : )));
216 24 : }
217 : }
218 :
219 6 : let this = self.as_mut().project();
220 6 :
221 6 : // we are sure this is a proxy protocol v2 entry and we have read all the bytes we need
222 6 : // discard the header we have parsed
223 6 : this.buf.advance(16);
224 6 :
225 6 : // Starting from the 17th byte, addresses are presented in network byte order.
226 6 : // The address order is always the same :
227 6 : // - source layer 3 address in network byte order
228 6 : // - destination layer 3 address in network byte order
229 6 : // - source layer 4 address if any, in network byte order (port)
230 6 : // - destination layer 4 address if any, in network byte order (port)
231 6 : let addresses = this.buf.split_to(address_length as usize);
232 6 : let socket = match address_length {
233 : 12 => {
234 4 : let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
235 4 : let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
236 4 : Some(SocketAddr::from((src_addr, src_port)))
237 : }
238 : 36 => {
239 2 : let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
240 2 : let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
241 2 : Some(SocketAddr::from((src_addr, src_port)))
242 : }
243 0 : _ => None,
244 : };
245 :
246 6 : *this.tlv_bytes = remaining_length - address_length;
247 6 : self.as_mut().skip_tlv_inner();
248 6 :
249 6 : Poll::Ready(Ok(socket))
250 40 : }
251 :
252 : #[cold]
253 40 : fn read_ip(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
254 40 : let ip = ready!(self.as_mut().poll_client_ip(cx)?);
255 40 : match ip {
256 6 : Some(x) => *self.as_mut().project().state = ProxyParse::Finished(x),
257 34 : None => *self.as_mut().project().state = ProxyParse::None,
258 : }
259 40 : Poll::Ready(Ok(()))
260 40 : }
261 :
262 : #[cold]
263 68 : fn skip_tlv(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
264 68 : let mut this = self.as_mut().project();
265 68 : // we know that this.buf is empty
266 68 : debug_assert_eq!(this.buf.len(), 0);
267 :
268 68 : this.buf.reserve((*this.tlv_bytes).clamp(0, 1024) as usize);
269 68 : ready!(pin!(this.inner.read_buf(this.buf)).poll(cx)?);
270 68 : self.skip_tlv_inner();
271 68 :
272 68 : Poll::Ready(Ok(()))
273 68 : }
274 :
275 74 : fn skip_tlv_inner(self: Pin<&mut Self>) {
276 74 : let tlv_bytes_read = match u16::try_from(self.buf.len()) {
277 : // we read more than u16::MAX therefore we must have read the full tlv_bytes
278 0 : Err(_) => self.tlv_bytes,
279 : // we might not have read the full tlv bytes yet
280 74 : Ok(n) => u16::min(n, self.tlv_bytes),
281 : };
282 74 : let this = self.project();
283 74 : *this.tlv_bytes -= tlv_bytes_read;
284 74 : this.buf.advance(tlv_bytes_read as usize);
285 74 : }
286 : }
287 :
288 : impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
289 : #[inline]
290 334 : fn poll_read(
291 334 : mut self: Pin<&mut Self>,
292 334 : cx: &mut Context<'_>,
293 334 : buf: &mut ReadBuf<'_>,
294 334 : ) -> Poll<io::Result<()>> {
295 334 : // I'm assuming these 3 comparisons will be easy to branch predict.
296 334 : // especially with the cold attributes
297 334 : // which should make this read wrapper almost invisible
298 334 :
299 334 : if let ProxyParse::NotStarted = self.state {
300 40 : ready!(self.as_mut().read_ip(cx)?);
301 294 : }
302 :
303 402 : while self.tlv_bytes > 0 {
304 68 : ready!(self.as_mut().skip_tlv(cx)?)
305 : }
306 :
307 334 : let this = self.project();
308 334 : if this.buf.is_empty() {
309 296 : this.inner.poll_read(cx, buf)
310 : } else {
311 : // we know that tlv_bytes is 0
312 38 : debug_assert_eq!(*this.tlv_bytes, 0);
313 :
314 38 : let write = usize::min(this.buf.len(), buf.remaining());
315 38 : let slice = this.buf.split_to(write).freeze();
316 38 : buf.put_slice(&slice);
317 38 :
318 38 : // reset the allocation so it can be freed
319 38 : if this.buf.is_empty() {
320 34 : *this.buf = BytesMut::new();
321 34 : }
322 :
323 38 : Poll::Ready(Ok(()))
324 : }
325 334 : }
326 : }
327 :
#[cfg(test)]
mod tests {
    use std::pin::pin;

    use tokio::io::{AsyncRead, AsyncReadExt};

    use crate::protocol2::{ProxyParse, WithClientIp};

    /// Wraps `reader`, reads it to the end, and returns the payload that came through
    /// together with the final parse state.
    async fn read_through<R: AsyncRead>(reader: R) -> (Vec<u8>, ProxyParse) {
        let mut read = pin!(WithClientIp::new(reader));

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        (bytes, read.state.clone())
    }

    #[tokio::test]
    async fn test_ipv4() {
        let header = super::HEADER
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 + 3 bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (bytes, state) = read_through(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(state, ProxyParse::Finished(([127, 0, 0, 1], 65535).into()));
    }

    #[tokio::test]
    async fn test_ipv6() {
        let header = super::HEADER
            // Proxy command, IPV6 | UDP
            .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
            // 36 + 3 bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (bytes, state) = read_through(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(
            state,
            ProxyParse::Finished(
                ([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into()
            )
        );
    }

    /// Data that does not start with the signature passes through unchanged.
    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        let (bytes, state) = read_through(data.as_slice()).await;

        assert_eq!(bytes, data);
        assert_eq!(state, ProxyParse::None);
    }

    /// Input shorter than the signature passes through unchanged.
    #[tokio::test]
    async fn test_short() {
        let data = [0x55; 10];

        let (bytes, state) = read_through(data.as_slice()).await;

        assert_eq!(bytes, data);
        assert_eq!(state, ProxyParse::None);
    }

    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let len = (12 + tlv.len() as u16).to_be_bytes();

        let header = super::HEADER
            // Proxy command, Inet << 4 | Stream
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 address bytes + the TLV bytes
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (bytes, state) = read_through(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(state, ProxyParse::Finished(([55, 56, 57, 58], 65535).into()));
    }
}
|