LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 71d97c4519b9017c5903db4dfe4edf4a84645500.info Lines: 49.7 % 165 82
Test Date: 2024-12-19 16:48:20 Functions: 20.0 % 100 20

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::Arc;
       3              : 
       4              : use futures::StreamExt;
       5              : use pq_proto::CancelKeyData;
       6              : use redis::aio::PubSub;
       7              : use serde::{Deserialize, Serialize};
       8              : use tokio_util::sync::CancellationToken;
       9              : use tracing::Instrument;
      10              : use uuid::Uuid;
      11              : 
      12              : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      13              : use crate::cache::project_info::ProjectInfoCache;
      14              : use crate::cancellation::{CancelMap, CancellationHandler};
      15              : use crate::intern::{ProjectIdInt, RoleNameInt};
      16              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      17              : 
      18              : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
      19              : pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
      20              : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
      21              : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      22              : 
      23            0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
      24            0 :     let mut conn = client.get_async_pubsub().await?;
      25            0 :     tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      26            0 :     conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      27            0 :     tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
      28            0 :     conn.subscribe(PROXY_CHANNEL_NAME).await?;
      29            0 :     Ok(conn)
      30            0 : }
      31              : 
      32           10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      33              : #[serde(tag = "topic", content = "data")]
      34              : pub(crate) enum Notification {
      35              :     #[serde(
      36              :         rename = "/allowed_ips_updated",
      37              :         deserialize_with = "deserialize_json_string"
      38              :     )]
      39              :     AllowedIpsUpdate {
      40              :         allowed_ips_update: AllowedIpsUpdate,
      41              :     },
      42              :     #[serde(
      43              :         rename = "/password_updated",
      44              :         deserialize_with = "deserialize_json_string"
      45              :     )]
      46              :     PasswordUpdate { password_update: PasswordUpdate },
      47              :     #[serde(rename = "/cancel_session")]
      48              :     Cancel(CancelSession),
      49              : }
      50            2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      51              : pub(crate) struct AllowedIpsUpdate {
      52              :     project_id: ProjectIdInt,
      53              : }
      54            3 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      55              : pub(crate) struct PasswordUpdate {
      56              :     project_id: ProjectIdInt,
      57              :     role_name: RoleNameInt,
      58              : }
      59           10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      60              : pub(crate) struct CancelSession {
      61              :     pub(crate) region_id: Option<String>,
      62              :     pub(crate) cancel_key_data: CancelKeyData,
      63              :     pub(crate) session_id: Uuid,
      64              :     pub(crate) peer_addr: Option<std::net::IpAddr>,
      65              : }
      66              : 
      67            2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
      68            2 : where
      69            2 :     T: for<'de2> serde::Deserialize<'de2>,
      70            2 :     D: serde::Deserializer<'de>,
      71            2 : {
      72            2 :     let s = String::deserialize(deserializer)?;
      73            2 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
      74            2 : }
      75              : 
      76              : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
      77              :     cache: Arc<C>,
      78              :     cancellation_handler: Arc<CancellationHandler<()>>,
      79              :     region_id: String,
      80              : }
      81              : 
      82              : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
      83            0 :     fn clone(&self) -> Self {
      84            0 :         Self {
      85            0 :             cache: self.cache.clone(),
      86            0 :             cancellation_handler: self.cancellation_handler.clone(),
      87            0 :             region_id: self.region_id.clone(),
      88            0 :         }
      89            0 :     }
      90              : }
      91              : 
      92              : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
      93            0 :     pub(crate) fn new(
      94            0 :         cache: Arc<C>,
      95            0 :         cancellation_handler: Arc<CancellationHandler<()>>,
      96            0 :         region_id: String,
      97            0 :     ) -> Self {
      98            0 :         Self {
      99            0 :             cache,
     100            0 :             cancellation_handler,
     101            0 :             region_id,
     102            0 :         }
     103            0 :     }
     104            0 :     pub(crate) async fn increment_active_listeners(&self) {
     105            0 :         self.cache.increment_active_listeners().await;
     106            0 :     }
     107            0 :     pub(crate) async fn decrement_active_listeners(&self) {
     108            0 :         self.cache.decrement_active_listeners().await;
     109            0 :     }
     110            0 :     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     111              :     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
     112              :         let payload: String = msg.get_payload()?;
     113              :         tracing::debug!(?payload, "received a message payload");
     114              : 
     115              :         let msg: Notification = match serde_json::from_str(&payload) {
     116              :             Ok(msg) => msg,
     117              :             Err(e) => {
     118              :                 Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     119              :                     channel: msg.get_channel_name(),
     120              :                 });
     121              :                 tracing::error!("broken message: {e}");
     122              :                 return Ok(());
     123              :             }
     124              :         };
     125              :         tracing::debug!(?msg, "received a message");
     126              :         match msg {
     127              :             Notification::Cancel(cancel_session) => {
     128              :                 tracing::Span::current().record(
     129              :                     "session_id",
     130              :                     tracing::field::display(cancel_session.session_id),
     131              :                 );
     132              :                 Metrics::get()
     133              :                     .proxy
     134              :                     .redis_events_count
     135              :                     .inc(RedisEventsCount::CancelSession);
     136              :                 if let Some(cancel_region) = cancel_session.region_id {
     137              :                     // If the message is not for this region, ignore it.
     138              :                     if cancel_region != self.region_id {
     139              :                         return Ok(());
     140              :                     }
     141              :                 }
     142              : 
     143              :                 // TODO: Remove unspecified peer_addr after the complete migration to the new format
     144              :                 let peer_addr = cancel_session
     145              :                     .peer_addr
     146              :                     .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
     147              :                 let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id);
     148              :                 cancel_span.follows_from(tracing::Span::current());
     149              :                 // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
     150              :                 match self
     151              :                     .cancellation_handler
     152              :                     .cancel_session(
     153              :                         cancel_session.cancel_key_data,
     154              :                         uuid::Uuid::nil(),
     155              :                         peer_addr,
     156              :                         cancel_session.peer_addr.is_some(),
     157              :                     )
     158              :                     .instrument(cancel_span)
     159              :                     .await
     160              :                 {
     161              :                     Ok(()) => {}
     162              :                     Err(e) => {
     163              :                         tracing::warn!("failed to cancel session: {e}");
     164              :                     }
     165              :                 }
     166              :             }
     167              :             Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
     168              :                 invalidate_cache(self.cache.clone(), msg.clone());
     169              :                 if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
     170              :                     Metrics::get()
     171              :                         .proxy
     172              :                         .redis_events_count
     173              :                         .inc(RedisEventsCount::AllowedIpsUpdate);
     174              :                 } else if matches!(msg, Notification::PasswordUpdate { .. }) {
     175              :                     Metrics::get()
     176              :                         .proxy
     177              :                         .redis_events_count
     178              :                         .inc(RedisEventsCount::PasswordUpdate);
     179              :                 }
     180              :                 // It might happen that the invalid entry is on the way to be cached.
     181              :                 // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
     182              :                 // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     183              :                 let cache = self.cache.clone();
     184            0 :                 tokio::spawn(async move {
     185            0 :                     tokio::time::sleep(INVALIDATION_LAG).await;
     186            0 :                     invalidate_cache(cache, msg);
     187            0 :                 });
     188              :             }
     189              :         }
     190              : 
     191              :         Ok(())
     192              :     }
     193              : }
     194              : 
     195            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     196            0 :     match msg {
     197            0 :         Notification::AllowedIpsUpdate { allowed_ips_update } => {
     198            0 :             cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
     199            0 :         }
     200            0 :         Notification::PasswordUpdate { password_update } => cache
     201            0 :             .invalidate_role_secret_for_project(
     202            0 :                 password_update.project_id,
     203            0 :                 password_update.role_name,
     204            0 :             ),
     205            0 :         Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
     206              :     }
     207            0 : }
     208              : 
     209            0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
     210            0 :     handler: MessageHandler<C>,
     211            0 :     redis: ConnectionWithCredentialsProvider,
     212            0 :     cancellation_token: CancellationToken,
     213            0 : ) -> anyhow::Result<()> {
     214              :     loop {
     215            0 :         if cancellation_token.is_cancelled() {
     216            0 :             return Ok(());
     217            0 :         }
     218            0 :         let mut conn = match try_connect(&redis).await {
     219            0 :             Ok(conn) => {
     220            0 :                 handler.increment_active_listeners().await;
     221            0 :                 conn
     222              :             }
     223            0 :             Err(e) => {
     224            0 :                 tracing::error!(
     225            0 :             "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     226              :         );
     227            0 :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     228            0 :                 continue;
     229              :             }
     230              :         };
     231            0 :         let mut stream = conn.on_message();
     232            0 :         while let Some(msg) = stream.next().await {
     233            0 :             match handler.handle_message(msg).await {
     234            0 :                 Ok(()) => {}
     235            0 :                 Err(e) => {
     236            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     237            0 :                     break;
     238              :                 }
     239              :             }
     240            0 :             if cancellation_token.is_cancelled() {
     241            0 :                 handler.decrement_active_listeners().await;
     242            0 :                 return Ok(());
     243            0 :             }
     244              :         }
     245            0 :         handler.decrement_active_listeners().await;
     246              :     }
     247            0 : }
     248              : 
     249              : /// Handle console's invalidation messages.
     250              : #[tracing::instrument(name = "redis_notifications", skip_all)]
     251              : pub async fn task_main<C>(
     252              :     redis: ConnectionWithCredentialsProvider,
     253              :     cache: Arc<C>,
     254              :     cancel_map: CancelMap,
     255              :     region_id: String,
     256              : ) -> anyhow::Result<Infallible>
     257              : where
     258              :     C: ProjectInfoCache + Send + Sync + 'static,
     259              : {
     260              :     let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
     261              :         cancel_map,
     262              :         crate::metrics::CancellationSource::FromRedis,
     263              :     ));
     264              :     let handler = MessageHandler::new(cache, cancellation_handler, region_id);
     265              :     // 6h - 1m.
     266              :     // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
     267              :     let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
     268              :     loop {
     269              :         let cancellation_token = CancellationToken::new();
     270              :         interval.tick().await;
     271              : 
     272              :         tokio::spawn(handle_messages(
     273              :             handler.clone(),
     274              :             redis.clone(),
     275              :             cancellation_token.clone(),
     276              :         ));
     277            0 :         tokio::spawn(async move {
     278            0 :             tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
     279            0 :             cancellation_token.cancel();
     280            0 :         });
     281              :     }
     282              : }
     283              : 
     284              : #[cfg(test)]
     285              : mod tests {
     286              :     use serde_json::json;
     287              : 
     288              :     use super::*;
     289              :     use crate::types::{ProjectId, RoleName};
     290              : 
     291              :     #[test]
     292            1 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     293            1 :         let project_id: ProjectId = "new_project".into();
     294            1 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     295            1 :         let text = json!({
     296            1 :             "type": "message",
     297            1 :             "topic": "/allowed_ips_updated",
     298            1 :             "data": data,
     299            1 :             "extre_fields": "something"
     300            1 :         })
     301            1 :         .to_string();
     302              : 
     303            1 :         let result: Notification = serde_json::from_str(&text)?;
     304            1 :         assert_eq!(
     305            1 :             result,
     306            1 :             Notification::AllowedIpsUpdate {
     307            1 :                 allowed_ips_update: AllowedIpsUpdate {
     308            1 :                     project_id: (&project_id).into()
     309            1 :                 }
     310            1 :             }
     311            1 :         );
     312              : 
     313            1 :         Ok(())
     314            1 :     }
     315              : 
     316              :     #[test]
     317            1 :     fn parse_password_updated() -> anyhow::Result<()> {
     318            1 :         let project_id: ProjectId = "new_project".into();
     319            1 :         let role_name: RoleName = "new_role".into();
     320            1 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     321            1 :         let text = json!({
     322            1 :             "type": "message",
     323            1 :             "topic": "/password_updated",
     324            1 :             "data": data,
     325            1 :             "extre_fields": "something"
     326            1 :         })
     327            1 :         .to_string();
     328              : 
     329            1 :         let result: Notification = serde_json::from_str(&text)?;
     330            1 :         assert_eq!(
     331            1 :             result,
     332            1 :             Notification::PasswordUpdate {
     333            1 :                 password_update: PasswordUpdate {
     334            1 :                     project_id: (&project_id).into(),
     335            1 :                     role_name: (&role_name).into(),
     336            1 :                 }
     337            1 :             }
     338            1 :         );
     339              : 
     340            1 :         Ok(())
     341            1 :     }
     342              :     #[test]
     343            1 :     fn parse_cancel_session() -> anyhow::Result<()> {
     344            1 :         let cancel_key_data = CancelKeyData {
     345            1 :             backend_pid: 42,
     346            1 :             cancel_key: 41,
     347            1 :         };
     348            1 :         let uuid = uuid::Uuid::new_v4();
     349            1 :         let msg = Notification::Cancel(CancelSession {
     350            1 :             cancel_key_data,
     351            1 :             region_id: None,
     352            1 :             session_id: uuid,
     353            1 :             peer_addr: None,
     354            1 :         });
     355            1 :         let text = serde_json::to_string(&msg)?;
     356            1 :         let result: Notification = serde_json::from_str(&text)?;
     357            1 :         assert_eq!(msg, result);
     358              : 
     359            1 :         let msg = Notification::Cancel(CancelSession {
     360            1 :             cancel_key_data,
     361            1 :             region_id: Some("region".to_string()),
     362            1 :             session_id: uuid,
     363            1 :             peer_addr: None,
     364            1 :         });
     365            1 :         let text = serde_json::to_string(&msg)?;
     366            1 :         let result: Notification = serde_json::from_str(&text)?;
     367            1 :         assert_eq!(msg, result,);
     368              : 
     369            1 :         Ok(())
     370            1 :     }
     371              : }
        

Generated by: LCOV version 2.1-beta