LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_raw.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 68.5 % 165 113
Test Date: 2025-07-26 17:20:05 Functions: 23.7 % 59 14

            Line data    Source code
       1              : use std::io;
       2              : use std::pin::Pin;
       3              : use std::task::{Context, Poll, ready};
       4              : 
       5              : use bytes::BytesMut;
       6              : use fallible_iterator::FallibleIterator;
       7              : use futures_util::{SinkExt, Stream, TryStreamExt};
       8              : use postgres_protocol2::authentication::sasl;
       9              : use postgres_protocol2::authentication::sasl::ScramSha256;
      10              : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
      11              : use postgres_protocol2::message::frontend;
      12              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      13              : use tokio_util::codec::{Framed, FramedParts};
      14              : 
      15              : use crate::Error;
      16              : use crate::codec::PostgresCodec;
      17              : use crate::config::{self, AuthKeys, Config};
      18              : use crate::connection::{GC_THRESHOLD, INITIAL_CAPACITY};
      19              : use crate::maybe_tls_stream::MaybeTlsStream;
      20              : use crate::tls::TlsStream;
      21              : 
      22              : pub struct StartupStream<S, T> {
      23              :     inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      24              :     read_buf: BytesMut,
      25              : }
      26              : 
      27              : impl<S, T> Stream for StartupStream<S, T>
      28              : where
      29              :     S: AsyncRead + AsyncWrite + Unpin,
      30              :     T: AsyncRead + AsyncWrite + Unpin,
      31              : {
      32              :     type Item = io::Result<Message>;
      33              : 
      34           77 :     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      35              :         // We don't use `self.inner.poll_next()` as that might over-read into the read buffer.
      36              : 
      37              :         // read 1 byte tag, 4 bytes length.
      38           77 :         let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
      39              : 
      40           35 :         let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
      41           35 :         if len < 4 {
      42            0 :             return Poll::Ready(Some(Err(std::io::Error::other(
      43            0 :                 "postgres message too small",
      44            0 :             ))));
      45            0 :         }
      46           35 :         if len >= 65536 {
      47            0 :             return Poll::Ready(Some(Err(std::io::Error::other(
      48            0 :                 "postgres message too large",
      49            0 :             ))));
      50            0 :         }
      51              : 
      52              :         // the tag is an additional byte.
      53           35 :         let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);
      54              : 
      55              :         // Message::parse will remove the all the bytes from the buffer.
      56           35 :         Poll::Ready(Message::parse(&mut self.read_buf).transpose())
      57            0 :     }
      58              : }
      59              : 
      60              : impl<S, T> StartupStream<S, T>
      61              : where
      62              :     S: AsyncRead + AsyncWrite + Unpin,
      63              :     T: AsyncRead + AsyncWrite + Unpin,
      64              : {
      65              :     /// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
      66              :     ///
      67              :     /// If the current buffer length is greater, nothing happens.
      68          112 :     fn poll_fill_buf_exact(
      69          112 :         self: Pin<&mut Self>,
      70          112 :         cx: &mut Context<'_>,
      71          112 :         len: usize,
      72          112 :     ) -> Poll<Result<&[u8], std::io::Error>> {
      73          112 :         let this = self.get_mut();
      74          112 :         let mut stream = Pin::new(this.inner.get_mut());
      75              : 
      76          112 :         let mut n = this.read_buf.len();
      77          182 :         while n < len {
      78          112 :             this.read_buf.resize(len, 0);
      79              : 
      80          112 :             let mut buf = ReadBuf::new(&mut this.read_buf[..]);
      81          112 :             buf.set_filled(n);
      82              : 
      83          112 :             if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
      84           36 :                 this.read_buf.truncate(n);
      85           36 :                 return Poll::Pending;
      86            0 :             }
      87              : 
      88           70 :             if buf.filled().len() == n {
      89            0 :                 return Poll::Ready(Err(std::io::Error::new(
      90            0 :                     std::io::ErrorKind::UnexpectedEof,
      91            0 :                     "early eof",
      92            0 :                 )));
      93            0 :             }
      94           70 :             n = buf.filled().len();
      95              : 
      96           70 :             this.read_buf.truncate(n);
      97              :         }
      98              : 
      99           70 :         Poll::Ready(Ok(&this.read_buf[..len]))
     100            0 :     }
     101              : 
     102            0 :     pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
     103            0 :         *self.inner.read_buffer_mut() = self.read_buf;
     104            0 :         self.inner
     105            0 :     }
     106              : 
     107           15 :     pub fn new(io: MaybeTlsStream<S, T>) -> Self {
     108           15 :         let mut parts = FramedParts::new(io, PostgresCodec);
     109           15 :         parts.write_buf = BytesMut::with_capacity(INITIAL_CAPACITY);
     110              : 
     111           15 :         let mut inner = Framed::from_parts(parts);
     112              : 
     113              :         // This is the default already, but nice to be explicit.
     114              :         // We divide by two because writes will overshoot the boundary.
     115              :         // We don't want constant overshoots to cause us to constantly re-shrink the buffer.
     116           15 :         inner.set_backpressure_boundary(GC_THRESHOLD / 2);
     117              : 
     118           15 :         Self {
     119           15 :             inner,
     120           15 :             read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
     121           15 :         }
     122            0 :     }
     123              : }
     124              : 
     125           15 : pub(crate) async fn authenticate<S, T>(
     126           15 :     stream: &mut StartupStream<S, T>,
     127           15 :     config: &Config,
     128           15 : ) -> Result<(), Error>
     129           15 : where
     130           15 :     S: AsyncRead + AsyncWrite + Unpin,
     131           15 :     T: TlsStream + Unpin,
     132            0 : {
     133           15 :     frontend::startup_message(&config.server_params, stream.inner.write_buffer_mut())
     134           15 :         .map_err(Error::encode)?;
     135              : 
     136           15 :     stream.inner.flush().await.map_err(Error::io)?;
     137           15 :     match stream.try_next().await.map_err(Error::io)? {
     138              :         Some(Message::AuthenticationOk) => {
     139            2 :             can_skip_channel_binding(config)?;
     140            2 :             return Ok(());
     141              :         }
     142              :         Some(Message::AuthenticationCleartextPassword) => {
     143            0 :             can_skip_channel_binding(config)?;
     144              : 
     145            0 :             let pass = config
     146            0 :                 .password
     147            0 :                 .as_ref()
     148            0 :                 .ok_or_else(|| Error::config("password missing".into()))?;
     149              : 
     150            0 :             frontend::password_message(pass, stream.inner.write_buffer_mut())
     151            0 :                 .map_err(Error::encode)?;
     152              :         }
     153           12 :         Some(Message::AuthenticationSasl(body)) => {
     154           12 :             authenticate_sasl(stream, body, config).await?;
     155              :         }
     156              :         Some(Message::AuthenticationMd5Password)
     157              :         | Some(Message::AuthenticationKerberosV5)
     158              :         | Some(Message::AuthenticationScmCredential)
     159              :         | Some(Message::AuthenticationGss)
     160              :         | Some(Message::AuthenticationSspi) => {
     161            0 :             return Err(Error::authentication(
     162            0 :                 "unsupported authentication method".into(),
     163            0 :             ));
     164              :         }
     165            1 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     166            0 :         Some(_) => return Err(Error::unexpected_message()),
     167            0 :         None => return Err(Error::closed()),
     168              :     }
     169              : 
     170            5 :     stream.inner.flush().await.map_err(Error::io)?;
     171            5 :     match stream.try_next().await.map_err(Error::io)? {
     172            5 :         Some(Message::AuthenticationOk) => Ok(()),
     173            0 :         Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
     174            0 :         Some(_) => Err(Error::unexpected_message()),
     175            0 :         None => Err(Error::closed()),
     176              :     }
     177            0 : }
     178              : 
     179            6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
     180            6 :     match config.channel_binding {
     181            5 :         config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
     182            1 :         config::ChannelBinding::Require => Err(Error::authentication(
     183            1 :             "server did not use channel binding".into(),
     184            1 :         )),
     185              :     }
     186            6 : }
     187              : 
     188           12 : async fn authenticate_sasl<S, T>(
     189           12 :     stream: &mut StartupStream<S, T>,
     190           12 :     body: AuthenticationSaslBody,
     191           12 :     config: &Config,
     192           12 : ) -> Result<(), Error>
     193           12 : where
     194           12 :     S: AsyncRead + AsyncWrite + Unpin,
     195           12 :     T: TlsStream + Unpin,
     196            0 : {
     197           12 :     let mut has_scram = false;
     198           12 :     let mut has_scram_plus = false;
     199           12 :     let mut mechanisms = body.mechanisms();
     200           34 :     while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
     201           22 :         match mechanism {
     202           22 :             sasl::SCRAM_SHA_256 => has_scram = true,
     203           10 :             sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
     204            0 :             _ => {}
     205              :         }
     206              :     }
     207              : 
     208           12 :     let channel_binding = stream
     209           12 :         .inner
     210           12 :         .get_ref()
     211           12 :         .channel_binding()
     212           12 :         .tls_server_end_point
     213           12 :         .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
     214           12 :         .map(sasl::ChannelBinding::tls_server_end_point);
     215              : 
     216           12 :     let (channel_binding, mechanism) = if has_scram_plus {
     217           10 :         match channel_binding {
     218            8 :             Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
     219            2 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     220              :         }
     221            2 :     } else if has_scram {
     222            2 :         match channel_binding {
     223            2 :             Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
     224            0 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     225              :         }
     226              :     } else {
     227            0 :         return Err(Error::authentication("unsupported SASL mechanism".into()));
     228              :     };
     229              : 
     230           12 :     if mechanism != sasl::SCRAM_SHA_256_PLUS {
     231            4 :         can_skip_channel_binding(config)?;
     232            0 :     }
     233              : 
     234           11 :     let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
     235            0 :         ScramSha256::new_with_keys(keys, channel_binding)
     236           11 :     } else if let Some(password) = config.get_password() {
     237           11 :         ScramSha256::new(password, channel_binding)
     238              :     } else {
     239            0 :         return Err(Error::config("password or auth keys missing".into()));
     240              :     };
     241              : 
     242           11 :     frontend::sasl_initial_response(mechanism, scram.message(), stream.inner.write_buffer_mut())
     243           11 :         .map_err(Error::encode)?;
     244              : 
     245           11 :     stream.inner.flush().await.map_err(Error::io)?;
     246           11 :     let body = match stream.try_next().await.map_err(Error::io)? {
     247           10 :         Some(Message::AuthenticationSaslContinue(body)) => body,
     248            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     249            0 :         Some(_) => return Err(Error::unexpected_message()),
     250            0 :         None => return Err(Error::closed()),
     251              :     };
     252              : 
     253           10 :     scram
     254           10 :         .update(body.data())
     255           10 :         .await
     256           10 :         .map_err(|e| Error::authentication(e.into()))?;
     257              : 
     258           10 :     frontend::sasl_response(scram.message(), stream.inner.write_buffer_mut())
     259           10 :         .map_err(Error::encode)?;
     260              : 
     261           10 :     stream.inner.flush().await.map_err(Error::io)?;
     262           10 :     let body = match stream.try_next().await.map_err(Error::io)? {
     263            5 :         Some(Message::AuthenticationSaslFinal(body)) => body,
     264            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     265            0 :         Some(_) => return Err(Error::unexpected_message()),
     266            0 :         None => return Err(Error::closed()),
     267              :     };
     268              : 
     269            5 :     scram
     270            5 :         .finish(body.data())
     271            5 :         .map_err(|e| Error::authentication(e.into()))?;
     272              : 
     273            5 :     Ok(())
     274            0 : }
        

Generated by: LCOV version 2.1-beta