Line data Source code
1 : //! Proxy Protocol V2 implementation
2 :
3 : use std::{
4 : io,
5 : net::SocketAddr,
6 : pin::Pin,
7 : task::{Context, Poll},
8 : };
9 :
10 : use bytes::BytesMut;
11 : use pin_project_lite::pin_project;
12 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13 :
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough.
    ///
    /// Reads first drain `buf` (bytes already pulled from `inner`, e.g. while probing for a
    /// proxy protocol header) and only then continue from `inner`. Writes, flushes and
    /// shutdowns are forwarded straight to `inner`.
    pub(crate) struct ChainRW<T> {
        // The wrapped stream; pinned so its poll methods can be called through `project()`.
        #[pin]
        pub(crate) inner: T,
        // Bytes already read from `inner` that have not yet been handed to the reader.
        buf: BytesMut,
    }
}
22 :
23 : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
24 : #[inline]
25 90 : fn poll_write(
26 90 : self: Pin<&mut Self>,
27 90 : cx: &mut Context<'_>,
28 90 : buf: &[u8],
29 90 : ) -> Poll<Result<usize, io::Error>> {
30 90 : self.project().inner.poll_write(cx, buf)
31 90 : }
32 :
33 : #[inline]
34 444 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
35 444 : self.project().inner.poll_flush(cx)
36 444 : }
37 :
38 : #[inline]
39 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
40 0 : self.project().inner.poll_shutdown(cx)
41 0 : }
42 :
43 : #[inline]
44 354 : fn poll_write_vectored(
45 354 : self: Pin<&mut Self>,
46 354 : cx: &mut Context<'_>,
47 354 : bufs: &[io::IoSlice<'_>],
48 354 : ) -> Poll<Result<usize, io::Error>> {
49 354 : self.project().inner.poll_write_vectored(cx, bufs)
50 354 : }
51 :
52 : #[inline]
53 0 : fn is_write_vectored(&self) -> bool {
54 0 : self.inner.is_write_vectored()
55 0 : }
56 : }
57 :
/// The fixed 12-byte signature that opens every PROXY protocol v2 header
/// (`\r\n\r\n\0\r\nQUIT\n`).
const HEADER: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
62 :
63 120 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
64 120 : mut read: T,
65 120 : ) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
66 120 : let mut buf = BytesMut::with_capacity(128);
67 174 : while buf.len() < 16 {
68 156 : let bytes_read = read.read_buf(&mut buf).await?;
69 :
70 : // exit for bad header
71 156 : let len = usize::min(buf.len(), HEADER.len());
72 156 : if buf[..len] != HEADER[..len] {
73 102 : return Ok((ChainRW { inner: read, buf }, None));
74 54 : }
75 54 :
76 54 : // if no more bytes available then exit
77 54 : if bytes_read == 0 {
78 0 : return Ok((ChainRW { inner: read, buf }, None));
79 54 : };
80 : }
81 :
82 18 : let header = buf.split_to(16);
83 18 :
84 18 : // The next byte (the 13th one) is the protocol version and command.
85 18 : // The highest four bits contains the version. As of this specification, it must
86 18 : // always be sent as \x2 and the receiver must only accept this value.
87 18 : let vc = header[12];
88 18 : let version = vc >> 4;
89 18 : let command = vc & 0b1111;
90 18 : if version != 2 {
91 0 : return Err(io::Error::new(
92 0 : io::ErrorKind::Other,
93 0 : "invalid proxy protocol version. expected version 2",
94 0 : ));
95 18 : }
96 18 : match command {
97 : // the connection was established on purpose by the proxy
98 : // without being relayed. The connection endpoints are the sender and the
99 : // receiver. Such connections exist when the proxy sends health-checks to the
100 : // server. The receiver must accept this connection as valid and must use the
101 : // real connection endpoints and discard the protocol block including the
102 : // family which is ignored.
103 0 : 0 => {}
104 : // the connection was established on behalf of another node,
105 : // and reflects the original connection endpoints. The receiver must then use
106 : // the information provided in the protocol block to get original the address.
107 18 : 1 => {}
108 : // other values are unassigned and must not be emitted by senders. Receivers
109 : // must drop connections presenting unexpected values here.
110 : _ => {
111 0 : return Err(io::Error::new(
112 0 : io::ErrorKind::Other,
113 0 : "invalid proxy protocol command. expected local (0) or proxy (1)",
114 0 : ))
115 : }
116 : };
117 :
118 : // The 14th byte contains the transport protocol and address family. The highest 4
119 : // bits contain the address family, the lowest 4 bits contain the protocol.
120 18 : let ft = header[13];
121 18 : let address_length = match ft {
122 : // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
123 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
124 : // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
125 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
126 12 : 0x11 | 0x12 => 12,
127 : // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
128 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
129 : // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
130 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
131 6 : 0x21 | 0x22 => 36,
132 : // unspecified or unix stream. ignore the addresses
133 0 : _ => 0,
134 : };
135 :
136 : // The 15th and 16th bytes is the address length in bytes in network endian order.
137 : // It is used so that the receiver knows how many address bytes to skip even when
138 : // it does not implement the presented protocol. Thus the length of the protocol
139 : // header in bytes is always exactly 16 + this value. When a sender presents a
140 : // LOCAL connection, it should not present any address so it sets this field to
141 : // zero. Receivers MUST always consider this field to skip the appropriate number
142 : // of bytes and must not assume zero is presented for LOCAL connections. When a
143 : // receiver accepts an incoming connection showing an UNSPEC address family or
144 : // protocol, it may or may not decide to log the address information if present.
145 18 : let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
146 18 : if remaining_length < address_length {
147 0 : return Err(io::Error::new(
148 0 : io::ErrorKind::Other,
149 0 : "invalid proxy protocol length. not enough to fit requested IP addresses",
150 0 : ));
151 18 : }
152 18 : drop(header);
153 :
154 162 : while buf.len() < remaining_length as usize {
155 144 : if read.read_buf(&mut buf).await? == 0 {
156 0 : return Err(io::Error::new(
157 0 : io::ErrorKind::UnexpectedEof,
158 0 : "stream closed while waiting for proxy protocol addresses",
159 0 : ));
160 144 : }
161 : }
162 :
163 : // Starting from the 17th byte, addresses are presented in network byte order.
164 : // The address order is always the same :
165 : // - source layer 3 address in network byte order
166 : // - destination layer 3 address in network byte order
167 : // - source layer 4 address if any, in network byte order (port)
168 : // - destination layer 4 address if any, in network byte order (port)
169 18 : let addresses = buf.split_to(remaining_length as usize);
170 18 : let socket = match address_length {
171 : 12 => {
172 12 : let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
173 12 : let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
174 12 : Some(SocketAddr::from((src_addr, src_port)))
175 : }
176 : 36 => {
177 6 : let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
178 6 : let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
179 6 : Some(SocketAddr::from((src_addr, src_port)))
180 : }
181 0 : _ => None,
182 : };
183 :
184 18 : Ok((ChainRW { inner: read, buf }, socket))
185 120 : }
186 :
187 : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
188 : #[inline]
189 1002 : fn poll_read(
190 1002 : self: Pin<&mut Self>,
191 1002 : cx: &mut Context<'_>,
192 1002 : buf: &mut ReadBuf<'_>,
193 1002 : ) -> Poll<io::Result<()>> {
194 1002 : if self.buf.is_empty() {
195 888 : self.project().inner.poll_read(cx, buf)
196 : } else {
197 114 : self.read_from_buf(buf)
198 : }
199 1002 : }
200 : }
201 :
202 : impl<T: AsyncRead> ChainRW<T> {
203 : #[cold]
204 114 : fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
205 114 : debug_assert!(!self.buf.is_empty());
206 114 : let this = self.project();
207 114 :
208 114 : let write = usize::min(this.buf.len(), buf.remaining());
209 114 : let slice = this.buf.split_to(write).freeze();
210 114 : buf.put_slice(&slice);
211 114 :
212 114 : // reset the allocation so it can be freed
213 114 : if this.buf.is_empty() {
214 102 : *this.buf = BytesMut::new();
215 102 : }
216 :
217 114 : Poll::Ready(Ok(()))
218 114 : }
219 : }
220 :
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;

    use crate::protocol2::read_proxy_protocol;

    #[tokio::test]
    async fn test_ipv4() {
        let header = super::HEADER
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 + 3 bytes
            .chain([0, 15].as_slice())
            // src ip
            .chain([127, 0, 0, 1].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
    }

    #[tokio::test]
    async fn test_ipv6() {
        let header = super::HEADER
            // Proxy command, IPV6 | UDP
            .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
            // 36 + 3 bytes
            .chain([0, 39].as_slice())
            // src ip
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // dst ip
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // src port
            .chain([1, 1].as_slice())
            // dst port
            .chain([255, 255].as_slice())
            // TLV
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(
            addr,
            Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
        );
    }

    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_short() {
        let data = [0x55; 10];

        let (mut read, addr) = read_proxy_protocol(data.as_slice()).await.unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();
        assert_eq!(bytes, data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let len = (12 + tlv.len() as u16).to_be_bytes();

        let header = super::HEADER
            // Proxy command, Inet << 4 | Stream
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 + 32768 bytes
            .chain(len.as_slice())
            // src ip
            .chain([55, 56, 57, 58].as_slice())
            // dst ip
            .chain([192, 168, 0, 1].as_slice())
            // src port
            .chain([255, 255].as_slice())
            // dst port
            .chain([1, 1].as_slice())
            // TLV
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
    }

    #[tokio::test]
    async fn test_local() {
        let header = super::HEADER
            // Local command, UNSPEC | UNSPEC
            .chain([2 << 4, 0].as_slice())
            // no address block
            .chain([0, 0].as_slice());

        let extra_data = [0x55; 256];

        let (mut read, addr) = read_proxy_protocol(header.chain(extra_data.as_slice()))
            .await
            .unwrap();

        let mut bytes = vec![];
        read.read_to_end(&mut bytes).await.unwrap();

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_truncated_addresses() {
        let header = super::HEADER
            // Proxy command, IPV4 | TCP
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // 12 bytes of addresses announced...
            .chain([0, 12].as_slice())
            // ...but only the source ip arrives before EOF
            .chain([127, 0, 0, 1].as_slice());

        let err = match read_proxy_protocol(header).await {
            Ok(_) => panic!("expected an error for a truncated proxy protocol header"),
            Err(err) => err,
        };
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn test_invalid_version() {
        let header = super::HEADER
            // version 1 (invalid), Proxy command, IPV4 | TCP
            .chain([(1 << 4) | 1, (1 << 4) | 1].as_slice())
            .chain([0, 12].as_slice())
            .chain([0; 12].as_slice());

        assert!(read_proxy_protocol(header).await.is_err());
    }

    #[tokio::test]
    async fn test_invalid_command() {
        let header = super::HEADER
            // version 2, unassigned command 2, IPV4 | TCP
            .chain([(2 << 4) | 2, (1 << 4) | 1].as_slice())
            .chain([0, 12].as_slice())
            .chain([0; 12].as_slice());

        assert!(read_proxy_protocol(header).await.is_err());
    }
}
|