LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 48.0 % 171 82
Test Date: 2025-07-16 12:29:03 Functions: 17.5 % 40 7

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::Arc;
       3              : 
       4              : use futures::StreamExt;
       5              : use redis::aio::PubSub;
       6              : use serde::Deserialize;
       7              : use tokio_util::sync::CancellationToken;
       8              : 
       9              : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      10              : use crate::cache::project_info::ProjectInfoCache;
      11              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      12              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      13              : 
/// Redis pub/sub channel that carries the console's invalidation notifications.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// How long to wait before retrying after a failed attempt to connect/subscribe to redis.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay after which the invalidation for a notification is applied a second time.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      17              : 
      18            0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
      19            0 :     let mut conn = client.get_async_pubsub().await?;
      20            0 :     tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      21            0 :     conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      22            0 :     Ok(conn)
      23            0 : }
      24              : 
      25            0 : #[derive(Debug, Deserialize)]
      26              : struct NotificationHeader<'a> {
      27              :     topic: &'a str,
      28              : }
      29              : 
/// A message received on `CPLANE_CHANNEL_NAME`, decoded according to its `topic`.
///
/// The enum is adjacently tagged: `topic` selects the variant and `data` carries its payload.
/// For every known topic, `data` is itself a JSON document encoded as a string, which is why
/// those variants go through `deserialize_json_string`. Each known variant also accepts the
/// alternative topic names listed as `alias`. Unrecognised top-level fields (e.g. `type`)
/// are ignored.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
enum Notification {
    /// Invalidate cached endpoint access for one or more accounts (orgs).
    #[serde(
        rename = "/account_settings_update",
        alias = "/allowed_vpc_endpoints_updated_for_org",
        deserialize_with = "deserialize_json_string"
    )]
    AccountSettingsUpdate(InvalidateAccount),

    /// Invalidate cached access information for one or more endpoints.
    #[serde(
        rename = "/endpoint_settings_update",
        deserialize_with = "deserialize_json_string"
    )]
    EndpointSettingsUpdate(InvalidateEndpoint),

    /// Invalidate cached endpoint access for one or more projects.
    #[serde(
        rename = "/project_settings_update",
        alias = "/allowed_ips_updated",
        alias = "/block_public_or_vpc_access_updated",
        alias = "/allowed_vpc_endpoints_updated_for_projects",
        deserialize_with = "deserialize_json_string"
    )]
    ProjectSettingsUpdate(InvalidateProject),

    /// Invalidate the cached secret of a single role within a project.
    #[serde(
        rename = "/role_setting_update",
        alias = "/password_updated",
        deserialize_with = "deserialize_json_string"
    )]
    RoleSettingUpdate(InvalidateRole),

    /// Any topic not listed above. Its `data`, if present at all, is skipped by
    /// `deserialize_unknown_topic`.
    #[serde(
        other,
        deserialize_with = "deserialize_unknown_topic",
        skip_serializing
    )]
    UnknownTopic,
}
      69              : 
      70            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      71              : #[serde(rename_all = "snake_case")]
      72              : enum InvalidateEndpoint {
      73              :     EndpointId(EndpointIdInt),
      74              :     EndpointIds(Vec<EndpointIdInt>),
      75              : }
      76              : impl std::ops::Deref for InvalidateEndpoint {
      77              :     type Target = [EndpointIdInt];
      78            0 :     fn deref(&self) -> &Self::Target {
      79            0 :         match self {
      80            0 :             Self::EndpointId(id) => std::slice::from_ref(id),
      81            0 :             Self::EndpointIds(ids) => ids,
      82              :         }
      83            0 :     }
      84              : }
      85              : 
      86            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      87              : #[serde(rename_all = "snake_case")]
      88              : enum InvalidateProject {
      89              :     ProjectId(ProjectIdInt),
      90              :     ProjectIds(Vec<ProjectIdInt>),
      91              : }
      92              : impl std::ops::Deref for InvalidateProject {
      93              :     type Target = [ProjectIdInt];
      94            0 :     fn deref(&self) -> &Self::Target {
      95            0 :         match self {
      96            0 :             Self::ProjectId(id) => std::slice::from_ref(id),
      97            0 :             Self::ProjectIds(ids) => ids,
      98              :         }
      99            0 :     }
     100              : }
     101              : 
     102            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
     103              : #[serde(rename_all = "snake_case")]
     104              : enum InvalidateAccount {
     105              :     AccountId(AccountIdInt),
     106              :     AccountIds(Vec<AccountIdInt>),
     107              : }
     108              : impl std::ops::Deref for InvalidateAccount {
     109              :     type Target = [AccountIdInt];
     110            0 :     fn deref(&self) -> &Self::Target {
     111            0 :         match self {
     112            0 :             Self::AccountId(id) => std::slice::from_ref(id),
     113            0 :             Self::AccountIds(ids) => ids,
     114              :         }
     115            0 :     }
     116              : }
     117              : 
     118            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
     119              : struct InvalidateRole {
     120              :     project_id: ProjectIdInt,
     121              :     role_name: RoleNameInt,
     122              : }
     123              : 
     124            3 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
     125            3 : where
     126            3 :     T: for<'de2> serde::Deserialize<'de2>,
     127            3 :     D: serde::Deserializer<'de>,
     128              : {
     129            3 :     let s = String::deserialize(deserializer)?;
     130            3 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
     131            3 : }
     132              : 
     133              : // https://github.com/serde-rs/serde/issues/1714
     134            1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
     135            1 : where
     136            1 :     D: serde::Deserializer<'de>,
     137              : {
     138            1 :     deserializer.deserialize_any(serde::de::IgnoredAny)?;
     139            1 :     Ok(())
     140            1 : }
     141              : 
     142              : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     143              :     cache: Arc<C>,
     144              : }
     145              : 
     146              : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
     147            0 :     fn clone(&self) -> Self {
     148            0 :         Self {
     149            0 :             cache: self.cache.clone(),
     150            0 :         }
     151            0 :     }
     152              : }
     153              : 
     154              : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
     155            0 :     pub(crate) fn new(cache: Arc<C>) -> Self {
     156            0 :         Self { cache }
     157            0 :     }
     158              : 
     159              :     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     160              :     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
     161              :         let payload: String = msg.get_payload()?;
     162              :         tracing::debug!(?payload, "received a message payload");
     163              : 
     164              :         let msg: Notification = match serde_json::from_str(&payload) {
     165              :             Ok(Notification::UnknownTopic) => {
     166              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     167              :                     // don't update the metric for redis errors if it's just a topic we don't know about.
     168              :                     Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
     169              :                     Err(e) => {
     170              :                         Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     171              :                             channel: msg.get_channel_name(),
     172              :                         });
     173              :                         tracing::error!("broken message: {e}");
     174              :                     }
     175              :                 }
     176              :                 return Ok(());
     177              :             }
     178              :             Ok(msg) => msg,
     179              :             Err(e) => {
     180              :                 Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     181              :                     channel: msg.get_channel_name(),
     182              :                 });
     183              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     184              :                     Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
     185              :                     Err(_) => tracing::error!("broken message: {e}"),
     186              :                 }
     187              :                 return Ok(());
     188              :             }
     189              :         };
     190              : 
     191              :         tracing::debug!(?msg, "received a message");
     192              :         match msg {
     193              :             Notification::RoleSettingUpdate { .. }
     194              :             | Notification::EndpointSettingsUpdate { .. }
     195              :             | Notification::ProjectSettingsUpdate { .. }
     196              :             | Notification::AccountSettingsUpdate { .. } => {
     197              :                 invalidate_cache(self.cache.clone(), msg.clone());
     198              : 
     199              :                 let m = &Metrics::get().proxy.redis_events_count;
     200              :                 match msg {
     201              :                     Notification::RoleSettingUpdate { .. } => {
     202              :                         m.inc(RedisEventsCount::InvalidateRole);
     203              :                     }
     204              :                     Notification::EndpointSettingsUpdate { .. } => {
     205              :                         m.inc(RedisEventsCount::InvalidateEndpoint);
     206              :                     }
     207              :                     Notification::ProjectSettingsUpdate { .. } => {
     208              :                         m.inc(RedisEventsCount::InvalidateProject);
     209              :                     }
     210              :                     Notification::AccountSettingsUpdate { .. } => {
     211              :                         m.inc(RedisEventsCount::InvalidateOrg);
     212              :                     }
     213              :                     Notification::UnknownTopic => {}
     214              :                 }
     215              : 
     216              :                 // TODO: add additional metrics for the other event types.
     217              : 
     218              :                 // It might happen that the invalid entry is on the way to be cached.
     219              :                 // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
     220              :                 // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     221              :                 let cache = self.cache.clone();
     222            0 :                 tokio::spawn(async move {
     223            0 :                     tokio::time::sleep(INVALIDATION_LAG).await;
     224            0 :                     invalidate_cache(cache, msg);
     225            0 :                 });
     226              :             }
     227              : 
     228              :             Notification::UnknownTopic => unreachable!(),
     229              :         }
     230              : 
     231              :         Ok(())
     232              :     }
     233              : }
     234              : 
     235            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     236            0 :     match msg {
     237            0 :         Notification::EndpointSettingsUpdate(ids) => ids
     238            0 :             .iter()
     239            0 :             .for_each(|&id| cache.invalidate_endpoint_access(id)),
     240              : 
     241            0 :         Notification::AccountSettingsUpdate(ids) => ids
     242            0 :             .iter()
     243            0 :             .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
     244              : 
     245            0 :         Notification::ProjectSettingsUpdate(ids) => ids
     246            0 :             .iter()
     247            0 :             .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
     248              : 
     249              :         Notification::RoleSettingUpdate(InvalidateRole {
     250            0 :             project_id,
     251            0 :             role_name,
     252            0 :         }) => cache.invalidate_role_secret_for_project(project_id, role_name),
     253              : 
     254            0 :         Notification::UnknownTopic => unreachable!(),
     255              :     }
     256            0 : }
     257              : 
     258            0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
     259            0 :     handler: MessageHandler<C>,
     260            0 :     redis: ConnectionWithCredentialsProvider,
     261            0 :     cancellation_token: CancellationToken,
     262            0 : ) -> anyhow::Result<()> {
     263              :     loop {
     264            0 :         if cancellation_token.is_cancelled() {
     265            0 :             return Ok(());
     266            0 :         }
     267            0 :         let mut conn = match try_connect(&redis).await {
     268            0 :             Ok(conn) => {
     269            0 :                 handler.cache.increment_active_listeners().await;
     270            0 :                 conn
     271              :             }
     272            0 :             Err(e) => {
     273            0 :                 tracing::error!(
     274            0 :                     "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     275              :                 );
     276            0 :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     277            0 :                 continue;
     278              :             }
     279              :         };
     280            0 :         let mut stream = conn.on_message();
     281            0 :         while let Some(msg) = stream.next().await {
     282            0 :             match handler.handle_message(msg).await {
     283            0 :                 Ok(()) => {}
     284            0 :                 Err(e) => {
     285            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     286            0 :                     break;
     287              :                 }
     288              :             }
     289            0 :             if cancellation_token.is_cancelled() {
     290            0 :                 handler.cache.decrement_active_listeners().await;
     291            0 :                 return Ok(());
     292            0 :             }
     293              :         }
     294            0 :         handler.cache.decrement_active_listeners().await;
     295              :     }
     296            0 : }
     297              : 
     298              : /// Handle console's invalidation messages.
     299              : #[tracing::instrument(name = "redis_notifications", skip_all)]
     300              : pub async fn task_main<C>(
     301              :     redis: ConnectionWithCredentialsProvider,
     302              :     cache: Arc<C>,
     303              : ) -> anyhow::Result<Infallible>
     304              : where
     305              :     C: ProjectInfoCache + Send + Sync + 'static,
     306              : {
     307              :     let handler = MessageHandler::new(cache);
     308              :     // 6h - 1m.
     309              :     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     310              :     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
     311              :     loop {
     312              :         let cancellation_token = CancellationToken::new();
     313              :         interval.tick().await;
     314              : 
     315              :         tokio::spawn(handle_messages(
     316              :             handler.clone(),
     317              :             redis.clone(),
     318              :             cancellation_token.clone(),
     319              :         ));
     320            0 :         tokio::spawn(async move {
     321            0 :             tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
     322            0 :             cancellation_token.cancel();
     323            0 :         });
     324              :     }
     325              : }
     326              : 
     327              : #[cfg(test)]
     328              : mod tests {
     329              :     use serde_json::json;
     330              : 
     331              :     use super::*;
     332              :     use crate::types::{ProjectId, RoleName};
     333              : 
     334              :     #[test]
     335            1 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     336            1 :         let project_id: ProjectId = "new_project".into();
     337            1 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     338            1 :         let text = json!({
     339            1 :             "type": "message",
     340            1 :             "topic": "/allowed_ips_updated",
     341            1 :             "data": data,
     342            1 :             "extre_fields": "something"
     343              :         })
     344            1 :         .to_string();
     345              : 
     346            1 :         let result: Notification = serde_json::from_str(&text)?;
     347            1 :         assert_eq!(
     348              :             result,
     349            1 :             Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
     350              :         );
     351              : 
     352            1 :         Ok(())
     353            1 :     }
     354              : 
     355              :     #[test]
     356            1 :     fn parse_multiple_projects() -> anyhow::Result<()> {
     357            1 :         let project_id1: ProjectId = "new_project1".into();
     358            1 :         let project_id2: ProjectId = "new_project2".into();
     359            1 :         let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
     360            1 :         let text = json!({
     361            1 :             "type": "message",
     362            1 :             "topic": "/allowed_vpc_endpoints_updated_for_projects",
     363            1 :             "data": data,
     364            1 :             "extre_fields": "something"
     365              :         })
     366            1 :         .to_string();
     367              : 
     368            1 :         let result: Notification = serde_json::from_str(&text)?;
     369            1 :         assert_eq!(
     370              :             result,
     371            1 :             Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
     372            1 :                 (&project_id1).into(),
     373            1 :                 (&project_id2).into()
     374            1 :             ]))
     375              :         );
     376              : 
     377            1 :         Ok(())
     378            1 :     }
     379              : 
     380              :     #[test]
     381            1 :     fn parse_password_updated() -> anyhow::Result<()> {
     382            1 :         let project_id: ProjectId = "new_project".into();
     383            1 :         let role_name: RoleName = "new_role".into();
     384            1 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     385            1 :         let text = json!({
     386            1 :             "type": "message",
     387            1 :             "topic": "/password_updated",
     388            1 :             "data": data,
     389            1 :             "extre_fields": "something"
     390              :         })
     391            1 :         .to_string();
     392              : 
     393            1 :         let result: Notification = serde_json::from_str(&text)?;
     394            1 :         assert_eq!(
     395              :             result,
     396            1 :             Notification::RoleSettingUpdate(InvalidateRole {
     397            1 :                 project_id: (&project_id).into(),
     398            1 :                 role_name: (&role_name).into(),
     399            1 :             })
     400              :         );
     401              : 
     402            1 :         Ok(())
     403            1 :     }
     404              : 
     405              :     #[test]
     406            1 :     fn parse_unknown_topic() -> anyhow::Result<()> {
     407            1 :         let with_data = json!({
     408            1 :             "type": "message",
     409            1 :             "topic": "/doesnotexist",
     410            1 :             "data": {
     411            1 :                 "payload": "ignored"
     412              :             },
     413            1 :             "extra_fields": "something"
     414              :         })
     415            1 :         .to_string();
     416            1 :         let result: Notification = serde_json::from_str(&with_data)?;
     417            1 :         assert_eq!(result, Notification::UnknownTopic);
     418              : 
     419            1 :         let without_data = json!({
     420            1 :             "type": "message",
     421            1 :             "topic": "/doesnotexist",
     422            1 :             "extra_fields": "something"
     423              :         })
     424            1 :         .to_string();
     425            1 :         let result: Notification = serde_json::from_str(&without_data)?;
     426            1 :         assert_eq!(result, Notification::UnknownTopic);
     427              : 
     428            1 :         Ok(())
     429            1 :     }
     430              : }
        

Generated by: LCOV version 2.1-beta