Line data Source code
1 : //! Proxy Protocol V2 implementation
2 :
3 : use std::{
4 : io,
5 : net::SocketAddr,
6 : pin::Pin,
7 : task::{Context, Poll},
8 : };
9 :
10 : use bytes::BytesMut;
11 : use pin_project_lite::pin_project;
12 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13 :
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough.
    ///
    /// Reads first drain the bytes held in the internal buffer (bytes that were already
    /// pulled from `inner` while probing for a PROXY protocol header but that do not
    /// belong to it), then continue directly from `inner`. Writes, flushes and
    /// shutdowns are forwarded to `inner` untouched.
    pub struct ChainRW<T> {
        // The underlying stream; structurally pinned so it can be polled in place.
        #[pin]
        pub inner: T,
        // Already-read bytes that must be yielded before `inner` is read again.
        buf: BytesMut,
    }
}
22 :
23 : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
24 : #[inline]
25 30 : fn poll_write(
26 30 : self: Pin<&mut Self>,
27 30 : cx: &mut Context<'_>,
28 30 : buf: &[u8],
29 30 : ) -> Poll<Result<usize, io::Error>> {
30 30 : self.project().inner.poll_write(cx, buf)
31 30 : }
32 :
33 : #[inline]
34 148 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
35 148 : self.project().inner.poll_flush(cx)
36 148 : }
37 :
38 : #[inline]
39 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
40 0 : self.project().inner.poll_shutdown(cx)
41 0 : }
42 :
43 : #[inline]
44 118 : fn poll_write_vectored(
45 118 : self: Pin<&mut Self>,
46 118 : cx: &mut Context<'_>,
47 118 : bufs: &[io::IoSlice<'_>],
48 118 : ) -> Poll<Result<usize, io::Error>> {
49 118 : self.project().inner.poll_write_vectored(cx, bufs)
50 118 : }
51 :
52 : #[inline]
53 0 : fn is_write_vectored(&self) -> bool {
54 0 : self.inner.is_write_vectored()
55 0 : }
56 : }
57 :
/// The 12-byte signature that opens every PROXY protocol version 2 header
/// (`\r\n\r\n\0\r\nQUIT\n`, i.e. `0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A`).
const HEADER: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
62 :
63 40 : pub async fn read_proxy_protocol<T: AsyncRead + Unpin>(
64 40 : mut read: T,
65 40 : ) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
66 40 : let mut buf = BytesMut::with_capacity(128);
67 58 : while buf.len() < 16 {
68 52 : let bytes_read = read.read_buf(&mut buf).await?;
69 :
70 : // exit for bad header
71 52 : let len = usize::min(buf.len(), HEADER.len());
72 52 : if buf[..len] != HEADER[..len] {
73 34 : return Ok((ChainRW { inner: read, buf }, None));
74 18 : }
75 18 :
76 18 : // if no more bytes available then exit
77 18 : if bytes_read == 0 {
78 0 : return Ok((ChainRW { inner: read, buf }, None));
79 18 : };
80 : }
81 :
82 6 : let header = buf.split_to(16);
83 6 :
84 6 : // The next byte (the 13th one) is the protocol version and command.
85 6 : // The highest four bits contains the version. As of this specification, it must
86 6 : // always be sent as \x2 and the receiver must only accept this value.
87 6 : let vc = header[12];
88 6 : let version = vc >> 4;
89 6 : let command = vc & 0b1111;
90 6 : if version != 2 {
91 0 : return Err(io::Error::new(
92 0 : io::ErrorKind::Other,
93 0 : "invalid proxy protocol version. expected version 2",
94 0 : ));
95 6 : }
96 6 : match command {
97 : // the connection was established on purpose by the proxy
98 : // without being relayed. The connection endpoints are the sender and the
99 : // receiver. Such connections exist when the proxy sends health-checks to the
100 : // server. The receiver must accept this connection as valid and must use the
101 : // real connection endpoints and discard the protocol block including the
102 : // family which is ignored.
103 0 : 0 => {}
104 : // the connection was established on behalf of another node,
105 : // and reflects the original connection endpoints. The receiver must then use
106 : // the information provided in the protocol block to get original the address.
107 6 : 1 => {}
108 : // other values are unassigned and must not be emitted by senders. Receivers
109 : // must drop connections presenting unexpected values here.
110 : _ => {
111 0 : return Err(io::Error::new(
112 0 : io::ErrorKind::Other,
113 0 : "invalid proxy protocol command. expected local (0) or proxy (1)",
114 0 : ))
115 : }
116 : };
117 :
118 : // The 14th byte contains the transport protocol and address family. The highest 4
119 : // bits contain the address family, the lowest 4 bits contain the protocol.
120 6 : let ft = header[13];
121 6 : let address_length = match ft {
122 : // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
123 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
124 : // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
125 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
126 4 : 0x11 | 0x12 => 12,
127 : // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
128 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
129 : // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
130 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
131 2 : 0x21 | 0x22 => 36,
132 : // unspecified or unix stream. ignore the addresses
133 0 : _ => 0,
134 : };
135 :
136 : // The 15th and 16th bytes is the address length in bytes in network endian order.
137 : // It is used so that the receiver knows how many address bytes to skip even when
138 : // it does not implement the presented protocol. Thus the length of the protocol
139 : // header in bytes is always exactly 16 + this value. When a sender presents a
140 : // LOCAL connection, it should not present any address so it sets this field to
141 : // zero. Receivers MUST always consider this field to skip the appropriate number
142 : // of bytes and must not assume zero is presented for LOCAL connections. When a
143 : // receiver accepts an incoming connection showing an UNSPEC address family or
144 : // protocol, it may or may not decide to log the address information if present.
145 6 : let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
146 6 : if remaining_length < address_length {
147 0 : return Err(io::Error::new(
148 0 : io::ErrorKind::Other,
149 0 : "invalid proxy protocol length. not enough to fit requested IP addresses",
150 0 : ));
151 6 : }
152 6 : drop(header);
153 :
154 54 : while buf.len() < remaining_length as usize {
155 48 : if read.read_buf(&mut buf).await? == 0 {
156 0 : return Err(io::Error::new(
157 0 : io::ErrorKind::UnexpectedEof,
158 0 : "stream closed while waiting for proxy protocol addresses",
159 0 : ));
160 48 : }
161 : }
162 :
163 : // Starting from the 17th byte, addresses are presented in network byte order.
164 : // The address order is always the same :
165 : // - source layer 3 address in network byte order
166 : // - destination layer 3 address in network byte order
167 : // - source layer 4 address if any, in network byte order (port)
168 : // - destination layer 4 address if any, in network byte order (port)
169 6 : let addresses = buf.split_to(remaining_length as usize);
170 6 : let socket = match address_length {
171 : 12 => {
172 4 : let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
173 4 : let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
174 4 : Some(SocketAddr::from((src_addr, src_port)))
175 : }
176 : 36 => {
177 2 : let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
178 2 : let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
179 2 : Some(SocketAddr::from((src_addr, src_port)))
180 : }
181 0 : _ => None,
182 : };
183 :
184 6 : Ok((ChainRW { inner: read, buf }, socket))
185 40 : }
186 :
187 : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
188 : #[inline]
189 334 : fn poll_read(
190 334 : self: Pin<&mut Self>,
191 334 : cx: &mut Context<'_>,
192 334 : buf: &mut ReadBuf<'_>,
193 334 : ) -> Poll<io::Result<()>> {
194 334 : if self.buf.is_empty() {
195 296 : self.project().inner.poll_read(cx, buf)
196 : } else {
197 38 : self.read_from_buf(buf)
198 : }
199 334 : }
200 : }
201 :
202 : impl<T: AsyncRead> ChainRW<T> {
203 : #[cold]
204 38 : fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
205 38 : debug_assert!(!self.buf.is_empty());
206 38 : let this = self.project();
207 38 :
208 38 : let write = usize::min(this.buf.len(), buf.remaining());
209 38 : let slice = this.buf.split_to(write).freeze();
210 38 : buf.put_slice(&slice);
211 38 :
212 38 : // reset the allocation so it can be freed
213 38 : if this.buf.is_empty() {
214 34 : *this.buf = BytesMut::new();
215 34 : }
216 :
217 38 : Poll::Ready(Ok(()))
218 38 : }
219 : }
220 :
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::read_proxy_protocol;

    #[tokio::test]
    async fn test_ipv4() {
        let header = super::HEADER
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 + 3 bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
    }

    #[tokio::test]
    async fn test_ipv6() {
        let header = super::HEADER
            // Proxy command, IPV6 | UDP
            .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
            // 36 + 3 bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(
            addr,
            Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
        );
    }

    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_short() {
        let data = [0x55; 10];

        let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let len = (12 + tlv.len() as u16).to_be_bytes();

        let header = super::HEADER
            // Proxy command, Inet << 4 | Stream
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 address bytes + 32768 TLV bytes
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
    }

    #[tokio::test]
    async fn test_local_unspec() {
        let header = super::HEADER
            // Local command, UNSPEC family/protocol
            .chain([2 << 4, 0].as_slice())
            // no address block
            .chain([0, 0].as_slice());

        let extra_data = [0x33; 64];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_invalid_version() {
        let addresses = [0u8; 12];
        let header = super::HEADER
            // Version 1 (invalid), Proxy command, IPV4 | TCP
            .chain([(1 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 bytes
            .chain([0, 12].as_slice())
            .chain(addresses.as_slice());

        let err = read_proxy_protocol(header)
            .await
            .err()
            .expect("a version other than 2 must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::Other);
    }

    #[tokio::test]
    async fn test_invalid_command() {
        let addresses = [0u8; 12];
        let header = super::HEADER
            // Version 2, command 2 (unassigned), IPV4 | TCP
            .chain([(2 << 4) | 2, (1 << 4) | 1].as_slice())
            // 12 bytes
            .chain([0, 12].as_slice())
            .chain(addresses.as_slice());

        let err = read_proxy_protocol(header)
            .await
            .err()
            .expect("an unassigned command must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::Other);
    }

    #[tokio::test]
    async fn test_length_too_short_for_ipv4() {
        let header = super::HEADER
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // only 4 bytes declared, but IPv4 addresses need 12
            .chain([0, 4].as_slice())
            .chain([127, 0, 0, 1].as_slice());

        let err = read_proxy_protocol(header)
            .await
            .err()
            .expect("a length smaller than the address block must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::Other);
    }

    #[tokio::test]
    async fn test_truncated_addresses() {
        let header = super::HEADER
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 bytes declared
            .chain([0, 12].as_slice())
            // ...but the stream ends after 4 of them
            .chain([127, 0, 0, 1].as_slice());

        let err = read_proxy_protocol(header)
            .await
            .err()
            .expect("a stream ending inside the address block must be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }
}
|