LCOV - code coverage report
Current view: top level - proxy/src/scram - exchange.rs (source / functions) Coverage Total Hit
Test: 1d5975439f3c9882b18414799141ebf9a3922c58.info Lines: 93.3 % 165 154
Test Date: 2025-07-31 15:59:03 Functions: 100.0 % 14 14

            Line data    Source code
       1              : //! Implementation of the SCRAM authentication algorithm.
       2              : 
       3              : use std::convert::Infallible;
       4              : 
       5              : use base64::Engine as _;
       6              : use base64::prelude::BASE64_STANDARD;
       7              : use tracing::{debug, trace};
       8              : 
       9              : use super::messages::{
      10              :     ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
      11              : };
      12              : use super::pbkdf2::Pbkdf2;
      13              : use super::secret::ServerSecret;
      14              : use super::signature::SignatureBuilder;
      15              : use super::threadpool::ThreadPool;
      16              : use super::{ScramKey, pbkdf2};
      17              : use crate::intern::{EndpointIdInt, RoleNameInt};
      18              : use crate::sasl::{self, ChannelBinding, Error as SaslError};
      19              : use crate::scram::cache::Pbkdf2CacheEntry;
      20              : 
      21              : /// The only channel binding mode we currently support.
      22              : #[derive(Debug)]
      23              : struct TlsServerEndPoint;
      24              : 
      25              : impl std::fmt::Display for TlsServerEndPoint {
      26            6 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      27            6 :         write!(f, "tls-server-end-point")
      28            6 :     }
      29              : }
      30              : 
      31              : impl std::str::FromStr for TlsServerEndPoint {
      32              :     type Err = sasl::Error;
      33              : 
      34            6 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      35            6 :         match s {
      36            6 :             "tls-server-end-point" => Ok(TlsServerEndPoint),
      37            0 :             _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
      38              :         }
      39            6 :     }
      40              : }
      41              : 
      42              : struct SaslSentInner {
      43              :     cbind_flag: ChannelBinding<TlsServerEndPoint>,
      44              :     client_first_message_bare: String,
      45              :     server_first_message: OwnedServerFirstMessage,
      46              : }
      47              : 
      48              : struct SaslInitial {
      49              :     nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      50              : }
      51              : 
      52              : enum ExchangeState {
      53              :     /// Waiting for [`ClientFirstMessage`].
      54              :     Initial(SaslInitial),
      55              :     /// Waiting for [`ClientFinalMessage`].
      56              :     SaltSent(SaslSentInner),
      57              : }
      58              : 
      59              : /// Server's side of SCRAM auth algorithm.
      60              : pub(crate) struct Exchange<'a> {
      61              :     state: ExchangeState,
      62              :     secret: &'a ServerSecret,
      63              :     tls_server_end_point: crate::tls::TlsServerEndPoint,
      64              : }
      65              : 
      66              : impl<'a> Exchange<'a> {
      67           13 :     pub(crate) fn new(
      68           13 :         secret: &'a ServerSecret,
      69           13 :         nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      70           13 :         tls_server_end_point: crate::tls::TlsServerEndPoint,
      71           13 :     ) -> Self {
      72           13 :         Self {
      73           13 :             state: ExchangeState::Initial(SaslInitial { nonce }),
      74           13 :             secret,
      75           13 :             tls_server_end_point,
      76           13 :         }
      77           13 :     }
      78              : }
      79              : 
      80           14 : async fn derive_client_key(
      81           14 :     pool: &ThreadPool,
      82           14 :     endpoint: EndpointIdInt,
      83           14 :     password: &[u8],
      84           14 :     salt: &[u8],
      85           14 :     iterations: u32,
      86           14 : ) -> pbkdf2::Block {
      87           14 :     pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
      88           14 :         .await
      89           14 : }
      90              : 
      91              : /// For cleartext flow, we need to derive the client key to
      92              : /// 1. authenticate the client.
      93              : /// 2. authenticate with compute.
      94            8 : pub(crate) async fn exchange(
      95            8 :     pool: &ThreadPool,
      96            8 :     endpoint: EndpointIdInt,
      97            8 :     role: RoleNameInt,
      98            8 :     secret: &ServerSecret,
      99            8 :     password: &[u8],
     100            8 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
     101            8 :     if secret.iterations > CACHED_ROUNDS {
     102            8 :         exchange_with_cache(pool, endpoint, role, secret, password).await
     103              :     } else {
     104            0 :         let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
     105            0 :         let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
     106            0 :         Ok(validate_pbkdf2(secret, &hash))
     107              :     }
     108            8 : }
     109              : 
     110              : /// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
     111              : /// which is not enough by itself to perform an offline brute force.
     112            8 : async fn exchange_with_cache(
     113            8 :     pool: &ThreadPool,
     114            8 :     endpoint: EndpointIdInt,
     115            8 :     role: RoleNameInt,
     116            8 :     secret: &ServerSecret,
     117            8 :     password: &[u8],
     118            8 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
     119            8 :     let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
     120              : 
     121            8 :     debug_assert!(
     122            8 :         secret.iterations > CACHED_ROUNDS,
     123            0 :         "we should not cache password data if there isn't enough rounds needed"
     124              :     );
     125              : 
     126              :     // compute the prefix of the pbkdf2 output.
     127            8 :     let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
     128              : 
     129            8 :     if let Some(entry) = pool.cache.get_entry(endpoint, role) {
     130              :         // hot path: let's check the threadpool cache
     131            2 :         if secret.cached_at == entry.cached_from {
     132              :             // cache is valid. compute the full hash by adding the prefix to the suffix.
     133            2 :             let mut hash = prefix;
     134            2 :             pbkdf2::xor_assign(&mut hash, &entry.suffix);
     135            2 :             let outcome = validate_pbkdf2(secret, &hash);
     136              : 
     137            2 :             if matches!(outcome, sasl::Outcome::Success(_)) {
     138            1 :                 trace!("password validated from cache");
     139            1 :             }
     140              : 
     141            2 :             return Ok(outcome);
     142            0 :         }
     143              : 
     144              :         // cached key is no longer valid.
     145            0 :         debug!("invalidating cached password");
     146            0 :         entry.invalidate();
     147            6 :     }
     148              : 
     149              :     // slow path: full password hash.
     150            6 :     let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
     151            6 :     let outcome = validate_pbkdf2(secret, &hash);
     152              : 
     153            6 :     let client_key = match outcome {
     154            4 :         sasl::Outcome::Success(client_key) => client_key,
     155            2 :         sasl::Outcome::Failure(_) => return Ok(outcome),
     156              :     };
     157              : 
     158            4 :     trace!("storing cached password");
     159              : 
     160              :     // time to cache, compute the suffix by subtracting the prefix from the hash.
     161            4 :     let mut suffix = hash;
     162            4 :     pbkdf2::xor_assign(&mut suffix, &prefix);
     163              : 
     164            4 :     pool.cache.insert(
     165            4 :         endpoint,
     166            4 :         role,
     167            4 :         Pbkdf2CacheEntry {
     168            4 :             cached_from: secret.cached_at,
     169            4 :             suffix,
     170            4 :         },
     171              :     );
     172              : 
     173            4 :     Ok(sasl::Outcome::Success(client_key))
     174            8 : }
     175              : 
     176            8 : fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
     177            8 :     let client_key = super::ScramKey::client_key(&(*hash).into());
     178            8 :     if secret.is_password_invalid(&client_key).into() {
     179            3 :         sasl::Outcome::Failure("password doesn't match")
     180              :     } else {
     181            5 :         sasl::Outcome::Success(client_key)
     182              :     }
     183            8 : }
     184              : 
     185              : const CACHED_ROUNDS: u32 = 16;
     186              : 
     187              : impl SaslInitial {
     188           13 :     fn transition(
     189           13 :         &self,
     190           13 :         secret: &ServerSecret,
     191           13 :         tls_server_end_point: &crate::tls::TlsServerEndPoint,
     192           13 :         input: &str,
     193           13 :     ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
     194           13 :         let client_first_message = ClientFirstMessage::parse(input)
     195           13 :             .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
     196              : 
     197              :         // If the flag is set to "y" and the server supports channel
     198              :         // binding, the server MUST fail authentication
     199           13 :         if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
     200            1 :             && tls_server_end_point.supported()
     201              :         {
     202            1 :             return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
     203           12 :         }
     204              : 
     205           12 :         let server_first_message = client_first_message.build_server_first_message(
     206           12 :             &(self.nonce)(),
     207           12 :             &secret.salt_base64,
     208           12 :             secret.iterations,
     209              :         );
     210           12 :         let msg = server_first_message.as_str().to_owned();
     211              : 
     212           12 :         let next = SaslSentInner {
     213           12 :             cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
     214           12 :             client_first_message_bare: client_first_message.bare.to_owned(),
     215           12 :             server_first_message,
     216              :         };
     217              : 
     218           12 :         Ok(sasl::Step::Continue(next, msg))
     219           13 :     }
     220              : }
     221              : 
     222              : impl SaslSentInner {
     223           12 :     fn transition(
     224           12 :         &self,
     225           12 :         secret: &ServerSecret,
     226           12 :         tls_server_end_point: &crate::tls::TlsServerEndPoint,
     227           12 :         input: &str,
     228           12 :     ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
     229              :         let Self {
     230           12 :             cbind_flag,
     231           12 :             client_first_message_bare,
     232           12 :             server_first_message,
     233           12 :         } = self;
     234              : 
     235           12 :         let client_final_message = ClientFinalMessage::parse(input)
     236           12 :             .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
     237              : 
     238           12 :         let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
     239            6 :             crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
     240            0 :             crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
     241            6 :         })?;
     242              : 
     243              :         // This might've been caused by a MITM attack
     244           12 :         if client_final_message.channel_binding != channel_binding {
     245            4 :             return Err(SaslError::ChannelBindingFailed(
     246            4 :                 "insecure connection: secure channel data mismatch",
     247            4 :             ));
     248            8 :         }
     249              : 
     250            8 :         if client_final_message.nonce != server_first_message.nonce() {
     251            0 :             return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
     252            8 :         }
     253              : 
     254            8 :         let signature_builder = SignatureBuilder {
     255            8 :             client_first_message_bare,
     256            8 :             server_first_message: server_first_message.as_str(),
     257            8 :             client_final_message_without_proof: client_final_message.without_proof,
     258            8 :         };
     259              : 
     260            8 :         let client_key = signature_builder
     261            8 :             .build(&secret.stored_key)
     262            8 :             .derive_client_key(&client_final_message.proof);
     263              : 
     264              :         // Auth fails either if keys don't match or it's pre-determined to fail.
     265            8 :         if secret.is_password_invalid(&client_key).into() {
     266            1 :             return Ok(sasl::Step::Failure("password doesn't match"));
     267            7 :         }
     268              : 
     269            7 :         let msg =
     270            7 :             client_final_message.build_server_final_message(signature_builder, &secret.server_key);
     271              : 
     272            7 :         Ok(sasl::Step::Success(client_key, msg))
     273           12 :     }
     274              : }
     275              : 
     276              : impl sasl::Mechanism for Exchange<'_> {
     277              :     type Output = super::ScramKey;
     278              : 
     279           25 :     fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
     280              :         use ExchangeState;
     281              :         use sasl::Step;
     282           25 :         match &self.state {
     283           13 :             ExchangeState::Initial(init) => {
     284           13 :                 match init.transition(self.secret, &self.tls_server_end_point, input)? {
     285           12 :                     Step::Continue(sent, msg) => {
     286           12 :                         self.state = ExchangeState::SaltSent(sent);
     287           12 :                         Ok(Step::Continue(self, msg))
     288              :                     }
     289            0 :                     Step::Failure(msg) => Ok(Step::Failure(msg)),
     290              :                 }
     291              :             }
     292           12 :             ExchangeState::SaltSent(sent) => {
     293           12 :                 match sent.transition(self.secret, &self.tls_server_end_point, input)? {
     294            7 :                     Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
     295            1 :                     Step::Failure(msg) => Ok(Step::Failure(msg)),
     296              :                 }
     297              :             }
     298              :         }
     299           25 :     }
     300              : }
        

Generated by: LCOV version 2.1-beta