LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_raw.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 67.5 % 209 141
Test Date: 2025-07-16 12:29:03 Functions: 23.0 % 122 28

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::io;
       3              : use std::pin::Pin;
       4              : use std::task::{Context, Poll};
       5              : 
       6              : use bytes::BytesMut;
       7              : use fallible_iterator::FallibleIterator;
       8              : use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
       9              : use postgres_protocol2::authentication::sasl;
      10              : use postgres_protocol2::authentication::sasl::ScramSha256;
      11              : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
      12              : use postgres_protocol2::message::frontend;
      13              : use tokio::io::{AsyncRead, AsyncWrite};
      14              : use tokio_util::codec::Framed;
      15              : 
      16              : use crate::Error;
      17              : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
      18              : use crate::config::{self, AuthKeys, Config};
      19              : use crate::maybe_tls_stream::MaybeTlsStream;
      20              : use crate::tls::TlsStream;
      21              : 
      22              : pub struct StartupStream<S, T> {
      23              :     inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      24              :     buf: BackendMessages,
      25              :     delayed_notice: Vec<NoticeResponseBody>,
      26              : }
      27              : 
      28              : impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
      29              : where
      30              :     S: AsyncRead + AsyncWrite + Unpin,
      31              :     T: AsyncRead + AsyncWrite + Unpin,
      32              : {
      33              :     type Error = io::Error;
      34              : 
      35           36 :     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      36           36 :         Pin::new(&mut self.inner).poll_ready(cx)
      37            0 :     }
      38              : 
      39           36 :     fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
      40           36 :         Pin::new(&mut self.inner).start_send(item)
      41            0 :     }
      42              : 
      43           36 :     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      44           36 :         Pin::new(&mut self.inner).poll_flush(cx)
      45            0 :     }
      46              : 
      47            0 :     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      48            0 :         Pin::new(&mut self.inner).poll_close(cx)
      49            0 :     }
      50              : }
      51              : 
      52              : impl<S, T> Stream for StartupStream<S, T>
      53              : where
      54              :     S: AsyncRead + AsyncWrite + Unpin,
      55              :     T: AsyncRead + AsyncWrite + Unpin,
      56              : {
      57              :     type Item = io::Result<Message>;
      58              : 
      59           91 :     fn poll_next(
      60           91 :         mut self: Pin<&mut Self>,
      61           91 :         cx: &mut Context<'_>,
      62           91 :     ) -> Poll<Option<io::Result<Message>>> {
      63              :         loop {
      64          129 :             match self.buf.next() {
      65           42 :                 Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
      66           87 :                 Ok(None) => {}
      67            0 :                 Err(e) => return Poll::Ready(Some(Err(e))),
      68              :             }
      69              : 
      70           87 :             match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
      71           38 :                 Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
      72            7 :                 Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
      73            6 :                 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
      74            0 :                 None => return Poll::Ready(None),
      75              :             }
      76              :         }
      77            0 :     }
      78              : }
      79              : 
      80              : pub struct RawConnection<S, T> {
      81              :     pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      82              :     pub parameters: HashMap<String, String>,
      83              :     pub delayed_notice: Vec<NoticeResponseBody>,
      84              :     pub process_id: i32,
      85              :     pub secret_key: i32,
      86              : }
      87              : 
      88           15 : pub async fn connect_raw<S, T>(
      89           15 :     stream: MaybeTlsStream<S, T>,
      90           15 :     config: &Config,
      91           15 : ) -> Result<RawConnection<S, T>, Error>
      92           15 : where
      93           15 :     S: AsyncRead + AsyncWrite + Unpin,
      94           15 :     T: TlsStream + Unpin,
      95            0 : {
      96           15 :     let mut stream = StartupStream {
      97           15 :         inner: Framed::new(stream, PostgresCodec),
      98           15 :         buf: BackendMessages::empty(),
      99           15 :         delayed_notice: Vec::new(),
     100           15 :     };
     101              : 
     102           15 :     startup(&mut stream, config).await?;
     103           15 :     authenticate(&mut stream, config).await?;
     104            7 :     let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
     105              : 
     106            7 :     Ok(RawConnection {
     107            7 :         stream: stream.inner,
     108            7 :         parameters,
     109            7 :         delayed_notice: stream.delayed_notice,
     110            7 :         process_id,
     111            7 :         secret_key,
     112            7 :     })
     113            0 : }
     114              : 
     115           15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     116           15 : where
     117           15 :     S: AsyncRead + AsyncWrite + Unpin,
     118           15 :     T: AsyncRead + AsyncWrite + Unpin,
     119            0 : {
     120           15 :     let mut buf = BytesMut::new();
     121           15 :     frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
     122              : 
     123           15 :     stream
     124           15 :         .send(FrontendMessage::Raw(buf.freeze()))
     125           15 :         .await
     126           15 :         .map_err(Error::io)
     127            0 : }
     128              : 
     129           15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     130           15 : where
     131           15 :     S: AsyncRead + AsyncWrite + Unpin,
     132           15 :     T: TlsStream + Unpin,
     133            0 : {
     134           15 :     match stream.try_next().await.map_err(Error::io)? {
     135              :         Some(Message::AuthenticationOk) => {
     136            2 :             can_skip_channel_binding(config)?;
     137            2 :             return Ok(());
     138              :         }
     139              :         Some(Message::AuthenticationCleartextPassword) => {
     140            0 :             can_skip_channel_binding(config)?;
     141              : 
     142            0 :             let pass = config
     143            0 :                 .password
     144            0 :                 .as_ref()
     145            0 :                 .ok_or_else(|| Error::config("password missing".into()))?;
     146              : 
     147            0 :             authenticate_password(stream, pass).await?;
     148              :         }
     149           12 :         Some(Message::AuthenticationSasl(body)) => {
     150           12 :             authenticate_sasl(stream, body, config).await?;
     151              :         }
     152              :         Some(Message::AuthenticationMd5Password)
     153              :         | Some(Message::AuthenticationKerberosV5)
     154              :         | Some(Message::AuthenticationScmCredential)
     155              :         | Some(Message::AuthenticationGss)
     156              :         | Some(Message::AuthenticationSspi) => {
     157            0 :             return Err(Error::authentication(
     158            0 :                 "unsupported authentication method".into(),
     159            0 :             ));
     160              :         }
     161            1 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     162            0 :         Some(_) => return Err(Error::unexpected_message()),
     163            0 :         None => return Err(Error::closed()),
     164              :     }
     165              : 
     166            5 :     match stream.try_next().await.map_err(Error::io)? {
     167            5 :         Some(Message::AuthenticationOk) => Ok(()),
     168            0 :         Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
     169            0 :         Some(_) => Err(Error::unexpected_message()),
     170            0 :         None => Err(Error::closed()),
     171              :     }
     172            0 : }
     173              : 
     174            6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
     175            6 :     match config.channel_binding {
     176            5 :         config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
     177            1 :         config::ChannelBinding::Require => Err(Error::authentication(
     178            1 :             "server did not use channel binding".into(),
     179            1 :         )),
     180              :     }
     181            6 : }
     182              : 
     183            0 : async fn authenticate_password<S, T>(
     184            0 :     stream: &mut StartupStream<S, T>,
     185            0 :     password: &[u8],
     186            0 : ) -> Result<(), Error>
     187            0 : where
     188            0 :     S: AsyncRead + AsyncWrite + Unpin,
     189            0 :     T: AsyncRead + AsyncWrite + Unpin,
     190            0 : {
     191            0 :     let mut buf = BytesMut::new();
     192            0 :     frontend::password_message(password, &mut buf).map_err(Error::encode)?;
     193              : 
     194            0 :     stream
     195            0 :         .send(FrontendMessage::Raw(buf.freeze()))
     196            0 :         .await
     197            0 :         .map_err(Error::io)
     198            0 : }
     199              : 
     200           12 : async fn authenticate_sasl<S, T>(
     201           12 :     stream: &mut StartupStream<S, T>,
     202           12 :     body: AuthenticationSaslBody,
     203           12 :     config: &Config,
     204           12 : ) -> Result<(), Error>
     205           12 : where
     206           12 :     S: AsyncRead + AsyncWrite + Unpin,
     207           12 :     T: TlsStream + Unpin,
     208            0 : {
     209           12 :     let mut has_scram = false;
     210           12 :     let mut has_scram_plus = false;
     211           12 :     let mut mechanisms = body.mechanisms();
     212           34 :     while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
     213           22 :         match mechanism {
     214           22 :             sasl::SCRAM_SHA_256 => has_scram = true,
     215           10 :             sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
     216            0 :             _ => {}
     217              :         }
     218              :     }
     219              : 
     220           12 :     let channel_binding = stream
     221           12 :         .inner
     222           12 :         .get_ref()
     223           12 :         .channel_binding()
     224           12 :         .tls_server_end_point
     225           12 :         .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
     226           12 :         .map(sasl::ChannelBinding::tls_server_end_point);
     227              : 
     228           12 :     let (channel_binding, mechanism) = if has_scram_plus {
     229           10 :         match channel_binding {
     230            8 :             Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
     231            2 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     232              :         }
     233            2 :     } else if has_scram {
     234            2 :         match channel_binding {
     235            2 :             Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
     236            0 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     237              :         }
     238              :     } else {
     239            0 :         return Err(Error::authentication("unsupported SASL mechanism".into()));
     240              :     };
     241              : 
     242           12 :     if mechanism != sasl::SCRAM_SHA_256_PLUS {
     243            4 :         can_skip_channel_binding(config)?;
     244            0 :     }
     245              : 
     246           11 :     let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
     247            0 :         ScramSha256::new_with_keys(keys, channel_binding)
     248           11 :     } else if let Some(password) = config.get_password() {
     249           11 :         ScramSha256::new(password, channel_binding)
     250              :     } else {
     251            0 :         return Err(Error::config("password or auth keys missing".into()));
     252              :     };
     253              : 
     254           11 :     let mut buf = BytesMut::new();
     255           11 :     frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
     256           11 :     stream
     257           11 :         .send(FrontendMessage::Raw(buf.freeze()))
     258           11 :         .await
     259           11 :         .map_err(Error::io)?;
     260              : 
     261           11 :     let body = match stream.try_next().await.map_err(Error::io)? {
     262           10 :         Some(Message::AuthenticationSaslContinue(body)) => body,
     263            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     264            0 :         Some(_) => return Err(Error::unexpected_message()),
     265            0 :         None => return Err(Error::closed()),
     266              :     };
     267              : 
     268           10 :     scram
     269           10 :         .update(body.data())
     270           10 :         .await
     271           10 :         .map_err(|e| Error::authentication(e.into()))?;
     272              : 
     273           10 :     let mut buf = BytesMut::new();
     274           10 :     frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
     275           10 :     stream
     276           10 :         .send(FrontendMessage::Raw(buf.freeze()))
     277           10 :         .await
     278           10 :         .map_err(Error::io)?;
     279              : 
     280           10 :     let body = match stream.try_next().await.map_err(Error::io)? {
     281            5 :         Some(Message::AuthenticationSaslFinal(body)) => body,
     282            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     283            0 :         Some(_) => return Err(Error::unexpected_message()),
     284            0 :         None => return Err(Error::closed()),
     285              :     };
     286              : 
     287            5 :     scram
     288            5 :         .finish(body.data())
     289            5 :         .map_err(|e| Error::authentication(e.into()))?;
     290              : 
     291            5 :     Ok(())
     292            0 : }
     293              : 
     294            7 : async fn read_info<S, T>(
     295            7 :     stream: &mut StartupStream<S, T>,
     296            7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
     297            7 : where
     298            7 :     S: AsyncRead + AsyncWrite + Unpin,
     299            7 :     T: AsyncRead + AsyncWrite + Unpin,
     300            0 : {
     301            7 :     let mut process_id = 0;
     302            7 :     let mut secret_key = 0;
     303            7 :     let mut parameters = HashMap::new();
     304              : 
     305              :     loop {
     306           14 :         match stream.try_next().await.map_err(Error::io)? {
     307            0 :             Some(Message::BackendKeyData(body)) => {
     308            0 :                 process_id = body.process_id();
     309            0 :                 secret_key = body.secret_key();
     310            0 :             }
     311            7 :             Some(Message::ParameterStatus(body)) => {
     312            7 :                 parameters.insert(
     313            7 :                     body.name().map_err(Error::parse)?.to_string(),
     314            7 :                     body.value().map_err(Error::parse)?.to_string(),
     315              :                 );
     316              :             }
     317            0 :             Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
     318            7 :             Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
     319            0 :             Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     320            0 :             Some(_) => return Err(Error::unexpected_message()),
     321            0 :             None => return Err(Error::closed()),
     322              :         }
     323              :     }
     324            0 : }
        

Generated by: LCOV version 2.1-beta