Line data Source code
1 : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2 : use crate::config::{self, AuthKeys, Config};
3 : use crate::connect_tls::connect_tls;
4 : use crate::maybe_tls_stream::MaybeTlsStream;
5 : use crate::tls::{TlsConnect, TlsStream};
6 : use crate::Error;
7 : use bytes::BytesMut;
8 : use fallible_iterator::FallibleIterator;
9 : use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
10 : use postgres_protocol2::authentication::sasl;
11 : use postgres_protocol2::authentication::sasl::ScramSha256;
12 : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
13 : use postgres_protocol2::message::frontend;
14 : use std::collections::HashMap;
15 : use std::io;
16 : use std::pin::Pin;
17 : use std::task::{Context, Poll};
18 : use tokio::io::{AsyncRead, AsyncWrite};
19 : use tokio_util::codec::Framed;
20 :
21 : pub struct StartupStream<S, T> {
22 : inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
23 : buf: BackendMessages,
24 : delayed_notice: Vec<NoticeResponseBody>,
25 : }
26 :
27 : impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
28 : where
29 : S: AsyncRead + AsyncWrite + Unpin,
30 : T: AsyncRead + AsyncWrite + Unpin,
31 : {
32 : type Error = io::Error;
33 :
34 36 : fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
35 36 : Pin::new(&mut self.inner).poll_ready(cx)
36 36 : }
37 :
38 36 : fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
39 36 : Pin::new(&mut self.inner).start_send(item)
40 36 : }
41 :
42 36 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
43 36 : Pin::new(&mut self.inner).poll_flush(cx)
44 36 : }
45 :
46 0 : fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
47 0 : Pin::new(&mut self.inner).poll_close(cx)
48 0 : }
49 : }
50 :
51 : impl<S, T> Stream for StartupStream<S, T>
52 : where
53 : S: AsyncRead + AsyncWrite + Unpin,
54 : T: AsyncRead + AsyncWrite + Unpin,
55 : {
56 : type Item = io::Result<Message>;
57 :
58 91 : fn poll_next(
59 91 : mut self: Pin<&mut Self>,
60 91 : cx: &mut Context<'_>,
61 91 : ) -> Poll<Option<io::Result<Message>>> {
62 : loop {
63 128 : match self.buf.next() {
64 42 : Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
65 86 : Ok(None) => {}
66 0 : Err(e) => return Poll::Ready(Some(Err(e))),
67 : }
68 :
69 86 : match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
70 37 : Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
71 7 : Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
72 6 : Some(Err(e)) => return Poll::Ready(Some(Err(e))),
73 0 : None => return Poll::Ready(None),
74 : }
75 : }
76 91 : }
77 : }
78 :
79 : pub struct RawConnection<S, T> {
80 : pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
81 : pub parameters: HashMap<String, String>,
82 : pub delayed_notice: Vec<NoticeResponseBody>,
83 : pub process_id: i32,
84 : pub secret_key: i32,
85 : }
86 :
87 15 : pub async fn connect_raw<S, T>(
88 15 : stream: S,
89 15 : tls: T,
90 15 : config: &Config,
91 15 : ) -> Result<RawConnection<S, T::Stream>, Error>
92 15 : where
93 15 : S: AsyncRead + AsyncWrite + Unpin,
94 15 : T: TlsConnect<S>,
95 15 : {
96 15 : let stream = connect_tls(stream, config.ssl_mode, tls).await?;
97 :
98 15 : let mut stream = StartupStream {
99 15 : inner: Framed::new(stream, PostgresCodec),
100 15 : buf: BackendMessages::empty(),
101 15 : delayed_notice: Vec::new(),
102 15 : };
103 15 :
104 15 : startup(&mut stream, config).await?;
105 15 : authenticate(&mut stream, config).await?;
106 7 : let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
107 :
108 7 : Ok(RawConnection {
109 7 : stream: stream.inner,
110 7 : parameters,
111 7 : delayed_notice: stream.delayed_notice,
112 7 : process_id,
113 7 : secret_key,
114 7 : })
115 15 : }
116 :
117 15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
118 15 : where
119 15 : S: AsyncRead + AsyncWrite + Unpin,
120 15 : T: AsyncRead + AsyncWrite + Unpin,
121 15 : {
122 15 : let mut buf = BytesMut::new();
123 15 : frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
124 :
125 15 : stream
126 15 : .send(FrontendMessage::Raw(buf.freeze()))
127 15 : .await
128 15 : .map_err(Error::io)
129 15 : }
130 :
131 15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
132 15 : where
133 15 : S: AsyncRead + AsyncWrite + Unpin,
134 15 : T: TlsStream + Unpin,
135 15 : {
136 15 : match stream.try_next().await.map_err(Error::io)? {
137 : Some(Message::AuthenticationOk) => {
138 2 : can_skip_channel_binding(config)?;
139 2 : return Ok(());
140 : }
141 : Some(Message::AuthenticationCleartextPassword) => {
142 0 : can_skip_channel_binding(config)?;
143 :
144 0 : let pass = config
145 0 : .password
146 0 : .as_ref()
147 0 : .ok_or_else(|| Error::config("password missing".into()))?;
148 :
149 0 : authenticate_password(stream, pass).await?;
150 : }
151 12 : Some(Message::AuthenticationSasl(body)) => {
152 12 : authenticate_sasl(stream, body, config).await?;
153 : }
154 : Some(Message::AuthenticationMd5Password)
155 : | Some(Message::AuthenticationKerberosV5)
156 : | Some(Message::AuthenticationScmCredential)
157 : | Some(Message::AuthenticationGss)
158 : | Some(Message::AuthenticationSspi) => {
159 0 : return Err(Error::authentication(
160 0 : "unsupported authentication method".into(),
161 0 : ))
162 : }
163 1 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
164 0 : Some(_) => return Err(Error::unexpected_message()),
165 0 : None => return Err(Error::closed()),
166 : }
167 :
168 5 : match stream.try_next().await.map_err(Error::io)? {
169 5 : Some(Message::AuthenticationOk) => Ok(()),
170 0 : Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
171 0 : Some(_) => Err(Error::unexpected_message()),
172 0 : None => Err(Error::closed()),
173 : }
174 15 : }
175 :
176 6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
177 6 : match config.channel_binding {
178 5 : config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
179 1 : config::ChannelBinding::Require => Err(Error::authentication(
180 1 : "server did not use channel binding".into(),
181 1 : )),
182 : }
183 6 : }
184 :
185 0 : async fn authenticate_password<S, T>(
186 0 : stream: &mut StartupStream<S, T>,
187 0 : password: &[u8],
188 0 : ) -> Result<(), Error>
189 0 : where
190 0 : S: AsyncRead + AsyncWrite + Unpin,
191 0 : T: AsyncRead + AsyncWrite + Unpin,
192 0 : {
193 0 : let mut buf = BytesMut::new();
194 0 : frontend::password_message(password, &mut buf).map_err(Error::encode)?;
195 :
196 0 : stream
197 0 : .send(FrontendMessage::Raw(buf.freeze()))
198 0 : .await
199 0 : .map_err(Error::io)
200 0 : }
201 :
202 12 : async fn authenticate_sasl<S, T>(
203 12 : stream: &mut StartupStream<S, T>,
204 12 : body: AuthenticationSaslBody,
205 12 : config: &Config,
206 12 : ) -> Result<(), Error>
207 12 : where
208 12 : S: AsyncRead + AsyncWrite + Unpin,
209 12 : T: TlsStream + Unpin,
210 12 : {
211 12 : let mut has_scram = false;
212 12 : let mut has_scram_plus = false;
213 12 : let mut mechanisms = body.mechanisms();
214 34 : while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
215 22 : match mechanism {
216 22 : sasl::SCRAM_SHA_256 => has_scram = true,
217 10 : sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
218 0 : _ => {}
219 : }
220 : }
221 :
222 12 : let channel_binding = stream
223 12 : .inner
224 12 : .get_ref()
225 12 : .channel_binding()
226 12 : .tls_server_end_point
227 12 : .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
228 12 : .map(sasl::ChannelBinding::tls_server_end_point);
229 :
230 12 : let (channel_binding, mechanism) = if has_scram_plus {
231 10 : match channel_binding {
232 8 : Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
233 2 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
234 : }
235 2 : } else if has_scram {
236 2 : match channel_binding {
237 2 : Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
238 0 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
239 : }
240 : } else {
241 0 : return Err(Error::authentication("unsupported SASL mechanism".into()));
242 : };
243 :
244 12 : if mechanism != sasl::SCRAM_SHA_256_PLUS {
245 4 : can_skip_channel_binding(config)?;
246 8 : }
247 :
248 11 : let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
249 0 : ScramSha256::new_with_keys(keys, channel_binding)
250 11 : } else if let Some(password) = config.get_password() {
251 11 : ScramSha256::new(password, channel_binding)
252 : } else {
253 0 : return Err(Error::config("password or auth keys missing".into()));
254 : };
255 :
256 11 : let mut buf = BytesMut::new();
257 11 : frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
258 11 : stream
259 11 : .send(FrontendMessage::Raw(buf.freeze()))
260 11 : .await
261 11 : .map_err(Error::io)?;
262 :
263 11 : let body = match stream.try_next().await.map_err(Error::io)? {
264 10 : Some(Message::AuthenticationSaslContinue(body)) => body,
265 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
266 0 : Some(_) => return Err(Error::unexpected_message()),
267 0 : None => return Err(Error::closed()),
268 : };
269 :
270 10 : scram
271 10 : .update(body.data())
272 10 : .await
273 10 : .map_err(|e| Error::authentication(e.into()))?;
274 :
275 10 : let mut buf = BytesMut::new();
276 10 : frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
277 10 : stream
278 10 : .send(FrontendMessage::Raw(buf.freeze()))
279 10 : .await
280 10 : .map_err(Error::io)?;
281 :
282 10 : let body = match stream.try_next().await.map_err(Error::io)? {
283 5 : Some(Message::AuthenticationSaslFinal(body)) => body,
284 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
285 0 : Some(_) => return Err(Error::unexpected_message()),
286 0 : None => return Err(Error::closed()),
287 : };
288 :
289 5 : scram
290 5 : .finish(body.data())
291 5 : .map_err(|e| Error::authentication(e.into()))?;
292 :
293 5 : Ok(())
294 12 : }
295 :
296 7 : async fn read_info<S, T>(
297 7 : stream: &mut StartupStream<S, T>,
298 7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
299 7 : where
300 7 : S: AsyncRead + AsyncWrite + Unpin,
301 7 : T: AsyncRead + AsyncWrite + Unpin,
302 7 : {
303 7 : let mut process_id = 0;
304 7 : let mut secret_key = 0;
305 7 : let mut parameters = HashMap::new();
306 :
307 : loop {
308 14 : match stream.try_next().await.map_err(Error::io)? {
309 0 : Some(Message::BackendKeyData(body)) => {
310 0 : process_id = body.process_id();
311 0 : secret_key = body.secret_key();
312 0 : }
313 7 : Some(Message::ParameterStatus(body)) => {
314 7 : parameters.insert(
315 7 : body.name().map_err(Error::parse)?.to_string(),
316 7 : body.value().map_err(Error::parse)?.to_string(),
317 : );
318 : }
319 0 : Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
320 7 : Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
321 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
322 0 : Some(_) => return Err(Error::unexpected_message()),
323 0 : None => return Err(Error::closed()),
324 : }
325 : }
326 7 : }
|