LCOV - code coverage report
Current view: top level - proxy/src/scram - exchange.rs (source / functions) Coverage Total Hit
Test: 322b88762cba8ea666f63cda880cccab6936bf37.info Lines: 95.5 % 134 128
Test Date: 2024-02-29 11:57:12 Functions: 72.7 % 11 8

            Line data    Source code
       1              : //! Implementation of the SCRAM authentication algorithm.
       2              : 
       3              : use std::convert::Infallible;
       4              : 
       5              : use postgres_protocol::authentication::sasl::ScramSha256;
       6              : 
       7              : use super::messages::{
       8              :     ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
       9              : };
      10              : use super::secret::ServerSecret;
      11              : use super::signature::SignatureBuilder;
      12              : use crate::config;
      13              : use crate::sasl::{self, ChannelBinding, Error as SaslError};
      14              : 
      15              : /// The only channel binding mode we currently support.
      16            0 : #[derive(Debug)]
      17              : struct TlsServerEndPoint;
      18              : 
      19              : impl std::fmt::Display for TlsServerEndPoint {
      20           12 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      21           12 :         write!(f, "tls-server-end-point")
      22           12 :     }
      23              : }
      24              : 
      25              : impl std::str::FromStr for TlsServerEndPoint {
      26              :     type Err = sasl::Error;
      27              : 
      28           12 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      29           12 :         match s {
      30           12 :             "tls-server-end-point" => Ok(TlsServerEndPoint),
      31            0 :             _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
      32              :         }
      33           12 :     }
      34              : }
      35              : 
      36              : struct SaslSentInner {
      37              :     cbind_flag: ChannelBinding<TlsServerEndPoint>,
      38              :     client_first_message_bare: String,
      39              :     server_first_message: OwnedServerFirstMessage,
      40              : }
      41              : 
      42              : struct SaslInitial {
      43              :     nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      44              : }
      45              : 
      46              : enum ExchangeState {
      47              :     /// Waiting for [`ClientFirstMessage`].
      48              :     Initial(SaslInitial),
      49              :     /// Waiting for [`ClientFinalMessage`].
      50              :     SaltSent(SaslSentInner),
      51              : }
      52              : 
      53              : /// Server's side of SCRAM auth algorithm.
      54              : pub struct Exchange<'a> {
      55              :     state: ExchangeState,
      56              :     secret: &'a ServerSecret,
      57              :     tls_server_end_point: config::TlsServerEndPoint,
      58              : }
      59              : 
      60              : impl<'a> Exchange<'a> {
      61           24 :     pub fn new(
      62           24 :         secret: &'a ServerSecret,
      63           24 :         nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      64           24 :         tls_server_end_point: config::TlsServerEndPoint,
      65           24 :     ) -> Self {
      66           24 :         Self {
      67           24 :             state: ExchangeState::Initial(SaslInitial { nonce }),
      68           24 :             secret,
      69           24 :             tls_server_end_point,
      70           24 :         }
      71           24 :     }
      72              : }
      73              : 
      74            4 : pub fn exchange(
      75            4 :     secret: &ServerSecret,
      76            4 :     mut client: ScramSha256,
      77            4 :     tls_server_end_point: config::TlsServerEndPoint,
      78            4 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
      79            4 :     use sasl::Step::*;
      80            4 : 
      81            4 :     let init = SaslInitial {
      82            4 :         nonce: rand::random,
      83            4 :     };
      84              : 
      85            4 :     let client_first = std::str::from_utf8(client.message())
      86            4 :         .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
      87            4 :     let sent = match init.transition(secret, &tls_server_end_point, client_first)? {
      88            4 :         Continue(sent, server_first) => {
      89            4 :             client.update(server_first.as_bytes())?;
      90            4 :             sent
      91              :         }
      92              :         Success(x, _) => match x {},
      93            0 :         Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
      94              :     };
      95              : 
      96            4 :     let client_final = std::str::from_utf8(client.message())
      97            4 :         .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
      98            4 :     let keys = match sent.transition(secret, &tls_server_end_point, client_final)? {
      99            2 :         Success(keys, server_final) => {
     100            2 :             client.finish(server_final.as_bytes())?;
     101            2 :             keys
     102              :         }
     103              :         Continue(x, _) => match x {},
     104            2 :         Failure(msg) => return Ok(sasl::Outcome::Failure(msg)),
     105              :     };
     106              : 
     107            2 :     Ok(sasl::Outcome::Success(keys))
     108            4 : }
     109              : 
     110              : impl SaslInitial {
     111           28 :     fn transition(
     112           28 :         &self,
     113           28 :         secret: &ServerSecret,
     114           28 :         tls_server_end_point: &config::TlsServerEndPoint,
     115           28 :         input: &str,
     116           28 :     ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
     117           28 :         let client_first_message = ClientFirstMessage::parse(input)
     118           28 :             .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
     119              : 
     120              :         // If the flag is set to "y" and the server supports channel
     121              :         // binding, the server MUST fail authentication
     122           28 :         if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
     123            2 :             && tls_server_end_point.supported()
     124              :         {
     125            2 :             return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
     126           26 :         }
     127           26 : 
     128           26 :         let server_first_message = client_first_message.build_server_first_message(
     129           26 :             &(self.nonce)(),
     130           26 :             &secret.salt_base64,
     131           26 :             secret.iterations,
     132           26 :         );
     133           26 :         let msg = server_first_message.as_str().to_owned();
     134              : 
     135           26 :         let next = SaslSentInner {
     136           26 :             cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
     137           26 :             client_first_message_bare: client_first_message.bare.to_owned(),
     138           26 :             server_first_message,
     139           26 :         };
     140           26 : 
     141           26 :         Ok(sasl::Step::Continue(next, msg))
     142           28 :     }
     143              : }
     144              : 
     145              : impl SaslSentInner {
     146           26 :     fn transition(
     147           26 :         &self,
     148           26 :         secret: &ServerSecret,
     149           26 :         tls_server_end_point: &config::TlsServerEndPoint,
     150           26 :         input: &str,
     151           26 :     ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
     152           26 :         let Self {
     153           26 :             cbind_flag,
     154           26 :             client_first_message_bare,
     155           26 :             server_first_message,
     156           26 :         } = self;
     157              : 
     158           26 :         let client_final_message = ClientFinalMessage::parse(input)
     159           26 :             .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
     160              : 
     161           26 :         let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
     162           12 :             config::TlsServerEndPoint::Sha256(x) => Ok(x),
     163            0 :             config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
     164           26 :         })?;
     165              : 
     166              :         // This might've been caused by a MITM attack
     167           26 :         if client_final_message.channel_binding != channel_binding {
     168            8 :             return Err(SaslError::ChannelBindingFailed(
     169            8 :                 "insecure connection: secure channel data mismatch",
     170            8 :             ));
     171           18 :         }
     172           18 : 
     173           18 :         if client_final_message.nonce != server_first_message.nonce() {
     174            0 :             return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
     175           18 :         }
     176           18 : 
     177           18 :         let signature_builder = SignatureBuilder {
     178           18 :             client_first_message_bare,
     179           18 :             server_first_message: server_first_message.as_str(),
     180           18 :             client_final_message_without_proof: client_final_message.without_proof,
     181           18 :         };
     182           18 : 
     183           18 :         let client_key = signature_builder
     184           18 :             .build(&secret.stored_key)
     185           18 :             .derive_client_key(&client_final_message.proof);
     186           18 : 
     187           18 :         // Auth fails either if keys don't match or it's pre-determined to fail.
     188           18 :         if client_key.sha256() != secret.stored_key || secret.doomed {
     189            4 :             return Ok(sasl::Step::Failure("password doesn't match"));
     190           14 :         }
     191           14 : 
     192           14 :         let msg =
     193           14 :             client_final_message.build_server_final_message(signature_builder, &secret.server_key);
     194           14 : 
     195           14 :         Ok(sasl::Step::Success(client_key, msg))
     196           26 :     }
     197              : }
     198              : 
     199              : impl sasl::Mechanism for Exchange<'_> {
     200              :     type Output = super::ScramKey;
     201              : 
     202           46 :     fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
     203           46 :         use {sasl::Step::*, ExchangeState::*};
     204           46 :         match &self.state {
     205           24 :             Initial(init) => {
     206           24 :                 match init.transition(self.secret, &self.tls_server_end_point, input)? {
     207           22 :                     Continue(sent, msg) => {
     208           22 :                         self.state = SaltSent(sent);
     209           22 :                         Ok(Continue(self, msg))
     210              :                     }
     211              :                     Success(x, _) => match x {},
     212            0 :                     Failure(msg) => Ok(Failure(msg)),
     213              :                 }
     214              :             }
     215           22 :             SaltSent(sent) => {
     216           22 :                 match sent.transition(self.secret, &self.tls_server_end_point, input)? {
     217           12 :                     Success(keys, msg) => Ok(Success(keys, msg)),
     218              :                     Continue(x, _) => match x {},
     219            2 :                     Failure(msg) => Ok(Failure(msg)),
     220              :                 }
     221              :             }
     222              :         }
     223           46 :     }
     224              : }
        

Generated by: LCOV version 2.1-beta