LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 48.8 % 172 84
Test Date: 2025-02-20 13:11:02 Functions: 9.1 % 121 11

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::Arc;
       3              : 
       4              : use futures::StreamExt;
       5              : use pq_proto::CancelKeyData;
       6              : use redis::aio::PubSub;
       7              : use serde::{Deserialize, Serialize};
       8              : use tokio_util::sync::CancellationToken;
       9              : use uuid::Uuid;
      10              : 
      11              : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      12              : use crate::cache::project_info::ProjectInfoCache;
      13              : use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
      14              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      15              : 
      16              : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
      17              : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
      18              : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      19              : 
      20            0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
      21            0 :     let mut conn = client.get_async_pubsub().await?;
      22            0 :     tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      23            0 :     conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      24            0 :     Ok(conn)
      25            0 : }
      26              : 
      27            0 : #[derive(Debug, Deserialize)]
      28              : struct NotificationHeader<'a> {
      29              :     topic: &'a str,
      30              : }
      31              : 
      32            6 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      33              : #[serde(tag = "topic", content = "data")]
      34              : pub(crate) enum Notification {
      35              :     #[serde(
      36              :         rename = "/allowed_ips_updated",
      37              :         deserialize_with = "deserialize_json_string"
      38              :     )]
      39              :     AllowedIpsUpdate {
      40              :         allowed_ips_update: AllowedIpsUpdate,
      41              :     },
      42              :     #[serde(
      43              :         rename = "/block_public_or_vpc_access_updated",
      44              :         deserialize_with = "deserialize_json_string"
      45              :     )]
      46              :     BlockPublicOrVpcAccessUpdated {
      47              :         block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
      48              :     },
      49              :     #[serde(
      50              :         rename = "/allowed_vpc_endpoints_updated_for_org",
      51              :         deserialize_with = "deserialize_json_string"
      52              :     )]
      53              :     AllowedVpcEndpointsUpdatedForOrg {
      54              :         allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
      55              :     },
      56              :     #[serde(
      57              :         rename = "/allowed_vpc_endpoints_updated_for_projects",
      58              :         deserialize_with = "deserialize_json_string"
      59              :     )]
      60              :     AllowedVpcEndpointsUpdatedForProjects {
      61              :         allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
      62              :     },
      63              :     #[serde(
      64              :         rename = "/password_updated",
      65              :         deserialize_with = "deserialize_json_string"
      66              :     )]
      67              :     PasswordUpdate { password_update: PasswordUpdate },
      68              : 
      69              :     #[serde(
      70              :         other,
      71              :         deserialize_with = "deserialize_unknown_topic",
      72              :         skip_serializing
      73              :     )]
      74              :     UnknownTopic,
      75              : }
      76              : 
      77            1 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      78              : pub(crate) struct AllowedIpsUpdate {
      79              :     project_id: ProjectIdInt,
      80              : }
      81              : 
      82            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      83              : pub(crate) struct BlockPublicOrVpcAccessUpdated {
      84              :     project_id: ProjectIdInt,
      85              : }
      86              : 
      87            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      88              : pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
      89              :     account_id: AccountIdInt,
      90              : }
      91              : 
      92            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      93              : pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
      94              :     project_ids: Vec<ProjectIdInt>,
      95              : }
      96              : 
      97            2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      98              : pub(crate) struct PasswordUpdate {
      99              :     project_id: ProjectIdInt,
     100              :     role_name: RoleNameInt,
     101              : }
     102              : 
     103            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
     104              : pub(crate) struct CancelSession {
     105              :     pub(crate) region_id: Option<String>,
     106              :     pub(crate) cancel_key_data: CancelKeyData,
     107              :     pub(crate) session_id: Uuid,
     108              :     pub(crate) peer_addr: Option<std::net::IpAddr>,
     109              : }
     110              : 
     111            2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
     112            2 : where
     113            2 :     T: for<'de2> serde::Deserialize<'de2>,
     114            2 :     D: serde::Deserializer<'de>,
     115            2 : {
     116            2 :     let s = String::deserialize(deserializer)?;
     117            2 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
     118            2 : }
     119              : 
     120              : // https://github.com/serde-rs/serde/issues/1714
     121            1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
     122            1 : where
     123            1 :     D: serde::Deserializer<'de>,
     124            1 : {
     125            1 :     deserializer.deserialize_any(serde::de::IgnoredAny)?;
     126            1 :     Ok(())
     127            1 : }
     128              : 
     129              : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     130              :     cache: Arc<C>,
     131              :     region_id: String,
     132              : }
     133              : 
     134              : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
     135            0 :     fn clone(&self) -> Self {
     136            0 :         Self {
     137            0 :             cache: self.cache.clone(),
     138            0 :             region_id: self.region_id.clone(),
     139            0 :         }
     140            0 :     }
     141              : }
     142              : 
     143              : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
     144            0 :     pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
     145            0 :         Self { cache, region_id }
     146            0 :     }
     147              : 
     148            0 :     pub(crate) async fn increment_active_listeners(&self) {
     149            0 :         self.cache.increment_active_listeners().await;
     150            0 :     }
     151              : 
     152            0 :     pub(crate) async fn decrement_active_listeners(&self) {
     153            0 :         self.cache.decrement_active_listeners().await;
     154            0 :     }
     155              : 
     156              :     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     157              :     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
     158              :         let payload: String = msg.get_payload()?;
     159              :         tracing::debug!(?payload, "received a message payload");
     160              : 
     161              :         let msg: Notification = match serde_json::from_str(&payload) {
     162              :             Ok(Notification::UnknownTopic) => {
     163              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     164              :                     // don't update the metric for redis errors if it's just a topic we don't know about.
     165              :                     Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
     166              :                     Err(e) => {
     167              :                         Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     168              :                             channel: msg.get_channel_name(),
     169              :                         });
     170              :                         tracing::error!("broken message: {e}");
     171              :                     }
     172              :                 }
     173              :                 return Ok(());
     174              :             }
     175              :             Ok(msg) => msg,
     176              :             Err(e) => {
     177              :                 Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     178              :                     channel: msg.get_channel_name(),
     179              :                 });
     180              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     181              :                     Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
     182              :                     Err(_) => tracing::error!("broken message: {e}"),
     183              :                 }
     184              :                 return Ok(());
     185              :             }
     186              :         };
     187              : 
     188              :         tracing::debug!(?msg, "received a message");
     189              :         match msg {
     190              :             Notification::AllowedIpsUpdate { .. }
     191              :             | Notification::PasswordUpdate { .. }
     192              :             | Notification::BlockPublicOrVpcAccessUpdated { .. }
     193              :             | Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
     194              :             | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
     195              :                 invalidate_cache(self.cache.clone(), msg.clone());
     196              :                 if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
     197              :                     Metrics::get()
     198              :                         .proxy
     199              :                         .redis_events_count
     200              :                         .inc(RedisEventsCount::AllowedIpsUpdate);
     201              :                 } else if matches!(msg, Notification::PasswordUpdate { .. }) {
     202              :                     Metrics::get()
     203              :                         .proxy
     204              :                         .redis_events_count
     205              :                         .inc(RedisEventsCount::PasswordUpdate);
     206              :                 } else if matches!(
     207              :                     msg,
     208              :                     Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
     209              :                 ) {
     210              :                     Metrics::get()
     211              :                         .proxy
     212              :                         .redis_events_count
     213              :                         .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
     214              :                 } else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
     215              :                     Metrics::get()
     216              :                         .proxy
     217              :                         .redis_events_count
     218              :                         .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
     219              :                 } else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
     220              :                     Metrics::get()
     221              :                         .proxy
     222              :                         .redis_events_count
     223              :                         .inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
     224              :                 }
     225              :                 // TODO: add additional metrics for the other event types.
     226              : 
     227              :                 // It might happen that the invalid entry is on the way to be cached.
     228              :                 // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
     229              :                 // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     230              :                 let cache = self.cache.clone();
     231            0 :                 tokio::spawn(async move {
     232            0 :                     tokio::time::sleep(INVALIDATION_LAG).await;
     233            0 :                     invalidate_cache(cache, msg);
     234            0 :                 });
     235              :             }
     236              : 
     237              :             Notification::UnknownTopic => unreachable!(),
     238              :         }
     239              : 
     240              :         Ok(())
     241              :     }
     242              : }
     243              : 
     244            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     245            0 :     match msg {
     246            0 :         Notification::AllowedIpsUpdate { allowed_ips_update } => {
     247            0 :             cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
     248            0 :         }
     249              :         Notification::BlockPublicOrVpcAccessUpdated {
     250            0 :             block_public_or_vpc_access_updated,
     251            0 :         } => cache.invalidate_block_public_or_vpc_access_for_project(
     252            0 :             block_public_or_vpc_access_updated.project_id,
     253            0 :         ),
     254              :         Notification::AllowedVpcEndpointsUpdatedForOrg {
     255            0 :             allowed_vpc_endpoints_updated_for_org,
     256            0 :         } => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
     257            0 :             allowed_vpc_endpoints_updated_for_org.account_id,
     258            0 :         ),
     259              :         Notification::AllowedVpcEndpointsUpdatedForProjects {
     260            0 :             allowed_vpc_endpoints_updated_for_projects,
     261            0 :         } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
     262            0 :             allowed_vpc_endpoints_updated_for_projects.project_ids,
     263            0 :         ),
     264            0 :         Notification::PasswordUpdate { password_update } => cache
     265            0 :             .invalidate_role_secret_for_project(
     266            0 :                 password_update.project_id,
     267            0 :                 password_update.role_name,
     268            0 :             ),
     269            0 :         Notification::UnknownTopic => unreachable!(),
     270              :     }
     271            0 : }
     272              : 
     273            0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
     274            0 :     handler: MessageHandler<C>,
     275            0 :     redis: ConnectionWithCredentialsProvider,
     276            0 :     cancellation_token: CancellationToken,
     277            0 : ) -> anyhow::Result<()> {
     278              :     loop {
     279            0 :         if cancellation_token.is_cancelled() {
     280            0 :             return Ok(());
     281            0 :         }
     282            0 :         let mut conn = match try_connect(&redis).await {
     283            0 :             Ok(conn) => {
     284            0 :                 handler.increment_active_listeners().await;
     285            0 :                 conn
     286              :             }
     287            0 :             Err(e) => {
     288            0 :                 tracing::error!(
     289            0 :                     "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     290              :                 );
     291            0 :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     292            0 :                 continue;
     293              :             }
     294              :         };
     295            0 :         let mut stream = conn.on_message();
     296            0 :         while let Some(msg) = stream.next().await {
     297            0 :             match handler.handle_message(msg).await {
     298            0 :                 Ok(()) => {}
     299            0 :                 Err(e) => {
     300            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     301            0 :                     break;
     302              :                 }
     303              :             }
     304            0 :             if cancellation_token.is_cancelled() {
     305            0 :                 handler.decrement_active_listeners().await;
     306            0 :                 return Ok(());
     307            0 :             }
     308              :         }
     309            0 :         handler.decrement_active_listeners().await;
     310              :     }
     311            0 : }
     312              : 
     313              : /// Handle console's invalidation messages.
     314              : #[tracing::instrument(name = "redis_notifications", skip_all)]
     315              : pub async fn task_main<C>(
     316              :     redis: ConnectionWithCredentialsProvider,
     317              :     cache: Arc<C>,
     318              :     region_id: String,
     319              : ) -> anyhow::Result<Infallible>
     320              : where
     321              :     C: ProjectInfoCache + Send + Sync + 'static,
     322              : {
     323              :     let handler = MessageHandler::new(cache, region_id);
     324              :     // 6h - 1m.
     325              :     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     326              :     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
     327              :     loop {
     328              :         let cancellation_token = CancellationToken::new();
     329              :         interval.tick().await;
     330              : 
     331              :         tokio::spawn(handle_messages(
     332              :             handler.clone(),
     333              :             redis.clone(),
     334              :             cancellation_token.clone(),
     335              :         ));
     336            0 :         tokio::spawn(async move {
     337            0 :             tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
     338            0 :             cancellation_token.cancel();
     339            0 :         });
     340              :     }
     341              : }
     342              : 
     343              : #[cfg(test)]
     344              : mod tests {
     345              :     use serde_json::json;
     346              : 
     347              :     use super::*;
     348              :     use crate::types::{ProjectId, RoleName};
     349              : 
     350              :     #[test]
     351            1 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     352            1 :         let project_id: ProjectId = "new_project".into();
     353            1 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     354            1 :         let text = json!({
     355            1 :             "type": "message",
     356            1 :             "topic": "/allowed_ips_updated",
     357            1 :             "data": data,
     358            1 :             "extre_fields": "something"
     359            1 :         })
     360            1 :         .to_string();
     361              : 
     362            1 :         let result: Notification = serde_json::from_str(&text)?;
     363            1 :         assert_eq!(
     364            1 :             result,
     365            1 :             Notification::AllowedIpsUpdate {
     366            1 :                 allowed_ips_update: AllowedIpsUpdate {
     367            1 :                     project_id: (&project_id).into()
     368            1 :                 }
     369            1 :             }
     370            1 :         );
     371              : 
     372            1 :         Ok(())
     373            1 :     }
     374              : 
     375              :     #[test]
     376            1 :     fn parse_password_updated() -> anyhow::Result<()> {
     377            1 :         let project_id: ProjectId = "new_project".into();
     378            1 :         let role_name: RoleName = "new_role".into();
     379            1 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     380            1 :         let text = json!({
     381            1 :             "type": "message",
     382            1 :             "topic": "/password_updated",
     383            1 :             "data": data,
     384            1 :             "extre_fields": "something"
     385            1 :         })
     386            1 :         .to_string();
     387              : 
     388            1 :         let result: Notification = serde_json::from_str(&text)?;
     389            1 :         assert_eq!(
     390            1 :             result,
     391            1 :             Notification::PasswordUpdate {
     392            1 :                 password_update: PasswordUpdate {
     393            1 :                     project_id: (&project_id).into(),
     394            1 :                     role_name: (&role_name).into(),
     395            1 :                 }
     396            1 :             }
     397            1 :         );
     398              : 
     399            1 :         Ok(())
     400            1 :     }
     401              : 
     402              :     #[test]
     403            1 :     fn parse_unknown_topic() -> anyhow::Result<()> {
     404            1 :         let with_data = json!({
     405            1 :             "type": "message",
     406            1 :             "topic": "/doesnotexist",
     407            1 :             "data": {
     408            1 :                 "payload": "ignored"
     409            1 :             },
     410            1 :             "extra_fields": "something"
     411            1 :         })
     412            1 :         .to_string();
     413            1 :         let result: Notification = serde_json::from_str(&with_data)?;
     414            1 :         assert_eq!(result, Notification::UnknownTopic);
     415              : 
     416            1 :         let without_data = json!({
     417            1 :             "type": "message",
     418            1 :             "topic": "/doesnotexist",
     419            1 :             "extra_fields": "something"
     420            1 :         })
     421            1 :         .to_string();
     422            1 :         let result: Notification = serde_json::from_str(&without_data)?;
     423            1 :         assert_eq!(result, Notification::UnknownTopic);
     424              : 
     425            1 :         Ok(())
     426            1 :     }
     427              : }
        

Generated by: LCOV version 2.1-beta