LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_raw.rs (source / functions) Coverage Total Hit
Test: 4be46b1c0003aa3bbac9ade362c676b419df4c20.info Lines: 67.0 % 197 132
Test Date: 2025-07-22 17:50:06 Functions: 23.0 % 122 28

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::io;
       3              : use std::pin::Pin;
       4              : use std::task::{Context, Poll};
       5              : 
       6              : use bytes::{Bytes, BytesMut};
       7              : use fallible_iterator::FallibleIterator;
       8              : use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
       9              : use postgres_protocol2::authentication::sasl;
      10              : use postgres_protocol2::authentication::sasl::ScramSha256;
      11              : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
      12              : use postgres_protocol2::message::frontend;
      13              : use tokio::io::{AsyncRead, AsyncWrite};
      14              : use tokio_util::codec::Framed;
      15              : 
      16              : use crate::Error;
      17              : use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
      18              : use crate::config::{self, AuthKeys, Config};
      19              : use crate::maybe_tls_stream::MaybeTlsStream;
      20              : use crate::tls::TlsStream;
      21              : 
      22              : pub struct StartupStream<S, T> {
      23              :     inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      24              :     buf: BackendMessages,
      25              :     delayed_notice: Vec<NoticeResponseBody>,
      26              : }
      27              : 
      28              : impl<S, T> Sink<Bytes> for StartupStream<S, T>
      29              : where
      30              :     S: AsyncRead + AsyncWrite + Unpin,
      31              :     T: AsyncRead + AsyncWrite + Unpin,
      32              : {
      33              :     type Error = io::Error;
      34              : 
      35           36 :     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      36           36 :         Pin::new(&mut self.inner).poll_ready(cx)
      37            0 :     }
      38              : 
      39           36 :     fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
      40           36 :         Pin::new(&mut self.inner).start_send(item)
      41            0 :     }
      42              : 
      43           36 :     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      44           36 :         Pin::new(&mut self.inner).poll_flush(cx)
      45            0 :     }
      46              : 
      47            0 :     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      48            0 :         Pin::new(&mut self.inner).poll_close(cx)
      49            0 :     }
      50              : }
      51              : 
      52              : impl<S, T> Stream for StartupStream<S, T>
      53              : where
      54              :     S: AsyncRead + AsyncWrite + Unpin,
      55              :     T: AsyncRead + AsyncWrite + Unpin,
      56              : {
      57              :     type Item = io::Result<Message>;
      58              : 
      59           91 :     fn poll_next(
      60           91 :         mut self: Pin<&mut Self>,
      61           91 :         cx: &mut Context<'_>,
      62           91 :     ) -> Poll<Option<io::Result<Message>>> {
      63              :         loop {
      64          129 :             match self.buf.next() {
      65           42 :                 Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
      66           87 :                 Ok(None) => {}
      67            0 :                 Err(e) => return Poll::Ready(Some(Err(e))),
      68              :             }
      69              : 
      70           87 :             match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
      71           38 :                 Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
      72            7 :                 Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
      73            6 :                 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
      74            0 :                 None => return Poll::Ready(None),
      75              :             }
      76              :         }
      77            0 :     }
      78              : }
      79              : 
      80              : pub struct RawConnection<S, T> {
      81              :     pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      82              :     pub parameters: HashMap<String, String>,
      83              :     pub delayed_notice: Vec<NoticeResponseBody>,
      84              :     pub process_id: i32,
      85              :     pub secret_key: i32,
      86              : }
      87              : 
      88           15 : pub async fn connect_raw<S, T>(
      89           15 :     stream: MaybeTlsStream<S, T>,
      90           15 :     config: &Config,
      91           15 : ) -> Result<RawConnection<S, T>, Error>
      92           15 : where
      93           15 :     S: AsyncRead + AsyncWrite + Unpin,
      94           15 :     T: TlsStream + Unpin,
      95            0 : {
      96           15 :     let mut stream = StartupStream {
      97           15 :         inner: Framed::new(stream, PostgresCodec),
      98           15 :         buf: BackendMessages::empty(),
      99           15 :         delayed_notice: Vec::new(),
     100           15 :     };
     101              : 
     102           15 :     startup(&mut stream, config).await?;
     103           15 :     authenticate(&mut stream, config).await?;
     104            7 :     let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
     105              : 
     106            7 :     Ok(RawConnection {
     107            7 :         stream: stream.inner,
     108            7 :         parameters,
     109            7 :         delayed_notice: stream.delayed_notice,
     110            7 :         process_id,
     111            7 :         secret_key,
     112            7 :     })
     113            0 : }
     114              : 
     115           15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     116           15 : where
     117           15 :     S: AsyncRead + AsyncWrite + Unpin,
     118           15 :     T: AsyncRead + AsyncWrite + Unpin,
     119            0 : {
     120           15 :     let mut buf = BytesMut::new();
     121           15 :     frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
     122              : 
     123           15 :     stream.send(buf.freeze()).await.map_err(Error::io)
     124            0 : }
     125              : 
     126           15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     127           15 : where
     128           15 :     S: AsyncRead + AsyncWrite + Unpin,
     129           15 :     T: TlsStream + Unpin,
     130            0 : {
     131           15 :     match stream.try_next().await.map_err(Error::io)? {
     132              :         Some(Message::AuthenticationOk) => {
     133            2 :             can_skip_channel_binding(config)?;
     134            2 :             return Ok(());
     135              :         }
     136              :         Some(Message::AuthenticationCleartextPassword) => {
     137            0 :             can_skip_channel_binding(config)?;
     138              : 
     139            0 :             let pass = config
     140            0 :                 .password
     141            0 :                 .as_ref()
     142            0 :                 .ok_or_else(|| Error::config("password missing".into()))?;
     143              : 
     144            0 :             authenticate_password(stream, pass).await?;
     145              :         }
     146           12 :         Some(Message::AuthenticationSasl(body)) => {
     147           12 :             authenticate_sasl(stream, body, config).await?;
     148              :         }
     149              :         Some(Message::AuthenticationMd5Password)
     150              :         | Some(Message::AuthenticationKerberosV5)
     151              :         | Some(Message::AuthenticationScmCredential)
     152              :         | Some(Message::AuthenticationGss)
     153              :         | Some(Message::AuthenticationSspi) => {
     154            0 :             return Err(Error::authentication(
     155            0 :                 "unsupported authentication method".into(),
     156            0 :             ));
     157              :         }
     158            1 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     159            0 :         Some(_) => return Err(Error::unexpected_message()),
     160            0 :         None => return Err(Error::closed()),
     161              :     }
     162              : 
     163            5 :     match stream.try_next().await.map_err(Error::io)? {
     164            5 :         Some(Message::AuthenticationOk) => Ok(()),
     165            0 :         Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
     166            0 :         Some(_) => Err(Error::unexpected_message()),
     167            0 :         None => Err(Error::closed()),
     168              :     }
     169            0 : }
     170              : 
     171            6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
     172            6 :     match config.channel_binding {
     173            5 :         config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
     174            1 :         config::ChannelBinding::Require => Err(Error::authentication(
     175            1 :             "server did not use channel binding".into(),
     176            1 :         )),
     177              :     }
     178            6 : }
     179              : 
     180            0 : async fn authenticate_password<S, T>(
     181            0 :     stream: &mut StartupStream<S, T>,
     182            0 :     password: &[u8],
     183            0 : ) -> Result<(), Error>
     184            0 : where
     185            0 :     S: AsyncRead + AsyncWrite + Unpin,
     186            0 :     T: AsyncRead + AsyncWrite + Unpin,
     187            0 : {
     188            0 :     let mut buf = BytesMut::new();
     189            0 :     frontend::password_message(password, &mut buf).map_err(Error::encode)?;
     190              : 
     191            0 :     stream.send(buf.freeze()).await.map_err(Error::io)
     192            0 : }
     193              : 
     194           12 : async fn authenticate_sasl<S, T>(
     195           12 :     stream: &mut StartupStream<S, T>,
     196           12 :     body: AuthenticationSaslBody,
     197           12 :     config: &Config,
     198           12 : ) -> Result<(), Error>
     199           12 : where
     200           12 :     S: AsyncRead + AsyncWrite + Unpin,
     201           12 :     T: TlsStream + Unpin,
     202            0 : {
     203           12 :     let mut has_scram = false;
     204           12 :     let mut has_scram_plus = false;
     205           12 :     let mut mechanisms = body.mechanisms();
     206           34 :     while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
     207           22 :         match mechanism {
     208           22 :             sasl::SCRAM_SHA_256 => has_scram = true,
     209           10 :             sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
     210            0 :             _ => {}
     211              :         }
     212              :     }
     213              : 
     214           12 :     let channel_binding = stream
     215           12 :         .inner
     216           12 :         .get_ref()
     217           12 :         .channel_binding()
     218           12 :         .tls_server_end_point
     219           12 :         .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
     220           12 :         .map(sasl::ChannelBinding::tls_server_end_point);
     221              : 
     222           12 :     let (channel_binding, mechanism) = if has_scram_plus {
     223           10 :         match channel_binding {
     224            8 :             Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
     225            2 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     226              :         }
     227            2 :     } else if has_scram {
     228            2 :         match channel_binding {
     229            2 :             Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
     230            0 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     231              :         }
     232              :     } else {
     233            0 :         return Err(Error::authentication("unsupported SASL mechanism".into()));
     234              :     };
     235              : 
     236           12 :     if mechanism != sasl::SCRAM_SHA_256_PLUS {
     237            4 :         can_skip_channel_binding(config)?;
     238            0 :     }
     239              : 
     240           11 :     let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
     241            0 :         ScramSha256::new_with_keys(keys, channel_binding)
     242           11 :     } else if let Some(password) = config.get_password() {
     243           11 :         ScramSha256::new(password, channel_binding)
     244              :     } else {
     245            0 :         return Err(Error::config("password or auth keys missing".into()));
     246              :     };
     247              : 
     248           11 :     let mut buf = BytesMut::new();
     249           11 :     frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
     250           11 :     stream.send(buf.freeze()).await.map_err(Error::io)?;
     251              : 
     252           11 :     let body = match stream.try_next().await.map_err(Error::io)? {
     253           10 :         Some(Message::AuthenticationSaslContinue(body)) => body,
     254            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     255            0 :         Some(_) => return Err(Error::unexpected_message()),
     256            0 :         None => return Err(Error::closed()),
     257              :     };
     258              : 
     259           10 :     scram
     260           10 :         .update(body.data())
     261           10 :         .await
     262           10 :         .map_err(|e| Error::authentication(e.into()))?;
     263              : 
     264           10 :     let mut buf = BytesMut::new();
     265           10 :     frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
     266           10 :     stream.send(buf.freeze()).await.map_err(Error::io)?;
     267              : 
     268           10 :     let body = match stream.try_next().await.map_err(Error::io)? {
     269            5 :         Some(Message::AuthenticationSaslFinal(body)) => body,
     270            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     271            0 :         Some(_) => return Err(Error::unexpected_message()),
     272            0 :         None => return Err(Error::closed()),
     273              :     };
     274              : 
     275            5 :     scram
     276            5 :         .finish(body.data())
     277            5 :         .map_err(|e| Error::authentication(e.into()))?;
     278              : 
     279            5 :     Ok(())
     280            0 : }
     281              : 
     282            7 : async fn read_info<S, T>(
     283            7 :     stream: &mut StartupStream<S, T>,
     284            7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
     285            7 : where
     286            7 :     S: AsyncRead + AsyncWrite + Unpin,
     287            7 :     T: AsyncRead + AsyncWrite + Unpin,
     288            0 : {
     289            7 :     let mut process_id = 0;
     290            7 :     let mut secret_key = 0;
     291            7 :     let mut parameters = HashMap::new();
     292              : 
     293              :     loop {
     294           14 :         match stream.try_next().await.map_err(Error::io)? {
     295            0 :             Some(Message::BackendKeyData(body)) => {
     296            0 :                 process_id = body.process_id();
     297            0 :                 secret_key = body.secret_key();
     298            0 :             }
     299            7 :             Some(Message::ParameterStatus(body)) => {
     300            7 :                 parameters.insert(
     301            7 :                     body.name().map_err(Error::parse)?.to_string(),
     302            7 :                     body.value().map_err(Error::parse)?.to_string(),
     303              :                 );
     304              :             }
     305            0 :             Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
     306            7 :             Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
     307            0 :             Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     308            0 :             Some(_) => return Err(Error::unexpected_message()),
     309            0 :             None => return Err(Error::closed()),
     310              :         }
     311              :     }
     312            0 : }
        

Generated by: LCOV version 2.1-beta