LCOV - code coverage report
Current view: top level - proxy/src/redis - connection_with_credentials_provider.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 181 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 27 0

            Line data    Source code
       1              : use std::sync::Arc;
       2              : use std::sync::atomic::{AtomicBool, Ordering};
       3              : use std::time::Duration;
       4              : 
       5              : use futures::FutureExt;
       6              : use redis::aio::{ConnectionLike, MultiplexedConnection};
       7              : use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisError, RedisResult};
       8              : use tokio::task::AbortHandle;
       9              : use tracing::{error, info, warn};
      10              : 
      11              : use super::elasticache::CredentialsProvider;
      12              : use crate::redis::elasticache::CredentialsProviderError;
      13              : 
      14              : enum Credentials {
      15              :     Static(ConnectionInfo),
      16              :     Dynamic(Arc<CredentialsProvider>, redis::ConnectionAddr),
      17              : }
      18              : 
      19              : impl Clone for Credentials {
      20            0 :     fn clone(&self) -> Self {
      21            0 :         match self {
      22            0 :             Credentials::Static(info) => Credentials::Static(info.clone()),
      23            0 :             Credentials::Dynamic(provider, addr) => {
      24            0 :                 Credentials::Dynamic(Arc::clone(provider), addr.clone())
      25              :             }
      26              :         }
      27            0 :     }
      28              : }
      29              : 
      30              : #[derive(thiserror::Error, Debug)]
      31              : pub enum ConnectionProviderError {
      32              :     #[error(transparent)]
      33              :     Redis(#[from] RedisError),
      34              :     #[error(transparent)]
      35              :     CredentialsProvider(#[from] CredentialsProviderError),
      36              : }
      37              : 
      38              : /// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token.
      39              : /// Provides PubSub connection without credentials refresh.
      40              : pub struct ConnectionWithCredentialsProvider {
      41              :     credentials: Credentials,
      42              :     // TODO: with more load on the connection, we should consider using a connection pool
      43              :     con: Option<MultiplexedConnection>,
      44              :     refresh_token_task: Option<AbortHandle>,
      45              :     mutex: tokio::sync::Mutex<()>,
      46              :     credentials_refreshed: Arc<AtomicBool>,
      47              : }
      48              : 
      49              : impl Clone for ConnectionWithCredentialsProvider {
      50            0 :     fn clone(&self) -> Self {
      51            0 :         Self {
      52            0 :             credentials: self.credentials.clone(),
      53            0 :             con: None,
      54            0 :             refresh_token_task: None,
      55            0 :             mutex: tokio::sync::Mutex::new(()),
      56            0 :             credentials_refreshed: Arc::new(AtomicBool::new(false)),
      57            0 :         }
      58            0 :     }
      59              : }
      60              : 
      61              : impl ConnectionWithCredentialsProvider {
      62            0 :     pub fn new_with_credentials_provider(
      63            0 :         host: String,
      64            0 :         port: u16,
      65            0 :         credentials_provider: Arc<CredentialsProvider>,
      66            0 :     ) -> Self {
      67            0 :         Self {
      68            0 :             credentials: Credentials::Dynamic(
      69            0 :                 credentials_provider,
      70            0 :                 redis::ConnectionAddr::TcpTls {
      71            0 :                     host,
      72            0 :                     port,
      73            0 :                     insecure: false,
      74            0 :                     tls_params: None,
      75            0 :                 },
      76            0 :             ),
      77            0 :             con: None,
      78            0 :             refresh_token_task: None,
      79            0 :             mutex: tokio::sync::Mutex::new(()),
      80            0 :             credentials_refreshed: Arc::new(AtomicBool::new(false)),
      81            0 :         }
      82            0 :     }
      83              : 
      84            0 :     pub fn new_with_static_credentials<T: IntoConnectionInfo>(params: T) -> Self {
      85            0 :         Self {
      86            0 :             credentials: Credentials::Static(
      87            0 :                 params
      88            0 :                     .into_connection_info()
      89            0 :                     .expect("static configured redis credentials should be a valid format"),
      90            0 :             ),
      91            0 :             con: None,
      92            0 :             refresh_token_task: None,
      93            0 :             mutex: tokio::sync::Mutex::new(()),
      94            0 :             credentials_refreshed: Arc::new(AtomicBool::new(true)),
      95            0 :         }
      96            0 :     }
      97              : 
      98            0 :     async fn ping(con: &mut MultiplexedConnection) -> Result<(), ConnectionProviderError> {
      99            0 :         redis::cmd("PING")
     100            0 :             .query_async(con)
     101            0 :             .await
     102            0 :             .map_err(Into::into)
     103            0 :     }
     104              : 
     105            0 :     pub(crate) fn credentials_refreshed(&self) -> bool {
     106            0 :         self.credentials_refreshed.load(Ordering::Relaxed)
     107            0 :     }
     108              : 
     109            0 :     pub(crate) async fn connect(&mut self) -> Result<(), ConnectionProviderError> {
     110            0 :         let _guard = self.mutex.lock().await;
     111            0 :         if let Some(con) = self.con.as_mut() {
     112            0 :             match Self::ping(con).await {
     113              :                 Ok(()) => {
     114            0 :                     return Ok(());
     115              :                 }
     116            0 :                 Err(e) => {
     117            0 :                     warn!("Error during PING: {e:?}");
     118              :                 }
     119              :             }
     120              :         } else {
     121            0 :             info!("Connection is not established");
     122              :         }
     123            0 :         info!("Establishing a new connection...");
     124            0 :         self.con = None;
     125            0 :         if let Some(f) = self.refresh_token_task.take() {
     126            0 :             f.abort();
     127            0 :         }
     128            0 :         let mut con = self
     129            0 :             .get_client()
     130            0 :             .await?
     131            0 :             .get_multiplexed_tokio_connection()
     132            0 :             .await?;
     133            0 :         if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
     134            0 :             let credentials_provider = credentials_provider.clone();
     135            0 :             let con2 = con.clone();
     136            0 :             let credentials_refreshed = self.credentials_refreshed.clone();
     137            0 :             let f = tokio::spawn(Self::keep_connection(
     138            0 :                 con2,
     139            0 :                 credentials_provider,
     140            0 :                 credentials_refreshed,
     141            0 :             ));
     142            0 :             self.refresh_token_task = Some(f.abort_handle());
     143            0 :         }
     144            0 :         match Self::ping(&mut con).await {
     145              :             Ok(()) => {
     146            0 :                 info!("Connection succesfully established");
     147              :             }
     148            0 :             Err(e) => {
     149            0 :                 warn!("Connection is broken. Error during PING: {e:?}");
     150              :             }
     151              :         }
     152            0 :         self.con = Some(con);
     153            0 :         Ok(())
     154            0 :     }
     155              : 
     156            0 :     async fn get_connection_info(&self) -> Result<ConnectionInfo, ConnectionProviderError> {
     157            0 :         match &self.credentials {
     158            0 :             Credentials::Static(info) => Ok(info.clone()),
     159            0 :             Credentials::Dynamic(provider, addr) => {
     160            0 :                 let (username, password) = provider.provide_credentials().await?;
     161            0 :                 Ok(ConnectionInfo {
     162            0 :                     addr: addr.clone(),
     163            0 :                     redis: RedisConnectionInfo {
     164            0 :                         db: 0,
     165            0 :                         username: Some(username),
     166            0 :                         password: Some(password.clone()),
     167            0 :                         // TODO: switch to RESP3 after testing new client version.
     168            0 :                         protocol: redis::ProtocolVersion::RESP2,
     169            0 :                     },
     170            0 :                 })
     171              :             }
     172              :         }
     173            0 :     }
     174              : 
     175            0 :     async fn get_client(&self) -> Result<redis::Client, ConnectionProviderError> {
     176            0 :         let client = redis::Client::open(self.get_connection_info().await?)?;
     177            0 :         self.credentials_refreshed.store(true, Ordering::Relaxed);
     178            0 :         Ok(client)
     179            0 :     }
     180              : 
     181              :     // PubSub does not support credentials refresh.
     182              :     // Requires manual reconnection every 12h.
     183            0 :     pub(crate) async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
     184            0 :         Ok(self.get_client().await?.get_async_pubsub().await?)
     185            0 :     }
     186              : 
     187              :     // The connection lives for 12h.
     188              :     // It can be prolonged with sending `AUTH` commands with the refreshed token.
     189              :     // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits
     190            0 :     async fn keep_connection(
     191            0 :         mut con: MultiplexedConnection,
     192            0 :         credentials_provider: Arc<CredentialsProvider>,
     193            0 :         credentials_refreshed: Arc<AtomicBool>,
     194            0 :     ) -> ! {
     195              :         loop {
     196              :             // The connection lives for 12h, for the sanity check we refresh it every hour.
     197            0 :             tokio::time::sleep(Duration::from_secs(60 * 60)).await;
     198            0 :             match Self::refresh_token(&mut con, credentials_provider.clone()).await {
     199              :                 Ok(()) => {
     200            0 :                     info!("Token refreshed");
     201            0 :                     credentials_refreshed.store(true, Ordering::Relaxed);
     202              :                 }
     203            0 :                 Err(e) => {
     204            0 :                     error!("Error during token refresh: {e:?}");
     205            0 :                     credentials_refreshed.store(false, Ordering::Relaxed);
     206              :                 }
     207              :             }
     208              :         }
     209              :     }
     210            0 :     async fn refresh_token(
     211            0 :         con: &mut MultiplexedConnection,
     212            0 :         credentials_provider: Arc<CredentialsProvider>,
     213            0 :     ) -> anyhow::Result<()> {
     214            0 :         let (user, password) = credentials_provider.provide_credentials().await?;
     215            0 :         let _: () = redis::cmd("AUTH")
     216            0 :             .arg(user)
     217            0 :             .arg(password)
     218            0 :             .query_async(con)
     219            0 :             .await?;
     220            0 :         Ok(())
     221            0 :     }
     222              :     /// Sends an already encoded (packed) command into the TCP socket and
     223              :     /// reads the single response from it.
     224            0 :     pub(crate) async fn send_packed_command(
     225            0 :         &mut self,
     226            0 :         cmd: &redis::Cmd,
     227            0 :     ) -> RedisResult<redis::Value> {
     228              :         // Clone connection to avoid having to lock the ArcSwap in write mode
     229            0 :         let con = self.con.as_mut().ok_or(redis::RedisError::from((
     230            0 :             redis::ErrorKind::IoError,
     231            0 :             "Connection not established",
     232            0 :         )))?;
     233            0 :         con.send_packed_command(cmd).await
     234            0 :     }
     235              : 
     236              :     /// Sends multiple already encoded (packed) command into the TCP socket
     237              :     /// and reads `count` responses from it.  This is used to implement
     238              :     /// pipelining.
     239            0 :     pub(crate) async fn send_packed_commands(
     240            0 :         &mut self,
     241            0 :         cmd: &redis::Pipeline,
     242            0 :         offset: usize,
     243            0 :         count: usize,
     244            0 :     ) -> RedisResult<Vec<redis::Value>> {
     245              :         // Clone shared connection future to avoid having to lock the ArcSwap in write mode
     246            0 :         let con = self.con.as_mut().ok_or(redis::RedisError::from((
     247            0 :             redis::ErrorKind::IoError,
     248            0 :             "Connection not established",
     249            0 :         )))?;
     250            0 :         con.send_packed_commands(cmd, offset, count).await
     251            0 :     }
     252              : }
     253              : 
     254              : impl ConnectionLike for ConnectionWithCredentialsProvider {
     255            0 :     fn req_packed_command<'a>(
     256            0 :         &'a mut self,
     257            0 :         cmd: &'a redis::Cmd,
     258            0 :     ) -> redis::RedisFuture<'a, redis::Value> {
     259            0 :         self.send_packed_command(cmd).boxed()
     260            0 :     }
     261              : 
     262            0 :     fn req_packed_commands<'a>(
     263            0 :         &'a mut self,
     264            0 :         cmd: &'a redis::Pipeline,
     265            0 :         offset: usize,
     266            0 :         count: usize,
     267            0 :     ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
     268            0 :         self.send_packed_commands(cmd, offset, count).boxed()
     269            0 :     }
     270              : 
     271            0 :     fn get_db(&self) -> i64 {
     272            0 :         self.con.as_ref().map_or(0, |c| c.get_db())
     273            0 :     }
     274              : }
        

Generated by: LCOV version 2.1-beta