LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 4be46b1c0003aa3bbac9ade362c676b419df4c20.info Lines: 46.9 % 160 75
Test Date: 2025-07-22 17:50:06 Functions: 15.6 % 32 5

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::Arc;
       3              : 
       4              : use futures::StreamExt;
       5              : use redis::aio::PubSub;
       6              : use serde::Deserialize;
       7              : use tokio_util::sync::CancellationToken;
       8              : 
       9              : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      10              : use crate::cache::project_info::ProjectInfoCache;
      11              : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
      12              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      13              : use crate::util::deserialize_json_string;
      14              : 
      15              : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
      16              : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
      17              : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      18              : 
      19            0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
      20            0 :     let mut conn = client.get_async_pubsub().await?;
      21            0 :     tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      22            0 :     conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      23            0 :     Ok(conn)
      24            0 : }
      25              : 
      26            0 : #[derive(Debug, Deserialize)]
      27              : struct NotificationHeader<'a> {
      28              :     topic: &'a str,
      29              : }
      30              : 
      31            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      32              : #[serde(tag = "topic", content = "data")]
      33              : enum Notification {
      34              :     #[serde(
      35              :         rename = "/account_settings_update",
      36              :         alias = "/allowed_vpc_endpoints_updated_for_org",
      37              :         deserialize_with = "deserialize_json_string"
      38              :     )]
      39              :     AccountSettingsUpdate(InvalidateAccount),
      40              : 
      41              :     #[serde(
      42              :         rename = "/endpoint_settings_update",
      43              :         deserialize_with = "deserialize_json_string"
      44              :     )]
      45              :     EndpointSettingsUpdate(InvalidateEndpoint),
      46              : 
      47              :     #[serde(
      48              :         rename = "/project_settings_update",
      49              :         alias = "/allowed_ips_updated",
      50              :         alias = "/block_public_or_vpc_access_updated",
      51              :         alias = "/allowed_vpc_endpoints_updated_for_projects",
      52              :         deserialize_with = "deserialize_json_string"
      53              :     )]
      54              :     ProjectSettingsUpdate(InvalidateProject),
      55              : 
      56              :     #[serde(
      57              :         rename = "/role_setting_update",
      58              :         alias = "/password_updated",
      59              :         deserialize_with = "deserialize_json_string"
      60              :     )]
      61              :     RoleSettingUpdate(InvalidateRole),
      62              : 
      63              :     #[serde(
      64              :         other,
      65              :         deserialize_with = "deserialize_unknown_topic",
      66              :         skip_serializing
      67              :     )]
      68              :     UnknownTopic,
      69              : }
      70              : 
      71            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      72              : #[serde(rename_all = "snake_case")]
      73              : enum InvalidateEndpoint {
      74              :     EndpointId(EndpointIdInt),
      75              :     EndpointIds(Vec<EndpointIdInt>),
      76              : }
      77              : impl std::ops::Deref for InvalidateEndpoint {
      78              :     type Target = [EndpointIdInt];
      79            0 :     fn deref(&self) -> &Self::Target {
      80            0 :         match self {
      81            0 :             Self::EndpointId(id) => std::slice::from_ref(id),
      82            0 :             Self::EndpointIds(ids) => ids,
      83              :         }
      84            0 :     }
      85              : }
      86              : 
      87            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      88              : #[serde(rename_all = "snake_case")]
      89              : enum InvalidateProject {
      90              :     ProjectId(ProjectIdInt),
      91              :     ProjectIds(Vec<ProjectIdInt>),
      92              : }
      93              : impl std::ops::Deref for InvalidateProject {
      94              :     type Target = [ProjectIdInt];
      95            0 :     fn deref(&self) -> &Self::Target {
      96            0 :         match self {
      97            0 :             Self::ProjectId(id) => std::slice::from_ref(id),
      98            0 :             Self::ProjectIds(ids) => ids,
      99              :         }
     100            0 :     }
     101              : }
     102              : 
     103            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
     104              : #[serde(rename_all = "snake_case")]
     105              : enum InvalidateAccount {
     106              :     AccountId(AccountIdInt),
     107              :     AccountIds(Vec<AccountIdInt>),
     108              : }
     109              : impl std::ops::Deref for InvalidateAccount {
     110              :     type Target = [AccountIdInt];
     111            0 :     fn deref(&self) -> &Self::Target {
     112            0 :         match self {
     113            0 :             Self::AccountId(id) => std::slice::from_ref(id),
     114            0 :             Self::AccountIds(ids) => ids,
     115              :         }
     116            0 :     }
     117              : }
     118              : 
     119            0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
     120              : struct InvalidateRole {
     121              :     project_id: ProjectIdInt,
     122              :     role_name: RoleNameInt,
     123              : }
     124              : 
     125              : // https://github.com/serde-rs/serde/issues/1714
     126            1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
     127            1 : where
     128            1 :     D: serde::Deserializer<'de>,
     129              : {
     130            1 :     deserializer.deserialize_any(serde::de::IgnoredAny)?;
     131            1 :     Ok(())
     132            1 : }
     133              : 
     134              : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     135              :     cache: Arc<C>,
     136              : }
     137              : 
     138              : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
     139            0 :     fn clone(&self) -> Self {
     140            0 :         Self {
     141            0 :             cache: self.cache.clone(),
     142            0 :         }
     143            0 :     }
     144              : }
     145              : 
     146              : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
     147            0 :     pub(crate) fn new(cache: Arc<C>) -> Self {
     148            0 :         Self { cache }
     149            0 :     }
     150              : 
     151              :     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     152              :     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
     153              :         let payload: String = msg.get_payload()?;
     154              :         tracing::debug!(?payload, "received a message payload");
     155              : 
     156              :         let msg: Notification = match serde_json::from_str(&payload) {
     157              :             Ok(Notification::UnknownTopic) => {
     158              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     159              :                     // don't update the metric for redis errors if it's just a topic we don't know about.
     160              :                     Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
     161              :                     Err(e) => {
     162              :                         Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     163              :                             channel: msg.get_channel_name(),
     164              :                         });
     165              :                         tracing::error!("broken message: {e}");
     166              :                     }
     167              :                 }
     168              :                 return Ok(());
     169              :             }
     170              :             Ok(msg) => msg,
     171              :             Err(e) => {
     172              :                 Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     173              :                     channel: msg.get_channel_name(),
     174              :                 });
     175              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     176              :                     Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
     177              :                     Err(_) => tracing::error!("broken message: {e}"),
     178              :                 }
     179              :                 return Ok(());
     180              :             }
     181              :         };
     182              : 
     183              :         tracing::debug!(?msg, "received a message");
     184              :         match msg {
     185              :             Notification::RoleSettingUpdate { .. }
     186              :             | Notification::EndpointSettingsUpdate { .. }
     187              :             | Notification::ProjectSettingsUpdate { .. }
     188              :             | Notification::AccountSettingsUpdate { .. } => {
     189              :                 invalidate_cache(self.cache.clone(), msg.clone());
     190              : 
     191              :                 let m = &Metrics::get().proxy.redis_events_count;
     192              :                 match msg {
     193              :                     Notification::RoleSettingUpdate { .. } => {
     194              :                         m.inc(RedisEventsCount::InvalidateRole);
     195              :                     }
     196              :                     Notification::EndpointSettingsUpdate { .. } => {
     197              :                         m.inc(RedisEventsCount::InvalidateEndpoint);
     198              :                     }
     199              :                     Notification::ProjectSettingsUpdate { .. } => {
     200              :                         m.inc(RedisEventsCount::InvalidateProject);
     201              :                     }
     202              :                     Notification::AccountSettingsUpdate { .. } => {
     203              :                         m.inc(RedisEventsCount::InvalidateOrg);
     204              :                     }
     205              :                     Notification::UnknownTopic => {}
     206              :                 }
     207              : 
     208              :                 // TODO: add additional metrics for the other event types.
     209              : 
     210              :                 // It might happen that the invalid entry is on the way to be cached.
     211              :                 // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
     212              :                 // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     213              :                 let cache = self.cache.clone();
     214            0 :                 tokio::spawn(async move {
     215            0 :                     tokio::time::sleep(INVALIDATION_LAG).await;
     216            0 :                     invalidate_cache(cache, msg);
     217            0 :                 });
     218              :             }
     219              : 
     220              :             Notification::UnknownTopic => unreachable!(),
     221              :         }
     222              : 
     223              :         Ok(())
     224              :     }
     225              : }
     226              : 
     227            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     228            0 :     match msg {
     229            0 :         Notification::EndpointSettingsUpdate(ids) => ids
     230            0 :             .iter()
     231            0 :             .for_each(|&id| cache.invalidate_endpoint_access(id)),
     232              : 
     233            0 :         Notification::AccountSettingsUpdate(ids) => ids
     234            0 :             .iter()
     235            0 :             .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
     236              : 
     237            0 :         Notification::ProjectSettingsUpdate(ids) => ids
     238            0 :             .iter()
     239            0 :             .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
     240              : 
     241              :         Notification::RoleSettingUpdate(InvalidateRole {
     242            0 :             project_id,
     243            0 :             role_name,
     244            0 :         }) => cache.invalidate_role_secret_for_project(project_id, role_name),
     245              : 
     246            0 :         Notification::UnknownTopic => unreachable!(),
     247              :     }
     248            0 : }
     249              : 
     250            0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
     251            0 :     handler: MessageHandler<C>,
     252            0 :     redis: ConnectionWithCredentialsProvider,
     253            0 :     cancellation_token: CancellationToken,
     254            0 : ) -> anyhow::Result<()> {
     255              :     loop {
     256            0 :         if cancellation_token.is_cancelled() {
     257            0 :             return Ok(());
     258            0 :         }
     259            0 :         let mut conn = match try_connect(&redis).await {
     260            0 :             Ok(conn) => conn,
     261            0 :             Err(e) => {
     262            0 :                 tracing::error!(
     263            0 :                     "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     264              :                 );
     265            0 :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     266            0 :                 continue;
     267              :             }
     268              :         };
     269            0 :         let mut stream = conn.on_message();
     270            0 :         while let Some(msg) = stream.next().await {
     271            0 :             match handler.handle_message(msg).await {
     272            0 :                 Ok(()) => {}
     273            0 :                 Err(e) => {
     274            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     275            0 :                     break;
     276              :                 }
     277              :             }
     278            0 :             if cancellation_token.is_cancelled() {
     279            0 :                 return Ok(());
     280            0 :             }
     281              :         }
     282              :     }
     283            0 : }
     284              : 
     285              : /// Handle console's invalidation messages.
     286              : #[tracing::instrument(name = "redis_notifications", skip_all)]
     287              : pub async fn task_main<C>(
     288              :     redis: ConnectionWithCredentialsProvider,
     289              :     cache: Arc<C>,
     290              : ) -> anyhow::Result<Infallible>
     291              : where
     292              :     C: ProjectInfoCache + Send + Sync + 'static,
     293              : {
     294              :     let handler = MessageHandler::new(cache);
     295              :     // 6h - 1m.
     296              :     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     297              :     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
     298              :     loop {
     299              :         let cancellation_token = CancellationToken::new();
     300              :         interval.tick().await;
     301              : 
     302              :         tokio::spawn(handle_messages(
     303              :             handler.clone(),
     304              :             redis.clone(),
     305              :             cancellation_token.clone(),
     306              :         ));
     307            0 :         tokio::spawn(async move {
     308            0 :             tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
     309            0 :             cancellation_token.cancel();
     310            0 :         });
     311              :     }
     312              : }
     313              : 
     314              : #[cfg(test)]
     315              : mod tests {
     316              :     use serde_json::json;
     317              : 
     318              :     use super::*;
     319              :     use crate::types::{ProjectId, RoleName};
     320              : 
     321              :     #[test]
     322            1 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     323            1 :         let project_id: ProjectId = "new_project".into();
     324            1 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     325            1 :         let text = json!({
     326            1 :             "type": "message",
     327            1 :             "topic": "/allowed_ips_updated",
     328            1 :             "data": data,
     329            1 :             "extre_fields": "something"
     330              :         })
     331            1 :         .to_string();
     332              : 
     333            1 :         let result: Notification = serde_json::from_str(&text)?;
     334            1 :         assert_eq!(
     335              :             result,
     336            1 :             Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()))
     337              :         );
     338              : 
     339            1 :         Ok(())
     340            1 :     }
     341              : 
     342              :     #[test]
     343            1 :     fn parse_multiple_projects() -> anyhow::Result<()> {
     344            1 :         let project_id1: ProjectId = "new_project1".into();
     345            1 :         let project_id2: ProjectId = "new_project2".into();
     346            1 :         let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
     347            1 :         let text = json!({
     348            1 :             "type": "message",
     349            1 :             "topic": "/allowed_vpc_endpoints_updated_for_projects",
     350            1 :             "data": data,
     351            1 :             "extre_fields": "something"
     352              :         })
     353            1 :         .to_string();
     354              : 
     355            1 :         let result: Notification = serde_json::from_str(&text)?;
     356            1 :         assert_eq!(
     357              :             result,
     358            1 :             Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
     359            1 :                 (&project_id1).into(),
     360            1 :                 (&project_id2).into()
     361            1 :             ]))
     362              :         );
     363              : 
     364            1 :         Ok(())
     365            1 :     }
     366              : 
     367              :     #[test]
     368            1 :     fn parse_password_updated() -> anyhow::Result<()> {
     369            1 :         let project_id: ProjectId = "new_project".into();
     370            1 :         let role_name: RoleName = "new_role".into();
     371            1 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     372            1 :         let text = json!({
     373            1 :             "type": "message",
     374            1 :             "topic": "/password_updated",
     375            1 :             "data": data,
     376            1 :             "extre_fields": "something"
     377              :         })
     378            1 :         .to_string();
     379              : 
     380            1 :         let result: Notification = serde_json::from_str(&text)?;
     381            1 :         assert_eq!(
     382              :             result,
     383            1 :             Notification::RoleSettingUpdate(InvalidateRole {
     384            1 :                 project_id: (&project_id).into(),
     385            1 :                 role_name: (&role_name).into(),
     386            1 :             })
     387              :         );
     388              : 
     389            1 :         Ok(())
     390            1 :     }
     391              : 
     392              :     #[test]
     393            1 :     fn parse_unknown_topic() -> anyhow::Result<()> {
     394            1 :         let with_data = json!({
     395            1 :             "type": "message",
     396            1 :             "topic": "/doesnotexist",
     397            1 :             "data": {
     398            1 :                 "payload": "ignored"
     399              :             },
     400            1 :             "extra_fields": "something"
     401              :         })
     402            1 :         .to_string();
     403            1 :         let result: Notification = serde_json::from_str(&with_data)?;
     404            1 :         assert_eq!(result, Notification::UnknownTopic);
     405              : 
     406            1 :         let without_data = json!({
     407            1 :             "type": "message",
     408            1 :             "topic": "/doesnotexist",
     409            1 :             "extra_fields": "something"
     410              :         })
     411            1 :         .to_string();
     412            1 :         let result: Notification = serde_json::from_str(&without_data)?;
     413            1 :         assert_eq!(result, Notification::UnknownTopic);
     414              : 
     415            1 :         Ok(())
     416            1 :     }
     417              : }
        

Generated by: LCOV version 2.1-beta