LCOV - code coverage report
Current view: top level - proxy/src/scram - exchange.rs (source / functions) Coverage Total Hit
Test: 691a4c28fe7169edd60b367c52d448a0a6605f1f.info Lines: 97.3 % 149 145
Test Date: 2024-05-10 13:18:37 Functions: 100.0 % 14 14

            Line data    Source code
       1              : //! Implementation of the SCRAM authentication algorithm.
       2              : 
       3              : use std::convert::Infallible;
       4              : 
       5              : use hmac::{Hmac, Mac};
       6              : use sha2::Sha256;
       7              : use tokio::task::yield_now;
       8              : 
       9              : use super::messages::{
      10              :     ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
      11              : };
      12              : use super::secret::ServerSecret;
      13              : use super::signature::SignatureBuilder;
      14              : use super::ScramKey;
      15              : use crate::config;
      16              : use crate::sasl::{self, ChannelBinding, Error as SaslError};
      17              : 
      18              : /// The only channel binding mode we currently support.
                        ///
                        /// A field-less marker type; its `Display` and `FromStr` impls both use the
                        /// literal method name `tls-server-end-point`.
      19              : #[derive(Debug)]
      20              : struct TlsServerEndPoint;
      21              : 
      22              : impl std::fmt::Display for TlsServerEndPoint {
                            /// Writes the fixed method name `tls-server-end-point` to `f`.
      23           12 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      24           12 :         write!(f, "tls-server-end-point")
      25           12 :     }
      26              : }
      27              : 
      28              : impl std::str::FromStr for TlsServerEndPoint {
                            /// Parse failures are reported with the SASL error type.
      29              :     type Err = sasl::Error;
      30              : 
                            /// Accepts exactly the string `"tls-server-end-point"`.
                            ///
                            /// # Errors
                            ///
                            /// Any other input yields `sasl::Error::ChannelBindingBadMethod`, built
                            /// from the rejected string via `s.into()`.
      31           12 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      32           12 :         match s {
      33           12 :             "tls-server-end-point" => Ok(TlsServerEndPoint),
      34            0 :             _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
      35              :         }
      36           12 :     }
      37              : }
      38              : 
                        /// State carried from the client-first-message step to the
                        /// client-final-message step.
                        ///
                        /// Built by `SaslInitial::transition` and read by `SaslSentInner::transition`.
      39              : struct SaslSentInner {
                            /// The client's channel-binding flag after `and_then(str::parse)`, i.e.
                            /// parameterised over [`TlsServerEndPoint`].
      40              :     cbind_flag: ChannelBinding<TlsServerEndPoint>,
                            /// Owned copy of the `bare` field of the parsed client-first-message;
                            /// later passed to the signature builder.
      41              :     client_first_message_bare: String,
                            /// The server-first-message that was produced in reply; its nonce and
                            /// text are used again when the final message is checked.
      42              :     server_first_message: OwnedServerFirstMessage,
      43              : }
      44              : 
                        /// State of the exchange before any client message has been processed.
      45              : struct SaslInitial {
                            /// Function pointer that returns `SCRAM_RAW_NONCE_LEN` raw nonce bytes.
                            /// `transition` calls it once when building the server-first-message.
      46              :     nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      47              : }
      48              : 
                        /// Which client message the exchange is currently waiting for.
                        ///
                        /// The `Mechanism::exchange` impl replaces `Initial` with `SaltSent` once the
                        /// first message has been accepted.
      49              : enum ExchangeState {
      50              :     /// Waiting for [`ClientFirstMessage`].
      51              :     Initial(SaslInitial),
      52              :     /// Waiting for [`ClientFinalMessage`].
      53              :     SaltSent(SaslSentInner),
      54              : }
      55              : 
      56              : /// Server's side of SCRAM auth algorithm.
                        ///
                        /// Holds the current [`ExchangeState`] together with the inputs every step
                        /// needs. It is driven one client message at a time through its
                        /// `sasl::Mechanism` impl.
      57              : pub struct Exchange<'a> {
                            /// Current step of the two-message exchange.
      58              :     state: ExchangeState,
                            /// Borrowed server-side secret; the steps read its salt, iteration count,
                            /// stored key and server key.
      59              :     secret: &'a ServerSecret,
                            /// Passed by reference to both transitions: the first consults
                            /// `supported()`, the second matches on its `Sha256` / `Undefined`
                            /// variants when encoding the expected channel binding.
      60              :     tls_server_end_point: config::TlsServerEndPoint,
      61              : }
      62              : 
      63              : impl<'a> Exchange<'a> {
                            /// Creates an exchange that starts in the `Initial` state.
                            ///
                            /// * `secret` - borrowed for the lifetime `'a` of the exchange.
                            /// * `nonce` - function pointer stored in the initial state; it supplies
                            ///   the `SCRAM_RAW_NONCE_LEN`-byte nonce.
                            /// * `tls_server_end_point` - stored as-is and handed to each transition.
      64           26 :     pub fn new(
      65           26 :         secret: &'a ServerSecret,
      66           26 :         nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      67           26 :         tls_server_end_point: config::TlsServerEndPoint,
      68           26 :     ) -> Self {
      69           26 :         Self {
      70           26 :             state: ExchangeState::Initial(SaslInitial { nonce }),
      71           26 :             secret,
      72           26 :             tls_server_end_point,
      73           26 :         }
      74           26 :     }
      75              : }
      76              : 
      77              : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
                        /// Single-block PBKDF2 with HMAC-SHA-256, producing 32 bytes.
                        ///
                        /// * `str` - used as the HMAC key for every round.
                        /// * `salt` - fed to the first round, followed by the big-endian bytes of
                        ///   `1u32`.
                        /// * `iterations` - total round count. The loop covers `1..iterations`, so
                        ///   values of 0 and 1 both return the first round's output unchanged.
                        ///
                        /// The function is `async` only so that it can call `yield_now()` between
                        /// rounds; it performs no I/O.
                        ///
                        /// # Panics
                        ///
                        /// Only if `Hmac::new_from_slice` rejects the key, which the `expect` message
                        /// states does not happen for any key size.
      78            8 : async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
      79            8 :     let hmac = Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
                            // First round: HMAC over `salt` followed by the block index 1 (big-endian).
                            // The keyed `hmac` is cloned so the same key state can be reused each round.
      80            8 :     let mut prev = hmac
      81            8 :         .clone()
      82            8 :         .chain_update(salt)
      83            8 :         .chain_update(1u32.to_be_bytes())
      84            8 :         .finalize()
      85            8 :         .into_bytes();
      86            8 : 
                            // `hi` accumulates the XOR of every round's output, starting with round one.
      87            8 :     let mut hi = prev;
      88              : 
                            // Starts at 1 because the first round was already computed above.
      89        32760 :     for i in 1..iterations {
      90        32760 :         prev = hmac.clone().chain_update(prev).finalize().into_bytes();
      91              : 
                                // Fold this round's output into the accumulator byte by byte.
      92      1048320 :         for (hi, prev) in hi.iter_mut().zip(prev) {
      93      1048320 :             *hi ^= prev;
      94      1048320 :         }
      95              :         // yield every ~250us
      96              :         // hopefully reduces tail latencies
      97        32760 :         if i % 1024 == 0 {
      98           24 :             yield_now().await
      99        32736 :         }
     100              :     }
     101              : 
     102            8 :     hi.into()
     103            8 : }
     104              : 
     105              : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
                        /// Derives the client key from a plaintext password.
                        ///
                        /// First stretches `password` with [`pbkdf2`] using `salt` and `iterations`,
                        /// then computes HMAC-SHA-256 over the literal `Client Key`, keyed by that
                        /// salted password. The 32-byte result is converted into a [`ScramKey`].
     106            8 : async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
     107           24 :     let salted_password = pbkdf2(password, salt, iterations).await;
     108              : 
                            // HMAC-SHA-256 of `name`, keyed by the salted password, as a 32-byte array.
     109            8 :     let make_key = |name| {
     110            8 :         let key = Hmac::<Sha256>::new_from_slice(&salted_password)
     111            8 :             .expect("HMAC is able to accept all key sizes")
     112            8 :             .chain_update(name)
     113            8 :             .finalize();
     114            8 : 
     115            8 :         <[u8; 32]>::from(key.into_bytes())
     116            8 :     };
     117              : 
     118            8 :     make_key(b"Client Key").into()
     119            8 : }
     120              : 
                        /// Checks a plaintext `password` against `secret` in a single call, without
                        /// going through the message-based state machine.
                        ///
                        /// Decodes `secret.salt_base64`, derives the client key with
                        /// `secret.iterations`, and asks `secret.is_password_invalid` for the verdict.
                        ///
                        /// # Returns
                        ///
                        /// * `Outcome::Success(client_key)` when the key is accepted.
                        /// * `Outcome::Failure("password doesn't match")` otherwise.
                        ///
                        /// # Errors
                        ///
                        /// A base64 decoding failure for the salt is propagated with `?`.
     121            8 : pub async fn exchange(
     122            8 :     secret: &ServerSecret,
     123            8 :     password: &[u8],
     124            8 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
     125            8 :     let salt = base64::decode(&secret.salt_base64)?;
     126           24 :     let client_key = derive_client_key(password, &salt, secret.iterations).await;
     127              : 
     128            8 :     if secret.is_password_invalid(&client_key).into() {
     129            2 :         Ok(sasl::Outcome::Failure("password doesn't match"))
     130              :     } else {
     131            6 :         Ok(sasl::Outcome::Success(client_key))
     132              :     }
     133            8 : }
     134              : 
     135              : impl SaslInitial {
                            /// Handles the client-first-message and produces the server's reply.
                            ///
                            /// Steps, in order:
                            /// 1. Parse `input`; a `None` from the parser becomes
                            ///    `SaslError::BadClientMessage`.
                            /// 2. Fail with `SaslError::ChannelBindingFailed` when the client's flag
                            ///    equals `ChannelBinding::NotSupportedServer` while
                            ///    `tls_server_end_point.supported()` is true.
                            /// 3. Build the server-first-message from a fresh nonce (`self.nonce`),
                            ///    `secret.salt_base64` and `secret.iterations`.
                            /// 4. Convert the client's flag with `and_then(str::parse)`; an error from
                            ///    that conversion is propagated with `?`.
                            ///
                            /// On success returns `Step::Continue` holding the state for the next step
                            /// and the message text to send. The output type is `Infallible`, so this
                            /// step cannot produce `Step::Success`.
     136           26 :     fn transition(
     137           26 :         &self,
     138           26 :         secret: &ServerSecret,
     139           26 :         tls_server_end_point: &config::TlsServerEndPoint,
     140           26 :         input: &str,
     141           26 :     ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
     142           26 :         let client_first_message = ClientFirstMessage::parse(input)
     143           26 :             .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
     144              : 
     145              :         // If the flag is set to "y" and the server supports channel
     146              :         // binding, the server MUST fail authentication
     147           26 :         if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
     148            2 :             && tls_server_end_point.supported()
     149              :         {
     150            2 :             return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
     151           24 :         }
     152           24 : 
                                // The nonce function is invoked here, once per accepted first message.
     153           24 :         let server_first_message = client_first_message.build_server_first_message(
     154           24 :             &(self.nonce)(),
     155           24 :             &secret.salt_base64,
     156           24 :             secret.iterations,
     157           24 :         );
                                // Copy the text out before `server_first_message` is moved into `next`.
     158           24 :         let msg = server_first_message.as_str().to_owned();
     159              : 
     160           24 :         let next = SaslSentInner {
     161           24 :             cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
     162           24 :             client_first_message_bare: client_first_message.bare.to_owned(),
     163           24 :             server_first_message,
     164           24 :         };
     165           24 : 
     166           24 :         Ok(sasl::Step::Continue(next, msg))
     167           26 :     }
     168              : }
     169              : 
     170              : impl SaslSentInner {
                            /// Handles the client-final-message and decides the outcome.
                            ///
                            /// Checks, in order:
                            /// 1. `input` parses as a client-final-message.
                            /// 2. The client's `channel_binding` equals the value produced by
                            ///    `cbind_flag.encode(..)`.
                            /// 3. The client's nonce equals the nonce of the stored
                            ///    server-first-message.
                            /// 4. The client key recovered from the proof is accepted by
                            ///    `secret.is_password_invalid`.
                            ///
                            /// # Returns
                            ///
                            /// * `Ok(Step::Success(client_key, msg))` when every check passes; `msg`
                            ///   is the server-final-message.
                            /// * `Ok(Step::Failure("password doesn't match"))` when check 4 fails.
                            ///
                            /// The continuation type is `Infallible`, so this step cannot produce
                            /// `Step::Continue`.
                            ///
                            /// # Errors
                            ///
                            /// * `SaslError::BadClientMessage` - parse failure or nonce mismatch.
                            /// * `SaslError::ChannelBindingFailed` - channel-binding data mismatch.
                            /// * `SaslError::MissingBinding` - returned by the `encode` callback when
                            ///   `tls_server_end_point` is `Undefined`. When `encode` invokes the
                            ///   callback is decided by `ChannelBinding::encode`, which is defined
                            ///   elsewhere.
     171           24 :     fn transition(
     172           24 :         &self,
     173           24 :         secret: &ServerSecret,
     174           24 :         tls_server_end_point: &config::TlsServerEndPoint,
     175           24 :         input: &str,
     176           24 :     ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
     177           24 :         let Self {
     178           24 :             cbind_flag,
     179           24 :             client_first_message_bare,
     180           24 :             server_first_message,
     181           24 :         } = self;
     182              : 
     183           24 :         let client_final_message = ClientFinalMessage::parse(input)
     184           24 :             .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
     185              : 
                                // Compute the channel-binding value the server expects to see.
     186           24 :         let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
     187           12 :             config::TlsServerEndPoint::Sha256(x) => Ok(x),
     188            0 :             config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
     189           24 :         })?;
     190              : 
     191              :         // This might've been caused by a MITM attack
     192           24 :         if client_final_message.channel_binding != channel_binding {
     193            8 :             return Err(SaslError::ChannelBindingFailed(
     194            8 :                 "insecure connection: secure channel data mismatch",
     195            8 :             ));
     196           16 :         }
     197           16 : 
     198           16 :         if client_final_message.nonce != server_first_message.nonce() {
     199            0 :             return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
     200           16 :         }
     201           16 : 
                                // Collect the three message texts the signature is built from.
     202           16 :         let signature_builder = SignatureBuilder {
     203           16 :             client_first_message_bare,
     204           16 :             server_first_message: server_first_message.as_str(),
     205           16 :             client_final_message_without_proof: client_final_message.without_proof,
     206           16 :         };
     207           16 : 
                                // Build the signature with the stored key, then recover the client key
                                // from the proof the client sent.
     208           16 :         let client_key = signature_builder
     209           16 :             .build(&secret.stored_key)
     210           16 :             .derive_client_key(&client_final_message.proof);
     211           16 : 
     212           16 :         // Auth fails either if keys don't match or it's pre-determined to fail.
     213           16 :         if secret.is_password_invalid(&client_key).into() {
     214            2 :             return Ok(sasl::Step::Failure("password doesn't match"));
     215           14 :         }
     216           14 : 
     217           14 :         let msg =
     218           14 :             client_final_message.build_server_final_message(signature_builder, &secret.server_key);
     219           14 : 
     220           14 :         Ok(sasl::Step::Success(client_key, msg))
     221           24 :     }
     222              : }
     223              : 
     224              : impl sasl::Mechanism for Exchange<'_> {
                            /// A successful exchange yields the client's SCRAM key.
     225              :     type Output = super::ScramKey;
     226              : 
                            /// Feeds one client message to the state machine.
                            ///
                            /// Takes `self` by value. In the `Initial` state an accepted message stores
                            /// the new `SaltSent` state and hands `self` back inside `Step::Continue`
                            /// together with the reply. In the `SaltSent` state the result is final:
                            /// `Step::Success` with the key and reply, or `Step::Failure` with a reason.
                            ///
                            /// # Errors
                            ///
                            /// Any error returned by the current state's `transition` is propagated
                            /// with `?`.
     227           50 :     fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
     228           50 :         use {sasl::Step::*, ExchangeState::*};
     229           50 :         match &self.state {
     230           26 :             Initial(init) => {
     231           26 :                 match init.transition(self.secret, &self.tls_server_end_point, input)? {
     232           24 :                     Continue(sent, msg) => {
     233           24 :                         self.state = SaltSent(sent);
     234           24 :                         Ok(Continue(self, msg))
     235              :                     }
                                        // `x` is `Infallible` here, so the empty match covers this arm
                                        // without any runtime code.
     236              :                     Success(x, _) => match x {},
     237            0 :                     Failure(msg) => Ok(Failure(msg)),
     238              :                 }
     239              :             }
     240           24 :             SaltSent(sent) => {
     241           24 :                 match sent.transition(self.secret, &self.tls_server_end_point, input)? {
     242           14 :                     Success(keys, msg) => Ok(Success(keys, msg)),
                                        // `x` is `Infallible` here as well; see the matching arm above.
     243              :                     Continue(x, _) => match x {},
     244            2 :                     Failure(msg) => Ok(Failure(msg)),
     245              :                 }
     246              :             }
     247              :         }
     248           50 :     }
     249              : }
        

Generated by: LCOV version 2.1-beta