Line data Source code
1 : //! Proxy Protocol V2 implementation
2 :
3 : use std::io;
4 : use std::net::SocketAddr;
5 : use std::pin::Pin;
6 : use std::task::{Context, Poll};
7 :
8 : use bytes::BytesMut;
9 : use pin_project_lite::pin_project;
10 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
11 :
pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough.
    ///
    /// Reads are served from `buf` until it is empty and only then from `inner`;
    /// writes are always forwarded directly to `inner`.
    pub(crate) struct ChainRW<T> {
        // The wrapped stream. Structurally pinned so the `AsyncRead`/`AsyncWrite`
        // impls can poll it through `project()`.
        #[pin]
        pub(crate) inner: T,
        // Bytes that were already pulled out of `inner` but have not been handed to
        // the reader yet. They are replayed first by `poll_read`.
        buf: BytesMut,
    }
}
20 :
21 : impl<T: AsyncWrite> AsyncWrite for ChainRW<T> {
22 : #[inline]
23 15 : fn poll_write(
24 15 : self: Pin<&mut Self>,
25 15 : cx: &mut Context<'_>,
26 15 : buf: &[u8],
27 15 : ) -> Poll<Result<usize, io::Error>> {
28 15 : self.project().inner.poll_write(cx, buf)
29 15 : }
30 :
31 : #[inline]
32 74 : fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
33 74 : self.project().inner.poll_flush(cx)
34 74 : }
35 :
36 : #[inline]
37 0 : fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
38 0 : self.project().inner.poll_shutdown(cx)
39 0 : }
40 :
41 : #[inline]
42 59 : fn poll_write_vectored(
43 59 : self: Pin<&mut Self>,
44 59 : cx: &mut Context<'_>,
45 59 : bufs: &[io::IoSlice<'_>],
46 59 : ) -> Poll<Result<usize, io::Error>> {
47 59 : self.project().inner.poll_write_vectored(cx, bufs)
48 59 : }
49 :
50 : #[inline]
51 0 : fn is_write_vectored(&self) -> bool {
52 0 : self.inner.is_write_vectored()
53 0 : }
54 : }
55 :
/// The fixed 12-byte signature that opens every Proxy Protocol Version 2 header:
/// `\r\n\r\n\0\r\nQUIT\n`.
const HEADER: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
60 :
61 20 : pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
62 20 : mut read: T,
63 20 : ) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
64 20 : let mut buf = BytesMut::with_capacity(128);
65 29 : while buf.len() < 16 {
66 26 : let bytes_read = read.read_buf(&mut buf).await?;
67 :
68 : // exit for bad header
69 26 : let len = usize::min(buf.len(), HEADER.len());
70 26 : if buf[..len] != HEADER[..len] {
71 17 : return Ok((ChainRW { inner: read, buf }, None));
72 9 : }
73 9 :
74 9 : // if no more bytes available then exit
75 9 : if bytes_read == 0 {
76 0 : return Ok((ChainRW { inner: read, buf }, None));
77 9 : };
78 : }
79 :
80 3 : let header = buf.split_to(16);
81 3 :
82 3 : // The next byte (the 13th one) is the protocol version and command.
83 3 : // The highest four bits contains the version. As of this specification, it must
84 3 : // always be sent as \x2 and the receiver must only accept this value.
85 3 : let vc = header[12];
86 3 : let version = vc >> 4;
87 3 : let command = vc & 0b1111;
88 3 : if version != 2 {
89 0 : return Err(io::Error::new(
90 0 : io::ErrorKind::Other,
91 0 : "invalid proxy protocol version. expected version 2",
92 0 : ));
93 3 : }
94 3 : match command {
95 : // the connection was established on purpose by the proxy
96 : // without being relayed. The connection endpoints are the sender and the
97 : // receiver. Such connections exist when the proxy sends health-checks to the
98 : // server. The receiver must accept this connection as valid and must use the
99 : // real connection endpoints and discard the protocol block including the
100 : // family which is ignored.
101 0 : 0 => {}
102 : // the connection was established on behalf of another node,
103 : // and reflects the original connection endpoints. The receiver must then use
104 : // the information provided in the protocol block to get original the address.
105 3 : 1 => {}
106 : // other values are unassigned and must not be emitted by senders. Receivers
107 : // must drop connections presenting unexpected values here.
108 : _ => {
109 0 : return Err(io::Error::new(
110 0 : io::ErrorKind::Other,
111 0 : "invalid proxy protocol command. expected local (0) or proxy (1)",
112 0 : ))
113 : }
114 : };
115 :
116 : // The 14th byte contains the transport protocol and address family. The highest 4
117 : // bits contain the address family, the lowest 4 bits contain the protocol.
118 3 : let ft = header[13];
119 3 : let address_length = match ft {
120 : // - \x11 : TCP over IPv4 : the forwarded connection uses TCP over the AF_INET
121 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
122 : // - \x12 : UDP over IPv4 : the forwarded connection uses UDP over the AF_INET
123 : // protocol family. Address length is 2*4 + 2*2 = 12 bytes.
124 2 : 0x11 | 0x12 => 12,
125 : // - \x21 : TCP over IPv6 : the forwarded connection uses TCP over the AF_INET6
126 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
127 : // - \x22 : UDP over IPv6 : the forwarded connection uses UDP over the AF_INET6
128 : // protocol family. Address length is 2*16 + 2*2 = 36 bytes.
129 1 : 0x21 | 0x22 => 36,
130 : // unspecified or unix stream. ignore the addresses
131 0 : _ => 0,
132 : };
133 :
134 : // The 15th and 16th bytes is the address length in bytes in network endian order.
135 : // It is used so that the receiver knows how many address bytes to skip even when
136 : // it does not implement the presented protocol. Thus the length of the protocol
137 : // header in bytes is always exactly 16 + this value. When a sender presents a
138 : // LOCAL connection, it should not present any address so it sets this field to
139 : // zero. Receivers MUST always consider this field to skip the appropriate number
140 : // of bytes and must not assume zero is presented for LOCAL connections. When a
141 : // receiver accepts an incoming connection showing an UNSPEC address family or
142 : // protocol, it may or may not decide to log the address information if present.
143 3 : let remaining_length = u16::from_be_bytes(header[14..16].try_into().unwrap());
144 3 : if remaining_length < address_length {
145 0 : return Err(io::Error::new(
146 0 : io::ErrorKind::Other,
147 0 : "invalid proxy protocol length. not enough to fit requested IP addresses",
148 0 : ));
149 3 : }
150 3 : drop(header);
151 :
152 27 : while buf.len() < remaining_length as usize {
153 24 : if read.read_buf(&mut buf).await? == 0 {
154 0 : return Err(io::Error::new(
155 0 : io::ErrorKind::UnexpectedEof,
156 0 : "stream closed while waiting for proxy protocol addresses",
157 0 : ));
158 24 : }
159 : }
160 :
161 : // Starting from the 17th byte, addresses are presented in network byte order.
162 : // The address order is always the same :
163 : // - source layer 3 address in network byte order
164 : // - destination layer 3 address in network byte order
165 : // - source layer 4 address if any, in network byte order (port)
166 : // - destination layer 4 address if any, in network byte order (port)
167 3 : let addresses = buf.split_to(remaining_length as usize);
168 3 : let socket = match address_length {
169 : 12 => {
170 2 : let src_addr: [u8; 4] = addresses[0..4].try_into().unwrap();
171 2 : let src_port = u16::from_be_bytes(addresses[8..10].try_into().unwrap());
172 2 : Some(SocketAddr::from((src_addr, src_port)))
173 : }
174 : 36 => {
175 1 : let src_addr: [u8; 16] = addresses[0..16].try_into().unwrap();
176 1 : let src_port = u16::from_be_bytes(addresses[32..34].try_into().unwrap());
177 1 : Some(SocketAddr::from((src_addr, src_port)))
178 : }
179 0 : _ => None,
180 : };
181 :
182 3 : Ok((ChainRW { inner: read, buf }, socket))
183 20 : }
184 :
185 : impl<T: AsyncRead> AsyncRead for ChainRW<T> {
186 : #[inline]
187 167 : fn poll_read(
188 167 : self: Pin<&mut Self>,
189 167 : cx: &mut Context<'_>,
190 167 : buf: &mut ReadBuf<'_>,
191 167 : ) -> Poll<io::Result<()>> {
192 167 : if self.buf.is_empty() {
193 148 : self.project().inner.poll_read(cx, buf)
194 : } else {
195 19 : self.read_from_buf(buf)
196 : }
197 167 : }
198 : }
199 :
200 : impl<T: AsyncRead> ChainRW<T> {
201 : #[cold]
202 19 : fn read_from_buf(self: Pin<&mut Self>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
203 19 : debug_assert!(!self.buf.is_empty());
204 19 : let this = self.project();
205 19 :
206 19 : let write = usize::min(this.buf.len(), buf.remaining());
207 19 : let slice = this.buf.split_to(write).freeze();
208 19 : buf.put_slice(&slice);
209 19 :
210 19 : // reset the allocation so it can be freed
211 19 : if this.buf.is_empty() {
212 17 : *this.buf = BytesMut::new();
213 17 : }
214 :
215 19 : Poll::Ready(Ok(()))
216 19 : }
217 : }
218 :
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use tokio::io::{AsyncRead, AsyncReadExt};

    use crate::protocol2::read_proxy_protocol;

    /// Parses the proxy header from `input`, then reads the wrapped stream to the end.
    /// Returns the remaining payload together with the address that was parsed.
    async fn parse_then_drain<R: AsyncRead + Unpin>(input: R) -> (Vec<u8>, Option<SocketAddr>) {
        let (mut stream, peer) = read_proxy_protocol(input).await.unwrap();

        let mut payload = vec![];
        stream.read_to_end(&mut payload).await.unwrap();

        (payload, peer)
    }

    /// PROXY command, TCP over IPv4, with a small TLV trailer.
    #[tokio::test]
    async fn test_ipv4() {
        let header = super::HEADER
            // version 2 + PROXY command, then AF_INET | STREAM
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // length: 12 address bytes + 3 TLV bytes
            .chain([0, 15].as_slice())
            // source IP
            .chain([127, 0, 0, 1].as_slice())
            // destination IP
            .chain([192, 168, 0, 1].as_slice())
            // source port
            .chain([255, 255].as_slice())
            // destination port
            .chain([1, 1].as_slice())
            // TLV bytes, skipped by the parser
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (bytes, addr) = parse_then_drain(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, Some(([127, 0, 0, 1], 65535).into()));
    }

    /// PROXY command, UDP over IPv6, with a small TLV trailer.
    #[tokio::test]
    async fn test_ipv6() {
        let header = super::HEADER
            // version 2 + PROXY command, then AF_INET6 | DGRAM
            .chain([(2 << 4) | 1, (2 << 4) | 2].as_slice())
            // length: 36 address bytes + 3 TLV bytes
            .chain([0, 39].as_slice())
            // source IP
            .chain([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0].as_slice())
            // destination IP
            .chain([0, 15, 1, 14, 2, 13, 3, 12, 4, 11, 5, 10, 6, 9, 7, 8].as_slice())
            // source port
            .chain([1, 1].as_slice())
            // destination port
            .chain([255, 255].as_slice())
            // TLV bytes, skipped by the parser
            .chain([1, 2, 3].as_slice());

        let extra_data = [0x55; 256];

        let (bytes, addr) = parse_then_drain(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(
            addr,
            Some(([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], 257).into())
        );
    }

    /// Data that does not start with the signature is passed through untouched.
    #[tokio::test]
    async fn test_invalid() {
        let data = [0x55; 256];

        let (bytes, addr) = parse_then_drain(data.as_slice()).await;

        assert_eq!(bytes, data);
        assert_eq!(addr, None);
    }

    /// Input shorter than the fixed header is passed through untouched.
    #[tokio::test]
    async fn test_short() {
        let data = [0x55; 10];

        let (bytes, addr) = parse_then_drain(data.as_slice()).await;

        assert_eq!(bytes, data);
        assert_eq!(addr, None);
    }

    /// A TLV section far larger than the initial buffer capacity is skipped entirely.
    #[tokio::test]
    async fn test_large_tlv() {
        let tlv = vec![0x55; 32768];
        let len = (12 + tlv.len() as u16).to_be_bytes();

        let header = super::HEADER
            // version 2 + PROXY command, then AF_INET | STREAM
            .chain([(2 << 4) | 1, (1 << 4) | 1].as_slice())
            // length: 12 address bytes + the whole TLV
            .chain(len.as_slice())
            // source IP
            .chain([55, 56, 57, 58].as_slice())
            // destination IP
            .chain([192, 168, 0, 1].as_slice())
            // source port
            .chain([255, 255].as_slice())
            // destination port
            .chain([1, 1].as_slice())
            // TLV bytes, skipped by the parser
            .chain(tlv.as_slice());

        let extra_data = [0xaa; 256];

        let (bytes, addr) = parse_then_drain(header.chain(extra_data.as_slice())).await;

        assert_eq!(bytes, extra_data);
        assert_eq!(addr, Some(([55, 56, 57, 58], 65535).into()));
    }
}
|