Line data Source code
1 : use std::collections::HashMap;
2 : use std::io;
3 : use std::pin::Pin;
4 : use std::task::{Context, Poll};
5 :
6 : use bytes::BytesMut;
7 : use fallible_iterator::FallibleIterator;
8 : use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
9 : use postgres_protocol2::authentication::sasl;
10 : use postgres_protocol2::authentication::sasl::ScramSha256;
11 : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
12 : use postgres_protocol2::message::frontend;
13 : use tokio::io::{AsyncRead, AsyncWrite};
14 : use tokio_util::codec::Framed;
15 :
16 : use crate::Error;
17 : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
18 : use crate::config::{self, AuthKeys, Config};
19 : use crate::maybe_tls_stream::MaybeTlsStream;
20 : use crate::tls::TlsStream;
21 :
22 : pub struct StartupStream<S, T> {
23 : inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
24 : buf: BackendMessages,
25 : delayed_notice: Vec<NoticeResponseBody>,
26 : }
27 :
28 : impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
29 : where
30 : S: AsyncRead + AsyncWrite + Unpin,
31 : T: AsyncRead + AsyncWrite + Unpin,
32 : {
33 : type Error = io::Error;
34 :
35 36 : fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
36 36 : Pin::new(&mut self.inner).poll_ready(cx)
37 0 : }
38 :
39 36 : fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
40 36 : Pin::new(&mut self.inner).start_send(item)
41 0 : }
42 :
43 36 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
44 36 : Pin::new(&mut self.inner).poll_flush(cx)
45 0 : }
46 :
47 0 : fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
48 0 : Pin::new(&mut self.inner).poll_close(cx)
49 0 : }
50 : }
51 :
52 : impl<S, T> Stream for StartupStream<S, T>
53 : where
54 : S: AsyncRead + AsyncWrite + Unpin,
55 : T: AsyncRead + AsyncWrite + Unpin,
56 : {
57 : type Item = io::Result<Message>;
58 :
59 91 : fn poll_next(
60 91 : mut self: Pin<&mut Self>,
61 91 : cx: &mut Context<'_>,
62 91 : ) -> Poll<Option<io::Result<Message>>> {
63 : loop {
64 129 : match self.buf.next() {
65 42 : Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
66 87 : Ok(None) => {}
67 0 : Err(e) => return Poll::Ready(Some(Err(e))),
68 : }
69 :
70 87 : match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
71 38 : Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
72 7 : Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
73 6 : Some(Err(e)) => return Poll::Ready(Some(Err(e))),
74 0 : None => return Poll::Ready(None),
75 : }
76 : }
77 0 : }
78 : }
79 :
80 : pub struct RawConnection<S, T> {
81 : pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
82 : pub parameters: HashMap<String, String>,
83 : pub delayed_notice: Vec<NoticeResponseBody>,
84 : pub process_id: i32,
85 : pub secret_key: i32,
86 : }
87 :
88 15 : pub async fn connect_raw<S, T>(
89 15 : stream: MaybeTlsStream<S, T>,
90 15 : config: &Config,
91 15 : ) -> Result<RawConnection<S, T>, Error>
92 15 : where
93 15 : S: AsyncRead + AsyncWrite + Unpin,
94 15 : T: TlsStream + Unpin,
95 0 : {
96 15 : let mut stream = StartupStream {
97 15 : inner: Framed::new(stream, PostgresCodec),
98 15 : buf: BackendMessages::empty(),
99 15 : delayed_notice: Vec::new(),
100 15 : };
101 :
102 15 : startup(&mut stream, config).await?;
103 15 : authenticate(&mut stream, config).await?;
104 7 : let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
105 :
106 7 : Ok(RawConnection {
107 7 : stream: stream.inner,
108 7 : parameters,
109 7 : delayed_notice: stream.delayed_notice,
110 7 : process_id,
111 7 : secret_key,
112 7 : })
113 0 : }
114 :
115 15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
116 15 : where
117 15 : S: AsyncRead + AsyncWrite + Unpin,
118 15 : T: AsyncRead + AsyncWrite + Unpin,
119 0 : {
120 15 : let mut buf = BytesMut::new();
121 15 : frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
122 :
123 15 : stream
124 15 : .send(FrontendMessage::Raw(buf.freeze()))
125 15 : .await
126 15 : .map_err(Error::io)
127 0 : }
128 :
129 15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
130 15 : where
131 15 : S: AsyncRead + AsyncWrite + Unpin,
132 15 : T: TlsStream + Unpin,
133 0 : {
134 15 : match stream.try_next().await.map_err(Error::io)? {
135 : Some(Message::AuthenticationOk) => {
136 2 : can_skip_channel_binding(config)?;
137 2 : return Ok(());
138 : }
139 : Some(Message::AuthenticationCleartextPassword) => {
140 0 : can_skip_channel_binding(config)?;
141 :
142 0 : let pass = config
143 0 : .password
144 0 : .as_ref()
145 0 : .ok_or_else(|| Error::config("password missing".into()))?;
146 :
147 0 : authenticate_password(stream, pass).await?;
148 : }
149 12 : Some(Message::AuthenticationSasl(body)) => {
150 12 : authenticate_sasl(stream, body, config).await?;
151 : }
152 : Some(Message::AuthenticationMd5Password)
153 : | Some(Message::AuthenticationKerberosV5)
154 : | Some(Message::AuthenticationScmCredential)
155 : | Some(Message::AuthenticationGss)
156 : | Some(Message::AuthenticationSspi) => {
157 0 : return Err(Error::authentication(
158 0 : "unsupported authentication method".into(),
159 0 : ));
160 : }
161 1 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
162 0 : Some(_) => return Err(Error::unexpected_message()),
163 0 : None => return Err(Error::closed()),
164 : }
165 :
166 5 : match stream.try_next().await.map_err(Error::io)? {
167 5 : Some(Message::AuthenticationOk) => Ok(()),
168 0 : Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
169 0 : Some(_) => Err(Error::unexpected_message()),
170 0 : None => Err(Error::closed()),
171 : }
172 0 : }
173 :
174 6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
175 6 : match config.channel_binding {
176 5 : config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
177 1 : config::ChannelBinding::Require => Err(Error::authentication(
178 1 : "server did not use channel binding".into(),
179 1 : )),
180 : }
181 6 : }
182 :
183 0 : async fn authenticate_password<S, T>(
184 0 : stream: &mut StartupStream<S, T>,
185 0 : password: &[u8],
186 0 : ) -> Result<(), Error>
187 0 : where
188 0 : S: AsyncRead + AsyncWrite + Unpin,
189 0 : T: AsyncRead + AsyncWrite + Unpin,
190 0 : {
191 0 : let mut buf = BytesMut::new();
192 0 : frontend::password_message(password, &mut buf).map_err(Error::encode)?;
193 :
194 0 : stream
195 0 : .send(FrontendMessage::Raw(buf.freeze()))
196 0 : .await
197 0 : .map_err(Error::io)
198 0 : }
199 :
200 12 : async fn authenticate_sasl<S, T>(
201 12 : stream: &mut StartupStream<S, T>,
202 12 : body: AuthenticationSaslBody,
203 12 : config: &Config,
204 12 : ) -> Result<(), Error>
205 12 : where
206 12 : S: AsyncRead + AsyncWrite + Unpin,
207 12 : T: TlsStream + Unpin,
208 0 : {
209 12 : let mut has_scram = false;
210 12 : let mut has_scram_plus = false;
211 12 : let mut mechanisms = body.mechanisms();
212 34 : while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
213 22 : match mechanism {
214 22 : sasl::SCRAM_SHA_256 => has_scram = true,
215 10 : sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
216 0 : _ => {}
217 : }
218 : }
219 :
220 12 : let channel_binding = stream
221 12 : .inner
222 12 : .get_ref()
223 12 : .channel_binding()
224 12 : .tls_server_end_point
225 12 : .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
226 12 : .map(sasl::ChannelBinding::tls_server_end_point);
227 :
228 12 : let (channel_binding, mechanism) = if has_scram_plus {
229 10 : match channel_binding {
230 8 : Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
231 2 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
232 : }
233 2 : } else if has_scram {
234 2 : match channel_binding {
235 2 : Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
236 0 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
237 : }
238 : } else {
239 0 : return Err(Error::authentication("unsupported SASL mechanism".into()));
240 : };
241 :
242 12 : if mechanism != sasl::SCRAM_SHA_256_PLUS {
243 4 : can_skip_channel_binding(config)?;
244 0 : }
245 :
246 11 : let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
247 0 : ScramSha256::new_with_keys(keys, channel_binding)
248 11 : } else if let Some(password) = config.get_password() {
249 11 : ScramSha256::new(password, channel_binding)
250 : } else {
251 0 : return Err(Error::config("password or auth keys missing".into()));
252 : };
253 :
254 11 : let mut buf = BytesMut::new();
255 11 : frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
256 11 : stream
257 11 : .send(FrontendMessage::Raw(buf.freeze()))
258 11 : .await
259 11 : .map_err(Error::io)?;
260 :
261 11 : let body = match stream.try_next().await.map_err(Error::io)? {
262 10 : Some(Message::AuthenticationSaslContinue(body)) => body,
263 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
264 0 : Some(_) => return Err(Error::unexpected_message()),
265 0 : None => return Err(Error::closed()),
266 : };
267 :
268 10 : scram
269 10 : .update(body.data())
270 10 : .await
271 10 : .map_err(|e| Error::authentication(e.into()))?;
272 :
273 10 : let mut buf = BytesMut::new();
274 10 : frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
275 10 : stream
276 10 : .send(FrontendMessage::Raw(buf.freeze()))
277 10 : .await
278 10 : .map_err(Error::io)?;
279 :
280 10 : let body = match stream.try_next().await.map_err(Error::io)? {
281 5 : Some(Message::AuthenticationSaslFinal(body)) => body,
282 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
283 0 : Some(_) => return Err(Error::unexpected_message()),
284 0 : None => return Err(Error::closed()),
285 : };
286 :
287 5 : scram
288 5 : .finish(body.data())
289 5 : .map_err(|e| Error::authentication(e.into()))?;
290 :
291 5 : Ok(())
292 0 : }
293 :
294 7 : async fn read_info<S, T>(
295 7 : stream: &mut StartupStream<S, T>,
296 7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
297 7 : where
298 7 : S: AsyncRead + AsyncWrite + Unpin,
299 7 : T: AsyncRead + AsyncWrite + Unpin,
300 0 : {
301 7 : let mut process_id = 0;
302 7 : let mut secret_key = 0;
303 7 : let mut parameters = HashMap::new();
304 :
305 : loop {
306 14 : match stream.try_next().await.map_err(Error::io)? {
307 0 : Some(Message::BackendKeyData(body)) => {
308 0 : process_id = body.process_id();
309 0 : secret_key = body.secret_key();
310 0 : }
311 7 : Some(Message::ParameterStatus(body)) => {
312 7 : parameters.insert(
313 7 : body.name().map_err(Error::parse)?.to_string(),
314 7 : body.value().map_err(Error::parse)?.to_string(),
315 : );
316 : }
317 0 : Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
318 7 : Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
319 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
320 0 : Some(_) => return Err(Error::unexpected_message()),
321 0 : None => return Err(Error::closed()),
322 : }
323 : }
324 0 : }
|