LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_raw.rs (source / functions) Coverage Total Hit
Test: 915229b2d22dd355ad718d9afbb773e7f2fba970.info Lines: 61.7 % 196 121
Test Date: 2025-07-24 10:33:41 Functions: 25.3 % 95 24

            Line data    Source code
       1              : use std::io;
       2              : use std::pin::Pin;
       3              : use std::task::{Context, Poll, ready};
       4              : 
       5              : use bytes::{Bytes, BytesMut};
       6              : use fallible_iterator::FallibleIterator;
       7              : use futures_util::{Sink, SinkExt, Stream, TryStreamExt};
       8              : use postgres_protocol2::authentication::sasl;
       9              : use postgres_protocol2::authentication::sasl::ScramSha256;
      10              : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
      11              : use postgres_protocol2::message::frontend;
      12              : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
      13              : use tokio_util::codec::{Framed, FramedParts, FramedWrite};
      14              : 
      15              : use crate::Error;
      16              : use crate::codec::PostgresCodec;
      17              : use crate::config::{self, AuthKeys, Config};
      18              : use crate::maybe_tls_stream::MaybeTlsStream;
      19              : use crate::tls::TlsStream;
      20              : 
      21              : pub struct StartupStream<S, T> {
      22              :     inner: FramedWrite<MaybeTlsStream<S, T>, PostgresCodec>,
      23              :     read_buf: BytesMut,
      24              : }
      25              : 
      26              : impl<S, T> Sink<Bytes> for StartupStream<S, T>
      27              : where
      28              :     S: AsyncRead + AsyncWrite + Unpin,
      29              :     T: AsyncRead + AsyncWrite + Unpin,
      30              : {
      31              :     type Error = io::Error;
      32              : 
      33           36 :     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      34           36 :         Pin::new(&mut self.inner).poll_ready(cx)
      35            0 :     }
      36              : 
      37           36 :     fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
      38           36 :         Pin::new(&mut self.inner).start_send(item)
      39            0 :     }
      40              : 
      41           36 :     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      42           36 :         Pin::new(&mut self.inner).poll_flush(cx)
      43            0 :     }
      44              : 
      45            0 :     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      46            0 :         Pin::new(&mut self.inner).poll_close(cx)
      47            0 :     }
      48              : }
      49              : 
      50              : impl<S, T> Stream for StartupStream<S, T>
      51              : where
      52              :     S: AsyncRead + AsyncWrite + Unpin,
      53              :     T: AsyncRead + AsyncWrite + Unpin,
      54              : {
      55              :     type Item = io::Result<Message>;
      56              : 
      57           77 :     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
      58              :         // read 1 byte tag, 4 bytes length.
      59           77 :         let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);
      60              : 
      61           35 :         let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
      62           35 :         if len < 4 {
      63            0 :             return Poll::Ready(Some(Err(std::io::Error::other(
      64            0 :                 "postgres message too small",
      65            0 :             ))));
      66            0 :         }
      67           35 :         if len >= 65536 {
      68            0 :             return Poll::Ready(Some(Err(std::io::Error::other(
      69            0 :                 "postgres message too large",
      70            0 :             ))));
      71            0 :         }
      72              : 
      73              :         // the tag is an additional byte.
      74           35 :         let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);
      75              : 
      76              :         // Message::parse will remove the all the bytes from the buffer.
      77           35 :         Poll::Ready(Message::parse(&mut self.read_buf).transpose())
      78            0 :     }
      79              : }
      80              : 
      81              : impl<S, T> StartupStream<S, T>
      82              : where
      83              :     S: AsyncRead + AsyncWrite + Unpin,
      84              :     T: AsyncRead + AsyncWrite + Unpin,
      85              : {
      86              :     /// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
      87              :     ///
      88              :     /// If the current buffer length is greater, nothing happens.
      89          112 :     fn poll_fill_buf_exact(
      90          112 :         self: Pin<&mut Self>,
      91          112 :         cx: &mut Context<'_>,
      92          112 :         len: usize,
      93          112 :     ) -> Poll<Result<&[u8], std::io::Error>> {
      94          112 :         let this = self.get_mut();
      95          112 :         let mut stream = Pin::new(this.inner.get_mut());
      96              : 
      97          112 :         let mut n = this.read_buf.len();
      98          182 :         while n < len {
      99          112 :             this.read_buf.resize(len, 0);
     100              : 
     101          112 :             let mut buf = ReadBuf::new(&mut this.read_buf[..]);
     102          112 :             buf.set_filled(n);
     103              : 
     104          112 :             if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
     105           36 :                 this.read_buf.truncate(n);
     106           36 :                 return Poll::Pending;
     107            0 :             }
     108              : 
     109           70 :             if buf.filled().len() == n {
     110            0 :                 return Poll::Ready(Err(std::io::Error::new(
     111            0 :                     std::io::ErrorKind::UnexpectedEof,
     112            0 :                     "early eof",
     113            0 :                 )));
     114            0 :             }
     115           70 :             n = buf.filled().len();
     116              : 
     117           70 :             this.read_buf.truncate(n);
     118              :         }
     119              : 
     120           70 :         Poll::Ready(Ok(&this.read_buf[..len]))
     121            0 :     }
     122              : 
     123            0 :     pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
     124            0 :         let write_buf = std::mem::take(self.inner.write_buffer_mut());
     125            0 :         let io = self.inner.into_inner();
     126            0 :         let mut parts = FramedParts::new(io, PostgresCodec);
     127            0 :         parts.read_buf = self.read_buf;
     128            0 :         parts.write_buf = write_buf;
     129            0 :         Framed::from_parts(parts)
     130            0 :     }
     131              : 
     132           15 :     pub fn new(io: MaybeTlsStream<S, T>) -> Self {
     133           15 :         Self {
     134           15 :             inner: FramedWrite::new(io, PostgresCodec),
     135           15 :             read_buf: BytesMut::new(),
     136           15 :         }
     137            0 :     }
     138              : }
     139              : 
     140           15 : pub(crate) async fn startup<S, T>(
     141           15 :     stream: &mut StartupStream<S, T>,
     142           15 :     config: &Config,
     143           15 : ) -> Result<(), Error>
     144           15 : where
     145           15 :     S: AsyncRead + AsyncWrite + Unpin,
     146           15 :     T: AsyncRead + AsyncWrite + Unpin,
     147            0 : {
     148           15 :     let mut buf = BytesMut::new();
     149           15 :     frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
     150              : 
     151           15 :     stream.send(buf.freeze()).await.map_err(Error::io)
     152            0 : }
     153              : 
     154           15 : pub(crate) async fn authenticate<S, T>(
     155           15 :     stream: &mut StartupStream<S, T>,
     156           15 :     config: &Config,
     157           15 : ) -> Result<(), Error>
     158           15 : where
     159           15 :     S: AsyncRead + AsyncWrite + Unpin,
     160           15 :     T: TlsStream + Unpin,
     161            0 : {
     162           15 :     match stream.try_next().await.map_err(Error::io)? {
     163              :         Some(Message::AuthenticationOk) => {
     164            2 :             can_skip_channel_binding(config)?;
     165            2 :             return Ok(());
     166              :         }
     167              :         Some(Message::AuthenticationCleartextPassword) => {
     168            0 :             can_skip_channel_binding(config)?;
     169              : 
     170            0 :             let pass = config
     171            0 :                 .password
     172            0 :                 .as_ref()
     173            0 :                 .ok_or_else(|| Error::config("password missing".into()))?;
     174              : 
     175            0 :             authenticate_password(stream, pass).await?;
     176              :         }
     177           12 :         Some(Message::AuthenticationSasl(body)) => {
     178           12 :             authenticate_sasl(stream, body, config).await?;
     179              :         }
     180              :         Some(Message::AuthenticationMd5Password)
     181              :         | Some(Message::AuthenticationKerberosV5)
     182              :         | Some(Message::AuthenticationScmCredential)
     183              :         | Some(Message::AuthenticationGss)
     184              :         | Some(Message::AuthenticationSspi) => {
     185            0 :             return Err(Error::authentication(
     186            0 :                 "unsupported authentication method".into(),
     187            0 :             ));
     188              :         }
     189            1 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     190            0 :         Some(_) => return Err(Error::unexpected_message()),
     191            0 :         None => return Err(Error::closed()),
     192              :     }
     193              : 
     194            5 :     match stream.try_next().await.map_err(Error::io)? {
     195            5 :         Some(Message::AuthenticationOk) => Ok(()),
     196            0 :         Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
     197            0 :         Some(_) => Err(Error::unexpected_message()),
     198            0 :         None => Err(Error::closed()),
     199              :     }
     200            0 : }
     201              : 
     202            6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
     203            6 :     match config.channel_binding {
     204            5 :         config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
     205            1 :         config::ChannelBinding::Require => Err(Error::authentication(
     206            1 :             "server did not use channel binding".into(),
     207            1 :         )),
     208              :     }
     209            6 : }
     210              : 
     211            0 : async fn authenticate_password<S, T>(
     212            0 :     stream: &mut StartupStream<S, T>,
     213            0 :     password: &[u8],
     214            0 : ) -> Result<(), Error>
     215            0 : where
     216            0 :     S: AsyncRead + AsyncWrite + Unpin,
     217            0 :     T: AsyncRead + AsyncWrite + Unpin,
     218            0 : {
     219            0 :     let mut buf = BytesMut::new();
     220            0 :     frontend::password_message(password, &mut buf).map_err(Error::encode)?;
     221              : 
     222            0 :     stream.send(buf.freeze()).await.map_err(Error::io)
     223            0 : }
     224              : 
     225           12 : async fn authenticate_sasl<S, T>(
     226           12 :     stream: &mut StartupStream<S, T>,
     227           12 :     body: AuthenticationSaslBody,
     228           12 :     config: &Config,
     229           12 : ) -> Result<(), Error>
     230           12 : where
     231           12 :     S: AsyncRead + AsyncWrite + Unpin,
     232           12 :     T: TlsStream + Unpin,
     233            0 : {
     234           12 :     let mut has_scram = false;
     235           12 :     let mut has_scram_plus = false;
     236           12 :     let mut mechanisms = body.mechanisms();
     237           34 :     while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
     238           22 :         match mechanism {
     239           22 :             sasl::SCRAM_SHA_256 => has_scram = true,
     240           10 :             sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
     241            0 :             _ => {}
     242              :         }
     243              :     }
     244              : 
     245           12 :     let channel_binding = stream
     246           12 :         .inner
     247           12 :         .get_ref()
     248           12 :         .channel_binding()
     249           12 :         .tls_server_end_point
     250           12 :         .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
     251           12 :         .map(sasl::ChannelBinding::tls_server_end_point);
     252              : 
     253           12 :     let (channel_binding, mechanism) = if has_scram_plus {
     254           10 :         match channel_binding {
     255            8 :             Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
     256            2 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     257              :         }
     258            2 :     } else if has_scram {
     259            2 :         match channel_binding {
     260            2 :             Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
     261            0 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     262              :         }
     263              :     } else {
     264            0 :         return Err(Error::authentication("unsupported SASL mechanism".into()));
     265              :     };
     266              : 
     267           12 :     if mechanism != sasl::SCRAM_SHA_256_PLUS {
     268            4 :         can_skip_channel_binding(config)?;
     269            0 :     }
     270              : 
     271           11 :     let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
     272            0 :         ScramSha256::new_with_keys(keys, channel_binding)
     273           11 :     } else if let Some(password) = config.get_password() {
     274           11 :         ScramSha256::new(password, channel_binding)
     275              :     } else {
     276            0 :         return Err(Error::config("password or auth keys missing".into()));
     277              :     };
     278              : 
     279           11 :     let mut buf = BytesMut::new();
     280           11 :     frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
     281           11 :     stream.send(buf.freeze()).await.map_err(Error::io)?;
     282              : 
     283           11 :     let body = match stream.try_next().await.map_err(Error::io)? {
     284           10 :         Some(Message::AuthenticationSaslContinue(body)) => body,
     285            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     286            0 :         Some(_) => return Err(Error::unexpected_message()),
     287            0 :         None => return Err(Error::closed()),
     288              :     };
     289              : 
     290           10 :     scram
     291           10 :         .update(body.data())
     292           10 :         .await
     293           10 :         .map_err(|e| Error::authentication(e.into()))?;
     294              : 
     295           10 :     let mut buf = BytesMut::new();
     296           10 :     frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
     297           10 :     stream.send(buf.freeze()).await.map_err(Error::io)?;
     298              : 
     299           10 :     let body = match stream.try_next().await.map_err(Error::io)? {
     300            5 :         Some(Message::AuthenticationSaslFinal(body)) => body,
     301            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     302            0 :         Some(_) => return Err(Error::unexpected_message()),
     303            0 :         None => return Err(Error::closed()),
     304              :     };
     305              : 
     306            5 :     scram
     307            5 :         .finish(body.data())
     308            5 :         .map_err(|e| Error::authentication(e.into()))?;
     309              : 
     310            5 :     Ok(())
     311            0 : }
        

Generated by: LCOV version 2.1-beta