LCOV - differential code coverage report
Current view: top level - proxy/src/scram - exchange.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 94.8 % 134 127 7 127
Current Date: 2024-01-09 02:06:09 Functions: 72.7 % 11 8 3 8
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Implementation of the SCRAM authentication algorithm.
       2                 : 
       3                 : use std::convert::Infallible;
       4                 : 
       5                 : use postgres_protocol::authentication::sasl::ScramSha256;
       6                 : 
       7                 : use super::messages::{
       8                 :     ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
       9                 : };
      10                 : use super::secret::ServerSecret;
      11                 : use super::signature::SignatureBuilder;
      12                 : use crate::config;
      13                 : use crate::sasl::{self, ChannelBinding, Error as SaslError};
      14                 : 
      15                 : /// The only channel binding mode we currently support.
      16 UBC           0 : #[derive(Debug)]
      17                 : struct TlsServerEndPoint;
      18                 : 
      19                 : impl std::fmt::Display for TlsServerEndPoint {
      20 CBC          42 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      21              42 :         write!(f, "tls-server-end-point")
      22              42 :     }
      23                 : }
      24                 : 
      25                 : impl std::str::FromStr for TlsServerEndPoint {
      26                 :     type Err = sasl::Error;
      27                 : 
      28              42 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      29              42 :         match s {
      30              42 :             "tls-server-end-point" => Ok(TlsServerEndPoint),
      31 UBC           0 :             _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
      32                 :         }
      33 CBC          42 :     }
      34                 : }
      35                 : 
      36                 : struct SaslSentInner {
      37                 :     cbind_flag: ChannelBinding<TlsServerEndPoint>,
      38                 :     client_first_message_bare: String,
      39                 :     server_first_message: OwnedServerFirstMessage,
      40                 : }
      41                 : 
      42                 : struct SaslInitial {
      43                 :     nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      44                 : }
      45                 : 
      46                 : enum ExchangeState {
      47                 :     /// Waiting for [`ClientFirstMessage`].
      48                 :     Initial(SaslInitial),
      49                 :     /// Waiting for [`ClientFinalMessage`].
      50                 :     SaltSent(SaslSentInner),
      51                 : }
      52                 : 
      53                 : /// Server's side of SCRAM auth algorithm.
      54                 : pub struct Exchange<'a> {
      55                 :     state: ExchangeState,
      56                 :     secret: &'a ServerSecret,
      57                 :     tls_server_end_point: config::TlsServerEndPoint,
      58                 : }
      59                 : 
      60                 : impl<'a> Exchange<'a> {
      61              48 :     pub fn new(
      62              48 :         secret: &'a ServerSecret,
      63              48 :         nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      64              48 :         tls_server_end_point: config::TlsServerEndPoint,
      65              48 :     ) -> Self {
      66              48 :         Self {
      67              48 :             state: ExchangeState::Initial(SaslInitial { nonce }),
      68              48 :             secret,
      69              48 :             tls_server_end_point,
      70              48 :         }
      71              48 :     }
      72                 : }
      73                 : 
      74               2 : pub fn exchange(
      75               2 :     secret: &ServerSecret,
      76               2 :     mut client: ScramSha256,
      77               2 :     tls_server_end_point: config::TlsServerEndPoint,
      78               2 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
      79               2 :     use sasl::Step::*;
      80               2 : 
      81               2 :     let init = SaslInitial {
      82               2 :         nonce: rand::random,
      83               2 :     };
      84                 : 
      85               2 :     let client_first = std::str::from_utf8(client.message())
      86               2 :         .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
      87               2 :     let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
      88               2 :         Continue(sent, server_first) => {
      89               2 :             client.update(server_first.as_bytes())?;
      90               2 :             sent
      91                 :         }
      92                 :         Success(x, _) => match x {},
      93 UBC           0 :         Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
      94                 :     };
      95                 : 
      96 CBC           2 :     let client_final = std::str::from_utf8(client.message())
      97               2 :         .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
      98               2 :     let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
      99               2 :         Success(keys, server_final) => {
     100               2 :             client.finish(server_final.as_bytes())?;
     101               2 :             keys
     102                 :         }
     103                 :         Continue(x, _) => match x {},
     104 UBC           0 :         Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
     105                 :     };
     106                 : 
     107 CBC           2 :     Ok(sasl::Outcome::Success(keys))
     108               2 : }
     109                 : 
     110                 : impl SaslInitial {
     111              50 :     fn transition(
     112              50 :         &self,
     113              50 :         secret: &ServerSecret,
     114              50 :         tls_server_end_point: &config::TlsServerEndPoint,
     115              50 :         input: &str,
     116              50 :     ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
     117              50 :         let client_first_message = ClientFirstMessage::parse(input)
     118              50 :             .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
     119                 : 
     120                 :         // If the flag is set to "y" and the server supports channel
     121                 :         // binding, the server MUST fail authentication
     122              50 :         if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
     123               1 :             && tls_server_end_point.supported()
     124                 :         {
     125               1 :             return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
     126              49 :         }
     127              49 : 
     128              49 :         let server_first_message = client_first_message.build_server_first_message(
     129              49 :             &(self.nonce)(),
     130              49 :             &secret.salt_base64,
     131              49 :             secret.iterations,
     132              49 :         );
     133              49 :         let msg = server_first_message.as_str().to_owned();
     134                 : 
     135              49 :         let next = SaslSentInner {
     136              49 :             cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
     137              49 :             client_first_message_bare: client_first_message.bare.to_owned(),
     138              49 :             server_first_message,
     139              49 :         };
     140              49 : 
     141              49 :         Ok(sasl::Step::Continue(next, msg))
     142              50 :     }
     143                 : }
     144                 : 
     145                 : impl SaslSentInner {
     146              49 :     fn transition(
     147              49 :         &self,
     148              49 :         secret: &ServerSecret,
     149              49 :         tls_server_end_point: &config::TlsServerEndPoint,
     150              49 :         input: &str,
     151              49 :     ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
     152              49 :         let Self {
     153              49 :             cbind_flag,
     154              49 :             client_first_message_bare,
     155              49 :             server_first_message,
     156              49 :         } = self;
     157                 : 
     158              49 :         let client_final_message = ClientFinalMessage::parse(input)
     159              49 :             .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
     160                 : 
     161              49 :         let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
     162              42 :             config::TlsServerEndPoint::Sha256(x) => Ok(x),
     163 UBC           0 :             config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
     164 CBC          49 :         })?;
     165                 : 
     166                 :         // This might've been caused by a MITM attack
     167              49 :         if client_final_message.channel_binding != channel_binding {
     168               4 :             return Err(SaslError::ChannelBindingFailed(
     169               4 :                 "insecure connection: secure channel data mismatch",
     170               4 :             ));
     171              45 :         }
     172              45 : 
     173              45 :         if client_final_message.nonce != server_first_message.nonce() {
     174 UBC           0 :             return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
     175 CBC          45 :         }
     176              45 : 
     177              45 :         let signature_builder = SignatureBuilder {
     178              45 :             client_first_message_bare,
     179              45 :             server_first_message: server_first_message.as_str(),
     180              45 :             client_final_message_without_proof: client_final_message.without_proof,
     181              45 :         };
     182              45 : 
     183              45 :         let client_key = signature_builder
     184              45 :             .build(&secret.stored_key)
     185              45 :             .derive_client_key(&client_final_message.proof);
     186              45 : 
     187              45 :         // Auth fails either if keys don't match or it's pre-determined to fail.
     188              45 :         if client_key.sha256() != secret.stored_key || secret.doomed {
     189               4 :             return Ok(sasl::Step::Failure("password doesn't match"));
     190              41 :         }
     191              41 : 
     192              41 :         let msg =
     193              41 :             client_final_message.build_server_final_message(signature_builder, &secret.server_key);
     194              41 : 
     195              41 :         Ok(sasl::Step::Success(client_key, msg))
     196              49 :     }
     197                 : }
     198                 : 
     199                 : impl sasl::Mechanism for Exchange<'_> {
     200                 :     type Output = super::ScramKey;
     201                 : 
     202              95 :     fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
     203              95 :         use {sasl::Step::*, ExchangeState::*};
     204              95 :         match &self.state {
     205              48 :             Initial(init) => {
     206              48 :                 match init.transition(self.secret, &self.tls_server_end_point, input)? {
     207              47 :                     Continue(sent, msg) => {
     208              47 :                         self.state = SaltSent(sent);
     209              47 :                         Ok(Continue(self, msg))
     210                 :                     }
     211                 :                     Success(x, _) => match x {},
     212 UBC           0 :                     Failure(msg) => Ok(Failure(msg)),
     213                 :                 }
     214                 :             }
     215 CBC          47 :             SaltSent(sent) => {
     216              47 :                 match sent.transition(self.secret, &self.tls_server_end_point, input)? {
     217              39 :                     Success(keys, msg) => Ok(Success(keys, msg)),
     218                 :                     Continue(x, _) => match x {},
     219               4 :                     Failure(msg) => Ok(Failure(msg)),
     220                 :                 }
     221                 :             }
     222                 :         }
     223              95 :     }
     224                 : }
        

Generated by: LCOV version 2.1-beta