Line data Source code
1 : use std::collections::HashMap;
2 : use std::io;
3 : use std::pin::Pin;
4 : use std::task::{Context, Poll};
5 :
6 : use bytes::{Bytes, BytesMut};
7 : use fallible_iterator::FallibleIterator;
8 : use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
9 : use postgres_protocol2::authentication::sasl;
10 : use postgres_protocol2::authentication::sasl::ScramSha256;
11 : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
12 : use postgres_protocol2::message::frontend;
13 : use tokio::io::{AsyncRead, AsyncWrite};
14 : use tokio_util::codec::Framed;
15 :
16 : use crate::Error;
17 : use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
18 : use crate::config::{self, AuthKeys, Config};
19 : use crate::maybe_tls_stream::MaybeTlsStream;
20 : use crate::tls::TlsStream;
21 :
22 : pub struct StartupStream<S, T> {
23 : inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
24 : buf: BackendMessages,
25 : delayed_notice: Vec<NoticeResponseBody>,
26 : }
27 :
28 : impl<S, T> Sink<Bytes> for StartupStream<S, T>
29 : where
30 : S: AsyncRead + AsyncWrite + Unpin,
31 : T: AsyncRead + AsyncWrite + Unpin,
32 : {
33 : type Error = io::Error;
34 :
35 36 : fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
36 36 : Pin::new(&mut self.inner).poll_ready(cx)
37 0 : }
38 :
39 36 : fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
40 36 : Pin::new(&mut self.inner).start_send(item)
41 0 : }
42 :
43 36 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
44 36 : Pin::new(&mut self.inner).poll_flush(cx)
45 0 : }
46 :
47 0 : fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
48 0 : Pin::new(&mut self.inner).poll_close(cx)
49 0 : }
50 : }
51 :
52 : impl<S, T> Stream for StartupStream<S, T>
53 : where
54 : S: AsyncRead + AsyncWrite + Unpin,
55 : T: AsyncRead + AsyncWrite + Unpin,
56 : {
57 : type Item = io::Result<Message>;
58 :
59 91 : fn poll_next(
60 91 : mut self: Pin<&mut Self>,
61 91 : cx: &mut Context<'_>,
62 91 : ) -> Poll<Option<io::Result<Message>>> {
63 : loop {
64 129 : match self.buf.next() {
65 42 : Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
66 87 : Ok(None) => {}
67 0 : Err(e) => return Poll::Ready(Some(Err(e))),
68 : }
69 :
70 87 : match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
71 38 : Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
72 7 : Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
73 6 : Some(Err(e)) => return Poll::Ready(Some(Err(e))),
74 0 : None => return Poll::Ready(None),
75 : }
76 : }
77 0 : }
78 : }
79 :
80 : pub struct RawConnection<S, T> {
81 : pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
82 : pub parameters: HashMap<String, String>,
83 : pub delayed_notice: Vec<NoticeResponseBody>,
84 : pub process_id: i32,
85 : pub secret_key: i32,
86 : }
87 :
88 15 : pub async fn connect_raw<S, T>(
89 15 : stream: MaybeTlsStream<S, T>,
90 15 : config: &Config,
91 15 : ) -> Result<RawConnection<S, T>, Error>
92 15 : where
93 15 : S: AsyncRead + AsyncWrite + Unpin,
94 15 : T: TlsStream + Unpin,
95 0 : {
96 15 : let mut stream = StartupStream {
97 15 : inner: Framed::new(stream, PostgresCodec),
98 15 : buf: BackendMessages::empty(),
99 15 : delayed_notice: Vec::new(),
100 15 : };
101 :
102 15 : startup(&mut stream, config).await?;
103 15 : authenticate(&mut stream, config).await?;
104 7 : let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
105 :
106 7 : Ok(RawConnection {
107 7 : stream: stream.inner,
108 7 : parameters,
109 7 : delayed_notice: stream.delayed_notice,
110 7 : process_id,
111 7 : secret_key,
112 7 : })
113 0 : }
114 :
115 15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
116 15 : where
117 15 : S: AsyncRead + AsyncWrite + Unpin,
118 15 : T: AsyncRead + AsyncWrite + Unpin,
119 0 : {
120 15 : let mut buf = BytesMut::new();
121 15 : frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
122 :
123 15 : stream.send(buf.freeze()).await.map_err(Error::io)
124 0 : }
125 :
126 15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
127 15 : where
128 15 : S: AsyncRead + AsyncWrite + Unpin,
129 15 : T: TlsStream + Unpin,
130 0 : {
131 15 : match stream.try_next().await.map_err(Error::io)? {
132 : Some(Message::AuthenticationOk) => {
133 2 : can_skip_channel_binding(config)?;
134 2 : return Ok(());
135 : }
136 : Some(Message::AuthenticationCleartextPassword) => {
137 0 : can_skip_channel_binding(config)?;
138 :
139 0 : let pass = config
140 0 : .password
141 0 : .as_ref()
142 0 : .ok_or_else(|| Error::config("password missing".into()))?;
143 :
144 0 : authenticate_password(stream, pass).await?;
145 : }
146 12 : Some(Message::AuthenticationSasl(body)) => {
147 12 : authenticate_sasl(stream, body, config).await?;
148 : }
149 : Some(Message::AuthenticationMd5Password)
150 : | Some(Message::AuthenticationKerberosV5)
151 : | Some(Message::AuthenticationScmCredential)
152 : | Some(Message::AuthenticationGss)
153 : | Some(Message::AuthenticationSspi) => {
154 0 : return Err(Error::authentication(
155 0 : "unsupported authentication method".into(),
156 0 : ));
157 : }
158 1 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
159 0 : Some(_) => return Err(Error::unexpected_message()),
160 0 : None => return Err(Error::closed()),
161 : }
162 :
163 5 : match stream.try_next().await.map_err(Error::io)? {
164 5 : Some(Message::AuthenticationOk) => Ok(()),
165 0 : Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
166 0 : Some(_) => Err(Error::unexpected_message()),
167 0 : None => Err(Error::closed()),
168 : }
169 0 : }
170 :
171 6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
172 6 : match config.channel_binding {
173 5 : config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
174 1 : config::ChannelBinding::Require => Err(Error::authentication(
175 1 : "server did not use channel binding".into(),
176 1 : )),
177 : }
178 6 : }
179 :
180 0 : async fn authenticate_password<S, T>(
181 0 : stream: &mut StartupStream<S, T>,
182 0 : password: &[u8],
183 0 : ) -> Result<(), Error>
184 0 : where
185 0 : S: AsyncRead + AsyncWrite + Unpin,
186 0 : T: AsyncRead + AsyncWrite + Unpin,
187 0 : {
188 0 : let mut buf = BytesMut::new();
189 0 : frontend::password_message(password, &mut buf).map_err(Error::encode)?;
190 :
191 0 : stream.send(buf.freeze()).await.map_err(Error::io)
192 0 : }
193 :
194 12 : async fn authenticate_sasl<S, T>(
195 12 : stream: &mut StartupStream<S, T>,
196 12 : body: AuthenticationSaslBody,
197 12 : config: &Config,
198 12 : ) -> Result<(), Error>
199 12 : where
200 12 : S: AsyncRead + AsyncWrite + Unpin,
201 12 : T: TlsStream + Unpin,
202 0 : {
203 12 : let mut has_scram = false;
204 12 : let mut has_scram_plus = false;
205 12 : let mut mechanisms = body.mechanisms();
206 34 : while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
207 22 : match mechanism {
208 22 : sasl::SCRAM_SHA_256 => has_scram = true,
209 10 : sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
210 0 : _ => {}
211 : }
212 : }
213 :
214 12 : let channel_binding = stream
215 12 : .inner
216 12 : .get_ref()
217 12 : .channel_binding()
218 12 : .tls_server_end_point
219 12 : .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
220 12 : .map(sasl::ChannelBinding::tls_server_end_point);
221 :
222 12 : let (channel_binding, mechanism) = if has_scram_plus {
223 10 : match channel_binding {
224 8 : Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
225 2 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
226 : }
227 2 : } else if has_scram {
228 2 : match channel_binding {
229 2 : Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
230 0 : None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
231 : }
232 : } else {
233 0 : return Err(Error::authentication("unsupported SASL mechanism".into()));
234 : };
235 :
236 12 : if mechanism != sasl::SCRAM_SHA_256_PLUS {
237 4 : can_skip_channel_binding(config)?;
238 0 : }
239 :
240 11 : let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
241 0 : ScramSha256::new_with_keys(keys, channel_binding)
242 11 : } else if let Some(password) = config.get_password() {
243 11 : ScramSha256::new(password, channel_binding)
244 : } else {
245 0 : return Err(Error::config("password or auth keys missing".into()));
246 : };
247 :
248 11 : let mut buf = BytesMut::new();
249 11 : frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
250 11 : stream.send(buf.freeze()).await.map_err(Error::io)?;
251 :
252 11 : let body = match stream.try_next().await.map_err(Error::io)? {
253 10 : Some(Message::AuthenticationSaslContinue(body)) => body,
254 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
255 0 : Some(_) => return Err(Error::unexpected_message()),
256 0 : None => return Err(Error::closed()),
257 : };
258 :
259 10 : scram
260 10 : .update(body.data())
261 10 : .await
262 10 : .map_err(|e| Error::authentication(e.into()))?;
263 :
264 10 : let mut buf = BytesMut::new();
265 10 : frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
266 10 : stream.send(buf.freeze()).await.map_err(Error::io)?;
267 :
268 10 : let body = match stream.try_next().await.map_err(Error::io)? {
269 5 : Some(Message::AuthenticationSaslFinal(body)) => body,
270 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
271 0 : Some(_) => return Err(Error::unexpected_message()),
272 0 : None => return Err(Error::closed()),
273 : };
274 :
275 5 : scram
276 5 : .finish(body.data())
277 5 : .map_err(|e| Error::authentication(e.into()))?;
278 :
279 5 : Ok(())
280 0 : }
281 :
282 7 : async fn read_info<S, T>(
283 7 : stream: &mut StartupStream<S, T>,
284 7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
285 7 : where
286 7 : S: AsyncRead + AsyncWrite + Unpin,
287 7 : T: AsyncRead + AsyncWrite + Unpin,
288 0 : {
289 7 : let mut process_id = 0;
290 7 : let mut secret_key = 0;
291 7 : let mut parameters = HashMap::new();
292 :
293 : loop {
294 14 : match stream.try_next().await.map_err(Error::io)? {
295 0 : Some(Message::BackendKeyData(body)) => {
296 0 : process_id = body.process_id();
297 0 : secret_key = body.secret_key();
298 0 : }
299 7 : Some(Message::ParameterStatus(body)) => {
300 7 : parameters.insert(
301 7 : body.name().map_err(Error::parse)?.to_string(),
302 7 : body.value().map_err(Error::parse)?.to_string(),
303 : );
304 : }
305 0 : Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
306 7 : Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
307 0 : Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
308 0 : Some(_) => return Err(Error::unexpected_message()),
309 0 : None => return Err(Error::closed()),
310 : }
311 : }
312 0 : }
|