LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connect_raw.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 75.0 % 212 159
Test Date: 2025-02-20 13:11:02 Functions: 27.5 % 102 28

            Line data    Source code
       1              : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
       2              : use crate::config::{self, AuthKeys, Config};
       3              : use crate::connect_tls::connect_tls;
       4              : use crate::maybe_tls_stream::MaybeTlsStream;
       5              : use crate::tls::{TlsConnect, TlsStream};
       6              : use crate::Error;
       7              : use bytes::BytesMut;
       8              : use fallible_iterator::FallibleIterator;
       9              : use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
      10              : use postgres_protocol2::authentication::sasl;
      11              : use postgres_protocol2::authentication::sasl::ScramSha256;
      12              : use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
      13              : use postgres_protocol2::message::frontend;
      14              : use std::collections::HashMap;
      15              : use std::io;
      16              : use std::pin::Pin;
      17              : use std::task::{Context, Poll};
      18              : use tokio::io::{AsyncRead, AsyncWrite};
      19              : use tokio_util::codec::Framed;
      20              : 
/// Wrapper around the framed connection that is used while the startup and
/// authentication exchange is in progress.
///
/// It implements `Sink<FrontendMessage>` and `Stream<Item = io::Result<Message>>`
/// by delegating to `inner`.
pub struct StartupStream<S, T> {
    /// Framed transport (plain or TLS) that all reads and writes go through.
    inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// The batch of backend messages most recently taken from `inner`;
    /// `poll_next` drains it before polling `inner` again.
    buf: BackendMessages,
    /// `NoticeResponse` bodies collected by `read_info`; `connect_raw` moves
    /// them into the resulting `RawConnection`.
    delayed_notice: Vec<NoticeResponseBody>,
}
      26              : 
      27              : impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
      28              : where
      29              :     S: AsyncRead + AsyncWrite + Unpin,
      30              :     T: AsyncRead + AsyncWrite + Unpin,
      31              : {
      32              :     type Error = io::Error;
      33              : 
      34           36 :     fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      35           36 :         Pin::new(&mut self.inner).poll_ready(cx)
      36           36 :     }
      37              : 
      38           36 :     fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
      39           36 :         Pin::new(&mut self.inner).start_send(item)
      40           36 :     }
      41              : 
      42           36 :     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      43           36 :         Pin::new(&mut self.inner).poll_flush(cx)
      44           36 :     }
      45              : 
      46            0 :     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
      47            0 :         Pin::new(&mut self.inner).poll_close(cx)
      48            0 :     }
      49              : }
      50              : 
      51              : impl<S, T> Stream for StartupStream<S, T>
      52              : where
      53              :     S: AsyncRead + AsyncWrite + Unpin,
      54              :     T: AsyncRead + AsyncWrite + Unpin,
      55              : {
      56              :     type Item = io::Result<Message>;
      57              : 
      58           91 :     fn poll_next(
      59           91 :         mut self: Pin<&mut Self>,
      60           91 :         cx: &mut Context<'_>,
      61           91 :     ) -> Poll<Option<io::Result<Message>>> {
      62              :         loop {
      63          128 :             match self.buf.next() {
      64           42 :                 Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
      65           86 :                 Ok(None) => {}
      66            0 :                 Err(e) => return Poll::Ready(Some(Err(e))),
      67              :             }
      68              : 
      69           86 :             match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
      70           37 :                 Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
      71            7 :                 Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
      72            6 :                 Some(Err(e)) => return Poll::Ready(Some(Err(e))),
      73            0 :                 None => return Poll::Ready(None),
      74              :             }
      75              :         }
      76           91 :     }
      77              : }
      78              : 
/// Result of a successful `connect_raw`: the framed connection together with
/// the information gathered while reading up to `ReadyForQuery`.
pub struct RawConnection<S, T> {
    /// The framed transport taken out of the startup stream.
    pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// Name/value pairs received in `ParameterStatus` messages.
    pub parameters: HashMap<String, String>,
    /// `NoticeResponse` bodies received before `ReadyForQuery`.
    pub delayed_notice: Vec<NoticeResponseBody>,
    /// Process id from `BackendKeyData`; stays 0 if that message was not received.
    pub process_id: i32,
    /// Secret key from `BackendKeyData`; stays 0 if that message was not received.
    pub secret_key: i32,
}
      86              : 
      87           15 : pub async fn connect_raw<S, T>(
      88           15 :     stream: S,
      89           15 :     tls: T,
      90           15 :     config: &Config,
      91           15 : ) -> Result<RawConnection<S, T::Stream>, Error>
      92           15 : where
      93           15 :     S: AsyncRead + AsyncWrite + Unpin,
      94           15 :     T: TlsConnect<S>,
      95           15 : {
      96           15 :     let stream = connect_tls(stream, config.ssl_mode, tls).await?;
      97              : 
      98           15 :     let mut stream = StartupStream {
      99           15 :         inner: Framed::new(stream, PostgresCodec),
     100           15 :         buf: BackendMessages::empty(),
     101           15 :         delayed_notice: Vec::new(),
     102           15 :     };
     103           15 : 
     104           15 :     startup(&mut stream, config).await?;
     105           15 :     authenticate(&mut stream, config).await?;
     106            7 :     let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
     107              : 
     108            7 :     Ok(RawConnection {
     109            7 :         stream: stream.inner,
     110            7 :         parameters,
     111            7 :         delayed_notice: stream.delayed_notice,
     112            7 :         process_id,
     113            7 :         secret_key,
     114            7 :     })
     115           15 : }
     116              : 
     117           15 : async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     118           15 : where
     119           15 :     S: AsyncRead + AsyncWrite + Unpin,
     120           15 :     T: AsyncRead + AsyncWrite + Unpin,
     121           15 : {
     122           15 :     let mut buf = BytesMut::new();
     123           15 :     frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
     124              : 
     125           15 :     stream
     126           15 :         .send(FrontendMessage::Raw(buf.freeze()))
     127           15 :         .await
     128           15 :         .map_err(Error::io)
     129           15 : }
     130              : 
     131           15 : async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
     132           15 : where
     133           15 :     S: AsyncRead + AsyncWrite + Unpin,
     134           15 :     T: TlsStream + Unpin,
     135           15 : {
     136           15 :     match stream.try_next().await.map_err(Error::io)? {
     137              :         Some(Message::AuthenticationOk) => {
     138            2 :             can_skip_channel_binding(config)?;
     139            2 :             return Ok(());
     140              :         }
     141              :         Some(Message::AuthenticationCleartextPassword) => {
     142            0 :             can_skip_channel_binding(config)?;
     143              : 
     144            0 :             let pass = config
     145            0 :                 .password
     146            0 :                 .as_ref()
     147            0 :                 .ok_or_else(|| Error::config("password missing".into()))?;
     148              : 
     149            0 :             authenticate_password(stream, pass).await?;
     150              :         }
     151           12 :         Some(Message::AuthenticationSasl(body)) => {
     152           12 :             authenticate_sasl(stream, body, config).await?;
     153              :         }
     154              :         Some(Message::AuthenticationMd5Password)
     155              :         | Some(Message::AuthenticationKerberosV5)
     156              :         | Some(Message::AuthenticationScmCredential)
     157              :         | Some(Message::AuthenticationGss)
     158              :         | Some(Message::AuthenticationSspi) => {
     159            0 :             return Err(Error::authentication(
     160            0 :                 "unsupported authentication method".into(),
     161            0 :             ))
     162              :         }
     163            1 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     164            0 :         Some(_) => return Err(Error::unexpected_message()),
     165            0 :         None => return Err(Error::closed()),
     166              :     }
     167              : 
     168            5 :     match stream.try_next().await.map_err(Error::io)? {
     169            5 :         Some(Message::AuthenticationOk) => Ok(()),
     170            0 :         Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
     171            0 :         Some(_) => Err(Error::unexpected_message()),
     172            0 :         None => Err(Error::closed()),
     173              :     }
     174           15 : }
     175              : 
     176            6 : fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
     177            6 :     match config.channel_binding {
     178            5 :         config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
     179            1 :         config::ChannelBinding::Require => Err(Error::authentication(
     180            1 :             "server did not use channel binding".into(),
     181            1 :         )),
     182              :     }
     183            6 : }
     184              : 
     185            0 : async fn authenticate_password<S, T>(
     186            0 :     stream: &mut StartupStream<S, T>,
     187            0 :     password: &[u8],
     188            0 : ) -> Result<(), Error>
     189            0 : where
     190            0 :     S: AsyncRead + AsyncWrite + Unpin,
     191            0 :     T: AsyncRead + AsyncWrite + Unpin,
     192            0 : {
     193            0 :     let mut buf = BytesMut::new();
     194            0 :     frontend::password_message(password, &mut buf).map_err(Error::encode)?;
     195              : 
     196            0 :     stream
     197            0 :         .send(FrontendMessage::Raw(buf.freeze()))
     198            0 :         .await
     199            0 :         .map_err(Error::io)
     200            0 : }
     201              : 
     202           12 : async fn authenticate_sasl<S, T>(
     203           12 :     stream: &mut StartupStream<S, T>,
     204           12 :     body: AuthenticationSaslBody,
     205           12 :     config: &Config,
     206           12 : ) -> Result<(), Error>
     207           12 : where
     208           12 :     S: AsyncRead + AsyncWrite + Unpin,
     209           12 :     T: TlsStream + Unpin,
     210           12 : {
     211           12 :     let mut has_scram = false;
     212           12 :     let mut has_scram_plus = false;
     213           12 :     let mut mechanisms = body.mechanisms();
     214           34 :     while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
     215           22 :         match mechanism {
     216           22 :             sasl::SCRAM_SHA_256 => has_scram = true,
     217           10 :             sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
     218            0 :             _ => {}
     219              :         }
     220              :     }
     221              : 
     222           12 :     let channel_binding = stream
     223           12 :         .inner
     224           12 :         .get_ref()
     225           12 :         .channel_binding()
     226           12 :         .tls_server_end_point
     227           12 :         .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
     228           12 :         .map(sasl::ChannelBinding::tls_server_end_point);
     229              : 
     230           12 :     let (channel_binding, mechanism) = if has_scram_plus {
     231           10 :         match channel_binding {
     232            8 :             Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
     233            2 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     234              :         }
     235            2 :     } else if has_scram {
     236            2 :         match channel_binding {
     237            2 :             Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
     238            0 :             None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
     239              :         }
     240              :     } else {
     241            0 :         return Err(Error::authentication("unsupported SASL mechanism".into()));
     242              :     };
     243              : 
     244           12 :     if mechanism != sasl::SCRAM_SHA_256_PLUS {
     245            4 :         can_skip_channel_binding(config)?;
     246            8 :     }
     247              : 
     248           11 :     let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
     249            0 :         ScramSha256::new_with_keys(keys, channel_binding)
     250           11 :     } else if let Some(password) = config.get_password() {
     251           11 :         ScramSha256::new(password, channel_binding)
     252              :     } else {
     253            0 :         return Err(Error::config("password or auth keys missing".into()));
     254              :     };
     255              : 
     256           11 :     let mut buf = BytesMut::new();
     257           11 :     frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
     258           11 :     stream
     259           11 :         .send(FrontendMessage::Raw(buf.freeze()))
     260           11 :         .await
     261           11 :         .map_err(Error::io)?;
     262              : 
     263           11 :     let body = match stream.try_next().await.map_err(Error::io)? {
     264           10 :         Some(Message::AuthenticationSaslContinue(body)) => body,
     265            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     266            0 :         Some(_) => return Err(Error::unexpected_message()),
     267            0 :         None => return Err(Error::closed()),
     268              :     };
     269              : 
     270           10 :     scram
     271           10 :         .update(body.data())
     272           10 :         .await
     273           10 :         .map_err(|e| Error::authentication(e.into()))?;
     274              : 
     275           10 :     let mut buf = BytesMut::new();
     276           10 :     frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
     277           10 :     stream
     278           10 :         .send(FrontendMessage::Raw(buf.freeze()))
     279           10 :         .await
     280           10 :         .map_err(Error::io)?;
     281              : 
     282           10 :     let body = match stream.try_next().await.map_err(Error::io)? {
     283            5 :         Some(Message::AuthenticationSaslFinal(body)) => body,
     284            0 :         Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     285            0 :         Some(_) => return Err(Error::unexpected_message()),
     286            0 :         None => return Err(Error::closed()),
     287              :     };
     288              : 
     289            5 :     scram
     290            5 :         .finish(body.data())
     291            5 :         .map_err(|e| Error::authentication(e.into()))?;
     292              : 
     293            5 :     Ok(())
     294           12 : }
     295              : 
     296            7 : async fn read_info<S, T>(
     297            7 :     stream: &mut StartupStream<S, T>,
     298            7 : ) -> Result<(i32, i32, HashMap<String, String>), Error>
     299            7 : where
     300            7 :     S: AsyncRead + AsyncWrite + Unpin,
     301            7 :     T: AsyncRead + AsyncWrite + Unpin,
     302            7 : {
     303            7 :     let mut process_id = 0;
     304            7 :     let mut secret_key = 0;
     305            7 :     let mut parameters = HashMap::new();
     306              : 
     307              :     loop {
     308           14 :         match stream.try_next().await.map_err(Error::io)? {
     309            0 :             Some(Message::BackendKeyData(body)) => {
     310            0 :                 process_id = body.process_id();
     311            0 :                 secret_key = body.secret_key();
     312            0 :             }
     313            7 :             Some(Message::ParameterStatus(body)) => {
     314            7 :                 parameters.insert(
     315            7 :                     body.name().map_err(Error::parse)?.to_string(),
     316            7 :                     body.value().map_err(Error::parse)?.to_string(),
     317              :                 );
     318              :             }
     319            0 :             Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
     320            7 :             Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
     321            0 :             Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
     322            0 :             Some(_) => return Err(Error::unexpected_message()),
     323            0 :             None => return Err(Error::closed()),
     324              :         }
     325              :     }
     326            7 : }
        

Generated by: LCOV version 2.1-beta