Line data Source code
1 : use std::{convert::Infallible, sync::Arc};
2 :
3 : use futures::StreamExt;
4 : use pq_proto::CancelKeyData;
5 : use redis::aio::PubSub;
6 : use serde::{Deserialize, Serialize};
7 : use tokio_util::sync::CancellationToken;
8 : use uuid::Uuid;
9 :
10 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
11 : use crate::{
12 : cache::project_info::ProjectInfoCache,
13 : cancellation::{CancelMap, CancellationHandler},
14 : intern::{ProjectIdInt, RoleNameInt},
15 : metrics::{Metrics, RedisErrors, RedisEventsCount},
16 : };
17 :
/// Redis pub/sub channel subscribed to in `try_connect` (first of the two channels).
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Second Redis pub/sub channel subscribed to in `try_connect`; visible crate-wide.
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
/// How long to wait after a failed connection attempt before trying again.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay before a cache invalidation is repeated, to catch entries that were
/// still being inserted when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
22 :
23 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
24 0 : let mut conn = client.get_async_pubsub().await?;
25 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
26 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
27 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
28 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
29 0 : Ok(conn)
30 0 : }
31 :
32 10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
33 : #[serde(tag = "topic", content = "data")]
34 : pub(crate) enum Notification {
35 : #[serde(
36 : rename = "/allowed_ips_updated",
37 : deserialize_with = "deserialize_json_string"
38 : )]
39 : AllowedIpsUpdate {
40 : allowed_ips_update: AllowedIpsUpdate,
41 : },
42 : #[serde(
43 : rename = "/password_updated",
44 : deserialize_with = "deserialize_json_string"
45 : )]
46 : PasswordUpdate { password_update: PasswordUpdate },
47 : #[serde(rename = "/cancel_session")]
48 : Cancel(CancelSession),
49 : }
50 2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
51 : pub(crate) struct AllowedIpsUpdate {
52 : project_id: ProjectIdInt,
53 : }
54 3 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
55 : pub(crate) struct PasswordUpdate {
56 : project_id: ProjectIdInt,
57 : role_name: RoleNameInt,
58 : }
59 8 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
60 : pub(crate) struct CancelSession {
61 : pub(crate) region_id: Option<String>,
62 : pub(crate) cancel_key_data: CancelKeyData,
63 : pub(crate) session_id: Uuid,
64 : }
65 :
66 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
67 2 : where
68 2 : T: for<'de2> serde::Deserialize<'de2>,
69 2 : D: serde::Deserializer<'de>,
70 2 : {
71 2 : let s = String::deserialize(deserializer)?;
72 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
73 2 : }
74 :
75 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
76 : cache: Arc<C>,
77 : cancellation_handler: Arc<CancellationHandler<()>>,
78 : region_id: String,
79 : }
80 :
81 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
82 0 : fn clone(&self) -> Self {
83 0 : Self {
84 0 : cache: self.cache.clone(),
85 0 : cancellation_handler: self.cancellation_handler.clone(),
86 0 : region_id: self.region_id.clone(),
87 0 : }
88 0 : }
89 : }
90 :
91 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
92 0 : pub(crate) fn new(
93 0 : cache: Arc<C>,
94 0 : cancellation_handler: Arc<CancellationHandler<()>>,
95 0 : region_id: String,
96 0 : ) -> Self {
97 0 : Self {
98 0 : cache,
99 0 : cancellation_handler,
100 0 : region_id,
101 0 : }
102 0 : }
103 0 : pub(crate) async fn increment_active_listeners(&self) {
104 0 : self.cache.increment_active_listeners().await;
105 0 : }
106 0 : pub(crate) async fn decrement_active_listeners(&self) {
107 0 : self.cache.decrement_active_listeners().await;
108 0 : }
109 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
110 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
111 : let payload: String = msg.get_payload()?;
112 : tracing::debug!(?payload, "received a message payload");
113 :
114 : let msg: Notification = match serde_json::from_str(&payload) {
115 : Ok(msg) => msg,
116 : Err(e) => {
117 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
118 : channel: msg.get_channel_name(),
119 : });
120 : tracing::error!("broken message: {e}");
121 : return Ok(());
122 : }
123 : };
124 : tracing::debug!(?msg, "received a message");
125 : match msg {
126 : Notification::Cancel(cancel_session) => {
127 : tracing::Span::current().record(
128 : "session_id",
129 : tracing::field::display(cancel_session.session_id),
130 : );
131 : Metrics::get()
132 : .proxy
133 : .redis_events_count
134 : .inc(RedisEventsCount::CancelSession);
135 : if let Some(cancel_region) = cancel_session.region_id {
136 : // If the message is not for this region, ignore it.
137 : if cancel_region != self.region_id {
138 : return Ok(());
139 : }
140 : }
141 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
142 : match self
143 : .cancellation_handler
144 : .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil())
145 : .await
146 : {
147 : Ok(()) => {}
148 : Err(e) => {
149 : tracing::error!("failed to cancel session: {e}");
150 : }
151 : }
152 : }
153 : Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
154 : invalidate_cache(self.cache.clone(), msg.clone());
155 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
156 : Metrics::get()
157 : .proxy
158 : .redis_events_count
159 : .inc(RedisEventsCount::AllowedIpsUpdate);
160 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
161 : Metrics::get()
162 : .proxy
163 : .redis_events_count
164 : .inc(RedisEventsCount::PasswordUpdate);
165 : }
166 : // It might happen that the invalid entry is on the way to be cached.
167 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
168 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
169 : let cache = self.cache.clone();
170 0 : tokio::spawn(async move {
171 0 : tokio::time::sleep(INVALIDATION_LAG).await;
172 0 : invalidate_cache(cache, msg);
173 0 : });
174 : }
175 : }
176 :
177 : Ok(())
178 : }
179 : }
180 :
181 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
182 0 : match msg {
183 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
184 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
185 0 : }
186 0 : Notification::PasswordUpdate { password_update } => cache
187 0 : .invalidate_role_secret_for_project(
188 0 : password_update.project_id,
189 0 : password_update.role_name,
190 0 : ),
191 0 : Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
192 : }
193 0 : }
194 :
195 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
196 0 : handler: MessageHandler<C>,
197 0 : redis: ConnectionWithCredentialsProvider,
198 0 : cancellation_token: CancellationToken,
199 0 : ) -> anyhow::Result<()> {
200 : loop {
201 0 : if cancellation_token.is_cancelled() {
202 0 : return Ok(());
203 0 : }
204 0 : let mut conn = match try_connect(&redis).await {
205 0 : Ok(conn) => {
206 0 : handler.increment_active_listeners().await;
207 0 : conn
208 : }
209 0 : Err(e) => {
210 0 : tracing::error!(
211 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
212 : );
213 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
214 0 : continue;
215 : }
216 : };
217 0 : let mut stream = conn.on_message();
218 0 : while let Some(msg) = stream.next().await {
219 0 : match handler.handle_message(msg).await {
220 0 : Ok(()) => {}
221 0 : Err(e) => {
222 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
223 0 : break;
224 : }
225 : }
226 0 : if cancellation_token.is_cancelled() {
227 0 : handler.decrement_active_listeners().await;
228 0 : return Ok(());
229 0 : }
230 : }
231 0 : handler.decrement_active_listeners().await;
232 : }
233 0 : }
234 :
235 : /// Handle console's invalidation messages.
236 : #[tracing::instrument(name = "redis_notifications", skip_all)]
237 : pub async fn task_main<C>(
238 : redis: ConnectionWithCredentialsProvider,
239 : cache: Arc<C>,
240 : cancel_map: CancelMap,
241 : region_id: String,
242 : ) -> anyhow::Result<Infallible>
243 : where
244 : C: ProjectInfoCache + Send + Sync + 'static,
245 : {
246 : let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
247 : cancel_map,
248 : crate::metrics::CancellationSource::FromRedis,
249 : ));
250 : let handler = MessageHandler::new(cache, cancellation_handler, region_id);
251 : // 6h - 1m.
252 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
253 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
254 : loop {
255 : let cancellation_token = CancellationToken::new();
256 : interval.tick().await;
257 :
258 : tokio::spawn(handle_messages(
259 : handler.clone(),
260 : redis.clone(),
261 : cancellation_token.clone(),
262 : ));
263 0 : tokio::spawn(async move {
264 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
265 0 : cancellation_token.cancel();
266 0 : });
267 : }
268 : }
269 :
#[cfg(test)]
mod tests {
    use crate::{ProjectId, RoleName};

    use super::*;
    use serde_json::json;

    /// Serializes `original` to JSON, parses it back, and checks that the
    /// result equals the input.
    fn assert_round_trip(original: &Notification) -> anyhow::Result<()> {
        let encoded = serde_json::to_string(original)?;
        let decoded: Notification = serde_json::from_str(&encoded)?;
        assert_eq!(*original, decoded);
        Ok(())
    }

    /// An allowed-IPs envelope with string-encoded `data` and extra fields parses.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let envelope = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let parsed: Notification = serde_json::from_str(&envelope)?;
        let expected = Notification::AllowedIpsUpdate {
            allowed_ips_update: AllowedIpsUpdate {
                project_id: (&project_id).into(),
            },
        };
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// A password-update envelope with string-encoded `data` and extra fields parses.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let envelope = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let parsed: Notification = serde_json::from_str(&envelope)?;
        let expected = Notification::PasswordUpdate {
            password_update: PasswordUpdate {
                project_id: (&project_id).into(),
                role_name: (&role_name).into(),
            },
        };
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// Cancel notifications survive a serialize/deserialize round trip, both
    /// without and with a region.
    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let session_id = uuid::Uuid::new_v4();

        assert_round_trip(&Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: None,
            session_id,
        }))?;

        assert_round_trip(&Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: Some("region".to_string()),
            session_id,
        }))?;

        Ok(())
    }
}
|