LCOV - code coverage report
Current view: top level - proxy/src/scram - exchange.rs (source / functions) Coverage Total Hit
Test: 2aa98e37cd3250b9a68c97ef6050b16fe702ab33.info Lines: 97.2 % 141 137
Test Date: 2024-08-29 11:33:10 Functions: 100.0 % 12 12

            Line data    Source code
       1              : //! Implementation of the SCRAM authentication algorithm.
       2              : 
       3              : use std::convert::Infallible;
       4              : 
       5              : use hmac::{Hmac, Mac};
       6              : use sha2::Sha256;
       7              : 
       8              : use super::messages::{
       9              :     ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
      10              : };
      11              : use super::pbkdf2::Pbkdf2;
      12              : use super::secret::ServerSecret;
      13              : use super::signature::SignatureBuilder;
      14              : use super::threadpool::ThreadPool;
      15              : use super::ScramKey;
      16              : use crate::config;
      17              : use crate::intern::EndpointIdInt;
      18              : use crate::sasl::{self, ChannelBinding, Error as SaslError};
      19              : 
      20              : /// The only channel binding mode we currently support.
      21              : #[derive(Debug)]
      22              : struct TlsServerEndPoint;
      23              : 
      24              : impl std::fmt::Display for TlsServerEndPoint {
      25           36 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      26           36 :         write!(f, "tls-server-end-point")
      27           36 :     }
      28              : }
      29              : 
      30              : impl std::str::FromStr for TlsServerEndPoint {
      31              :     type Err = sasl::Error;
      32              : 
      33           36 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      34           36 :         match s {
      35           36 :             "tls-server-end-point" => Ok(TlsServerEndPoint),
      36            0 :             _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
      37              :         }
      38           36 :     }
      39              : }
      40              : 
      41              : struct SaslSentInner {
      42              :     cbind_flag: ChannelBinding<TlsServerEndPoint>,
      43              :     client_first_message_bare: String,
      44              :     server_first_message: OwnedServerFirstMessage,
      45              : }
      46              : 
      47              : struct SaslInitial {
      48              :     nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      49              : }
      50              : 
      51              : enum ExchangeState {
      52              :     /// Waiting for [`ClientFirstMessage`].
      53              :     Initial(SaslInitial),
      54              :     /// Waiting for [`ClientFinalMessage`].
      55              :     SaltSent(SaslSentInner),
      56              : }
      57              : 
      58              : /// Server's side of SCRAM auth algorithm.
      59              : pub(crate) struct Exchange<'a> {
      60              :     state: ExchangeState,
      61              :     secret: &'a ServerSecret,
      62              :     tls_server_end_point: config::TlsServerEndPoint,
      63              : }
      64              : 
      65              : impl<'a> Exchange<'a> {
      66           78 :     pub(crate) fn new(
      67           78 :         secret: &'a ServerSecret,
      68           78 :         nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      69           78 :         tls_server_end_point: config::TlsServerEndPoint,
      70           78 :     ) -> Self {
      71           78 :         Self {
      72           78 :             state: ExchangeState::Initial(SaslInitial { nonce }),
      73           78 :             secret,
      74           78 :             tls_server_end_point,
      75           78 :         }
      76           78 :     }
      77              : }
      78              : 
      79              : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
      80           24 : async fn derive_client_key(
      81           24 :     pool: &ThreadPool,
      82           24 :     endpoint: EndpointIdInt,
      83           24 :     password: &[u8],
      84           24 :     salt: &[u8],
      85           24 :     iterations: u32,
      86           24 : ) -> ScramKey {
      87           24 :     let salted_password = pool
      88           24 :         .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
      89           24 :         .await
      90           24 :         .expect("job should not be cancelled");
      91           24 : 
      92           24 :     let make_key = |name| {
      93           24 :         let key = Hmac::<Sha256>::new_from_slice(&salted_password)
      94           24 :             .expect("HMAC is able to accept all key sizes")
      95           24 :             .chain_update(name)
      96           24 :             .finalize();
      97           24 : 
      98           24 :         <[u8; 32]>::from(key.into_bytes())
      99           24 :     };
     100              : 
     101           24 :     make_key(b"Client Key").into()
     102           24 : }
     103              : 
     104           24 : pub(crate) async fn exchange(
     105           24 :     pool: &ThreadPool,
     106           24 :     endpoint: EndpointIdInt,
     107           24 :     secret: &ServerSecret,
     108           24 :     password: &[u8],
     109           24 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
     110           24 :     let salt = base64::decode(&secret.salt_base64)?;
     111           24 :     let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
     112              : 
     113           24 :     if secret.is_password_invalid(&client_key).into() {
     114            6 :         Ok(sasl::Outcome::Failure("password doesn't match"))
     115              :     } else {
     116           18 :         Ok(sasl::Outcome::Success(client_key))
     117              :     }
     118           24 : }
     119              : 
     120              : impl SaslInitial {
     121           78 :     fn transition(
     122           78 :         &self,
     123           78 :         secret: &ServerSecret,
     124           78 :         tls_server_end_point: &config::TlsServerEndPoint,
     125           78 :         input: &str,
     126           78 :     ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
     127           78 :         let client_first_message = ClientFirstMessage::parse(input)
     128           78 :             .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
     129              : 
     130              :         // If the flag is set to "y" and the server supports channel
     131              :         // binding, the server MUST fail authentication
     132           78 :         if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
     133            6 :             && tls_server_end_point.supported()
     134              :         {
     135            6 :             return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
     136           72 :         }
     137           72 : 
     138           72 :         let server_first_message = client_first_message.build_server_first_message(
     139           72 :             &(self.nonce)(),
     140           72 :             &secret.salt_base64,
     141           72 :             secret.iterations,
     142           72 :         );
     143           72 :         let msg = server_first_message.as_str().to_owned();
     144              : 
     145           72 :         let next = SaslSentInner {
     146           72 :             cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
     147           72 :             client_first_message_bare: client_first_message.bare.to_owned(),
     148           72 :             server_first_message,
     149           72 :         };
     150           72 : 
     151           72 :         Ok(sasl::Step::Continue(next, msg))
     152           78 :     }
     153              : }
     154              : 
     155              : impl SaslSentInner {
     156           72 :     fn transition(
     157           72 :         &self,
     158           72 :         secret: &ServerSecret,
     159           72 :         tls_server_end_point: &config::TlsServerEndPoint,
     160           72 :         input: &str,
     161           72 :     ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
     162           72 :         let Self {
     163           72 :             cbind_flag,
     164           72 :             client_first_message_bare,
     165           72 :             server_first_message,
     166           72 :         } = self;
     167              : 
     168           72 :         let client_final_message = ClientFinalMessage::parse(input)
     169           72 :             .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
     170              : 
     171           72 :         let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
     172           36 :             config::TlsServerEndPoint::Sha256(x) => Ok(x),
     173            0 :             config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
     174           72 :         })?;
     175              : 
     176              :         // This might've been caused by a MITM attack
     177           72 :         if client_final_message.channel_binding != channel_binding {
     178           24 :             return Err(SaslError::ChannelBindingFailed(
     179           24 :                 "insecure connection: secure channel data mismatch",
     180           24 :             ));
     181           48 :         }
     182           48 : 
     183           48 :         if client_final_message.nonce != server_first_message.nonce() {
     184            0 :             return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
     185           48 :         }
     186           48 : 
     187           48 :         let signature_builder = SignatureBuilder {
     188           48 :             client_first_message_bare,
     189           48 :             server_first_message: server_first_message.as_str(),
     190           48 :             client_final_message_without_proof: client_final_message.without_proof,
     191           48 :         };
     192           48 : 
     193           48 :         let client_key = signature_builder
     194           48 :             .build(&secret.stored_key)
     195           48 :             .derive_client_key(&client_final_message.proof);
     196           48 : 
     197           48 :         // Auth fails either if keys don't match or it's pre-determined to fail.
     198           48 :         if secret.is_password_invalid(&client_key).into() {
     199            6 :             return Ok(sasl::Step::Failure("password doesn't match"));
     200           42 :         }
     201           42 : 
     202           42 :         let msg =
     203           42 :             client_final_message.build_server_final_message(signature_builder, &secret.server_key);
     204           42 : 
     205           42 :         Ok(sasl::Step::Success(client_key, msg))
     206           72 :     }
     207              : }
     208              : 
     209              : impl sasl::Mechanism for Exchange<'_> {
     210              :     type Output = super::ScramKey;
     211              : 
     212          150 :     fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
     213          150 :         use {sasl::Step, ExchangeState};
     214          150 :         match &self.state {
     215           78 :             ExchangeState::Initial(init) => {
     216           78 :                 match init.transition(self.secret, &self.tls_server_end_point, input)? {
     217           72 :                     Step::Continue(sent, msg) => {
     218           72 :                         self.state = ExchangeState::SaltSent(sent);
     219           72 :                         Ok(Step::Continue(self, msg))
     220              :                     }
     221              :                     Step::Success(x, _) => match x {},
     222            0 :                     Step::Failure(msg) => Ok(Step::Failure(msg)),
     223              :                 }
     224              :             }
     225           72 :             ExchangeState::SaltSent(sent) => {
     226           72 :                 match sent.transition(self.secret, &self.tls_server_end_point, input)? {
     227           42 :                     Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
     228              :                     Step::Continue(x, _) => match x {},
     229            6 :                     Step::Failure(msg) => Ok(Step::Failure(msg)),
     230              :                 }
     231              :             }
     232              :         }
     233          150 :     }
     234              : }
        

Generated by: LCOV version 2.1-beta