LCOV - code coverage report
Current view: top level - proxy/src/redis - elasticache.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 0.0 % 102 0
Test Date: 2025-02-20 13:11:02 Functions: 0.0 % 5 0

            Line data    Source code
       1              : use std::sync::Arc;
       2              : use std::time::{Duration, SystemTime};
       3              : 
       4              : use aws_config::environment::EnvironmentVariableCredentialsProvider;
       5              : use aws_config::imds::credentials::ImdsCredentialsProvider;
       6              : use aws_config::meta::credentials::CredentialsProviderChain;
       7              : use aws_config::meta::region::RegionProviderChain;
       8              : use aws_config::profile::ProfileFileCredentialsProvider;
       9              : use aws_config::provider_config::ProviderConfig;
      10              : use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
      11              : use aws_config::Region;
      12              : use aws_sdk_iam::config::ProvideCredentials;
      13              : use aws_sigv4::http_request::{
      14              :     self, SignableBody, SignableRequest, SignatureLocation, SigningSettings,
      15              : };
      16              : use tracing::info;
      17              : 
      18              : #[derive(Debug)]
      19              : pub struct AWSIRSAConfig {
      20              :     region: String,
      21              :     service_name: String,
      22              :     cluster_name: String,
      23              :     user_id: String,
      24              :     token_ttl: Duration,
      25              :     action: String,
      26              : }
      27              : 
      28              : impl AWSIRSAConfig {
      29            0 :     pub fn new(region: String, cluster_name: Option<String>, user_id: Option<String>) -> Self {
      30            0 :         AWSIRSAConfig {
      31            0 :             region,
      32            0 :             service_name: "elasticache".to_string(),
      33            0 :             cluster_name: cluster_name.unwrap_or_default(),
      34            0 :             user_id: user_id.unwrap_or_default(),
      35            0 :             // "The IAM authentication token is valid for 15 minutes"
      36            0 :             // https://docs.aws.amazon.com/memorydb/latest/devguide/auth-iam.html#auth-iam-limits
      37            0 :             token_ttl: Duration::from_secs(15 * 60),
      38            0 :             action: "connect".to_string(),
      39            0 :         }
      40            0 :     }
      41              : }
      42              : 
      43              : /// Credentials provider for AWS elasticache authentication.
      44              : ///
      45              : /// Official documentation:
      46              : /// <https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html>
      47              : ///
      48              : /// Useful resources:
      49              : /// <https://aws.amazon.com/blogs/database/simplify-managing-access-to-amazon-elasticache-for-redis-clusters-with-iam/>
      50              : pub struct CredentialsProvider {
      51              :     config: AWSIRSAConfig,
      52              :     credentials_provider: CredentialsProviderChain,
      53              : }
      54              : 
      55              : impl CredentialsProvider {
      56            0 :     pub async fn new(
      57            0 :         aws_region: String,
      58            0 :         redis_cluster_name: Option<String>,
      59            0 :         redis_user_id: Option<String>,
      60            0 :     ) -> Arc<CredentialsProvider> {
      61            0 :         let region_provider =
      62            0 :             RegionProviderChain::default_provider().or_else(Region::new(aws_region.clone()));
      63            0 :         let provider_conf =
      64            0 :             ProviderConfig::without_region().with_region(region_provider.region().await);
      65            0 :         let aws_credentials_provider = {
      66            0 :             // uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
      67            0 :             CredentialsProviderChain::first_try(
      68            0 :                 "env",
      69            0 :                 EnvironmentVariableCredentialsProvider::new(),
      70            0 :             )
      71            0 :             // uses "AWS_PROFILE" / `aws sso login --profile <profile>`
      72            0 :             .or_else(
      73            0 :                 "profile-sso",
      74            0 :                 ProfileFileCredentialsProvider::builder()
      75            0 :                     .configure(&provider_conf)
      76            0 :                     .build(),
      77            0 :             )
      78            0 :             // uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME"
      79            0 :             // needed to access remote extensions bucket
      80            0 :             .or_else(
      81            0 :                 "token",
      82            0 :                 WebIdentityTokenCredentialsProvider::builder()
      83            0 :                     .configure(&provider_conf)
      84            0 :                     .build(),
      85            0 :             )
      86            0 :             // uses imds v2
      87            0 :             .or_else("imds", ImdsCredentialsProvider::builder().build())
      88            0 :         };
      89            0 :         Arc::new(CredentialsProvider {
      90            0 :             config: AWSIRSAConfig::new(aws_region, redis_cluster_name, redis_user_id),
      91            0 :             credentials_provider: aws_credentials_provider,
      92            0 :         })
      93            0 :     }
      94              : 
      95            0 :     pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
      96            0 :         let aws_credentials = self
      97            0 :             .credentials_provider
      98            0 :             .provide_credentials()
      99            0 :             .await?
     100            0 :             .into();
     101            0 :         info!("AWS credentials successfully obtained");
     102            0 :         info!("Connecting to Redis with configuration: {:?}", self.config);
     103            0 :         let mut settings = SigningSettings::default();
     104            0 :         settings.signature_location = SignatureLocation::QueryParams;
     105            0 :         settings.expires_in = Some(self.config.token_ttl);
     106            0 :         let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
     107            0 :             .identity(&aws_credentials)
     108            0 :             .region(&self.config.region)
     109            0 :             .name(&self.config.service_name)
     110            0 :             .time(SystemTime::now())
     111            0 :             .settings(settings)
     112            0 :             .build()?
     113            0 :             .into();
     114            0 :         let auth_params = [
     115            0 :             ("Action", &self.config.action),
     116            0 :             ("User", &self.config.user_id),
     117            0 :         ];
     118            0 :         let auth_params = url::form_urlencoded::Serializer::new(String::new())
     119            0 :             .extend_pairs(auth_params)
     120            0 :             .finish();
     121            0 :         let auth_uri = http::Uri::builder()
     122            0 :             .scheme("http")
     123            0 :             .authority(self.config.cluster_name.as_bytes())
     124            0 :             .path_and_query(format!("/?{auth_params}"))
     125            0 :             .build()?;
     126            0 :         info!("{}", auth_uri);
     127              : 
     128              :         // Convert the HTTP request into a signable request
     129            0 :         let signable_request = SignableRequest::new(
     130            0 :             "GET",
     131            0 :             auth_uri.to_string(),
     132            0 :             std::iter::empty(),
     133            0 :             SignableBody::Bytes(&[]),
     134            0 :         )?;
     135              : 
     136              :         // Sign and then apply the signature to the request
     137            0 :         let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts();
     138            0 :         let mut signable_request = http::Request::builder()
     139            0 :             .method("GET")
     140            0 :             .uri(auth_uri)
     141            0 :             .body(())?;
     142            0 :         si.apply_to_request_http1x(&mut signable_request);
     143            0 :         Ok((
     144            0 :             self.config.user_id.clone(),
     145            0 :             signable_request
     146            0 :                 .uri()
     147            0 :                 .to_string()
     148            0 :                 .replacen("http://", "", 1),
     149            0 :         ))
     150            0 :     }
     151              : }
        

Generated by: LCOV version 2.1-beta