LCOV - code coverage report
Current view: top level - proxy/src/redis - elasticache.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 92 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 5 0

            Line data    Source code
       1              : use std::sync::Arc;
       2              : use std::time::{Duration, SystemTime};
       3              : 
       4              : use aws_config::Region;
       5              : use aws_config::environment::EnvironmentVariableCredentialsProvider;
       6              : use aws_config::imds::credentials::ImdsCredentialsProvider;
       7              : use aws_config::meta::credentials::CredentialsProviderChain;
       8              : use aws_config::meta::region::RegionProviderChain;
       9              : use aws_config::profile::ProfileFileCredentialsProvider;
      10              : use aws_config::provider_config::ProviderConfig;
      11              : use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
      12              : use aws_credential_types::provider::error::CredentialsError;
      13              : use aws_sdk_iam::config::ProvideCredentials;
      14              : use aws_sigv4::http_request::{
      15              :     self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings,
      16              : };
      17              : use aws_sigv4::sign::v4::signing_params::BuildError;
      18              : use tracing::info;
      19              : 
      20              : #[derive(Debug)]
      21              : pub struct AWSIRSAConfig {
      22              :     region: String,
      23              :     service_name: String,
      24              :     cluster_name: String,
      25              :     user_id: String,
      26              :     token_ttl: Duration,
      27              :     action: String,
      28              : }
      29              : 
      30              : impl AWSIRSAConfig {
      31            0 :     pub fn new(region: String, cluster_name: Option<String>, user_id: Option<String>) -> Self {
      32            0 :         AWSIRSAConfig {
      33            0 :             region,
      34            0 :             service_name: "elasticache".to_string(),
      35            0 :             cluster_name: cluster_name.unwrap_or_default(),
      36            0 :             user_id: user_id.unwrap_or_default(),
      37            0 :             // "The IAM authentication token is valid for 15 minutes"
      38            0 :             // https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits
      39            0 :             token_ttl: Duration::from_secs(15 * 60),
      40            0 :             action: "connect".to_string(),
      41            0 :         }
      42            0 :     }
      43              : }
      44              : 
      45              : #[derive(thiserror::Error, Debug)]
      46              : pub enum CredentialsProviderError {
      47              :     #[error(transparent)]
      48              :     AwsCredentials(#[from] CredentialsError),
      49              :     #[error(transparent)]
      50              :     AwsSigv4Build(#[from] BuildError),
      51              :     #[error(transparent)]
      52              :     AwsSigv4Singing(#[from] SigningError),
      53              :     #[error(transparent)]
      54              :     Http(#[from] http::Error),
      55              : }
      56              : 
      57              : /// Credentials provider for AWS elasticache authentication.
      58              : ///
      59              : /// Official documentation:
      60              : /// <https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html>
      61              : ///
      62              : /// Useful resources:
      63              : /// <https://aws.amazon.com/blogs/database/simplify-managing-access-to-amazon-elasticache-for-redis-clusters-with-iam/>
      64              : pub struct CredentialsProvider {
      65              :     config: AWSIRSAConfig,
      66              :     credentials_provider: CredentialsProviderChain,
      67              : }
      68              : 
      69              : impl CredentialsProvider {
      70            0 :     pub async fn new(
      71            0 :         aws_region: String,
      72            0 :         redis_cluster_name: Option<String>,
      73            0 :         redis_user_id: Option<String>,
      74            0 :     ) -> Arc<CredentialsProvider> {
      75            0 :         let region_provider =
      76            0 :             RegionProviderChain::default_provider().or_else(Region::new(aws_region.clone()));
      77            0 :         let provider_conf =
      78            0 :             ProviderConfig::without_region().with_region(region_provider.region().await);
      79            0 :         let aws_credentials_provider = {
      80              :             // uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
      81            0 :             CredentialsProviderChain::first_try(
      82              :                 "env",
      83            0 :                 EnvironmentVariableCredentialsProvider::new(),
      84              :             )
      85              :             // uses "AWS_PROFILE" / `aws sso login --profile <profile>`
      86            0 :             .or_else(
      87              :                 "profile-sso",
      88            0 :                 ProfileFileCredentialsProvider::builder()
      89            0 :                     .configure(&provider_conf)
      90            0 :                     .build(),
      91              :             )
      92              :             // uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME"
      93              :             // needed to access remote extensions bucket
      94            0 :             .or_else(
      95              :                 "token",
      96            0 :                 WebIdentityTokenCredentialsProvider::builder()
      97            0 :                     .configure(&provider_conf)
      98            0 :                     .build(),
      99              :             )
     100              :             // uses imds v2
     101            0 :             .or_else("imds", ImdsCredentialsProvider::builder().build())
     102              :         };
     103            0 :         Arc::new(CredentialsProvider {
     104            0 :             config: AWSIRSAConfig::new(aws_region, redis_cluster_name, redis_user_id),
     105            0 :             credentials_provider: aws_credentials_provider,
     106            0 :         })
     107            0 :     }
     108              : 
     109            0 :     pub(crate) async fn provide_credentials(
     110            0 :         &self,
     111            0 :     ) -> Result<(String, String), CredentialsProviderError> {
     112            0 :         let aws_credentials = self
     113            0 :             .credentials_provider
     114            0 :             .provide_credentials()
     115            0 :             .await?
     116            0 :             .into();
     117            0 :         info!("AWS credentials successfully obtained");
     118            0 :         info!("Connecting to Redis with configuration: {:?}", self.config);
     119            0 :         let mut settings = SigningSettings::default();
     120            0 :         settings.signature_location = SignatureLocation::QueryParams;
     121            0 :         settings.expires_in = Some(self.config.token_ttl);
     122            0 :         let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
     123            0 :             .identity(&aws_credentials)
     124            0 :             .region(&self.config.region)
     125            0 :             .name(&self.config.service_name)
     126            0 :             .time(SystemTime::now())
     127            0 :             .settings(settings)
     128            0 :             .build()?
     129            0 :             .into();
     130            0 :         let auth_params = [
     131            0 :             ("Action", &self.config.action),
     132            0 :             ("User", &self.config.user_id),
     133            0 :         ];
     134            0 :         let auth_params = url::form_urlencoded::Serializer::new(String::new())
     135            0 :             .extend_pairs(auth_params)
     136            0 :             .finish();
     137            0 :         let auth_uri = http::Uri::builder()
     138            0 :             .scheme("http")
     139            0 :             .authority(self.config.cluster_name.as_bytes())
     140            0 :             .path_and_query(format!("/?{auth_params}"))
     141            0 :             .build()?;
     142            0 :         info!("{}", auth_uri);
     143              : 
     144              :         // Convert the HTTP request into a signable request
     145            0 :         let signable_request = SignableRequest::new(
     146            0 :             "GET",
     147            0 :             auth_uri.to_string(),
     148            0 :             std::iter::empty(),
     149            0 :             SignableBody::Bytes(&[]),
     150            0 :         )?;
     151              : 
     152              :         // Sign and then apply the signature to the request
     153            0 :         let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts();
     154            0 :         let mut signable_request = http::Request::builder()
     155            0 :             .method("GET")
     156            0 :             .uri(auth_uri)
     157            0 :             .body(())?;
     158            0 :         si.apply_to_request_http1x(&mut signable_request);
     159            0 :         Ok((
     160            0 :             self.config.user_id.clone(),
     161            0 :             signable_request
     162            0 :                 .uri()
     163            0 :                 .to_string()
     164            0 :                 .replacen("http://", "", 1),
     165            0 :         ))
     166            0 :     }
     167              : }
        

Generated by: LCOV version 2.1-beta