LCOV - code coverage report
Current view: top level - proxy/src/scram - exchange.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 97.1 % 138 134
Test Date: 2025-02-20 13:11:02 Functions: 100.0 % 12 12

            Line data    Source code
       1              : //! Implementation of the SCRAM authentication algorithm.
       2              : 
       3              : use std::convert::Infallible;
       4              : 
       5              : use hmac::{Hmac, Mac};
       6              : use sha2::Sha256;
       7              : 
       8              : use super::messages::{
       9              :     ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
      10              : };
      11              : use super::pbkdf2::Pbkdf2;
      12              : use super::secret::ServerSecret;
      13              : use super::signature::SignatureBuilder;
      14              : use super::threadpool::ThreadPool;
      15              : use super::ScramKey;
      16              : use crate::intern::EndpointIdInt;
      17              : use crate::sasl::{self, ChannelBinding, Error as SaslError};
      18              : 
      19              : /// The only channel binding mode we currently support.
      20              : #[derive(Debug)]
      21              : struct TlsServerEndPoint;
      22              : 
      23              : impl std::fmt::Display for TlsServerEndPoint {
      24            6 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      25            6 :         write!(f, "tls-server-end-point")
      26            6 :     }
      27              : }
      28              : 
      29              : impl std::str::FromStr for TlsServerEndPoint {
      30              :     type Err = sasl::Error;
      31              : 
      32            6 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
      33            6 :         match s {
      34            6 :             "tls-server-end-point" => Ok(TlsServerEndPoint),
      35            0 :             _ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
      36              :         }
      37            6 :     }
      38              : }
      39              : 
      40              : struct SaslSentInner {
      41              :     cbind_flag: ChannelBinding<TlsServerEndPoint>,
      42              :     client_first_message_bare: String,
      43              :     server_first_message: OwnedServerFirstMessage,
      44              : }
      45              : 
      46              : struct SaslInitial {
      47              :     nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      48              : }
      49              : 
      50              : enum ExchangeState {
      51              :     /// Waiting for [`ClientFirstMessage`].
      52              :     Initial(SaslInitial),
      53              :     /// Waiting for [`ClientFinalMessage`].
      54              :     SaltSent(SaslSentInner),
      55              : }
      56              : 
      57              : /// Server's side of SCRAM auth algorithm.
      58              : pub(crate) struct Exchange<'a> {
      59              :     state: ExchangeState,
      60              :     secret: &'a ServerSecret,
      61              :     tls_server_end_point: crate::tls::TlsServerEndPoint,
      62              : }
      63              : 
      64              : impl<'a> Exchange<'a> {
      65           13 :     pub(crate) fn new(
      66           13 :         secret: &'a ServerSecret,
      67           13 :         nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
      68           13 :         tls_server_end_point: crate::tls::TlsServerEndPoint,
      69           13 :     ) -> Self {
      70           13 :         Self {
      71           13 :             state: ExchangeState::Initial(SaslInitial { nonce }),
      72           13 :             secret,
      73           13 :             tls_server_end_point,
      74           13 :         }
      75           13 :     }
      76              : }
      77              : 
      78              : // copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
      79            4 : async fn derive_client_key(
      80            4 :     pool: &ThreadPool,
      81            4 :     endpoint: EndpointIdInt,
      82            4 :     password: &[u8],
      83            4 :     salt: &[u8],
      84            4 :     iterations: u32,
      85            4 : ) -> ScramKey {
      86            4 :     let salted_password = pool
      87            4 :         .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
      88            4 :         .await;
      89              : 
      90            4 :     let make_key = |name| {
      91            4 :         let key = Hmac::<Sha256>::new_from_slice(&salted_password)
      92            4 :             .expect("HMAC is able to accept all key sizes")
      93            4 :             .chain_update(name)
      94            4 :             .finalize();
      95            4 : 
      96            4 :         <[u8; 32]>::from(key.into_bytes())
      97            4 :     };
      98              : 
      99            4 :     make_key(b"Client Key").into()
     100            4 : }
     101              : 
     102            4 : pub(crate) async fn exchange(
     103            4 :     pool: &ThreadPool,
     104            4 :     endpoint: EndpointIdInt,
     105            4 :     secret: &ServerSecret,
     106            4 :     password: &[u8],
     107            4 : ) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
     108            4 :     let salt = base64::decode(&secret.salt_base64)?;
     109            4 :     let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
     110              : 
     111            4 :     if secret.is_password_invalid(&client_key).into() {
     112            1 :         Ok(sasl::Outcome::Failure("password doesn't match"))
     113              :     } else {
     114            3 :         Ok(sasl::Outcome::Success(client_key))
     115              :     }
     116            4 : }
     117              : 
     118              : impl SaslInitial {
     119           13 :     fn transition(
     120           13 :         &self,
     121           13 :         secret: &ServerSecret,
     122           13 :         tls_server_end_point: &crate::tls::TlsServerEndPoint,
     123           13 :         input: &str,
     124           13 :     ) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
     125           13 :         let client_first_message = ClientFirstMessage::parse(input)
     126           13 :             .ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
     127              : 
     128              :         // If the flag is set to "y" and the server supports channel
     129              :         // binding, the server MUST fail authentication
     130           13 :         if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
     131            1 :             && tls_server_end_point.supported()
     132              :         {
     133            1 :             return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
     134           12 :         }
     135           12 : 
     136           12 :         let server_first_message = client_first_message.build_server_first_message(
     137           12 :             &(self.nonce)(),
     138           12 :             &secret.salt_base64,
     139           12 :             secret.iterations,
     140           12 :         );
     141           12 :         let msg = server_first_message.as_str().to_owned();
     142              : 
     143           12 :         let next = SaslSentInner {
     144           12 :             cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
     145           12 :             client_first_message_bare: client_first_message.bare.to_owned(),
     146           12 :             server_first_message,
     147           12 :         };
     148           12 : 
     149           12 :         Ok(sasl::Step::Continue(next, msg))
     150           13 :     }
     151              : }
     152              : 
     153              : impl SaslSentInner {
     154           12 :     fn transition(
     155           12 :         &self,
     156           12 :         secret: &ServerSecret,
     157           12 :         tls_server_end_point: &crate::tls::TlsServerEndPoint,
     158           12 :         input: &str,
     159           12 :     ) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
     160           12 :         let Self {
     161           12 :             cbind_flag,
     162           12 :             client_first_message_bare,
     163           12 :             server_first_message,
     164           12 :         } = self;
     165              : 
     166           12 :         let client_final_message = ClientFinalMessage::parse(input)
     167           12 :             .ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
     168              : 
     169           12 :         let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
     170            6 :             crate::tls::TlsServerEndPoint::Sha256(x) => Ok(x),
     171            0 :             crate::tls::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
     172           12 :         })?;
     173              : 
     174              :         // This might've been caused by a MITM attack
     175           12 :         if client_final_message.channel_binding != channel_binding {
     176            4 :             return Err(SaslError::ChannelBindingFailed(
     177            4 :                 "insecure connection: secure channel data mismatch",
     178            4 :             ));
     179            8 :         }
     180            8 : 
     181            8 :         if client_final_message.nonce != server_first_message.nonce() {
     182            0 :             return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
     183            8 :         }
     184            8 : 
     185            8 :         let signature_builder = SignatureBuilder {
     186            8 :             client_first_message_bare,
     187            8 :             server_first_message: server_first_message.as_str(),
     188            8 :             client_final_message_without_proof: client_final_message.without_proof,
     189            8 :         };
     190            8 : 
     191            8 :         let client_key = signature_builder
     192            8 :             .build(&secret.stored_key)
     193            8 :             .derive_client_key(&client_final_message.proof);
     194            8 : 
     195            8 :         // Auth fails either if keys don't match or it's pre-determined to fail.
     196            8 :         if secret.is_password_invalid(&client_key).into() {
     197            1 :             return Ok(sasl::Step::Failure("password doesn't match"));
     198            7 :         }
     199            7 : 
     200            7 :         let msg =
     201            7 :             client_final_message.build_server_final_message(signature_builder, &secret.server_key);
     202            7 : 
     203            7 :         Ok(sasl::Step::Success(client_key, msg))
     204           12 :     }
     205              : }
     206              : 
     207              : impl sasl::Mechanism for Exchange<'_> {
     208              :     type Output = super::ScramKey;
     209              : 
     210           25 :     fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
     211              :         use sasl::Step;
     212              :         use ExchangeState;
     213           25 :         match &self.state {
     214           13 :             ExchangeState::Initial(init) => {
     215           13 :                 match init.transition(self.secret, &self.tls_server_end_point, input)? {
     216           12 :                     Step::Continue(sent, msg) => {
     217           12 :                         self.state = ExchangeState::SaltSent(sent);
     218           12 :                         Ok(Step::Continue(self, msg))
     219              :                     }
     220            0 :                     Step::Failure(msg) => Ok(Step::Failure(msg)),
     221              :                 }
     222              :             }
     223           12 :             ExchangeState::SaltSent(sent) => {
     224           12 :                 match sent.transition(self.secret, &self.tls_server_end_point, input)? {
     225            7 :                     Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
     226            1 :                     Step::Failure(msg) => Ok(Step::Failure(msg)),
     227              :                 }
     228              :             }
     229              :         }
     230           25 :     }
     231              : }
        

Generated by: LCOV version 2.1-beta