LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 553e39c2773e5840c720c90d86e56f89a4330d43.info Lines: 51.6 % 192 99
Test Date: 2025-06-13 20:01:21 Functions: 17.1 % 70 12

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::Arc;
       3              : 
       4              : use futures::StreamExt;
       5              : use redis::aio::PubSub;
       6              : use serde::Deserialize;
       7              : use tokio_util::sync::CancellationToken;
       8              : 
       9              : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      10              : use crate::cache::project_info::ProjectInfoCache;
      11              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      12              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      13              : 
      14              : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
      15              : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
      16              : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      17              : 
      18            0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
      19            0 :     let mut conn = client.get_async_pubsub().await?;
      20            0 :     tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      21            0 :     conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      22            0 :     Ok(conn)
      23            0 : }
      24              : 
      25            0 : #[derive(Debug, Deserialize)]
      26              : struct NotificationHeader<'a> {
      27              :     topic: &'a str,
      28              : }
      29              : 
      30            8 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      31              : #[serde(tag = "topic", content = "data")]
      32              : enum Notification {
      33              :     #[serde(
      34              :         rename = "/account_settings_update",
      35              :         alias = "/allowed_vpc_endpoints_updated_for_org",
      36              :         deserialize_with = "deserialize_json_string"
      37              :     )]
      38              :     AccountSettingsUpdate(InvalidateAccount),
      39              : 
      40              :     #[serde(
      41              :         rename = "/endpoint_settings_update",
      42              :         deserialize_with = "deserialize_json_string"
      43              :     )]
      44              :     EndpointSettingsUpdate(InvalidateEndpoint),
      45              : 
      46              :     #[serde(
      47              :         rename = "/project_settings_update",
      48              :         alias = "/allowed_ips_updated",
      49              :         alias = "/block_public_or_vpc_access_updated",
      50              :         alias = "/allowed_vpc_endpoints_updated_for_projects",
      51              :         deserialize_with = "deserialize_json_string"
      52              :     )]
      53              :     ProjectSettingsUpdate(InvalidateProject),
      54              : 
      55              :     #[serde(
      56              :         rename = "/role_setting_update",
      57              :         alias = "/password_updated",
      58              :         deserialize_with = "deserialize_json_string"
      59              :     )]
      60              :     RoleSettingUpdate(InvalidateRole),
      61              : 
      62              :     #[serde(
      63              :         other,
      64              :         deserialize_with = "deserialize_unknown_topic",
      65              :         skip_serializing
      66              :     )]
      67              :     UnknownTopic,
      68              : }
      69              : 
      70            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      71              : #[serde(rename_all = "snake_case")]
      72              : enum InvalidateEndpoint {
      73              :     EndpointId(EndpointIdInt),
      74              :     EndpointIds(Vec<EndpointIdInt>),
      75              : }
      76              : impl std::ops::Deref for InvalidateEndpoint {
      77              :     type Target = [EndpointIdInt];
      78            0 :     fn deref(&self) -> &Self::Target {
      79            0 :         match self {
      80            0 :             Self::EndpointId(id) => std::slice::from_ref(id),
      81            0 :             Self::EndpointIds(ids) => ids,
      82              :         }
      83            0 :     }
      84              : }
      85              : 
      86            2 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      87              : #[serde(rename_all = "snake_case")]
      88              : enum InvalidateProject {
      89              :     ProjectId(ProjectIdInt),
      90              :     ProjectIds(Vec<ProjectIdInt>),
      91              : }
      92              : impl std::ops::Deref for InvalidateProject {
      93              :     type Target = [ProjectIdInt];
      94            0 :     fn deref(&self) -> &Self::Target {
      95            0 :         match self {
      96            0 :             Self::ProjectId(id) => std::slice::from_ref(id),
      97            0 :             Self::ProjectIds(ids) => ids,
      98              :         }
      99            0 :     }
     100              : }
     101              : 
     102            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
     103              : #[serde(rename_all = "snake_case")]
     104              : enum InvalidateAccount {
     105              :     AccountId(AccountIdInt),
     106              :     AccountIds(Vec<AccountIdInt>),
     107              : }
     108              : impl std::ops::Deref for InvalidateAccount {
     109              :     type Target = [AccountIdInt];
     110            0 :     fn deref(&self) -> &Self::Target {
     111            0 :         match self {
     112            0 :             Self::AccountId(id) => std::slice::from_ref(id),
     113            0 :             Self::AccountIds(ids) => ids,
     114              :         }
     115            0 :     }
     116              : }
     117              : 
     118            2 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
     119              : struct InvalidateRole {
     120              :     project_id: ProjectIdInt,
     121              :     role_name: RoleNameInt,
     122              : }
     123              : 
     124            3 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
     125            3 : where
     126            3 :     T: for<'de2> serde::Deserialize<'de2>,
     127            3 :     D: serde::Deserializer<'de>,
     128            3 : {
     129            3 :     let s = String::deserialize(deserializer)?;
     130            3 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
     131            3 : }
     132              : 
     133              : // https://github.com/serde-rs/serde/issues/1714
     134            1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
     135            1 : where
     136            1 :     D: serde::Deserializer<'de>,
     137            1 : {
     138            1 :     deserializer.deserialize_any(serde::de::IgnoredAny)?;
     139            1 :     Ok(())
     140            1 : }
     141              : 
     142              : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     143              :     cache: Arc<C>,
     144              :     region_id: String,
     145              : }
     146              : 
     147              : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
     148            0 :     fn clone(&self) -> Self {
     149            0 :         Self {
     150            0 :             cache: self.cache.clone(),
     151            0 :             region_id: self.region_id.clone(),
     152            0 :         }
     153            0 :     }
     154              : }
     155              : 
     156              : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
     157            0 :     pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
     158            0 :         Self { cache, region_id }
     159            0 :     }
     160              : 
     161            0 :     pub(crate) async fn increment_active_listeners(&self) {
     162            0 :         self.cache.increment_active_listeners().await;
     163            0 :     }
     164              : 
     165            0 :     pub(crate) async fn decrement_active_listeners(&self) {
     166            0 :         self.cache.decrement_active_listeners().await;
     167            0 :     }
     168              : 
     169              :     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     170              :     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
     171              :         let payload: String = msg.get_payload()?;
     172              :         tracing::debug!(?payload, "received a message payload");
     173              : 
     174              :         let msg: Notification = match serde_json::from_str(&payload) {
     175              :             Ok(Notification::UnknownTopic) => {
     176              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     177              :                     // don't update the metric for redis errors if it's just a topic we don't know about.
     178              :                     Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
     179              :                     Err(e) => {
     180              :                         Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     181              :                             channel: msg.get_channel_name(),
     182              :                         });
     183              :                         tracing::error!("broken message: {e}");
     184              :                     }
     185              :                 }
     186              :                 return Ok(());
     187              :             }
     188              :             Ok(msg) => msg,
     189              :             Err(e) => {
     190              :                 Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     191              :                     channel: msg.get_channel_name(),
     192              :                 });
     193              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     194              :                     Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
     195              :                     Err(_) => tracing::error!("broken message: {e}"),
     196              :                 }
     197              :                 return Ok(());
     198              :             }
     199              :         };
     200              : 
     201              :         tracing::debug!(?msg, "received a message");
     202              :         match msg {
     203              :             Notification::RoleSettingUpdate { .. }
     204              :             | Notification::EndpointSettingsUpdate { .. }
     205              :             | Notification::ProjectSettingsUpdate { .. }
     206              :             | Notification::AccountSettingsUpdate { .. } => {
     207              :                 invalidate_cache(self.cache.clone(), msg.clone());
     208              : 
     209              :                 let m = &Metrics::get().proxy.redis_events_count;
     210              :                 match msg {
     211              :                     Notification::RoleSettingUpdate { .. } => {
     212              :                         m.inc(RedisEventsCount::InvalidateRole);
     213              :                     }
     214              :                     Notification::EndpointSettingsUpdate { .. } => {
     215              :                         m.inc(RedisEventsCount::InvalidateEndpoint);
     216              :                     }
     217              :                     Notification::ProjectSettingsUpdate { .. } => {
     218              :                         m.inc(RedisEventsCount::InvalidateProject);
     219              :                     }
     220              :                     Notification::AccountSettingsUpdate { .. } => {
     221              :                         m.inc(RedisEventsCount::InvalidateOrg);
     222              :                     }
     223              :                     Notification::UnknownTopic => {}
     224              :                 }
     225              : 
     226              :                 // TODO: add additional metrics for the other event types.
     227              : 
     228              :                 // It might happen that the invalid entry is on the way to be cached.
     229              :                 // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
     230              :                 // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     231              :                 let cache = self.cache.clone();
     232            0 :                 tokio::spawn(async move {
     233            0 :                     tokio::time::sleep(INVALIDATION_LAG).await;
     234            0 :                     invalidate_cache(cache, msg);
     235            0 :                 });
     236              :             }
     237              : 
     238              :             Notification::UnknownTopic => unreachable!(),
     239              :         }
     240              : 
     241              :         Ok(())
     242              :     }
     243              : }
     244              : 
     245            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     246            0 :     match msg {
     247            0 :         Notification::EndpointSettingsUpdate(ids) => ids
     248            0 :             .iter()
     249            0 :             .for_each(|&id| cache.invalidate_endpoint_access(id)),
     250              : 
     251            0 :         Notification::AccountSettingsUpdate(ids) => ids
     252            0 :             .iter()
     253            0 :             .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
     254              : 
     255            0 :         Notification::ProjectSettingsUpdate(ids) => ids
     256            0 :             .iter()
     257            0 :             .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
     258              : 
     259              :         Notification::RoleSettingUpdate(InvalidateRole {
     260            0 :             project_id,
     261            0 :             role_name,
     262            0 :         }) => cache.invalidate_role_secret_for_project(project_id, role_name),
     263              : 
     264            0 :         Notification::UnknownTopic => unreachable!(),
     265              :     }
     266            0 : }
     267              : 
     268            0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
     269            0 :     handler: MessageHandler<C>,
     270            0 :     redis: ConnectionWithCredentialsProvider,
     271            0 :     cancellation_token: CancellationToken,
     272            0 : ) -> anyhow::Result<()> {
     273              :     loop {
     274            0 :         if cancellation_token.is_cancelled() {
     275            0 :             return Ok(());
     276            0 :         }
     277            0 :         let mut conn = match try_connect(&redis).await {
     278            0 :             Ok(conn) => {
     279            0 :                 handler.increment_active_listeners().await;
     280            0 :                 conn
     281              :             }
     282            0 :             Err(e) => {
     283            0 :                 tracing::error!(
     284            0 :                     "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     285              :                 );
     286            0 :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     287            0 :                 continue;
     288              :             }
     289              :         };
     290            0 :         let mut stream = conn.on_message();
     291            0 :         while let Some(msg) = stream.next().await {
     292            0 :             match handler.handle_message(msg).await {
     293            0 :                 Ok(()) => {}
     294            0 :                 Err(e) => {
     295            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     296            0 :                     break;
     297              :                 }
     298              :             }
     299            0 :             if cancellation_token.is_cancelled() {
     300            0 :                 handler.decrement_active_listeners().await;
     301            0 :                 return Ok(());
     302            0 :             }
     303              :         }
     304            0 :         handler.decrement_active_listeners().await;
     305              :     }
     306            0 : }
     307              : 
     308              : /// Handle console's invalidation messages.
     309              : #[tracing::instrument(name = "redis_notifications", skip_all)]
     310              : pub async fn task_main<C>(
     311              :     redis: ConnectionWithCredentialsProvider,
     312              :     cache: Arc<C>,
     313              :     region_id: String,
     314              : ) -> anyhow::Result<Infallible>
     315              : where
     316              :     C: ProjectInfoCache + Send + Sync + 'static,
     317              : {
     318              :     let handler = MessageHandler::new(cache, region_id);
     319              :     // 6h - 1m.
     320              :     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     321              :     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
     322              :     loop {
     323              :         let cancellation_token = CancellationToken::new();
     324              :         interval.tick().await;
     325              : 
     326              :         tokio::spawn(handle_messages(
     327              :             handler.clone(),
     328              :             redis.clone(),
     329              :             cancellation_token.clone(),
     330              :         ));
     331            0 :         tokio::spawn(async move {
     332            0 :             tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
     333            0 :             cancellation_token.cancel();
     334            0 :         });
     335              :     }
     336              : }
     337              : 
     338              : #[cfg(test)]
     339              : mod tests {
     340              :     use serde_json::json;
     341              : 
     342              :     use super::*;
     343              :     use crate::types::{ProjectId, RoleName};
     344              : 
     345              :     #[test]
     346            1 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     347            1 :         let project_id: ProjectId = "new_project".into();
     348            1 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     349            1 :         let text = json!({
     350            1 :             "type": "message",
     351            1 :             "topic": "/allowed_ips_updated",
     352            1 :             "data": data,
     353            1 :             "extre_fields": "something"
     354            1 :         })
     355            1 :         .to_string();
     356              : 
     357            1 :         let result: Notification = serde_json::from_str(&text)?;
     358            1 :         assert_eq!(
     359            1 :             result,
     360            1 :             Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
     361            1 :         );
     362              : 
     363            1 :         Ok(())
     364            1 :     }
     365              : 
     366              :     #[test]
     367            1 :     fn parse_multiple_projects() -> anyhow::Result<()> {
     368            1 :         let project_id1: ProjectId = "new_project1".into();
     369            1 :         let project_id2: ProjectId = "new_project2".into();
     370            1 :         let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
     371            1 :         let text = json!({
     372            1 :             "type": "message",
     373            1 :             "topic": "/allowed_vpc_endpoints_updated_for_projects",
     374            1 :             "data": data,
     375            1 :             "extre_fields": "something"
     376            1 :         })
     377            1 :         .to_string();
     378              : 
     379            1 :         let result: Notification = serde_json::from_str(&text)?;
     380            1 :         assert_eq!(
     381            1 :             result,
     382            1 :             Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
     383            1 :                 (&project_id1).into(),
     384            1 :                 (&project_id2).into()
     385            1 :             ]))
     386            1 :         );
     387              : 
     388            1 :         Ok(())
     389            1 :     }
     390              : 
     391              :     #[test]
     392            1 :     fn parse_password_updated() -> anyhow::Result<()> {
     393            1 :         let project_id: ProjectId = "new_project".into();
     394            1 :         let role_name: RoleName = "new_role".into();
     395            1 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     396            1 :         let text = json!({
     397            1 :             "type": "message",
     398            1 :             "topic": "/password_updated",
     399            1 :             "data": data,
     400            1 :             "extre_fields": "something"
     401            1 :         })
     402            1 :         .to_string();
     403              : 
     404            1 :         let result: Notification = serde_json::from_str(&text)?;
     405            1 :         assert_eq!(
     406            1 :             result,
     407            1 :             Notification::RoleSettingUpdate(InvalidateRole {
     408            1 :                 project_id: (&project_id).into(),
     409            1 :                 role_name: (&role_name).into(),
     410            1 :             })
     411            1 :         );
     412              : 
     413            1 :         Ok(())
     414            1 :     }
     415              : 
     416              :     #[test]
     417            1 :     fn parse_unknown_topic() -> anyhow::Result<()> {
     418            1 :         let with_data = json!({
     419            1 :             "type": "message",
     420            1 :             "topic": "/doesnotexist",
     421            1 :             "data": {
     422            1 :                 "payload": "ignored"
     423            1 :             },
     424            1 :             "extra_fields": "something"
     425            1 :         })
     426            1 :         .to_string();
     427            1 :         let result: Notification = serde_json::from_str(&with_data)?;
     428            1 :         assert_eq!(result, Notification::UnknownTopic);
     429              : 
     430            1 :         let without_data = json!({
     431            1 :             "type": "message",
     432            1 :             "topic": "/doesnotexist",
     433            1 :             "extra_fields": "something"
     434            1 :         })
     435            1 :         .to_string();
     436            1 :         let result: Notification = serde_json::from_str(&without_data)?;
     437            1 :         assert_eq!(result, Notification::UnknownTopic);
     438              : 
     439            1 :         Ok(())
     440            1 :     }
     441              : }
        

Generated by: LCOV version 2.1-beta