Line data Source code
1 : use std::collections::HashMap;
2 : use std::io;
3 : use std::pin::Pin;
4 : use std::task::{Context, Poll};
5 :
6 : use bytes::BytesMut;
7 : use fallible_iterator::FallibleIterator;
8 : use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
9 : use postgres_protocol2::authentication::sasl;
10 : use postgres_protocol2::authentication::sasl::ScramSha256;
11 : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
12 : use postgres_protocol2::message::frontend;
13 : use tokio::io::{AsyncRead, AsyncWrite};
14 : use tokio_util::codec::Framed;
15 :
16 : use crate::Error;
17 : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
18 : use crate::config::{self, AuthKeys, Config};
19 : use crate::connect_tls::connect_tls;
20 : use crate::maybe_tls_stream::MaybeTlsStream;
21 : use crate::tls::{TlsConnect, TlsStream};
22 :
23 : pub struct StartupStream<S, T> {
24 : inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
25 : buf: BackendMessages,
26 : delayed_notice: Vec<NoticeResponseBody>,
27 : }
28 :
29 : impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
30 : where
31 : S: AsyncRead + AsyncWrite + Unpin,
32 : T: AsyncRead + AsyncWrite + Unpin,
33 : {
34 : type Error = io::Error;
35 :
36 36 : fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
37 36 : Pin::new(&mut self.inner).poll_ready(cx)
38 36 : }
39 :
40 36 : fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
41 36 : Pin::new(&mut self.inner).start_send(item)
42 36 : }
43 :
44 36 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
45 36 : Pin::new(&mut self.inner).poll_flush(cx)
46 36 : }
47 :
48 0 : fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
49 0 : Pin::new(&mut self.inner).poll_close(cx)
50 0 : }
51 : }
52 :
53 : impl<S, T> Stream for StartupStream<S, T>
54 : where
55 : S: AsyncRead + AsyncWrite + Unpin,
56 : T: AsyncRead + AsyncWrite + Unpin,
57 : {
58 : type Item = io::Result<Message>;
59 :
60 91 : fn poll_next(
61 91 : mut self: Pin<&mut Self>,
62 91 : cx: &mut Context<'_>,
63 91 : ) -> Poll<Option<io::Result<Message>>> {
64 : loop {
65 128 : match self.buf.next() {
66 42 : Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
67 86 : Ok(None) => {}
68 0 : Err(e) => return Poll::Ready(Some(Err(e))),
69 : }
70 :
71 86 : match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
72 37 : Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
73 7 : Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
74 6 : Some(Err(e)) => return Poll::Ready(Some(Err(e))),
75 0 : None => return Poll::Ready(None),
76 : }
77 : }
78 0 : }
79 : }
80 :
81 : pub struct RawConnection<S, T> {
82 : pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
83 : pub parameters: HashMap<String, String>,
84 : pub delayed_notice: Vec<NoticeResponseBody>,
85 : pub process_id: i32,
86 : pub secret_key: i32,
87 : }
88 :
89 15 : pub async fn connect_raw<S, T>(
90 15 : stream: S,
91 15 : tls: T,
92 15 : config: &Config,
93 15 : ) -> Result<RawConnection<S, T::Stream>, Error>
94 15 : where
95 15 : S: AsyncRead + AsyncWrite + Unpin,
96 15 : T: TlsConnect<S>,
97 15 : {
98 15 : let stream = connect_tls(stream, config.ssl_mode, tls).await?;
99 :
100 15 : let mut stream = StartupStream {
101 15 : inner: Framed::new(stream, PostgresCodec),
102 15 : buf: BackendMessages::empty(),
103 15 : delayed_notice: Vec::new(),
104 15 : };
105 15 :
106 15 : startup(&mut stream, config).await?;
107 15 : authenticate(&mut stream, config).await?;
108 7 : let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
109 :
110 7 : Ok(RawConnection {
111 7 : stream: stream.inner,
112 7 : parameters,
113 7 : delayed_notice: stream.delayed_notice,
114 7 : process_id,
115 7 : secret_key,
116 7 : })
117 0 : }
118 :
119 15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
120 15 : where
121 15 : S: AsyncRead + AsyncWrite + Unpin,
122 15 : T: AsyncRead + AsyncWrite + Unpin,
123 15 : {
124 15 : let mut buf = BytesMut::new();
125 15 : frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
126 :
127 15 : stream
128 15 : .send(FrontendMessage::Raw(buf.freeze()))
129 15 : .await
130 15 : .map_err(Error::io)
131 0 : }
132 :
133 15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
134 15 : where
135 15 : S: AsyncRead + AsyncWrite + Unpin,
136 15 : T: TlsStream + Unpin,
137 15 : {
138 15 : match stream.try_next().await.map_err(Error::io)? {
139 : Some(Message::AuthenticationOk) => {
140 2 : can_skip_channel_binding(config)?;
141 2 : return Ok(());
142 : }
143 : Some(Message::AuthenticationCleartextPassword) => {
144 0 : can_skip_channel_binding(config)?;
145 :
146 0 : let pass = config
147 0 : .password
148 0 : .as_ref()
149 0 : .ok_or_else(|| Error::config("password missing".into()))?;
150 :
151 0 : authenticate_password(stream, pass).await?;
152 : }
153 12 : Some(Message::AuthenticationSasl(body)) => {
154 12 : authenticate_sasl(stream, body, config).await?;
155 : }
156 : Some(Message::AuthenticationMd5Password)
157 : | Some(Message::AuthenticationKerberosV5)
158 : | Some(Message::AuthenticationScmCredential)
159 : | Some(Message::AuthenticationGss)
160 : | Some(Message::AuthenticationSspi) => {
161 0 : return Err(Error::authentication(
162 0 : "unsupported authentication method".into(),
163 0 : ));
164 : }
165 1 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
166 0 : Some(_) => return Err(Error::unexpected_message()),
167 0 : None => return Err(Error::closed()),
168 : }
169 :
170 5 : match stream.try_next().await.map_err(Error::io)? {
171 5 : Some(Message::AuthenticationOk) => Ok(()),
172 0 : Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
173 0 : Some(_) => Err(Error::unexpected_message()),
174 0 : None => Err(Error::closed()),
175 : }
176 0 : }
177 :
178 6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
179 6 : match config.channel_binding {
180 5 : config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
181 1 : config::ChannelBinding::Require => Err(Error::authentication(
182 1 : "server did not use channel binding".into(),
183 1 : )),
184 : }
185 6 : }
186 :
187 0 : async fn authenticate_password<S, T>(
188 0 : stream: &mut StartupStream<S, T>,
189 0 : password: &[u8],
190 0 : ) -> Result<(), Error>
191 0 : where
192 0 : S: AsyncRead + AsyncWrite + Unpin,
193 0 : T: AsyncRead + AsyncWrite + Unpin,
194 0 : {
195 0 : let mut buf = BytesMut::new();
196 0 : frontend::password_message(password, &mut buf).map_err(Error::encode)?;
197 :
198 0 : stream
199 0 : .send(FrontendMessage::Raw(buf.freeze()))
200 0 : .await
201 0 : .map_err(Error::io)
202 0 : }
203 :
204 12 : async fn authenticate_sasl<S, T>(
205 12 : stream: &mut StartupStream<S, T>,
206 12 : body: AuthenticationSaslBody,
207 12 : config: &Config,
208 12 : ) -> Result<(), Error>
209 12 : where
210 12 : S: AsyncRead + AsyncWrite + Unpin,
211 12 : T: TlsStream + Unpin,
212 12 : {
213 12 : let mut has_scram = false;
214 12 : let mut has_scram_plus = false;
215 12 : let mut mechanisms = body.mechanisms();
216 34 : while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
217 22 : match mechanism {
218 22 : sasl::SCRAM_SHA_256 => has_scram = true,
219 10 : sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
220 0 : _ => {}
221 : }
222 : }
223 :
224 12 : let channel_binding = stream
225 12 : .inner
226 12 : .get_ref()
227 12 : .channel_binding()
228 12 : .tls_server_end_point
229 12 : .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
230 12 : .map(sasl::ChannelBinding::tls_server_end_point);
231 :
232 12 : let (channel_binding, mechanism) = if has_scram_plus {
233 10 : match channel_binding {
234 8 : Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
235 2 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
236 : }
237 2 : } else if has_scram {
238 2 : match channel_binding {
239 2 : Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
240 0 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
241 : }
242 : } else {
243 0 : return Err(Error::authentication("unsupported SASL mechanism".into()));
244 : };
245 :
246 12 : if mechanism != sasl::SCRAM_SHA_256_PLUS {
247 4 : can_skip_channel_binding(config)?;
248 0 : }
249 :
250 11 : let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
251 0 : ScramSha256::new_with_keys(keys, channel_binding)
252 11 : } else if let Some(password) = config.get_password() {
253 11 : ScramSha256::new(password, channel_binding)
254 : } else {
255 0 : return Err(Error::config("password or auth keys missing".into()));
256 : };
257 :
258 11 : let mut buf = BytesMut::new();
259 11 : frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
260 11 : stream
261 11 : .send(FrontendMessage::Raw(buf.freeze()))
262 11 : .await
263 11 : .map_err(Error::io)?;
264 :
265 11 : let body = match stream.try_next().await.map_err(Error::io)? {
266 10 : Some(Message::AuthenticationSaslContinue(body)) => body,
267 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
268 0 : Some(_) => return Err(Error::unexpected_message()),
269 0 : None => return Err(Error::closed()),
270 : };
271 :
272 10 : scram
273 10 : .update(body.data())
274 10 : .await
275 10 : .map_err(|e| Error::authentication(e.into()))?;
276 :
277 10 : let mut buf = BytesMut::new();
278 10 : frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
279 10 : stream
280 10 : .send(FrontendMessage::Raw(buf.freeze()))
281 10 : .await
282 10 : .map_err(Error::io)?;
283 :
284 10 : let body = match stream.try_next().await.map_err(Error::io)? {
285 5 : Some(Message::AuthenticationSaslFinal(body)) => body,
286 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
287 0 : Some(_) => return Err(Error::unexpected_message()),
288 0 : None => return Err(Error::closed()),
289 : };
290 :
291 5 : scram
292 5 : .finish(body.data())
293 5 : .map_err(|e| Error::authentication(e.into()))?;
294 :
295 5 : Ok(())
296 0 : }
297 :
298 7 : async fn read_info<S, T>(
299 7 : stream: &mut StartupStream<S, T>,
300 7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
301 7 : where
302 7 : S: AsyncRead + AsyncWrite + Unpin,
303 7 : T: AsyncRead + AsyncWrite + Unpin,
304 7 : {
305 7 : let mut process_id = 0;
306 7 : let mut secret_key = 0;
307 7 : let mut parameters = HashMap::new();
308 :
309 : loop {
310 14 : match stream.try_next().await.map_err(Error::io)? {
311 0 : Some(Message::BackendKeyData(body)) => {
312 0 : process_id = body.process_id();
313 0 : secret_key = body.secret_key();
314 0 : }
315 7 : Some(Message::ParameterStatus(body)) => {
316 7 : parameters.insert(
317 7 : body.name().map_err(Error::parse)?.to_string(),
318 7 : body.value().map_err(Error::parse)?.to_string(),
319 : );
320 : }
321 0 : Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
322 7 : Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
323 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
324 0 : Some(_) => return Err(Error::unexpected_message()),
325 0 : None => return Err(Error::closed()),
326 : }
327 : }
328 0 : }
|