Line data Source code
1 : use std::sync::Arc;
2 : use std::sync::atomic::{AtomicBool, Ordering};
3 : use std::time::Duration;
4 :
5 : use futures::FutureExt;
6 : use redis::aio::{ConnectionLike, MultiplexedConnection};
7 : use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisError, RedisResult};
8 : use tokio::task::AbortHandle;
9 : use tracing::{error, info, warn};
10 :
11 : use super::elasticache::CredentialsProvider;
12 : use crate::redis::elasticache::CredentialsProviderError;
13 :
14 : enum Credentials {
15 : Static(ConnectionInfo),
16 : Dynamic(Arc<CredentialsProvider>, redis::ConnectionAddr),
17 : }
18 :
19 : impl Clone for Credentials {
20 0 : fn clone(&self) -> Self {
21 0 : match self {
22 0 : Credentials::Static(info) => Credentials::Static(info.clone()),
23 0 : Credentials::Dynamic(provider, addr) => {
24 0 : Credentials::Dynamic(Arc::clone(provider), addr.clone())
25 : }
26 : }
27 0 : }
28 : }
29 :
30 : #[derive(thiserror::Error, Debug)]
31 : pub enum ConnectionProviderError {
32 : #[error(transparent)]
33 : Redis(#[from] RedisError),
34 : #[error(transparent)]
35 : CredentialsProvider(#[from] CredentialsProviderError),
36 : }
37 :
/// A wrapper around `redis::aio::MultiplexedConnection` that, when built with a
/// credentials provider, periodically re-authenticates the connection with a
/// freshly obtained token.
///
/// PubSub connections handed out by `get_async_pubsub` are NOT refreshed.
pub struct ConnectionWithCredentialsProvider {
    /// Static connection parameters, or a provider for dynamic credentials.
    credentials: Credentials,
    // TODO: with more load on the connection, we should consider using a connection pool
    /// The current connection; `None` until `connect` succeeds.
    con: Option<MultiplexedConnection>,
    /// Abort handle of the background token-refresh task (dynamic credentials only).
    refresh_token_task: Option<AbortHandle>,
    /// Locked for the duration of `connect`.
    // NOTE(review): `connect` already takes `&mut self`, so this lock looks redundant — confirm.
    mutex: tokio::sync::Mutex<()>,
    /// Shared with the refresh task; exposed through `credentials_refreshed()`.
    credentials_refreshed: Arc<AtomicBool>,
}
48 :
49 : impl Clone for ConnectionWithCredentialsProvider {
50 0 : fn clone(&self) -> Self {
51 0 : Self {
52 0 : credentials: self.credentials.clone(),
53 0 : con: None,
54 0 : refresh_token_task: None,
55 0 : mutex: tokio::sync::Mutex::new(()),
56 0 : credentials_refreshed: Arc::new(AtomicBool::new(false)),
57 0 : }
58 0 : }
59 : }
60 :
61 : impl ConnectionWithCredentialsProvider {
62 0 : pub fn new_with_credentials_provider(
63 0 : host: String,
64 0 : port: u16,
65 0 : credentials_provider: Arc<CredentialsProvider>,
66 0 : ) -> Self {
67 0 : Self {
68 0 : credentials: Credentials::Dynamic(
69 0 : credentials_provider,
70 0 : redis::ConnectionAddr::TcpTls {
71 0 : host,
72 0 : port,
73 0 : insecure: false,
74 0 : tls_params: None,
75 0 : },
76 0 : ),
77 0 : con: None,
78 0 : refresh_token_task: None,
79 0 : mutex: tokio::sync::Mutex::new(()),
80 0 : credentials_refreshed: Arc::new(AtomicBool::new(false)),
81 0 : }
82 0 : }
83 :
84 0 : pub fn new_with_static_credentials<T: IntoConnectionInfo>(params: T) -> Self {
85 0 : Self {
86 0 : credentials: Credentials::Static(
87 0 : params
88 0 : .into_connection_info()
89 0 : .expect("static configured redis credentials should be a valid format"),
90 0 : ),
91 0 : con: None,
92 0 : refresh_token_task: None,
93 0 : mutex: tokio::sync::Mutex::new(()),
94 0 : credentials_refreshed: Arc::new(AtomicBool::new(true)),
95 0 : }
96 0 : }
97 :
98 0 : async fn ping(con: &mut MultiplexedConnection) -> Result<(), ConnectionProviderError> {
99 0 : redis::cmd("PING")
100 0 : .query_async(con)
101 0 : .await
102 0 : .map_err(Into::into)
103 0 : }
104 :
105 0 : pub(crate) fn credentials_refreshed(&self) -> bool {
106 0 : self.credentials_refreshed.load(Ordering::Relaxed)
107 0 : }
108 :
109 0 : pub(crate) async fn connect(&mut self) -> Result<(), ConnectionProviderError> {
110 0 : let _guard = self.mutex.lock().await;
111 0 : if let Some(con) = self.con.as_mut() {
112 0 : match Self::ping(con).await {
113 : Ok(()) => {
114 0 : return Ok(());
115 : }
116 0 : Err(e) => {
117 0 : warn!("Error during PING: {e:?}");
118 : }
119 : }
120 : } else {
121 0 : info!("Connection is not established");
122 : }
123 0 : info!("Establishing a new connection...");
124 0 : self.con = None;
125 0 : if let Some(f) = self.refresh_token_task.take() {
126 0 : f.abort();
127 0 : }
128 0 : let mut con = self
129 0 : .get_client()
130 0 : .await?
131 0 : .get_multiplexed_tokio_connection()
132 0 : .await?;
133 0 : if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
134 0 : let credentials_provider = credentials_provider.clone();
135 0 : let con2 = con.clone();
136 0 : let credentials_refreshed = self.credentials_refreshed.clone();
137 0 : let f = tokio::spawn(Self::keep_connection(
138 0 : con2,
139 0 : credentials_provider,
140 0 : credentials_refreshed,
141 0 : ));
142 0 : self.refresh_token_task = Some(f.abort_handle());
143 0 : }
144 0 : match Self::ping(&mut con).await {
145 : Ok(()) => {
146 0 : info!("Connection succesfully established");
147 : }
148 0 : Err(e) => {
149 0 : warn!("Connection is broken. Error during PING: {e:?}");
150 : }
151 : }
152 0 : self.con = Some(con);
153 0 : Ok(())
154 0 : }
155 :
156 0 : async fn get_connection_info(&self) -> Result<ConnectionInfo, ConnectionProviderError> {
157 0 : match &self.credentials {
158 0 : Credentials::Static(info) => Ok(info.clone()),
159 0 : Credentials::Dynamic(provider, addr) => {
160 0 : let (username, password) = provider.provide_credentials().await?;
161 0 : Ok(ConnectionInfo {
162 0 : addr: addr.clone(),
163 0 : redis: RedisConnectionInfo {
164 0 : db: 0,
165 0 : username: Some(username),
166 0 : password: Some(password.clone()),
167 0 : // TODO: switch to RESP3 after testing new client version.
168 0 : protocol: redis::ProtocolVersion::RESP2,
169 0 : },
170 0 : })
171 : }
172 : }
173 0 : }
174 :
175 0 : async fn get_client(&self) -> Result<redis::Client, ConnectionProviderError> {
176 0 : let client = redis::Client::open(self.get_connection_info().await?)?;
177 0 : self.credentials_refreshed.store(true, Ordering::Relaxed);
178 0 : Ok(client)
179 0 : }
180 :
181 : // PubSub does not support credentials refresh.
182 : // Requires manual reconnection every 12h.
183 0 : pub(crate) async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
184 0 : Ok(self.get_client().await?.get_async_pubsub().await?)
185 0 : }
186 :
187 : // The connection lives for 12h.
188 : // It can be prolonged with sending `AUTH` commands with the refreshed token.
189 : // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits
190 0 : async fn keep_connection(
191 0 : mut con: MultiplexedConnection,
192 0 : credentials_provider: Arc<CredentialsProvider>,
193 0 : credentials_refreshed: Arc<AtomicBool>,
194 0 : ) -> ! {
195 : loop {
196 : // The connection lives for 12h, for the sanity check we refresh it every hour.
197 0 : tokio::time::sleep(Duration::from_secs(60 * 60)).await;
198 0 : match Self::refresh_token(&mut con, credentials_provider.clone()).await {
199 : Ok(()) => {
200 0 : info!("Token refreshed");
201 0 : credentials_refreshed.store(true, Ordering::Relaxed);
202 : }
203 0 : Err(e) => {
204 0 : error!("Error during token refresh: {e:?}");
205 0 : credentials_refreshed.store(false, Ordering::Relaxed);
206 : }
207 : }
208 : }
209 : }
210 0 : async fn refresh_token(
211 0 : con: &mut MultiplexedConnection,
212 0 : credentials_provider: Arc<CredentialsProvider>,
213 0 : ) -> anyhow::Result<()> {
214 0 : let (user, password) = credentials_provider.provide_credentials().await?;
215 0 : let _: () = redis::cmd("AUTH")
216 0 : .arg(user)
217 0 : .arg(password)
218 0 : .query_async(con)
219 0 : .await?;
220 0 : Ok(())
221 0 : }
222 : /// Sends an already encoded (packed) command into the TCP socket and
223 : /// reads the single response from it.
224 0 : pub(crate) async fn send_packed_command(
225 0 : &mut self,
226 0 : cmd: &redis::Cmd,
227 0 : ) -> RedisResult<redis::Value> {
228 : // Clone connection to avoid having to lock the ArcSwap in write mode
229 0 : let con = self.con.as_mut().ok_or(redis::RedisError::from((
230 0 : redis::ErrorKind::IoError,
231 0 : "Connection not established",
232 0 : )))?;
233 0 : con.send_packed_command(cmd).await
234 0 : }
235 :
236 : /// Sends multiple already encoded (packed) command into the TCP socket
237 : /// and reads `count` responses from it. This is used to implement
238 : /// pipelining.
239 0 : pub(crate) async fn send_packed_commands(
240 0 : &mut self,
241 0 : cmd: &redis::Pipeline,
242 0 : offset: usize,
243 0 : count: usize,
244 0 : ) -> RedisResult<Vec<redis::Value>> {
245 : // Clone shared connection future to avoid having to lock the ArcSwap in write mode
246 0 : let con = self.con.as_mut().ok_or(redis::RedisError::from((
247 0 : redis::ErrorKind::IoError,
248 0 : "Connection not established",
249 0 : )))?;
250 0 : con.send_packed_commands(cmd, offset, count).await
251 0 : }
252 : }
253 :
254 : impl ConnectionLike for ConnectionWithCredentialsProvider {
255 0 : fn req_packed_command<'a>(
256 0 : &'a mut self,
257 0 : cmd: &'a redis::Cmd,
258 0 : ) -> redis::RedisFuture<'a, redis::Value> {
259 0 : self.send_packed_command(cmd).boxed()
260 0 : }
261 :
262 0 : fn req_packed_commands<'a>(
263 0 : &'a mut self,
264 0 : cmd: &'a redis::Pipeline,
265 0 : offset: usize,
266 0 : count: usize,
267 0 : ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
268 0 : self.send_packed_commands(cmd, offset, count).boxed()
269 0 : }
270 :
271 0 : fn get_db(&self) -> i64 {
272 0 : self.con.as_ref().map_or(0, |c| c.get_db())
273 0 : }
274 : }
|