Line data Source code
1 : //! Proxy Protocol V2 implementation
2 :
3 : use std::{
4 : io,
5 : net::SocketAddr,
6 : pin::Pin,
7 : task::{Context, Poll},
8 : };
9 :
10 : use bytes::BytesMut;
11 : use pin_project_lite::pin_project;
12 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
13 :
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
    ///
    /// Reads are served from `buf` first (bytes already pulled off the stream while looking
    /// for a proxy header) and only then from `inner`; writes always go straight to `inner`.
    pub(crate) struct ChainRW<T> {
        // Pinned so the poll_* methods can project to `Pin<&mut T>` without requiring `T: Unpin`.
        #[pin]
        pub(crate) inner: T,
        // Read-ahead bytes that must be returned to the reader before any data from `inner`.
        buf: BytesMut,
    }
}
22 :
23 : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
24 : #[inline]
25 15 : fn poll_write(
26 15 : self: Pin<&mut Self>,
27 15 : cx: &mut Context<'_>,
28 15 : buf: &[u8],
29 15 : ) -> Poll<Result<usize, io::Error>> {
30 15 : self.project().inner.poll_write(cx, buf)
31 15 : }
32 :
33 : #[inline]
34 74 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
35 74 : self.project().inner.poll_flush(cx)
36 74 : }
37 :
38 : #[inline]
39 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
40 0 : self.project().inner.poll_shutdown(cx)
41 0 : }
42 :
43 : #[inline]
44 59 : fn poll_write_vectored(
45 59 : self: Pin<&mut Self>,
46 59 : cx: &mut Context<'_>,
47 59 : bufs: &[io::IoSlice<'_>],
48 59 : ) -> Poll<Result<usize, io::Error>> {
49 59 : self.project().inner.poll_write_vectored(cx, bufs)
50 59 : }
51 :
52 : #[inline]
53 0 : fn is_write_vectored(&self) -> bool {
54 0 : self.inner.is_write_vectored()
55 0 : }
56 : }
57 :
/// The 12-byte signature that opens every PROXY protocol version 2 header
/// (`\r\n\r\n\0\r\nQUIT\n`, i.e. `0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A`).
const HEADER: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
62 :
63 20 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
64 20 : mut read: T,
65 20 : ) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
66 20 : let mut buf = BytesMut::with_capacity(128);
67 29 : while buf.len() < 16 {
68 26 : let bytes_read = read.read_buf(&mut buf).await?;
69 :
70 : // exit for bad header
71 26 : let len = usize::min(buf.len(), HEADER.len());
72 26 : if buf[..len] != HEADER[..len] {
73 17 : return Ok((ChainRW { inner: read, buf }, None));
74 9 : }
75 9 :
76 9 : // if no more bytes available then exit
77 9 : if bytes_read == 0 {
78 0 : return Ok((ChainRW { inner: read, buf }, None));
79 9 : };
80 : }
81 :
82 3 : let header = buf.split_to(16);
83 3 :
84 3 : // The next byte (the 13th one) is the protocol version and command.
85 3 : // The highest four bits contains the version. As of this specification, it must
86 3 : // always be sent as \x2 and the receiver must only accept this value.
87 3 : let vc = header[12];
88 3 : let version = vc >> 4;
89 3 : let command = vc & 0b1111;
90 3 : if version != 2 {
91 0 : return Err(io::Error::new(
92 0 : io::ErrorKind::Other,
93 0 : "invalid proxy protocol version. expected version 2",
94 0 : ));
95 3 : }
96 3 : match command {
97 : // the connection was established on purpose by the proxy
98 : // without being relayed. The connection endpoints are the sender and the
99 : // receiver. Such connections exist when the proxy sends health-checks to the
100 : // server. The receiver must accept this connection as valid and must use the
101 : // real connection endpoints and discard the protocol block including the
102 : // family which is ignored.
103 0 : 0 => {}
104 : // the connection was established on behalf of another node,
105 : // and reflects the original connection endpoints. The receiver must then use
106 : // the information provided in the protocol block to get original the address.
107 3 : 1 => {}
108 : // other values are unassigned and must not be emitted by senders. Receivers
109 : // must drop connections presenting unexpected values here.
110 : _ => {
111 0 : return Err(io::Error::new(
112 0 : io::ErrorKind::Other,
113 0 : "invalid proxy protocol command. expected local (0) or proxy (1)",
114 0 : ))
115 : }
116 : };
117 :
118 : // The 14th byte contains the transport protocol and address family. The highest 4
119 : // bits contain the address family, the lowest 4 bits contain the protocol.
120 3 : let ft = header[13];
121 3 : let address_length = match ft {
122 : // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
123 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
124 : // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
125 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
126 2 : 0x11 | 0x12 => 12,
127 : // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
128 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
129 : // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
130 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
131 1 : 0x21 | 0x22 => 36,
132 : // unspecified or unix stream. ignore the addresses
133 0 : _ => 0,
134 : };
135 :
136 : // The 15th and 16th bytes is the address length in bytes in network endian order.
137 : // It is used so that the receiver knows how many address bytes to skip even when
138 : // it does not implement the presented protocol. Thus the length of the protocol
139 : // header in bytes is always exactly 16 + this value. When a sender presents a
140 : // LOCAL connection, it should not present any address so it sets this field to
141 : // zero. Receivers MUST always consider this field to skip the appropriate number
142 : // of bytes and must not assume zero is presented for LOCAL connections. When a
143 : // receiver accepts an incoming connection showing an UNSPEC address family or
144 : // protocol, it may or may not decide to log the address information if present.
145 3 : let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
146 3 : if remaining_length < address_length {
147 0 : return Err(io::Error::new(
148 0 : io::ErrorKind::Other,
149 0 : "invalid proxy protocol length. not enough to fit requested IP addresses",
150 0 : ));
151 3 : }
152 3 : drop(header);
153 :
154 27 : while buf.len() < remaining_length as usize {
155 24 : if read.read_buf(&mut buf).await? == 0 {
156 0 : return Err(io::Error::new(
157 0 : io::ErrorKind::UnexpectedEof,
158 0 : "stream closed while waiting for proxy protocol addresses",
159 0 : ));
160 24 : }
161 : }
162 :
163 : // Starting from the 17th byte, addresses are presented in network byte order.
164 : // The address order is always the same :
165 : // - source layer 3 address in network byte order
166 : // - destination layer 3 address in network byte order
167 : // - source layer 4 address if any, in network byte order (port)
168 : // - destination layer 4 address if any, in network byte order (port)
169 3 : let addresses = buf.split_to(remaining_length as usize);
170 3 : let socket = match address_length {
171 : 12 => {
172 2 : let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
173 2 : let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
174 2 : Some(SocketAddr::from((src_addr, src_port)))
175 : }
176 : 36 => {
177 1 : let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
178 1 : let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
179 1 : Some(SocketAddr::from((src_addr, src_port)))
180 : }
181 0 : _ => None,
182 : };
183 :
184 3 : Ok((ChainRW { inner: read, buf }, socket))
185 20 : }
186 :
187 : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
188 : #[inline]
189 167 : fn poll_read(
190 167 : self: Pin<&mut Self>,
191 167 : cx: &mut Context<'_>,
192 167 : buf: &mut ReadBuf<'_>,
193 167 : ) -> Poll<io::Result<()>> {
194 167 : if self.buf.is_empty() {
195 148 : self.project().inner.poll_read(cx, buf)
196 : } else {
197 19 : self.read_from_buf(buf)
198 : }
199 167 : }
200 : }
201 :
202 : impl<T: AsyncRead> ChainRW<T> {
203 : #[cold]
204 19 : fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
205 19 : debug_assert!(!self.buf.is_empty());
206 19 : let this = self.project();
207 19 :
208 19 : let write = usize::min(this.buf.len(), buf.remaining());
209 19 : let slice = this.buf.split_to(write).freeze();
210 19 : buf.put_slice(&slice);
211 19 :
212 19 : // reset the allocation so it can be freed
213 19 : if this.buf.is_empty() {
214 17 : *this.buf = BytesMut::new();
215 17 : }
216 :
217 19 : Poll::Ready(Ok(()))
218 19 : }
219 : }
220 :
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use tokio::io::{AsyncRead, AsyncReadExt};

    use super::{read_proxy_protocol, HEADER};

    /// Parses the proxy header from `input`, then reads the returned stream to EOF.
    ///
    /// Returns the bytes that followed the header block together with the parsed address.
    async fn parse_then_drain<R: AsyncRead + Unpin>(input: R) -> (Vec<u8>, Option<SocketAddr>) {
        let (mut stream, addr) = read_proxy_protocol(input).await.unwrap();
        let mut rest = Vec::new();
        stream.read_to_end(&mut rest).await.unwrap();
        (rest, addr)
    }

    #[tokio::test]
    async fn test_ipv4() {
        let extra_data = [0x55; 256];
        let input = HEADER
            // version 2 + PROXY command, AF_INET + STREAM (TCP)
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // block length: 12 address bytes + 3 TLV bytes
            .chain([0, 15].as_slice())
            // source address
            .chain([127, 0, 0, 1].as_slice())
            // destination address
            .chain([192, 168, 0, 1].as_slice())
            // source port
            .chain([255, 255].as_slice())
            // destination port
            .chain([1, 1].as_slice())
            // TLV bytes
            .chain([1, 2, 3].as_slice())
            // payload that must pass through untouched
            .chain(extra_data.as_slice());

        let (rest, addr) = parse_then_drain(input).await;

        assert_eq!(rest, extra_data);
        assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
    }

    #[tokio::test]
    async fn test_ipv6() {
        let extra_data = [0x55; 256];
        let input = HEADER
            // version 2 + PROXY command, AF_INET6 + DGRAM (UDP)
            .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
            // block length: 36 address bytes + 3 TLV bytes
            .chain([0, 39].as_slice())
            // source address
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // destination address
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // source port
            .chain([1, 1].as_slice())
            // destination port
            .chain([255, 255].as_slice())
            // TLV bytes
            .chain([1, 2, 3].as_slice())
            // payload that must pass through untouched
            .chain(extra_data.as_slice());

        let (rest, addr) = parse_then_drain(input).await;

        assert_eq!(rest, extra_data);
        assert_eq!(
            addr,
            Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
        );
    }

    #[tokio::test]
    async fn test_invalid() {
        // No signature at all: every byte must come back as stream data.
        let data = [0x55; 256];

        let (rest, addr) = parse_then_drain(data.as_slice()).await;

        assert_eq!(rest, data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_short() {
        // Shorter than the fixed header and not a signature prefix.
        let data = [0x55; 10];

        let (rest, addr) = parse_then_drain(data.as_slice()).await;

        assert_eq!(rest, data);
        assert_eq!(addr, None);
    }

    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let len = (12 + tlv.len() as u16).to_be_bytes();
        let extra_data = [0xaa; 256];

        let input = HEADER
            // version 2 + PROXY command, AF_INET + STREAM (TCP)
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // block length: 12 address bytes + tlv.len() TLV bytes
            .chain(len.as_slice())
            // source address
            .chain([55, 56, 57, 58].as_slice())
            // destination address
            .chain([192, 168, 0, 1].as_slice())
            // source port
            .chain([255, 255].as_slice())
            // destination port
            .chain([1, 1].as_slice())
            // TLV bytes
            .chain(tlv.as_slice())
            // payload that must pass through untouched
            .chain(extra_data.as_slice());

        let (rest, addr) = parse_then_drain(input).await;

        assert_eq!(rest, extra_data);
        assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
    }
}
|