Line data Source code
1 : use std::sync::Arc;
2 : use std::time::Duration;
3 :
4 : use futures::FutureExt;
5 : use redis::aio::{ConnectionLike, MultiplexedConnection};
6 : use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
7 : use tokio::task::JoinHandle;
8 : use tracing::{debug, error, info, warn};
9 :
10 : use super::elasticache::CredentialsProvider;
11 :
12 : enum Credentials {
13 : Static(ConnectionInfo),
14 : Dynamic(Arc<CredentialsProvider>, redis::ConnectionAddr),
15 : }
16 :
17 : impl Clone for Credentials {
18 0 : fn clone(&self) -> Self {
19 0 : match self {
20 0 : Credentials::Static(info) => Credentials::Static(info.clone()),
21 0 : Credentials::Dynamic(provider, addr) => {
22 0 : Credentials::Dynamic(Arc::clone(provider), addr.clone())
23 : }
24 : }
25 0 : }
26 : }
27 :
28 : /// A wrapper around `redis::MultiplexedConnection` that automatically refreshes the token.
29 : /// Provides PubSub connection without credentials refresh.
30 : pub struct ConnectionWithCredentialsProvider {
31 : credentials: Credentials,
32 : con: Option<MultiplexedConnection>,
33 : refresh_token_task: Option<JoinHandle<()>>,
34 : mutex: tokio::sync::Mutex<()>,
35 : }
36 :
37 : impl Clone for ConnectionWithCredentialsProvider {
38 0 : fn clone(&self) -> Self {
39 0 : Self {
40 0 : credentials: self.credentials.clone(),
41 0 : con: None,
42 0 : refresh_token_task: None,
43 0 : mutex: tokio::sync::Mutex::new(()),
44 0 : }
45 0 : }
46 : }
47 :
48 : impl ConnectionWithCredentialsProvider {
49 0 : pub fn new_with_credentials_provider(
50 0 : host: String,
51 0 : port: u16,
52 0 : credentials_provider: Arc<CredentialsProvider>,
53 0 : ) -> Self {
54 0 : Self {
55 0 : credentials: Credentials::Dynamic(
56 0 : credentials_provider,
57 0 : redis::ConnectionAddr::TcpTls {
58 0 : host,
59 0 : port,
60 0 : insecure: false,
61 0 : tls_params: None,
62 0 : },
63 0 : ),
64 0 : con: None,
65 0 : refresh_token_task: None,
66 0 : mutex: tokio::sync::Mutex::new(()),
67 0 : }
68 0 : }
69 :
70 0 : pub fn new_with_static_credentials<T: IntoConnectionInfo>(params: T) -> Self {
71 0 : Self {
72 0 : credentials: Credentials::Static(
73 0 : params
74 0 : .into_connection_info()
75 0 : .expect("static configured redis credentials should be a valid format"),
76 0 : ),
77 0 : con: None,
78 0 : refresh_token_task: None,
79 0 : mutex: tokio::sync::Mutex::new(()),
80 0 : }
81 0 : }
82 :
83 0 : async fn ping(con: &mut MultiplexedConnection) -> RedisResult<()> {
84 0 : redis::cmd("PING").query_async(con).await
85 0 : }
86 :
87 0 : pub(crate) async fn connect(&mut self) -> anyhow::Result<()> {
88 0 : let _guard = self.mutex.lock().await;
89 0 : if let Some(con) = self.con.as_mut() {
90 0 : match Self::ping(con).await {
91 : Ok(()) => {
92 0 : return Ok(());
93 : }
94 0 : Err(e) => {
95 0 : warn!("Error during PING: {e:?}");
96 : }
97 : }
98 : } else {
99 0 : info!("Connection is not established");
100 : }
101 0 : info!("Establishing a new connection...");
102 0 : self.con = None;
103 0 : if let Some(f) = self.refresh_token_task.take() {
104 0 : f.abort();
105 0 : }
106 0 : let mut con = self
107 0 : .get_client()
108 0 : .await?
109 0 : .get_multiplexed_tokio_connection()
110 0 : .await?;
111 0 : if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
112 0 : let credentials_provider = credentials_provider.clone();
113 0 : let con2 = con.clone();
114 0 : let f = tokio::spawn(async move {
115 0 : Self::keep_connection(con2, credentials_provider)
116 0 : .await
117 0 : .inspect_err(|e| debug!("keep_connection failed: {e}"))
118 0 : .ok();
119 0 : });
120 0 : self.refresh_token_task = Some(f);
121 0 : }
122 0 : match Self::ping(&mut con).await {
123 : Ok(()) => {
124 0 : info!("Connection succesfully established");
125 : }
126 0 : Err(e) => {
127 0 : warn!("Connection is broken. Error during PING: {e:?}");
128 : }
129 : }
130 0 : self.con = Some(con);
131 0 : Ok(())
132 0 : }
133 :
134 0 : async fn get_connection_info(&self) -> anyhow::Result<ConnectionInfo> {
135 0 : match &self.credentials {
136 0 : Credentials::Static(info) => Ok(info.clone()),
137 0 : Credentials::Dynamic(provider, addr) => {
138 0 : let (username, password) = provider.provide_credentials().await?;
139 0 : Ok(ConnectionInfo {
140 0 : addr: addr.clone(),
141 0 : redis: RedisConnectionInfo {
142 0 : db: 0,
143 0 : username: Some(username),
144 0 : password: Some(password.clone()),
145 0 : },
146 0 : })
147 : }
148 : }
149 0 : }
150 :
151 0 : async fn get_client(&self) -> anyhow::Result<redis::Client> {
152 0 : let client = redis::Client::open(self.get_connection_info().await?)?;
153 0 : Ok(client)
154 0 : }
155 :
156 : // PubSub does not support credentials refresh.
157 : // Requires manual reconnection every 12h.
158 0 : pub(crate) async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
159 0 : Ok(self.get_client().await?.get_async_pubsub().await?)
160 0 : }
161 :
162 : // The connection lives for 12h.
163 : // It can be prolonged with sending `AUTH` commands with the refreshed token.
164 : // https://docs.aws.amazon.com/AmazonElastiCache/latest/red-ug/auth-iam.html#auth-iam-limits
165 0 : async fn keep_connection(
166 0 : mut con: MultiplexedConnection,
167 0 : credentials_provider: Arc<CredentialsProvider>,
168 0 : ) -> anyhow::Result<()> {
169 : loop {
170 : // The connection lives for 12h, for the sanity check we refresh it every hour.
171 0 : tokio::time::sleep(Duration::from_secs(60 * 60)).await;
172 0 : match Self::refresh_token(&mut con, credentials_provider.clone()).await {
173 : Ok(()) => {
174 0 : info!("Token refreshed");
175 : }
176 0 : Err(e) => {
177 0 : error!("Error during token refresh: {e:?}");
178 : }
179 : }
180 : }
181 : }
182 0 : async fn refresh_token(
183 0 : con: &mut MultiplexedConnection,
184 0 : credentials_provider: Arc<CredentialsProvider>,
185 0 : ) -> anyhow::Result<()> {
186 0 : let (user, password) = credentials_provider.provide_credentials().await?;
187 0 : let _: () = redis::cmd("AUTH")
188 0 : .arg(user)
189 0 : .arg(password)
190 0 : .query_async(con)
191 0 : .await?;
192 0 : Ok(())
193 0 : }
194 : /// Sends an already encoded (packed) command into the TCP socket and
195 : /// reads the single response from it.
196 0 : pub(crate) async fn send_packed_command(
197 0 : &mut self,
198 0 : cmd: &redis::Cmd,
199 0 : ) -> RedisResult<redis::Value> {
200 : // Clone connection to avoid having to lock the ArcSwap in write mode
201 0 : let con = self.con.as_mut().ok_or(redis::RedisError::from((
202 0 : redis::ErrorKind::IoError,
203 0 : "Connection not established",
204 0 : )))?;
205 0 : con.send_packed_command(cmd).await
206 0 : }
207 :
208 : /// Sends multiple already encoded (packed) command into the TCP socket
209 : /// and reads `count` responses from it. This is used to implement
210 : /// pipelining.
211 0 : pub(crate) async fn send_packed_commands(
212 0 : &mut self,
213 0 : cmd: &redis::Pipeline,
214 0 : offset: usize,
215 0 : count: usize,
216 0 : ) -> RedisResult<Vec<redis::Value>> {
217 : // Clone shared connection future to avoid having to lock the ArcSwap in write mode
218 0 : let con = self.con.as_mut().ok_or(redis::RedisError::from((
219 0 : redis::ErrorKind::IoError,
220 0 : "Connection not established",
221 0 : )))?;
222 0 : con.send_packed_commands(cmd, offset, count).await
223 0 : }
224 : }
225 :
226 : impl ConnectionLike for ConnectionWithCredentialsProvider {
227 0 : fn req_packed_command<'a>(
228 0 : &'a mut self,
229 0 : cmd: &'a redis::Cmd,
230 0 : ) -> redis::RedisFuture<'a, redis::Value> {
231 0 : (async move { self.send_packed_command(cmd).await }).boxed()
232 0 : }
233 :
234 0 : fn req_packed_commands<'a>(
235 0 : &'a mut self,
236 0 : cmd: &'a redis::Pipeline,
237 0 : offset: usize,
238 0 : count: usize,
239 0 : ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
240 0 : (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
241 0 : }
242 :
243 0 : fn get_db(&self) -> i64 {
244 0 : 0
245 0 : }
246 : }
|