LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_raw.rs (source / functions) Coverage Total Hit
Test: 5fe7fa8d483b39476409aee736d6d5e32728bfac.info Lines: 71.7 % 212 152
Test Date: 2025-03-12 16:10:49 Functions: 27.5 % 102 28

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::io;
       3              : use std::pin::Pin;
       4              : use std::task::{Context, Poll};
       5              : 
       6              : use bytes::BytesMut;
       7              : use fallible_iterator::FallibleIterator;
       8              : use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
       9              : use postgres_protocol2::authentication::sasl;
      10              : use postgres_protocol2::authentication::sasl::ScramSha256;
      11              : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
      12              : use postgres_protocol2::message::frontend;
      13              : use tokio::io::{AsyncRead, AsyncWrite};
      14              : use tokio_util::codec::Framed;
      15              : 
      16              : use crate::Error;
      17              : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
      18              : use crate::config::{self, AuthKeys, Config};
      19              : use crate::connect_tls::connect_tls;
      20              : use crate::maybe_tls_stream::MaybeTlsStream;
      21              : use crate::tls::{TlsConnect, TlsStream};
      22              : 
      23              : pub struct StartupStream<S, T> {
      24              :     inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      25              :     buf: BackendMessages,
      26              :     delayed_notice: Vec<NoticeResponseBody>,
      27              : }
      28              : 
      29              : impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
      30              : where
      31              :     S: AsyncRead + AsyncWrite + Unpin,
      32              :     T: AsyncRead + AsyncWrite + Unpin,
      33              : {
      34              :     type Error = io::Error;
      35              : 
      36           36 :     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      37           36 :         Pin::new(&mut self.inner).poll_ready(cx)
      38           36 :     }
      39              : 
      40           36 :     fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
      41           36 :         Pin::new(&mut self.inner).start_send(item)
      42           36 :     }
      43              : 
      44           36 :     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      45           36 :         Pin::new(&mut self.inner).poll_flush(cx)
      46           36 :     }
      47              : 
      48            0 :     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      49            0 :         Pin::new(&mut self.inner).poll_close(cx)
      50            0 :     }
      51              : }
      52              : 
      53              : impl<S, T> Stream for StartupStream<S, T>
      54              : where
      55              :     S: AsyncRead + AsyncWrite + Unpin,
      56              :     T: AsyncRead + AsyncWrite + Unpin,
      57              : {
      58              :     type Item = io::Result<Message>;
      59              : 
      60           91 :     fn poll_next(
      61           91 :         mut self: Pin<&mut Self>,
      62           91 :         cx: &mut Context<'_>,
      63           91 :     ) -> Poll<Option<io::Result<Message>>> {
      64              :         loop {
      65          128 :             match self.buf.next() {
      66           42 :                 Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
      67           86 :                 Ok(None) => {}
      68            0 :                 Err(e) => return Poll::Ready(Some(Err(e))),
      69              :             }
      70              : 
      71           86 :             match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
      72           37 :                 Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
      73            7 :                 Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
      74            6 :                 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
      75            0 :                 None => return Poll::Ready(None),
      76              :             }
      77              :         }
      78            0 :     }
      79              : }
      80              : 
      81              : pub struct RawConnection<S, T> {
      82              :     pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      83              :     pub parameters: HashMap<String, String>,
      84              :     pub delayed_notice: Vec<NoticeResponseBody>,
      85              :     pub process_id: i32,
      86              :     pub secret_key: i32,
      87              : }
      88              : 
      89           15 : pub async fn connect_raw<S, T>(
      90           15 :     stream: S,
      91           15 :     tls: T,
      92           15 :     config: &Config,
      93           15 : ) -> Result<RawConnection<S, T::Stream>, Error>
      94           15 : where
      95           15 :     S: AsyncRead + AsyncWrite + Unpin,
      96           15 :     T: TlsConnect<S>,
      97           15 : {
      98           15 :     let stream = connect_tls(stream, config.ssl_mode, tls).await?;
      99              : 
     100           15 :     let mut stream = StartupStream {
     101           15 :         inner: Framed::new(stream, PostgresCodec),
     102           15 :         buf: BackendMessages::empty(),
     103           15 :         delayed_notice: Vec::new(),
     104           15 :     };
     105           15 : 
     106           15 :     startup(&mut stream, config).await?;
     107           15 :     authenticate(&mut stream, config).await?;
     108            7 :     let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
     109              : 
     110            7 :     Ok(RawConnection {
     111            7 :         stream: stream.inner,
     112            7 :         parameters,
     113            7 :         delayed_notice: stream.delayed_notice,
     114            7 :         process_id,
     115            7 :         secret_key,
     116            7 :     })
     117            0 : }
     118              : 
     119           15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     120           15 : where
     121           15 :     S: AsyncRead + AsyncWrite + Unpin,
     122           15 :     T: AsyncRead + AsyncWrite + Unpin,
     123           15 : {
     124           15 :     let mut buf = BytesMut::new();
     125           15 :     frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
     126              : 
     127           15 :     stream
     128           15 :         .send(FrontendMessage::Raw(buf.freeze()))
     129           15 :         .await
     130           15 :         .map_err(Error::io)
     131            0 : }
     132              : 
     133           15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     134           15 : where
     135           15 :     S: AsyncRead + AsyncWrite + Unpin,
     136           15 :     T: TlsStream + Unpin,
     137           15 : {
     138           15 :     match stream.try_next().await.map_err(Error::io)? {
     139              :         Some(Message::AuthenticationOk) => {
     140            2 :             can_skip_channel_binding(config)?;
     141            2 :             return Ok(());
     142              :         }
     143              :         Some(Message::AuthenticationCleartextPassword) => {
     144            0 :             can_skip_channel_binding(config)?;
     145              : 
     146            0 :             let pass = config
     147            0 :                 .password
     148            0 :                 .as_ref()
     149            0 :                 .ok_or_else(|| Error::config("password missing".into()))?;
     150              : 
     151            0 :             authenticate_password(stream, pass).await?;
     152              :         }
     153           12 :         Some(Message::AuthenticationSasl(body)) => {
     154           12 :             authenticate_sasl(stream, body, config).await?;
     155              :         }
     156              :         Some(Message::AuthenticationMd5Password)
     157              :         | Some(Message::AuthenticationKerberosV5)
     158              :         | Some(Message::AuthenticationScmCredential)
     159              :         | Some(Message::AuthenticationGss)
     160              :         | Some(Message::AuthenticationSspi) => {
     161            0 :             return Err(Error::authentication(
     162            0 :                 "unsupported authentication method".into(),
     163            0 :             ));
     164              :         }
     165            1 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     166            0 :         Some(_) => return Err(Error::unexpected_message()),
     167            0 :         None => return Err(Error::closed()),
     168              :     }
     169              : 
     170            5 :     match stream.try_next().await.map_err(Error::io)? {
     171            5 :         Some(Message::AuthenticationOk) => Ok(()),
     172            0 :         Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
     173            0 :         Some(_) => Err(Error::unexpected_message()),
     174            0 :         None => Err(Error::closed()),
     175              :     }
     176            0 : }
     177              : 
     178            6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
     179            6 :     match config.channel_binding {
     180            5 :         config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
     181            1 :         config::ChannelBinding::Require => Err(Error::authentication(
     182            1 :             "server did not use channel binding".into(),
     183            1 :         )),
     184              :     }
     185            6 : }
     186              : 
     187            0 : async fn authenticate_password<S, T>(
     188            0 :     stream: &mut StartupStream<S, T>,
     189            0 :     password: &[u8],
     190            0 : ) -> Result<(), Error>
     191            0 : where
     192            0 :     S: AsyncRead + AsyncWrite + Unpin,
     193            0 :     T: AsyncRead + AsyncWrite + Unpin,
     194            0 : {
     195            0 :     let mut buf = BytesMut::new();
     196            0 :     frontend::password_message(password, &mut buf).map_err(Error::encode)?;
     197              : 
     198            0 :     stream
     199            0 :         .send(FrontendMessage::Raw(buf.freeze()))
     200            0 :         .await
     201            0 :         .map_err(Error::io)
     202            0 : }
     203              : 
     204           12 : async fn authenticate_sasl<S, T>(
     205           12 :     stream: &mut StartupStream<S, T>,
     206           12 :     body: AuthenticationSaslBody,
     207           12 :     config: &Config,
     208           12 : ) -> Result<(), Error>
     209           12 : where
     210           12 :     S: AsyncRead + AsyncWrite + Unpin,
     211           12 :     T: TlsStream + Unpin,
     212           12 : {
     213           12 :     let mut has_scram = false;
     214           12 :     let mut has_scram_plus = false;
     215           12 :     let mut mechanisms = body.mechanisms();
     216           34 :     while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
     217           22 :         match mechanism {
     218           22 :             sasl::SCRAM_SHA_256 => has_scram = true,
     219           10 :             sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
     220            0 :             _ => {}
     221              :         }
     222              :     }
     223              : 
     224           12 :     let channel_binding = stream
     225           12 :         .inner
     226           12 :         .get_ref()
     227           12 :         .channel_binding()
     228           12 :         .tls_server_end_point
     229           12 :         .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
     230           12 :         .map(sasl::ChannelBinding::tls_server_end_point);
     231              : 
     232           12 :     let (channel_binding, mechanism) = if has_scram_plus {
     233           10 :         match channel_binding {
     234            8 :             Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
     235            2 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     236              :         }
     237            2 :     } else if has_scram {
     238            2 :         match channel_binding {
     239            2 :             Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
     240            0 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     241              :         }
     242              :     } else {
     243            0 :         return Err(Error::authentication("unsupported SASL mechanism".into()));
     244              :     };
     245              : 
     246           12 :     if mechanism != sasl::SCRAM_SHA_256_PLUS {
     247            4 :         can_skip_channel_binding(config)?;
     248            0 :     }
     249              : 
     250           11 :     let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
     251            0 :         ScramSha256::new_with_keys(keys, channel_binding)
     252           11 :     } else if let Some(password) = config.get_password() {
     253           11 :         ScramSha256::new(password, channel_binding)
     254              :     } else {
     255            0 :         return Err(Error::config("password or auth keys missing".into()));
     256              :     };
     257              : 
     258           11 :     let mut buf = BytesMut::new();
     259           11 :     frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
     260           11 :     stream
     261           11 :         .send(FrontendMessage::Raw(buf.freeze()))
     262           11 :         .await
     263           11 :         .map_err(Error::io)?;
     264              : 
     265           11 :     let body = match stream.try_next().await.map_err(Error::io)? {
     266           10 :         Some(Message::AuthenticationSaslContinue(body)) => body,
     267            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     268            0 :         Some(_) => return Err(Error::unexpected_message()),
     269            0 :         None => return Err(Error::closed()),
     270              :     };
     271              : 
     272           10 :     scram
     273           10 :         .update(body.data())
     274           10 :         .await
     275           10 :         .map_err(|e| Error::authentication(e.into()))?;
     276              : 
     277           10 :     let mut buf = BytesMut::new();
     278           10 :     frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
     279           10 :     stream
     280           10 :         .send(FrontendMessage::Raw(buf.freeze()))
     281           10 :         .await
     282           10 :         .map_err(Error::io)?;
     283              : 
     284           10 :     let body = match stream.try_next().await.map_err(Error::io)? {
     285            5 :         Some(Message::AuthenticationSaslFinal(body)) => body,
     286            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     287            0 :         Some(_) => return Err(Error::unexpected_message()),
     288            0 :         None => return Err(Error::closed()),
     289              :     };
     290              : 
     291            5 :     scram
     292            5 :         .finish(body.data())
     293            5 :         .map_err(|e| Error::authentication(e.into()))?;
     294              : 
     295            5 :     Ok(())
     296            0 : }
     297              : 
     298            7 : async fn read_info<S, T>(
     299            7 :     stream: &mut StartupStream<S, T>,
     300            7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
     301            7 : where
     302            7 :     S: AsyncRead + AsyncWrite + Unpin,
     303            7 :     T: AsyncRead + AsyncWrite + Unpin,
     304            7 : {
     305            7 :     let mut process_id = 0;
     306            7 :     let mut secret_key = 0;
     307            7 :     let mut parameters = HashMap::new();
     308              : 
     309              :     loop {
     310           14 :         match stream.try_next().await.map_err(Error::io)? {
     311            0 :             Some(Message::BackendKeyData(body)) => {
     312            0 :                 process_id = body.process_id();
     313            0 :                 secret_key = body.secret_key();
     314            0 :             }
     315            7 :             Some(Message::ParameterStatus(body)) => {
     316            7 :                 parameters.insert(
     317            7 :                     body.name().map_err(Error::parse)?.to_string(),
     318            7 :                     body.value().map_err(Error::parse)?.to_string(),
     319              :                 );
     320              :             }
     321            0 :             Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
     322            7 :             Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
     323            0 :             Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     324            0 :             Some(_) => return Err(Error::unexpected_message()),
     325            0 :             None => return Err(Error::closed()),
     326              :         }
     327              :     }
     328            0 : }
        

Generated by: LCOV version 2.1-beta