LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 4f58e98c51285c7fa348e0b410c88a10caf68ad2.info Lines: 45.8 % 179 82
Test Date: 2025-01-07 20:58:07 Functions: 12.0 % 166 20

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::Arc;
       3              : 
       4              : use futures::StreamExt;
       5              : use pq_proto::CancelKeyData;
       6              : use redis::aio::PubSub;
       7              : use serde::{Deserialize, Serialize};
       8              : use tokio_util::sync::CancellationToken;
       9              : use tracing::Instrument;
      10              : use uuid::Uuid;
      11              : 
      12              : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      13              : use crate::cache::project_info::ProjectInfoCache;
      14              : use crate::cancellation::{CancelMap, CancellationHandler};
      15              : use crate::config::ProxyConfig;
      16              : use crate::intern::{ProjectIdInt, RoleNameInt};
      17              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      18              : 
/// Redis channel carrying control-plane notifications; `try_connect` subscribes to it.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Redis channel for messages exchanged between proxy instances; `try_connect`
/// subscribes to it, and it is `pub(crate)` so other modules can reference it.
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
/// Delay before retrying after a failed Redis connection attempt (see `handle_messages`).
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay after which a cache invalidation is repeated, to catch entries that were
/// being cached concurrently with the first invalidation (see `handle_message`).
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      23              : 
      24            0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
      25            0 :     let mut conn = client.get_async_pubsub().await?;
      26            0 :     tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      27            0 :     conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      28            0 :     tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
      29            0 :     conn.subscribe(PROXY_CHANNEL_NAME).await?;
      30            0 :     Ok(conn)
      31            0 : }
      32              : 
      33            0 : #[derive(Debug, Deserialize)]
      34              : struct NotificationHeader<'a> {
      35              :     topic: &'a str,
      36              : }
      37              : 
      38           10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      39              : #[serde(tag = "topic", content = "data")]
      40              : pub(crate) enum Notification {
      41              :     #[serde(
      42              :         rename = "/allowed_ips_updated",
      43              :         deserialize_with = "deserialize_json_string"
      44              :     )]
      45              :     AllowedIpsUpdate {
      46              :         allowed_ips_update: AllowedIpsUpdate,
      47              :     },
      48              :     #[serde(
      49              :         rename = "/block_public_or_vpc_access_updated",
      50              :         deserialize_with = "deserialize_json_string"
      51              :     )]
      52              :     BlockPublicOrVpcAccessUpdated {
      53              :         block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
      54              :     },
      55              :     #[serde(
      56              :         rename = "/allowed_vpc_endpoints_updated_for_org",
      57              :         deserialize_with = "deserialize_json_string"
      58              :     )]
      59              :     AllowedVpcEndpointsUpdatedForOrg {
      60              :         allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
      61              :     },
      62              :     #[serde(
      63              :         rename = "/allowed_vpc_endpoints_updated_for_projects",
      64              :         deserialize_with = "deserialize_json_string"
      65              :     )]
      66              :     AllowedVpcEndpointsUpdatedForProjects {
      67              :         allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
      68              :     },
      69              :     #[serde(
      70              :         rename = "/password_updated",
      71              :         deserialize_with = "deserialize_json_string"
      72              :     )]
      73              :     PasswordUpdate { password_update: PasswordUpdate },
      74              :     #[serde(rename = "/cancel_session")]
      75              :     Cancel(CancelSession),
      76              : 
      77              :     #[serde(other, skip_serializing)]
      78              :     UnknownTopic,
      79              : }
      80              : 
      81            2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      82              : pub(crate) struct AllowedIpsUpdate {
      83              :     project_id: ProjectIdInt,
      84              : }
      85              : 
      86            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      87              : pub(crate) struct BlockPublicOrVpcAccessUpdated {
      88              :     project_id: ProjectIdInt,
      89              : }
      90              : 
      91            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      92              : pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
      93              :     // TODO: change type once the implementation is more fully fledged.
      94              :     // See e.g. https://github.com/neondatabase/neon/pull/10073.
      95              :     account_id: ProjectIdInt,
      96              : }
      97              : 
      98            0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      99              : pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
     100              :     project_ids: Vec<ProjectIdInt>,
     101              : }
     102              : 
     103            3 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
     104              : pub(crate) struct PasswordUpdate {
     105              :     project_id: ProjectIdInt,
     106              :     role_name: RoleNameInt,
     107              : }
     108              : 
     109           10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
     110              : pub(crate) struct CancelSession {
     111              :     pub(crate) region_id: Option<String>,
     112              :     pub(crate) cancel_key_data: CancelKeyData,
     113              :     pub(crate) session_id: Uuid,
     114              :     pub(crate) peer_addr: Option<std::net::IpAddr>,
     115              : }
     116              : 
     117            2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
     118            2 : where
     119            2 :     T: for<'de2> serde::Deserialize<'de2>,
     120            2 :     D: serde::Deserializer<'de>,
     121            2 : {
     122            2 :     let s = String::deserialize(deserializer)?;
     123            2 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
     124            2 : }
     125              : 
     126              : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
     127              :     cache: Arc<C>,
     128              :     cancellation_handler: Arc<CancellationHandler<()>>,
     129              :     region_id: String,
     130              : }
     131              : 
     132              : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
     133            0 :     fn clone(&self) -> Self {
     134            0 :         Self {
     135            0 :             cache: self.cache.clone(),
     136            0 :             cancellation_handler: self.cancellation_handler.clone(),
     137            0 :             region_id: self.region_id.clone(),
     138            0 :         }
     139            0 :     }
     140              : }
     141              : 
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
    /// Creates a handler that:
    /// - invalidates entries in `cache`;
    /// - executes cancel requests through `cancellation_handler`;
    /// - ignores cancel requests addressed to a region other than `region_id`.
    pub(crate) fn new(
        cache: Arc<C>,
        cancellation_handler: Arc<CancellationHandler<()>>,
        region_id: String,
    ) -> Self {
        Self {
            cache,
            cancellation_handler,
            region_id,
        }
    }

    /// Forwards to the cache's `increment_active_listeners`; `handle_messages`
    /// calls this once a Redis subscription is established.
    pub(crate) async fn increment_active_listeners(&self) {
        self.cache.increment_active_listeners().await;
    }

    /// Forwards to the cache's `decrement_active_listeners`; `handle_messages`
    /// calls this when a Redis subscription is abandoned.
    pub(crate) async fn decrement_active_listeners(&self) {
        self.cache.decrement_active_listeners().await;
    }

    /// Parses one pub/sub message and acts on it.
    ///
    /// - `/cancel_session`: if the message names no region or this proxy's
    ///   region, asks the cancellation handler to cancel the session. Failures
    ///   are logged as warnings, not returned.
    /// - Cache-related topics: invalidates the cache immediately, then again
    ///   after `INVALIDATION_LAG` in a spawned task.
    /// - Unknown topics: logged as a warning and skipped.
    /// - Unparsable payloads: logged, counted in `redis_errors_total`, and skipped.
    ///
    /// # Errors
    ///
    /// Returns an error only if the message payload cannot be read as a string.
    /// `handle_messages` treats that as a reason to reconnect.
    #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
    async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
        let payload: String = msg.get_payload()?;
        tracing::debug!(?payload, "received a message payload");

        // Within this `match`, `msg` still refers to the raw `redis::Msg` parameter
        // (used for the channel-name metric label). The parsed `Notification`
        // shadows it only once the whole statement completes.
        let msg: Notification = match serde_json::from_str(&payload) {
            // `#[serde(other)]` routes unrecognised topics here. Re-parse only the
            // header so the log can say which topic it was.
            Ok(Notification::UnknownTopic) => {
                match serde_json::from_str::<NotificationHeader>(&payload) {
                    // don't update the metric for redis errors if it's just a topic we don't know about.
                    Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
                    Err(e) => {
                        Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
                            channel: msg.get_channel_name(),
                        });
                        tracing::error!("broken message: {e}");
                    }
                };
                return Ok(());
            }
            Ok(msg) => msg,
            Err(e) => {
                Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
                    channel: msg.get_channel_name(),
                });
                // Include the topic in the log if at least the header is readable.
                match serde_json::from_str::<NotificationHeader>(&payload) {
                    Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
                    Err(_) => tracing::error!("broken message: {e}"),
                };
                return Ok(());
            }
        };

        tracing::debug!(?msg, "received a message");
        match msg {
            Notification::Cancel(cancel_session) => {
                // Fill in the `session_id` field declared empty in `#[tracing::instrument]`.
                tracing::Span::current().record(
                    "session_id",
                    tracing::field::display(cancel_session.session_id),
                );
                Metrics::get()
                    .proxy
                    .redis_events_count
                    .inc(RedisEventsCount::CancelSession);
                if let Some(cancel_region) = cancel_session.region_id {
                    // If the message is not for this region, ignore it.
                    if cancel_region != self.region_id {
                        return Ok(());
                    }
                }

                // TODO: Remove unspecified peer_addr after the complete migration to the new format
                let peer_addr = cancel_session
                    .peer_addr
                    .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
                // Root span (`parent: None`) for the cancellation itself, linked back to
                // this message's span via `follows_from` rather than nested under it.
                let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id);
                cancel_span.follows_from(tracing::Span::current());
                // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
                // NOTE(review): a nil session id is passed, and the last argument
                // presumably tells the handler whether `peer_addr` is genuine rather
                // than the UNSPECIFIED fallback. Confirm against
                // `CancellationHandler::cancel_session`.
                match self
                    .cancellation_handler
                    .cancel_session(
                        cancel_session.cancel_key_data,
                        uuid::Uuid::nil(),
                        peer_addr,
                        cancel_session.peer_addr.is_some(),
                    )
                    .instrument(cancel_span)
                    .await
                {
                    Ok(()) => {}
                    Err(e) => {
                        tracing::warn!("failed to cancel session: {e}");
                    }
                }
            }
            Notification::AllowedIpsUpdate { .. }
            | Notification::PasswordUpdate { .. }
            | Notification::BlockPublicOrVpcAccessUpdated { .. }
            | Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
            | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
                invalidate_cache(self.cache.clone(), msg.clone());
                if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
                    Metrics::get()
                        .proxy
                        .redis_events_count
                        .inc(RedisEventsCount::AllowedIpsUpdate);
                } else if matches!(msg, Notification::PasswordUpdate { .. }) {
                    Metrics::get()
                        .proxy
                        .redis_events_count
                        .inc(RedisEventsCount::PasswordUpdate);
                }
                // TODO: add additional metrics for the other event types.

                // It might happen that the invalid entry is on the way to be cached.
                // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
                // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
                let cache = self.cache.clone();
                tokio::spawn(async move {
                    tokio::time::sleep(INVALIDATION_LAG).await;
                    invalidate_cache(cache, msg);
                });
            }

            // Already handled by the early return in the parsing step above.
            Notification::UnknownTopic => unreachable!(),
        }

        Ok(())
    }
}
     266              :             Notification::UnknownTopic => unreachable!(),
     267              :         }
     268              : 
     269              :         Ok(())
     270              :     }
     271              : }
     272              : 
     273            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     274            0 :     match msg {
     275            0 :         Notification::AllowedIpsUpdate { allowed_ips_update } => {
     276            0 :             cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
     277            0 :         }
     278            0 :         Notification::PasswordUpdate { password_update } => cache
     279            0 :             .invalidate_role_secret_for_project(
     280            0 :                 password_update.project_id,
     281            0 :                 password_update.role_name,
     282            0 :             ),
     283            0 :         Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
     284            0 :         Notification::BlockPublicOrVpcAccessUpdated { .. } => {
     285            0 :             // https://github.com/neondatabase/neon/pull/10073
     286            0 :         }
     287            0 :         Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
     288            0 :             // https://github.com/neondatabase/neon/pull/10073
     289            0 :         }
     290            0 :         Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
     291            0 :             // https://github.com/neondatabase/neon/pull/10073
     292            0 :         }
     293            0 :         Notification::UnknownTopic => unreachable!(),
     294              :     }
     295            0 : }
     296              : 
     297            0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
     298            0 :     handler: MessageHandler<C>,
     299            0 :     redis: ConnectionWithCredentialsProvider,
     300            0 :     cancellation_token: CancellationToken,
     301            0 : ) -> anyhow::Result<()> {
     302              :     loop {
     303            0 :         if cancellation_token.is_cancelled() {
     304            0 :             return Ok(());
     305            0 :         }
     306            0 :         let mut conn = match try_connect(&redis).await {
     307            0 :             Ok(conn) => {
     308            0 :                 handler.increment_active_listeners().await;
     309            0 :                 conn
     310              :             }
     311            0 :             Err(e) => {
     312            0 :                 tracing::error!(
     313            0 :             "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     314              :         );
     315            0 :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     316            0 :                 continue;
     317              :             }
     318              :         };
     319            0 :         let mut stream = conn.on_message();
     320            0 :         while let Some(msg) = stream.next().await {
     321            0 :             match handler.handle_message(msg).await {
     322            0 :                 Ok(()) => {}
     323            0 :                 Err(e) => {
     324            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     325            0 :                     break;
     326              :                 }
     327              :             }
     328            0 :             if cancellation_token.is_cancelled() {
     329            0 :                 handler.decrement_active_listeners().await;
     330            0 :                 return Ok(());
     331            0 :             }
     332              :         }
     333            0 :         handler.decrement_active_listeners().await;
     334              :     }
     335            0 : }
     336              : 
     337              : /// Handle console's invalidation messages.
     338              : #[tracing::instrument(name = "redis_notifications", skip_all)]
     339              : pub async fn task_main<C>(
     340              :     config: &'static ProxyConfig,
     341              :     redis: ConnectionWithCredentialsProvider,
     342              :     cache: Arc<C>,
     343              :     cancel_map: CancelMap,
     344              :     region_id: String,
     345              : ) -> anyhow::Result<Infallible>
     346              : where
     347              :     C: ProjectInfoCache + Send + Sync + 'static,
     348              : {
     349              :     let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
     350              :         &config.connect_to_compute,
     351              :         cancel_map,
     352              :         crate::metrics::CancellationSource::FromRedis,
     353              :     ));
     354              :     let handler = MessageHandler::new(cache, cancellation_handler, region_id);
     355              :     // 6h - 1m.
     356              :     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     357              :     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
     358              :     loop {
     359              :         let cancellation_token = CancellationToken::new();
     360              :         interval.tick().await;
     361              : 
     362              :         tokio::spawn(handle_messages(
     363              :             handler.clone(),
     364              :             redis.clone(),
     365              :             cancellation_token.clone(),
     366              :         ));
     367            0 :         tokio::spawn(async move {
     368            0 :             tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
     369            0 :             cancellation_token.cancel();
     370            0 :         });
     371              :     }
     372              : }
     373              : 
     374              : #[cfg(test)]
     375              : mod tests {
     376              :     use serde_json::json;
     377              : 
     378              :     use super::*;
     379              :     use crate::types::{ProjectId, RoleName};
     380              : 
     381              :     #[test]
     382            1 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     383            1 :         let project_id: ProjectId = "new_project".into();
     384            1 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     385            1 :         let text = json!({
     386            1 :             "type": "message",
     387            1 :             "topic": "/allowed_ips_updated",
     388            1 :             "data": data,
     389            1 :             "extre_fields": "something"
     390            1 :         })
     391            1 :         .to_string();
     392              : 
     393            1 :         let result: Notification = serde_json::from_str(&text)?;
     394            1 :         assert_eq!(
     395            1 :             result,
     396            1 :             Notification::AllowedIpsUpdate {
     397            1 :                 allowed_ips_update: AllowedIpsUpdate {
     398            1 :                     project_id: (&project_id).into()
     399            1 :                 }
     400            1 :             }
     401            1 :         );
     402              : 
     403            1 :         Ok(())
     404            1 :     }
     405              : 
     406              :     #[test]
     407            1 :     fn parse_password_updated() -> anyhow::Result<()> {
     408            1 :         let project_id: ProjectId = "new_project".into();
     409            1 :         let role_name: RoleName = "new_role".into();
     410            1 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     411            1 :         let text = json!({
     412            1 :             "type": "message",
     413            1 :             "topic": "/password_updated",
     414            1 :             "data": data,
     415            1 :             "extre_fields": "something"
     416            1 :         })
     417            1 :         .to_string();
     418              : 
     419            1 :         let result: Notification = serde_json::from_str(&text)?;
     420            1 :         assert_eq!(
     421            1 :             result,
     422            1 :             Notification::PasswordUpdate {
     423            1 :                 password_update: PasswordUpdate {
     424            1 :                     project_id: (&project_id).into(),
     425            1 :                     role_name: (&role_name).into(),
     426            1 :                 }
     427            1 :             }
     428            1 :         );
     429              : 
     430            1 :         Ok(())
     431            1 :     }
     432              :     #[test]
     433            1 :     fn parse_cancel_session() -> anyhow::Result<()> {
     434            1 :         let cancel_key_data = CancelKeyData {
     435            1 :             backend_pid: 42,
     436            1 :             cancel_key: 41,
     437            1 :         };
     438            1 :         let uuid = uuid::Uuid::new_v4();
     439            1 :         let msg = Notification::Cancel(CancelSession {
     440            1 :             cancel_key_data,
     441            1 :             region_id: None,
     442            1 :             session_id: uuid,
     443            1 :             peer_addr: None,
     444            1 :         });
     445            1 :         let text = serde_json::to_string(&msg)?;
     446            1 :         let result: Notification = serde_json::from_str(&text)?;
     447            1 :         assert_eq!(msg, result);
     448              : 
     449            1 :         let msg = Notification::Cancel(CancelSession {
     450            1 :             cancel_key_data,
     451            1 :             region_id: Some("region".to_string()),
     452            1 :             session_id: uuid,
     453            1 :             peer_addr: None,
     454            1 :         });
     455            1 :         let text = serde_json::to_string(&msg)?;
     456            1 :         let result: Notification = serde_json::from_str(&text)?;
     457            1 :         assert_eq!(msg, result,);
     458              : 
     459            1 :         Ok(())
     460            1 :     }
     461              : }
        

Generated by: LCOV version 2.1-beta