LCOV - code coverage report
Current view: top level - proxy/src/scram - mod.rs (source / functions) Coverage Total Hit
Test: b4ae4c4857f9ef3e144e982a35ee23bc84c71983.info Lines: 93.7 % 79 74
Test Date: 2024-10-22 22:13:45 Functions: 100.0 % 13 13

            Line data    Source code
       1              : //! Salted Challenge Response Authentication Mechanism.
       2              : //!
       3              : //! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
       4              : //!
       5              : //! Reference implementation:
       6              : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
       7              : //! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
       8              : 
       9              : mod countmin;
      10              : mod exchange;
      11              : mod key;
      12              : mod messages;
      13              : mod pbkdf2;
      14              : mod secret;
      15              : mod signature;
      16              : pub mod threadpool;
      17              : 
      18              : pub(crate) use exchange::{exchange, Exchange};
      19              : use hmac::{Hmac, Mac};
      20              : pub(crate) use key::ScramKey;
      21              : pub(crate) use secret::ServerSecret;
      22              : use sha2::{Digest, Sha256};
      23              : 
      24              : const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; // plain mechanism name
      25              : const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; // `-PLUS` names the channel-binding variant (RFC 5802)
      26              : 
      27              : /// A list of supported SCRAM methods, with the `-PLUS` variant listed first.
      28              : pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
      29              : pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256]; // same list minus the `-PLUS` variant
      30              : 
      31              : /// Decode standard-alphabet base64 `input` into a `[u8; N]` without any heap allocations; returns `None` if decoding fails or the decoded length is not exactly `N`.
      32           49 : fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
      33           49 :     let mut bytes = [0u8; N]; // fixed-size output buffer, zero-initialised
      34              :     // NOTE(review): confirm what `decode_config_slice` does when `input` decodes to more than N bytes (error vs. panic) for the base64 version in use.
      35           49 :     let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?; // any decode error becomes `None`
      36           49 :     if size != N { // decoded length must match the array length exactly
      37            0 :         return None;
      38           49 :     }
      39           49 : 
      40           49 :     Some(bytes)
      41           49 : }
      42              : 
      43              : /// This function is essentially `Hmac(sha256, key, input)`, where `input` is fed to the MAC piece by piece from `parts`, in iteration order; returns the 32-byte result.
      44              : /// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
      45           15 : fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
      46           15 :     let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size"); // panics if the MAC cannot be built from `key`
      47           75 :     parts.into_iter().for_each(|s| mac.update(s)); // feed each part in turn; no joined buffer is built
      48           15 : 
      49           15 :     mac.finalize().into_bytes().into()
      50           15 : }
      51              : 
      52           12 : fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] { // SHA-256 over all of `parts`; returns the 32-byte digest
      53           12 :     let mut hasher = Sha256::new();
      54           12 :     parts.into_iter().for_each(|s| hasher.update(s)); // feed each part to the hasher in iteration order
      55           12 : 
      56           12 :     hasher.finalize().into()
      57           12 : }
      58              : 
      59              : #[cfg(test)]
      60              : mod tests { // unit tests: one fixed-vector exchange and two password round trips
      61              :     use super::threadpool::ThreadPool;
      62              :     use super::{Exchange, ServerSecret};
      63              :     use crate::intern::EndpointIdInt;
      64              :     use crate::sasl::{Mechanism, Step}; // NOTE(review): `Mechanism` is never named below; presumably the trait providing `exchange()` — confirm
      65              :     use crate::EndpointId;
      66              :     /// Replays one exchange with a fixed secret and nonce, pinning both server messages and the resulting key bytes.
      67              :     #[test]
      68            1 :     fn snapshot() {
      69            1 :         let iterations = 4096;
      70            1 :         let salt = "QSXCR+Q6sek8bf92";
      71            1 :         let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
      72            1 :         let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
      73            1 :         let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",);
      74            1 :         let secret = ServerSecret::parse(&secret).unwrap();
      75              :         // Fixed nonce source, so the server messages asserted below are reproducible.
      76              :         const NONCE: [u8; 18] = [
      77              :             1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
      78              :         ];
      79            1 :         let mut exchange = Exchange::new(
      80            1 :             &secret,
      81            1 :             || NONCE,
      82            1 :             crate::config::TlsServerEndPoint::Undefined,
      83            1 :         );
      84            1 :         // In the vectors below, `biws` is base64 of the `n,,` header, and the `AQID...EBES` nonce suffix is base64 of bytes 1..=18 (NONCE).
      85            1 :         let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
      86            1 :         let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
      87            1 :         let server_first =
      88            1 :             "r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,s=QSXCR+Q6sek8bf92,i=4096";
      89            1 :         let server_final = "v=qtUDIofVnIhM7tKn93EQUUt5vgMOldcDVu1HC+OH0o0=";
      90              :         // Step 1: the client-first message must yield `Continue` carrying the expected server-first message.
      91            1 :         exchange = match exchange.exchange(client_first).unwrap() {
      92            1 :             Step::Continue(exchange, message) => {
      93            1 :                 assert_eq!(message, server_first);
      94            1 :                 exchange
      95              :             }
      96            0 :             Step::Success(_, _) => panic!("expected continue, got success"),
      97            0 :             Step::Failure(f) => panic!("{f}"),
      98              :         };
      99              :         // Step 2: the client-final message must yield `Success` carrying the expected server-final message.
     100            1 :         let key = match exchange.exchange(client_final).unwrap() {
     101            1 :             Step::Success(key, message) => {
     102            1 :                 assert_eq!(message, server_final);
     103            1 :                 key
     104              :             }
     105            0 :             Step::Continue(_, _) => panic!("expected success, got continue"),
     106            0 :             Step::Failure(f) => panic!("{f}"),
     107              :         };
     108              :         // The key returned on success is pinned byte for byte.
     109            1 :         assert_eq!(
     110            1 :             key.as_bytes(),
     111            1 :             [
     112            1 :                 74, 103, 1, 132, 12, 31, 200, 48, 28, 54, 82, 232, 207, 12, 138, 189, 40, 32, 134,
     113            1 :                 27, 125, 170, 232, 35, 171, 167, 166, 41, 70, 228, 182, 112,
     114            1 :             ]
     115            1 :         );
     116            1 :     }
     117              :     /// Builds a secret from `server_password`, runs `super::exchange` with `client_password`, and panics with the reason on a `Failure` outcome.
     118            2 :     async fn run_round_trip_test(server_password: &str, client_password: &str) {
     119            2 :         let pool = ThreadPool::new(1);
     120            2 : 
     121            2 :         let ep = EndpointId::from("foo");
     122            2 :         let ep = EndpointIdInt::from(ep);
     123              :         // Build the secret from the server-side password, then run the exchange with the client-side one.
     124            6 :         let scram_secret = ServerSecret::build(server_password).await.unwrap();
     125            2 :         let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
     126            2 :             .await
     127            2 :             .unwrap();
     128            2 :         // Success passes silently; a failure reason becomes the panic message.
     129            2 :         match outcome {
     130            1 :             crate::sasl::Outcome::Success(_) => {}
     131            1 :             crate::sasl::Outcome::Failure(r) => panic!("{r}"),
     132              :         }
     133            1 :     }
     134              :     /// Matching server and client passwords must complete without panicking.
     135              :     #[tokio::test]
     136            1 :     async fn round_trip() {
     137            4 :         run_round_trip_test("pencil", "pencil").await;
     138            1 :     }
     139              :     /// Mismatched passwords must make the helper panic with the message named in `should_panic`.
     140              :     #[tokio::test]
     141              :     #[should_panic(expected = "password doesn't match")]
     142            1 :     async fn failure() {
     143            4 :         run_round_trip_test("pencil", "eraser").await;
     144            1 :     }
     145              : }
        

Generated by: LCOV version 2.1-beta