Line data Source code
1 : use std::io;
2 : use std::pin::Pin;
3 : use std::task::{Context, Poll, ready};
4 :
5 : use bytes::{Bytes, BytesMut};
6 : use fallible_iterator::FallibleIterator;
7 : use futures_util::{Sink, SinkExt, Stream, TryStreamExt};
8 : use postgres_protocol2::authentication::sasl;
9 : use postgres_protocol2::authentication::sasl::ScramSha256;
10 : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
11 : use postgres_protocol2::message::frontend;
12 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
13 : use tokio_util::codec::{Framed, FramedParts, FramedWrite};
14 :
15 : use crate::Error;
16 : use crate::codec::PostgresCodec;
17 : use crate::config::{self, AuthKeys, Config};
18 : use crate::maybe_tls_stream::MaybeTlsStream;
19 : use crate::tls::TlsStream;
20 :
21 : pub struct StartupStream<S, T> {
22 : inner: FramedWrite<MaybeTlsStream<S, T>, PostgresCodec>,
23 : read_buf: BytesMut,
24 : }
25 :
26 : impl<S, T> Sink<Bytes> for StartupStream<S, T>
27 : where
28 : S: AsyncRead + AsyncWrite + Unpin,
29 : T: AsyncRead + AsyncWrite + Unpin,
30 : {
31 : type Error = io::Error;
32 :
33 36 : fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
34 36 : Pin::new(&mut self.inner).poll_ready(cx)
35 0 : }
36 :
37 36 : fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
38 36 : Pin::new(&mut self.inner).start_send(item)
39 0 : }
40 :
41 36 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
42 36 : Pin::new(&mut self.inner).poll_flush(cx)
43 0 : }
44 :
45 0 : fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46 0 : Pin::new(&mut self.inner).poll_close(cx)
47 0 : }
48 : }
49 :
50 : impl<S, T> Stream for StartupStream<S, T>
51 : where
52 : S: AsyncRead + AsyncWrite + Unpin,
53 : T: AsyncRead + AsyncWrite + Unpin,
54 : {
55 : type Item = io::Result<Message>;
56 :
57 77 : fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
58 : // read 1 byte tag, 4 bytes length.
59 77 : let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
60 :
61 35 : let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
62 35 : if len < 4 {
63 0 : return Poll::Ready(Some(Err(std::io::Error::other(
64 0 : "postgres message too small",
65 0 : ))));
66 0 : }
67 35 : if len >= 65536 {
68 0 : return Poll::Ready(Some(Err(std::io::Error::other(
69 0 : "postgres message too large",
70 0 : ))));
71 0 : }
72 :
73 : // the tag is an additional byte.
74 35 : let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);
75 :
76 : // Message::parse will remove the all the bytes from the buffer.
77 35 : Poll::Ready(Message::parse(&mut self.read_buf).transpose())
78 0 : }
79 : }
80 :
81 : impl<S, T> StartupStream<S, T>
82 : where
83 : S: AsyncRead + AsyncWrite + Unpin,
84 : T: AsyncRead + AsyncWrite + Unpin,
85 : {
86 : /// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
87 : ///
88 : /// If the current buffer length is greater, nothing happens.
89 112 : fn poll_fill_buf_exact(
90 112 : self: Pin<&mut Self>,
91 112 : cx: &mut Context<'_>,
92 112 : len: usize,
93 112 : ) -> Poll<Result<&[u8], std::io::Error>> {
94 112 : let this = self.get_mut();
95 112 : let mut stream = Pin::new(this.inner.get_mut());
96 :
97 112 : let mut n = this.read_buf.len();
98 182 : while n < len {
99 112 : this.read_buf.resize(len, 0);
100 :
101 112 : let mut buf = ReadBuf::new(&mut this.read_buf[..]);
102 112 : buf.set_filled(n);
103 :
104 112 : if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
105 36 : this.read_buf.truncate(n);
106 36 : return Poll::Pending;
107 0 : }
108 :
109 70 : if buf.filled().len() == n {
110 0 : return Poll::Ready(Err(std::io::Error::new(
111 0 : std::io::ErrorKind::UnexpectedEof,
112 0 : "early eof",
113 0 : )));
114 0 : }
115 70 : n = buf.filled().len();
116 :
117 70 : this.read_buf.truncate(n);
118 : }
119 :
120 70 : Poll::Ready(Ok(&this.read_buf[..len]))
121 0 : }
122 :
123 0 : pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
124 0 : let write_buf = std::mem::take(self.inner.write_buffer_mut());
125 0 : let io = self.inner.into_inner();
126 0 : let mut parts = FramedParts::new(io, PostgresCodec);
127 0 : parts.read_buf = self.read_buf;
128 0 : parts.write_buf = write_buf;
129 0 : Framed::from_parts(parts)
130 0 : }
131 :
132 15 : pub fn new(io: MaybeTlsStream<S, T>) -> Self {
133 15 : Self {
134 15 : inner: FramedWrite::new(io, PostgresCodec),
135 15 : read_buf: BytesMut::new(),
136 15 : }
137 0 : }
138 : }
139 :
140 15 : pub(crate) async fn startup<S, T>(
141 15 : stream: &mut StartupStream<S, T>,
142 15 : config: &Config,
143 15 : ) -> Result<(), Error>
144 15 : where
145 15 : S: AsyncRead + AsyncWrite + Unpin,
146 15 : T: AsyncRead + AsyncWrite + Unpin,
147 0 : {
148 15 : let mut buf = BytesMut::new();
149 15 : frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
150 :
151 15 : stream.send(buf.freeze()).await.map_err(Error::io)
152 0 : }
153 :
154 15 : pub(crate) async fn authenticate<S, T>(
155 15 : stream: &mut StartupStream<S, T>,
156 15 : config: &Config,
157 15 : ) -> Result<(), Error>
158 15 : where
159 15 : S: AsyncRead + AsyncWrite + Unpin,
160 15 : T: TlsStream + Unpin,
161 0 : {
162 15 : match stream.try_next().await.map_err(Error::io)? {
163 : Some(Message::AuthenticationOk) => {
164 2 : can_skip_channel_binding(config)?;
165 2 : return Ok(());
166 : }
167 : Some(Message::AuthenticationCleartextPassword) => {
168 0 : can_skip_channel_binding(config)?;
169 :
170 0 : let pass = config
171 0 : .password
172 0 : .as_ref()
173 0 : .ok_or_else(|| Error::config("password missing".into()))?;
174 :
175 0 : authenticate_password(stream, pass).await?;
176 : }
177 12 : Some(Message::AuthenticationSasl(body)) => {
178 12 : authenticate_sasl(stream, body, config).await?;
179 : }
180 : Some(Message::AuthenticationMd5Password)
181 : | Some(Message::AuthenticationKerberosV5)
182 : | Some(Message::AuthenticationScmCredential)
183 : | Some(Message::AuthenticationGss)
184 : | Some(Message::AuthenticationSspi) => {
185 0 : return Err(Error::authentication(
186 0 : "unsupported authentication method".into(),
187 0 : ));
188 : }
189 1 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
190 0 : Some(_) => return Err(Error::unexpected_message()),
191 0 : None => return Err(Error::closed()),
192 : }
193 :
194 5 : match stream.try_next().await.map_err(Error::io)? {
195 5 : Some(Message::AuthenticationOk) => Ok(()),
196 0 : Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
197 0 : Some(_) => Err(Error::unexpected_message()),
198 0 : None => Err(Error::closed()),
199 : }
200 0 : }
201 :
202 6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
203 6 : match config.channel_binding {
204 5 : config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
205 1 : config::ChannelBinding::Require => Err(Error::authentication(
206 1 : "server did not use channel binding".into(),
207 1 : )),
208 : }
209 6 : }
210 :
211 0 : async fn authenticate_password<S, T>(
212 0 : stream: &mut StartupStream<S, T>,
213 0 : password: &[u8],
214 0 : ) -> Result<(), Error>
215 0 : where
216 0 : S: AsyncRead + AsyncWrite + Unpin,
217 0 : T: AsyncRead + AsyncWrite + Unpin,
218 0 : {
219 0 : let mut buf = BytesMut::new();
220 0 : frontend::password_message(password, &mut buf).map_err(Error::encode)?;
221 :
222 0 : stream.send(buf.freeze()).await.map_err(Error::io)
223 0 : }
224 :
225 12 : async fn authenticate_sasl<S, T>(
226 12 : stream: &mut StartupStream<S, T>,
227 12 : body: AuthenticationSaslBody,
228 12 : config: &Config,
229 12 : ) -> Result<(), Error>
230 12 : where
231 12 : S: AsyncRead + AsyncWrite + Unpin,
232 12 : T: TlsStream + Unpin,
233 0 : {
234 12 : let mut has_scram = false;
235 12 : let mut has_scram_plus = false;
236 12 : let mut mechanisms = body.mechanisms();
237 34 : while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
238 22 : match mechanism {
239 22 : sasl::SCRAM_SHA_256 => has_scram = true,
240 10 : sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
241 0 : _ => {}
242 : }
243 : }
244 :
245 12 : let channel_binding = stream
246 12 : .inner
247 12 : .get_ref()
248 12 : .channel_binding()
249 12 : .tls_server_end_point
250 12 : .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
251 12 : .map(sasl::ChannelBinding::tls_server_end_point);
252 :
253 12 : let (channel_binding, mechanism) = if has_scram_plus {
254 10 : match channel_binding {
255 8 : Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
256 2 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
257 : }
258 2 : } else if has_scram {
259 2 : match channel_binding {
260 2 : Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
261 0 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
262 : }
263 : } else {
264 0 : return Err(Error::authentication("unsupported SASL mechanism".into()));
265 : };
266 :
267 12 : if mechanism != sasl::SCRAM_SHA_256_PLUS {
268 4 : can_skip_channel_binding(config)?;
269 0 : }
270 :
271 11 : let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
272 0 : ScramSha256::new_with_keys(keys, channel_binding)
273 11 : } else if let Some(password) = config.get_password() {
274 11 : ScramSha256::new(password, channel_binding)
275 : } else {
276 0 : return Err(Error::config("password or auth keys missing".into()));
277 : };
278 :
279 11 : let mut buf = BytesMut::new();
280 11 : frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
281 11 : stream.send(buf.freeze()).await.map_err(Error::io)?;
282 :
283 11 : let body = match stream.try_next().await.map_err(Error::io)? {
284 10 : Some(Message::AuthenticationSaslContinue(body)) => body,
285 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
286 0 : Some(_) => return Err(Error::unexpected_message()),
287 0 : None => return Err(Error::closed()),
288 : };
289 :
290 10 : scram
291 10 : .update(body.data())
292 10 : .await
293 10 : .map_err(|e| Error::authentication(e.into()))?;
294 :
295 10 : let mut buf = BytesMut::new();
296 10 : frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
297 10 : stream.send(buf.freeze()).await.map_err(Error::io)?;
298 :
299 10 : let body = match stream.try_next().await.map_err(Error::io)? {
300 5 : Some(Message::AuthenticationSaslFinal(body)) => body,
301 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
302 0 : Some(_) => return Err(Error::unexpected_message()),
303 0 : None => return Err(Error::closed()),
304 : };
305 :
306 5 : scram
307 5 : .finish(body.data())
308 5 : .map_err(|e| Error::authentication(e.into()))?;
309 :
310 5 : Ok(())
311 0 : }
|