Line data Source code
1 : use std::io;
2 : use std::pin::Pin;
3 : use std::task::{Context, Poll, ready};
4 :
5 : use bytes::BytesMut;
6 : use fallible_iterator::FallibleIterator;
7 : use futures_util::{SinkExt, Stream, TryStreamExt};
8 : use postgres_protocol2::authentication::sasl;
9 : use postgres_protocol2::authentication::sasl::ScramSha256;
10 : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
11 : use postgres_protocol2::message::frontend;
12 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13 : use tokio_util::codec::{Framed, FramedParts};
14 :
15 : use crate::Error;
16 : use crate::codec::PostgresCodec;
17 : use crate::config::{self, AuthKeys, Config};
18 : use crate::connection::{GC_THRESHOLD, INITIAL_CAPACITY};
19 : use crate::maybe_tls_stream::MaybeTlsStream;
20 : use crate::tls::TlsStream;
21 :
22 : pub struct StartupStream<S, T> {
23 : inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
24 : read_buf: BytesMut,
25 : }
26 :
27 : impl<S, T> Stream for StartupStream<S, T>
28 : where
29 : S: AsyncRead + AsyncWrite + Unpin,
30 : T: AsyncRead + AsyncWrite + Unpin,
31 : {
32 : type Item = io::Result<Message>;
33 :
34 77 : fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
35 : // We don't use `self.inner.poll_next()` as that might over-read into the read buffer.
36 :
37 : // read 1 byte tag, 4 bytes length.
38 77 : let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
39 :
40 35 : let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
41 35 : if len < 4 {
42 0 : return Poll::Ready(Some(Err(std::io::Error::other(
43 0 : "postgres message too small",
44 0 : ))));
45 0 : }
46 35 : if len >= 65536 {
47 0 : return Poll::Ready(Some(Err(std::io::Error::other(
48 0 : "postgres message too large",
49 0 : ))));
50 0 : }
51 :
52 : // the tag is an additional byte.
53 35 : let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);
54 :
55 : // Message::parse will remove the all the bytes from the buffer.
56 35 : Poll::Ready(Message::parse(&mut self.read_buf).transpose())
57 0 : }
58 : }
59 :
60 : impl<S, T> StartupStream<S, T>
61 : where
62 : S: AsyncRead + AsyncWrite + Unpin,
63 : T: AsyncRead + AsyncWrite + Unpin,
64 : {
65 : /// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
66 : ///
67 : /// If the current buffer length is greater, nothing happens.
68 112 : fn poll_fill_buf_exact(
69 112 : self: Pin<&mut Self>,
70 112 : cx: &mut Context<'_>,
71 112 : len: usize,
72 112 : ) -> Poll<Result<&[u8], std::io::Error>> {
73 112 : let this = self.get_mut();
74 112 : let mut stream = Pin::new(this.inner.get_mut());
75 :
76 112 : let mut n = this.read_buf.len();
77 182 : while n < len {
78 112 : this.read_buf.resize(len, 0);
79 :
80 112 : let mut buf = ReadBuf::new(&mut this.read_buf[..]);
81 112 : buf.set_filled(n);
82 :
83 112 : if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
84 36 : this.read_buf.truncate(n);
85 36 : return Poll::Pending;
86 0 : }
87 :
88 70 : if buf.filled().len() == n {
89 0 : return Poll::Ready(Err(std::io::Error::new(
90 0 : std::io::ErrorKind::UnexpectedEof,
91 0 : "early eof",
92 0 : )));
93 0 : }
94 70 : n = buf.filled().len();
95 :
96 70 : this.read_buf.truncate(n);
97 : }
98 :
99 70 : Poll::Ready(Ok(&this.read_buf[..len]))
100 0 : }
101 :
102 0 : pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
103 0 : *self.inner.read_buffer_mut() = self.read_buf;
104 0 : self.inner
105 0 : }
106 :
107 15 : pub fn new(io: MaybeTlsStream<S, T>) -> Self {
108 15 : let mut parts = FramedParts::new(io, PostgresCodec);
109 15 : parts.write_buf = BytesMut::with_capacity(INITIAL_CAPACITY);
110 :
111 15 : let mut inner = Framed::from_parts(parts);
112 :
113 : // This is the default already, but nice to be explicit.
114 : // We divide by two because writes will overshoot the boundary.
115 : // We don't want constant overshoots to cause us to constantly re-shrink the buffer.
116 15 : inner.set_backpressure_boundary(GC_THRESHOLD / 2);
117 :
118 15 : Self {
119 15 : inner,
120 15 : read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
121 15 : }
122 0 : }
123 : }
124 :
125 15 : pub(crate) async fn authenticate<S, T>(
126 15 : stream: &mut StartupStream<S, T>,
127 15 : config: &Config,
128 15 : ) -> Result<(), Error>
129 15 : where
130 15 : S: AsyncRead + AsyncWrite + Unpin,
131 15 : T: TlsStream + Unpin,
132 0 : {
133 15 : frontend::startup_message(&config.server_params, stream.inner.write_buffer_mut())
134 15 : .map_err(Error::encode)?;
135 :
136 15 : stream.inner.flush().await.map_err(Error::io)?;
137 15 : match stream.try_next().await.map_err(Error::io)? {
138 : Some(Message::AuthenticationOk) => {
139 2 : can_skip_channel_binding(config)?;
140 2 : return Ok(());
141 : }
142 : Some(Message::AuthenticationCleartextPassword) => {
143 0 : can_skip_channel_binding(config)?;
144 :
145 0 : let pass = config
146 0 : .password
147 0 : .as_ref()
148 0 : .ok_or_else(|| Error::config("password missing".into()))?;
149 :
150 0 : frontend::password_message(pass, stream.inner.write_buffer_mut())
151 0 : .map_err(Error::encode)?;
152 : }
153 12 : Some(Message::AuthenticationSasl(body)) => {
154 12 : authenticate_sasl(stream, body, config).await?;
155 : }
156 : Some(Message::AuthenticationMd5Password)
157 : | Some(Message::AuthenticationKerberosV5)
158 : | Some(Message::AuthenticationScmCredential)
159 : | Some(Message::AuthenticationGss)
160 : | Some(Message::AuthenticationSspi) => {
161 0 : return Err(Error::authentication(
162 0 : "unsupported authentication method".into(),
163 0 : ));
164 : }
165 1 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
166 0 : Some(_) => return Err(Error::unexpected_message()),
167 0 : None => return Err(Error::closed()),
168 : }
169 :
170 5 : stream.inner.flush().await.map_err(Error::io)?;
171 5 : match stream.try_next().await.map_err(Error::io)? {
172 5 : Some(Message::AuthenticationOk) => Ok(()),
173 0 : Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
174 0 : Some(_) => Err(Error::unexpected_message()),
175 0 : None => Err(Error::closed()),
176 : }
177 0 : }
178 :
179 6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
180 6 : match config.channel_binding {
181 5 : config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
182 1 : config::ChannelBinding::Require => Err(Error::authentication(
183 1 : "server did not use channel binding".into(),
184 1 : )),
185 : }
186 6 : }
187 :
188 12 : async fn authenticate_sasl<S, T>(
189 12 : stream: &mut StartupStream<S, T>,
190 12 : body: AuthenticationSaslBody,
191 12 : config: &Config,
192 12 : ) -> Result<(), Error>
193 12 : where
194 12 : S: AsyncRead + AsyncWrite + Unpin,
195 12 : T: TlsStream + Unpin,
196 0 : {
197 12 : let mut has_scram = false;
198 12 : let mut has_scram_plus = false;
199 12 : let mut mechanisms = body.mechanisms();
200 34 : while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
201 22 : match mechanism {
202 22 : sasl::SCRAM_SHA_256 => has_scram = true,
203 10 : sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
204 0 : _ => {}
205 : }
206 : }
207 :
208 12 : let channel_binding = stream
209 12 : .inner
210 12 : .get_ref()
211 12 : .channel_binding()
212 12 : .tls_server_end_point
213 12 : .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
214 12 : .map(sasl::ChannelBinding::tls_server_end_point);
215 :
216 12 : let (channel_binding, mechanism) = if has_scram_plus {
217 10 : match channel_binding {
218 8 : Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
219 2 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
220 : }
221 2 : } else if has_scram {
222 2 : match channel_binding {
223 2 : Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
224 0 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
225 : }
226 : } else {
227 0 : return Err(Error::authentication("unsupported SASL mechanism".into()));
228 : };
229 :
230 12 : if mechanism != sasl::SCRAM_SHA_256_PLUS {
231 4 : can_skip_channel_binding(config)?;
232 0 : }
233 :
234 11 : let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
235 0 : ScramSha256::new_with_keys(keys, channel_binding)
236 11 : } else if let Some(password) = config.get_password() {
237 11 : ScramSha256::new(password, channel_binding)
238 : } else {
239 0 : return Err(Error::config("password or auth keys missing".into()));
240 : };
241 :
242 11 : frontend::sasl_initial_response(mechanism, scram.message(), stream.inner.write_buffer_mut())
243 11 : .map_err(Error::encode)?;
244 :
245 11 : stream.inner.flush().await.map_err(Error::io)?;
246 11 : let body = match stream.try_next().await.map_err(Error::io)? {
247 10 : Some(Message::AuthenticationSaslContinue(body)) => body,
248 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
249 0 : Some(_) => return Err(Error::unexpected_message()),
250 0 : None => return Err(Error::closed()),
251 : };
252 :
253 10 : scram
254 10 : .update(body.data())
255 10 : .await
256 10 : .map_err(|e| Error::authentication(e.into()))?;
257 :
258 10 : frontend::sasl_response(scram.message(), stream.inner.write_buffer_mut())
259 10 : .map_err(Error::encode)?;
260 :
261 10 : stream.inner.flush().await.map_err(Error::io)?;
262 10 : let body = match stream.try_next().await.map_err(Error::io)? {
263 5 : Some(Message::AuthenticationSaslFinal(body)) => body,
264 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
265 0 : Some(_) => return Err(Error::unexpected_message()),
266 0 : None => return Err(Error::closed()),
267 : };
268 :
269 5 : scram
270 5 : .finish(body.data())
271 5 : .map_err(|e| Error::authentication(e.into()))?;
272 :
273 5 : Ok(())
274 0 : }
|