Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use pq_proto::CancelKeyData;
6 : use redis::aio::PubSub;
7 : use serde::{Deserialize, Serialize};
8 : use tokio_util::sync::CancellationToken;
9 : use uuid::Uuid;
10 :
11 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
12 : use crate::cache::project_info::ProjectInfoCache;
13 : use crate::cancellation::{CancelMap, CancellationHandler};
14 : use crate::intern::{ProjectIdInt, RoleNameInt};
15 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
16 :
/// Redis channel that `try_connect` subscribes to first.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";

/// Second Redis channel that `try_connect` subscribes to; visible crate-wide.
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";

/// How long `handle_messages` waits after a failed connection attempt before retrying.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);

/// Delay before a cache invalidation is repeated, to catch an entry that was still
/// on its way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
21 :
22 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
23 0 : let mut conn = client.get_async_pubsub().await?;
24 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
25 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
26 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
27 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
28 0 : Ok(conn)
29 0 : }
30 :
31 10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
32 : #[serde(tag = "topic", content = "data")]
33 : pub(crate) enum Notification {
34 : #[serde(
35 : rename = "/allowed_ips_updated",
36 : deserialize_with = "deserialize_json_string"
37 : )]
38 : AllowedIpsUpdate {
39 : allowed_ips_update: AllowedIpsUpdate,
40 : },
41 : #[serde(
42 : rename = "/password_updated",
43 : deserialize_with = "deserialize_json_string"
44 : )]
45 : PasswordUpdate { password_update: PasswordUpdate },
46 : #[serde(rename = "/cancel_session")]
47 : Cancel(CancelSession),
48 : }
49 2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
50 : pub(crate) struct AllowedIpsUpdate {
51 : project_id: ProjectIdInt,
52 : }
53 3 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
54 : pub(crate) struct PasswordUpdate {
55 : project_id: ProjectIdInt,
56 : role_name: RoleNameInt,
57 : }
58 8 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
59 : pub(crate) struct CancelSession {
60 : pub(crate) region_id: Option<String>,
61 : pub(crate) cancel_key_data: CancelKeyData,
62 : pub(crate) session_id: Uuid,
63 : }
64 :
65 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
66 2 : where
67 2 : T: for<'de2> serde::Deserialize<'de2>,
68 2 : D: serde::Deserializer<'de>,
69 2 : {
70 2 : let s = String::deserialize(deserializer)?;
71 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
72 2 : }
73 :
74 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
75 : cache: Arc<C>,
76 : cancellation_handler: Arc<CancellationHandler<()>>,
77 : region_id: String,
78 : }
79 :
80 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
81 0 : fn clone(&self) -> Self {
82 0 : Self {
83 0 : cache: self.cache.clone(),
84 0 : cancellation_handler: self.cancellation_handler.clone(),
85 0 : region_id: self.region_id.clone(),
86 0 : }
87 0 : }
88 : }
89 :
90 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
91 0 : pub(crate) fn new(
92 0 : cache: Arc<C>,
93 0 : cancellation_handler: Arc<CancellationHandler<()>>,
94 0 : region_id: String,
95 0 : ) -> Self {
96 0 : Self {
97 0 : cache,
98 0 : cancellation_handler,
99 0 : region_id,
100 0 : }
101 0 : }
102 0 : pub(crate) async fn increment_active_listeners(&self) {
103 0 : self.cache.increment_active_listeners().await;
104 0 : }
105 0 : pub(crate) async fn decrement_active_listeners(&self) {
106 0 : self.cache.decrement_active_listeners().await;
107 0 : }
108 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
109 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
110 : let payload: String = msg.get_payload()?;
111 : tracing::debug!(?payload, "received a message payload");
112 :
113 : let msg: Notification = match serde_json::from_str(&payload) {
114 : Ok(msg) => msg,
115 : Err(e) => {
116 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
117 : channel: msg.get_channel_name(),
118 : });
119 : tracing::error!("broken message: {e}");
120 : return Ok(());
121 : }
122 : };
123 : tracing::debug!(?msg, "received a message");
124 : match msg {
125 : Notification::Cancel(cancel_session) => {
126 : tracing::Span::current().record(
127 : "session_id",
128 : tracing::field::display(cancel_session.session_id),
129 : );
130 : Metrics::get()
131 : .proxy
132 : .redis_events_count
133 : .inc(RedisEventsCount::CancelSession);
134 : if let Some(cancel_region) = cancel_session.region_id {
135 : // If the message is not for this region, ignore it.
136 : if cancel_region != self.region_id {
137 : return Ok(());
138 : }
139 : }
140 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
141 : match self
142 : .cancellation_handler
143 : .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil())
144 : .await
145 : {
146 : Ok(()) => {}
147 : Err(e) => {
148 : tracing::warn!("failed to cancel session: {e}");
149 : }
150 : }
151 : }
152 : Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
153 : invalidate_cache(self.cache.clone(), msg.clone());
154 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
155 : Metrics::get()
156 : .proxy
157 : .redis_events_count
158 : .inc(RedisEventsCount::AllowedIpsUpdate);
159 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
160 : Metrics::get()
161 : .proxy
162 : .redis_events_count
163 : .inc(RedisEventsCount::PasswordUpdate);
164 : }
165 : // It might happen that the invalid entry is on the way to be cached.
166 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
167 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
168 : let cache = self.cache.clone();
169 0 : tokio::spawn(async move {
170 0 : tokio::time::sleep(INVALIDATION_LAG).await;
171 0 : invalidate_cache(cache, msg);
172 0 : });
173 : }
174 : }
175 :
176 : Ok(())
177 : }
178 : }
179 :
180 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
181 0 : match msg {
182 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
183 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
184 0 : }
185 0 : Notification::PasswordUpdate { password_update } => cache
186 0 : .invalidate_role_secret_for_project(
187 0 : password_update.project_id,
188 0 : password_update.role_name,
189 0 : ),
190 0 : Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
191 : }
192 0 : }
193 :
194 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
195 0 : handler: MessageHandler<C>,
196 0 : redis: ConnectionWithCredentialsProvider,
197 0 : cancellation_token: CancellationToken,
198 0 : ) -> anyhow::Result<()> {
199 : loop {
200 0 : if cancellation_token.is_cancelled() {
201 0 : return Ok(());
202 0 : }
203 0 : let mut conn = match try_connect(&redis).await {
204 0 : Ok(conn) => {
205 0 : handler.increment_active_listeners().await;
206 0 : conn
207 : }
208 0 : Err(e) => {
209 0 : tracing::error!(
210 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
211 : );
212 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
213 0 : continue;
214 : }
215 : };
216 0 : let mut stream = conn.on_message();
217 0 : while let Some(msg) = stream.next().await {
218 0 : match handler.handle_message(msg).await {
219 0 : Ok(()) => {}
220 0 : Err(e) => {
221 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
222 0 : break;
223 : }
224 : }
225 0 : if cancellation_token.is_cancelled() {
226 0 : handler.decrement_active_listeners().await;
227 0 : return Ok(());
228 0 : }
229 : }
230 0 : handler.decrement_active_listeners().await;
231 : }
232 0 : }
233 :
234 : /// Handle console's invalidation messages.
235 : #[tracing::instrument(name = "redis_notifications", skip_all)]
236 : pub async fn task_main<C>(
237 : redis: ConnectionWithCredentialsProvider,
238 : cache: Arc<C>,
239 : cancel_map: CancelMap,
240 : region_id: String,
241 : ) -> anyhow::Result<Infallible>
242 : where
243 : C: ProjectInfoCache + Send + Sync + 'static,
244 : {
245 : let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
246 : cancel_map,
247 : crate::metrics::CancellationSource::FromRedis,
248 : ));
249 : let handler = MessageHandler::new(cache, cancellation_handler, region_id);
250 : // 6h - 1m.
251 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
252 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
253 : loop {
254 : let cancellation_token = CancellationToken::new();
255 : interval.tick().await;
256 :
257 : tokio::spawn(handle_messages(
258 : handler.clone(),
259 : redis.clone(),
260 : cancellation_token.clone(),
261 : ));
262 0 : tokio::spawn(async move {
263 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
264 0 : cancellation_token.cancel();
265 0 : });
266 : }
267 : }
268 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// An allowed-IPs message whose `data` is a JSON string parses into the typed
    /// variant; unknown top-level fields are tolerated.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let raw = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let expected = Notification::AllowedIpsUpdate {
            allowed_ips_update: AllowedIpsUpdate {
                project_id: (&project_id).into(),
            },
        };
        let parsed: Notification = serde_json::from_str(&raw)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// A password message whose `data` is a JSON string parses into the typed
    /// variant; unknown top-level fields are tolerated.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let raw = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let expected = Notification::PasswordUpdate {
            password_update: PasswordUpdate {
                project_id: (&project_id).into(),
                role_name: (&role_name).into(),
            },
        };
        let parsed: Notification = serde_json::from_str(&raw)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// Cancel messages survive a serialize/deserialize round trip, both without
    /// and with a region.
    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let session_id = uuid::Uuid::new_v4();

        for region_id in [None, Some("region".to_string())] {
            let original = Notification::Cancel(CancelSession {
                cancel_key_data,
                region_id,
                session_id,
            });
            let encoded = serde_json::to_string(&original)?;
            let decoded: Notification = serde_json::from_str(&encoded)?;
            assert_eq!(original, decoded);
        }

        Ok(())
    }
}
|