LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 727bdccc1d7d53837da843959afb612f56da4e79.info Lines: 49.7 % 169 84
Test Date: 2025-01-30 15:18:43 Functions: 9.1 % 121 11

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::Arc;
       3              : 
       4              : use futures::StreamExt;
       5              : use pq_proto::CancelKeyData;
       6              : use redis::aio::PubSub;
       7              : use serde::{Deserialize, Serialize};
       8              : use tokio_util::sync::CancellationToken;
       9              : use uuid::Uuid;
      10              : 
      11              : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      12              : use crate::cache::project_info::ProjectInfoCache;
      13              : use crate::intern::{ProjectIdInt, RoleNameInt};
      14              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      15              : 
      16              : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
      17              : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
      18              : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      19              : 
      20            0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
      21            0 :     let mut conn = client.get_async_pubsub().await?;
      22            0 :     tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      23            0 :     conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      24            0 :     Ok(conn)
      25            0 : }
      26              : 
      27            0 : #[derive(Debug, Deserialize)]
      28              : struct NotificationHeader<'a> {
      29              :     topic: &'a str,
      30              : }
      31              : 
      32            6 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      33              : #[serde(tag = "topic", content = "data")]
      34              : pub(crate) enum Notification {
      35              :     #[serde(
      36              :         rename = "/allowed_ips_updated",
      37              :         deserialize_with = "deserialize_json_string"
      38              :     )]
      39              :     AllowedIpsUpdate {
      40              :         allowed_ips_update: AllowedIpsUpdate,
      41              :     },
      42              :     #[serde(
      43              :         rename = "/block_public_or_vpc_access_updated",
      44              :         deserialize_with = "deserialize_json_string"
      45              :     )]
      46              :     BlockPublicOrVpcAccessUpdated {
      47              :         block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
      48              :     },
      49              :     #[serde(
      50              :         rename = "/allowed_vpc_endpoints_updated_for_org",
      51              :         deserialize_with = "deserialize_json_string"
      52              :     )]
      53              :     AllowedVpcEndpointsUpdatedForOrg {
      54              :         allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
      55              :     },
      56              :     #[serde(
      57              :         rename = "/allowed_vpc_endpoints_updated_for_projects",
      58              :         deserialize_with = "deserialize_json_string"
      59              :     )]
      60              :     AllowedVpcEndpointsUpdatedForProjects {
      61              :         allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
      62              :     },
      63              :     #[serde(
      64              :         rename = "/password_updated",
      65              :         deserialize_with = "deserialize_json_string"
      66              :     )]
      67              :     PasswordUpdate { password_update: PasswordUpdate },
      68              : 
      69              :     #[serde(
      70              :         other,
      71              :         deserialize_with = "deserialize_unknown_topic",
      72              :         skip_serializing
      73              :     )]
      74              :     UnknownTopic,
      75              : }
      76              : 
      77            1 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      78              : pub(crate) struct AllowedIpsUpdate {
      79              :     project_id: ProjectIdInt,
      80              : }
      81              : 
      82            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      83              : pub(crate) struct BlockPublicOrVpcAccessUpdated {
      84              :     project_id: ProjectIdInt,
      85              : }
      86              : 
      87            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      88              : pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
      89              :     // TODO: change type once the implementation is more fully fledged.
      90              :     // See e.g. https://github.com/neondatabase/neon/pull/10073.
      91              :     account_id: ProjectIdInt,
      92              : }
      93              : 
      94            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      95              : pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
      96              :     project_ids: Vec<ProjectIdInt>,
      97              : }
      98              : 
      99            2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
     100              : pub(crate) struct PasswordUpdate {
     101              :     project_id: ProjectIdInt,
     102              :     role_name: RoleNameInt,
     103              : }
     104              : 
     105            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
     106              : pub(crate) struct CancelSession {
     107              :     pub(crate) region_id: Option<String>,
     108              :     pub(crate) cancel_key_data: CancelKeyData,
     109              :     pub(crate) session_id: Uuid,
     110              :     pub(crate) peer_addr: Option<std::net::IpAddr>,
     111              : }
     112              : 
     113            2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
     114            2 : where
     115            2 :     T: for<'de2> serde::Deserialize<'de2>,
     116            2 :     D: serde::Deserializer<'de>,
     117            2 : {
     118            2 :     let s = String::deserialize(deserializer)?;
     119            2 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
     120            2 : }
     121              : 
     122              : // https://github.com/serde-rs/serde/issues/1714
     123            1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
     124            1 : where
     125            1 :     D: serde::Deserializer<'de>,
     126            1 : {
     127            1 :     deserializer.deserialize_any(serde::de::IgnoredAny)?;
     128            1 :     Ok(())
     129            1 : }
     130              : 
     131              : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     132              :     cache: Arc<C>,
     133              :     region_id: String,
     134              : }
     135              : 
     136              : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
     137            0 :     fn clone(&self) -> Self {
     138            0 :         Self {
     139            0 :             cache: self.cache.clone(),
     140            0 :             region_id: self.region_id.clone(),
     141            0 :         }
     142            0 :     }
     143              : }
     144              : 
     145              : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
     146            0 :     pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
     147            0 :         Self { cache, region_id }
     148            0 :     }
     149              : 
     150            0 :     pub(crate) async fn increment_active_listeners(&self) {
     151            0 :         self.cache.increment_active_listeners().await;
     152            0 :     }
     153              : 
     154            0 :     pub(crate) async fn decrement_active_listeners(&self) {
     155            0 :         self.cache.decrement_active_listeners().await;
     156            0 :     }
     157              : 
     158              :     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     159              :     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
     160              :         let payload: String = msg.get_payload()?;
     161              :         tracing::debug!(?payload, "received a message payload");
     162              : 
     163              :         let msg: Notification = match serde_json::from_str(&payload) {
     164              :             Ok(Notification::UnknownTopic) => {
     165              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     166              :                     // don't update the metric for redis errors if it's just a topic we don't know about.
     167              :                     Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
     168              :                     Err(e) => {
     169              :                         Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     170              :                             channel: msg.get_channel_name(),
     171              :                         });
     172              :                         tracing::error!("broken message: {e}");
     173              :                     }
     174              :                 };
     175              :                 return Ok(());
     176              :             }
     177              :             Ok(msg) => msg,
     178              :             Err(e) => {
     179              :                 Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     180              :                     channel: msg.get_channel_name(),
     181              :                 });
     182              :                 match serde_json::from_str::<NotificationHeader>(&payload) {
     183              :                     Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
     184              :                     Err(_) => tracing::error!("broken message: {e}"),
     185              :                 };
     186              :                 return Ok(());
     187              :             }
     188              :         };
     189              : 
     190              :         tracing::debug!(?msg, "received a message");
     191              :         match msg {
     192              :             Notification::AllowedIpsUpdate { .. }
     193              :             | Notification::PasswordUpdate { .. }
     194              :             | Notification::BlockPublicOrVpcAccessUpdated { .. }
     195              :             | Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
     196              :             | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
     197              :                 invalidate_cache(self.cache.clone(), msg.clone());
     198              :                 if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
     199              :                     Metrics::get()
     200              :                         .proxy
     201              :                         .redis_events_count
     202              :                         .inc(RedisEventsCount::AllowedIpsUpdate);
     203              :                 } else if matches!(msg, Notification::PasswordUpdate { .. }) {
     204              :                     Metrics::get()
     205              :                         .proxy
     206              :                         .redis_events_count
     207              :                         .inc(RedisEventsCount::PasswordUpdate);
     208              :                 }
     209              :                 // TODO: add additional metrics for the other event types.
     210              : 
     211              :                 // It might happen that the invalid entry is on the way to be cached.
     212              :                 // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
     213              :                 // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     214              :                 let cache = self.cache.clone();
     215            0 :                 tokio::spawn(async move {
     216            0 :                     tokio::time::sleep(INVALIDATION_LAG).await;
     217            0 :                     invalidate_cache(cache, msg);
     218            0 :                 });
     219              :             }
     220              : 
     221              :             Notification::UnknownTopic => unreachable!(),
     222              :         }
     223              : 
     224              :         Ok(())
     225              :     }
     226              : }
     227              : 
     228            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     229            0 :     match msg {
     230            0 :         Notification::AllowedIpsUpdate { allowed_ips_update } => {
     231            0 :             cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
     232            0 :         }
     233            0 :         Notification::PasswordUpdate { password_update } => cache
     234            0 :             .invalidate_role_secret_for_project(
     235            0 :                 password_update.project_id,
     236            0 :                 password_update.role_name,
     237            0 :             ),
     238            0 :         Notification::BlockPublicOrVpcAccessUpdated { .. } => {
     239            0 :             // https://github.com/neondatabase/neon/pull/10073
     240            0 :         }
     241            0 :         Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
     242            0 :             // https://github.com/neondatabase/neon/pull/10073
     243            0 :         }
     244            0 :         Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
     245            0 :             // https://github.com/neondatabase/neon/pull/10073
     246            0 :         }
     247            0 :         Notification::UnknownTopic => unreachable!(),
     248              :     }
     249            0 : }
     250              : 
     251            0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
     252            0 :     handler: MessageHandler<C>,
     253            0 :     redis: ConnectionWithCredentialsProvider,
     254            0 :     cancellation_token: CancellationToken,
     255            0 : ) -> anyhow::Result<()> {
     256              :     loop {
     257            0 :         if cancellation_token.is_cancelled() {
     258            0 :             return Ok(());
     259            0 :         }
     260            0 :         let mut conn = match try_connect(&redis).await {
     261            0 :             Ok(conn) => {
     262            0 :                 handler.increment_active_listeners().await;
     263            0 :                 conn
     264              :             }
     265            0 :             Err(e) => {
     266            0 :                 tracing::error!(
     267            0 :                     "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     268              :                 );
     269            0 :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     270            0 :                 continue;
     271              :             }
     272              :         };
     273            0 :         let mut stream = conn.on_message();
     274            0 :         while let Some(msg) = stream.next().await {
     275            0 :             match handler.handle_message(msg).await {
     276            0 :                 Ok(()) => {}
     277            0 :                 Err(e) => {
     278            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     279            0 :                     break;
     280              :                 }
     281              :             }
     282            0 :             if cancellation_token.is_cancelled() {
     283            0 :                 handler.decrement_active_listeners().await;
     284            0 :                 return Ok(());
     285            0 :             }
     286              :         }
     287            0 :         handler.decrement_active_listeners().await;
     288              :     }
     289            0 : }
     290              : 
     291              : /// Handle console's invalidation messages.
     292              : #[tracing::instrument(name = "redis_notifications", skip_all)]
     293              : pub async fn task_main<C>(
     294              :     redis: ConnectionWithCredentialsProvider,
     295              :     cache: Arc<C>,
     296              :     region_id: String,
     297              : ) -> anyhow::Result<Infallible>
     298              : where
     299              :     C: ProjectInfoCache + Send + Sync + 'static,
     300              : {
     301              :     let handler = MessageHandler::new(cache, region_id);
     302              :     // 6h - 1m.
     303              :     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     304              :     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
     305              :     loop {
     306              :         let cancellation_token = CancellationToken::new();
     307              :         interval.tick().await;
     308              : 
     309              :         tokio::spawn(handle_messages(
     310              :             handler.clone(),
     311              :             redis.clone(),
     312              :             cancellation_token.clone(),
     313              :         ));
     314            0 :         tokio::spawn(async move {
     315            0 :             tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
     316            0 :             cancellation_token.cancel();
     317            0 :         });
     318              :     }
     319              : }
     320              : 
     321              : #[cfg(test)]
     322              : mod tests {
     323              :     use serde_json::json;
     324              : 
     325              :     use super::*;
     326              :     use crate::types::{ProjectId, RoleName};
     327              : 
     328              :     #[test]
     329            1 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     330            1 :         let project_id: ProjectId = "new_project".into();
     331            1 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     332            1 :         let text = json!({
     333            1 :             "type": "message",
     334            1 :             "topic": "/allowed_ips_updated",
     335            1 :             "data": data,
     336            1 :             "extre_fields": "something"
     337            1 :         })
     338            1 :         .to_string();
     339              : 
     340            1 :         let result: Notification = serde_json::from_str(&text)?;
     341            1 :         assert_eq!(
     342            1 :             result,
     343            1 :             Notification::AllowedIpsUpdate {
     344            1 :                 allowed_ips_update: AllowedIpsUpdate {
     345            1 :                     project_id: (&project_id).into()
     346            1 :                 }
     347            1 :             }
     348            1 :         );
     349              : 
     350            1 :         Ok(())
     351            1 :     }
     352              : 
     353              :     #[test]
     354            1 :     fn parse_password_updated() -> anyhow::Result<()> {
     355            1 :         let project_id: ProjectId = "new_project".into();
     356            1 :         let role_name: RoleName = "new_role".into();
     357            1 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     358            1 :         let text = json!({
     359            1 :             "type": "message",
     360            1 :             "topic": "/password_updated",
     361            1 :             "data": data,
     362            1 :             "extre_fields": "something"
     363            1 :         })
     364            1 :         .to_string();
     365              : 
     366            1 :         let result: Notification = serde_json::from_str(&text)?;
     367            1 :         assert_eq!(
     368            1 :             result,
     369            1 :             Notification::PasswordUpdate {
     370            1 :                 password_update: PasswordUpdate {
     371            1 :                     project_id: (&project_id).into(),
     372            1 :                     role_name: (&role_name).into(),
     373            1 :                 }
     374            1 :             }
     375            1 :         );
     376              : 
     377            1 :         Ok(())
     378            1 :     }
     379              : 
     380              :     #[test]
     381            1 :     fn parse_unknown_topic() -> anyhow::Result<()> {
     382            1 :         let with_data = json!({
     383            1 :             "type": "message",
     384            1 :             "topic": "/doesnotexist",
     385            1 :             "data": {
     386            1 :                 "payload": "ignored"
     387            1 :             },
     388            1 :             "extra_fields": "something"
     389            1 :         })
     390            1 :         .to_string();
     391            1 :         let result: Notification = serde_json::from_str(&with_data)?;
     392            1 :         assert_eq!(result, Notification::UnknownTopic);
     393              : 
     394            1 :         let without_data = json!({
     395            1 :             "type": "message",
     396            1 :             "topic": "/doesnotexist",
     397            1 :             "extra_fields": "something"
     398            1 :         })
     399            1 :         .to_string();
     400            1 :         let result: Notification = serde_json::from_str(&without_data)?;
     401            1 :         assert_eq!(result, Notification::UnknownTopic);
     402              : 
     403            1 :         Ok(())
     404            1 :     }
     405              : }
        

Generated by: LCOV version 2.1-beta