Line data Source code
1 : use std::sync::Arc;
2 : use std::time::{Duration, SystemTime};
3 :
4 : use aws_config::environment::EnvironmentVariableCredentialsProvider;
5 : use aws_config::imds::credentials::ImdsCredentialsProvider;
6 : use aws_config::meta::credentials::CredentialsProviderChain;
7 : use aws_config::meta::region::RegionProviderChain;
8 : use aws_config::profile::ProfileFileCredentialsProvider;
9 : use aws_config::provider_config::ProviderConfig;
10 : use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
11 : use aws_config::Region;
12 : use aws_sdk_iam::config::ProvideCredentials;
13 : use aws_sigv4::http_request::{
14 : self, SignableBody, SignableRequest, SignatureLocation, SigningSettings,
15 : };
16 : use tracing::info;
17 :
/// Static parameters used to build an ElastiCache IAM authentication token.
#[derive(Debug)]
pub struct AWSIRSAConfig {
    /// AWS region the token is signed for.
    region: String,
    /// SigV4 signing service name; always `"elasticache"`.
    service_name: String,
    /// Cluster name; used as the host (authority) of the presigned URL.
    cluster_name: String,
    /// User the token authenticates; sent as the `User` query parameter.
    user_id: String,
    /// How long the presigned token remains valid.
    token_ttl: Duration,
    /// Value of the `Action` query parameter; always `"connect"`.
    action: String,
}

impl AWSIRSAConfig {
    /// Builds a config for `region`.
    ///
    /// A missing `cluster_name` or `user_id` becomes an empty string.
    pub fn new(region: String, cluster_name: Option<String>, user_id: Option<String>) -> Self {
        // ElastiCache IAM authentication tokens are valid for 15 minutes:
        // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html
        let token_ttl = Duration::from_secs(15 * 60);
        Self {
            region,
            service_name: String::from("elasticache"),
            cluster_name: cluster_name.unwrap_or_default(),
            user_id: user_id.unwrap_or_default(),
            token_ttl,
            action: String::from("connect"),
        }
    }
}
42 :
43 : /// Credentials provider for AWS elasticache authentication.
44 : ///
45 : /// Official documentation:
46 : /// <https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html>
47 : ///
48 : /// Useful resources:
49 : /// <https://aws.amazon.com/blogs/database/simplify-managing-access-to-amazon-elasticache-for-redis-clusters-with-iam/>
50 : pub struct CredentialsProvider {
51 : config: AWSIRSAConfig,
52 : credentials_provider: CredentialsProviderChain,
53 : }
54 :
55 : impl CredentialsProvider {
56 0 : pub async fn new(
57 0 : aws_region: String,
58 0 : redis_cluster_name: Option<String>,
59 0 : redis_user_id: Option<String>,
60 0 : ) -> Arc<CredentialsProvider> {
61 0 : let region_provider =
62 0 : RegionProviderChain::default_provider().or_else(Region::new(aws_region.clone()));
63 0 : let provider_conf =
64 0 : ProviderConfig::without_region().with_region(region_provider.region().await);
65 0 : let aws_credentials_provider = {
66 0 : // uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
67 0 : CredentialsProviderChain::first_try(
68 0 : "env",
69 0 : EnvironmentVariableCredentialsProvider::new(),
70 0 : )
71 0 : // uses "AWS_PROFILE" / `aws sso login --profile <profile>`
72 0 : .or_else(
73 0 : "profile-sso",
74 0 : ProfileFileCredentialsProvider::builder()
75 0 : .configure(&provider_conf)
76 0 : .build(),
77 0 : )
78 0 : // uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME"
79 0 : // needed to access remote extensions bucket
80 0 : .or_else(
81 0 : "token",
82 0 : WebIdentityTokenCredentialsProvider::builder()
83 0 : .configure(&provider_conf)
84 0 : .build(),
85 0 : )
86 0 : // uses imds v2
87 0 : .or_else("imds", ImdsCredentialsProvider::builder().build())
88 0 : };
89 0 : Arc::new(CredentialsProvider {
90 0 : config: AWSIRSAConfig::new(aws_region, redis_cluster_name, redis_user_id),
91 0 : credentials_provider: aws_credentials_provider,
92 0 : })
93 0 : }
94 :
95 0 : pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
96 0 : let aws_credentials = self
97 0 : .credentials_provider
98 0 : .provide_credentials()
99 0 : .await?
100 0 : .into();
101 0 : info!("AWS credentials successfully obtained");
102 0 : info!("Connecting to Redis with configuration: {:?}", self.config);
103 0 : let mut settings = SigningSettings::default();
104 0 : settings.signature_location = SignatureLocation::QueryParams;
105 0 : settings.expires_in = Some(self.config.token_ttl);
106 0 : let signing_params = aws_sigv4::sign::v4::SigningParams::builder()
107 0 : .identity(&aws_credentials)
108 0 : .region(&self.config.region)
109 0 : .name(&self.config.service_name)
110 0 : .time(SystemTime::now())
111 0 : .settings(settings)
112 0 : .build()?
113 0 : .into();
114 0 : let auth_params = [
115 0 : ("Action", &self.config.action),
116 0 : ("User", &self.config.user_id),
117 0 : ];
118 0 : let auth_params = url::form_urlencoded::Serializer::new(String::new())
119 0 : .extend_pairs(auth_params)
120 0 : .finish();
121 0 : let auth_uri = http::Uri::builder()
122 0 : .scheme("http")
123 0 : .authority(self.config.cluster_name.as_bytes())
124 0 : .path_and_query(format!("/?{auth_params}"))
125 0 : .build()?;
126 0 : info!("{}", auth_uri);
127 :
128 : // Convert the HTTP request into a signable request
129 0 : let signable_request = SignableRequest::new(
130 0 : "GET",
131 0 : auth_uri.to_string(),
132 0 : std::iter::empty(),
133 0 : SignableBody::Bytes(&[]),
134 0 : )?;
135 :
136 : // Sign and then apply the signature to the request
137 0 : let (si, _) = http_request::sign(signable_request, &signing_params)?.into_parts();
138 0 : let mut signable_request = http::Request::builder()
139 0 : .method("GET")
140 0 : .uri(auth_uri)
141 0 : .body(())?;
142 0 : si.apply_to_request_http1x(&mut signable_request);
143 0 : Ok((
144 0 : self.config.user_id.clone(),
145 0 : signable_request
146 0 : .uri()
147 0 : .to_string()
148 0 : .replacen("http://", "", 1),
149 0 : ))
150 0 : }
151 : }
|