LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: aca8877be6ceba750c1be359ed71bc1799d52b30.info Lines: 62.4 % 133 83
Test Date: 2024-02-14 18:05:35 Functions: 21.5 % 149 32

            Line data    Source code
       1              : use std::{convert::Infallible, sync::Arc};
       2              : 
       3              : use futures::StreamExt;
       4              : use pq_proto::CancelKeyData;
       5              : use redis::aio::PubSub;
       6              : use serde::{Deserialize, Serialize};
       7              : use uuid::Uuid;
       8              : 
       9              : use crate::{
      10              :     cache::project_info::ProjectInfoCache,
      11              :     cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
      12              :     intern::{ProjectIdInt, RoleNameInt},
      13              : };
      14              : 
      15              : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
      16              : pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
      17              : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
      18              : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      19              : 
      20              : struct RedisConsumerClient {
      21              :     client: redis::Client,
      22              : }
      23              : 
      24              : impl RedisConsumerClient {
      25            0 :     pub fn new(url: &str) -> anyhow::Result<Self> {
      26            0 :         let client = redis::Client::open(url)?;
      27            0 :         Ok(Self { client })
      28            0 :     }
      29            0 :     async fn try_connect(&self) -> anyhow::Result<PubSub> {
      30            0 :         let mut conn = self.client.get_async_connection().await?.into_pubsub();
      31            0 :         tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
      32            0 :         conn.subscribe(CPLANE_CHANNEL_NAME).await?;
      33            0 :         tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
      34            0 :         conn.subscribe(PROXY_CHANNEL_NAME).await?;
      35            0 :         Ok(conn)
      36            0 :     }
      37              : }
      38              : 
      39           20 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      40              : #[serde(tag = "topic", content = "data")]
      41              : pub(crate) enum Notification {
      42              :     #[serde(
      43              :         rename = "/allowed_ips_updated",
      44              :         deserialize_with = "deserialize_json_string"
      45              :     )]
      46              :     AllowedIpsUpdate {
      47              :         allowed_ips_update: AllowedIpsUpdate,
      48              :     },
      49              :     #[serde(
      50              :         rename = "/password_updated",
      51              :         deserialize_with = "deserialize_json_string"
      52              :     )]
      53              :     PasswordUpdate { password_update: PasswordUpdate },
      54              :     #[serde(rename = "/cancel_session")]
      55              :     Cancel(CancelSession),
      56              : }
      57            6 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      58              : pub(crate) struct AllowedIpsUpdate {
      59              :     project_id: ProjectIdInt,
      60              : }
      61           10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      62              : pub(crate) struct PasswordUpdate {
      63              :     project_id: ProjectIdInt,
      64              :     role_name: RoleNameInt,
      65              : }
      66           28 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
      67              : pub(crate) struct CancelSession {
      68              :     pub region_id: Option<String>,
      69              :     pub cancel_key_data: CancelKeyData,
      70              :     pub session_id: Uuid,
      71              : }
      72              : 
      73            4 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
      74            4 : where
      75            4 :     T: for<'de2> serde::Deserialize<'de2>,
      76            4 :     D: serde::Deserializer<'de>,
      77            4 : {
      78            4 :     let s = String::deserialize(deserializer)?;
      79            4 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
      80            4 : }
      81              : 
      82              : struct MessageHandler<
      83              :     C: ProjectInfoCache + Send + Sync + 'static,
      84              :     H: NotificationsCancellationHandler + Send + Sync + 'static,
      85              : > {
      86              :     cache: Arc<C>,
      87              :     cancellation_handler: Arc<H>,
      88              :     region_id: String,
      89              : }
      90              : 
      91              : impl<
      92              :         C: ProjectInfoCache + Send + Sync + 'static,
      93              :         H: NotificationsCancellationHandler + Send + Sync + 'static,
      94              :     > MessageHandler<C, H>
      95              : {
      96            0 :     pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
      97            0 :         Self {
      98            0 :             cache,
      99            0 :             cancellation_handler,
     100            0 :             region_id,
     101            0 :         }
     102            0 :     }
     103            0 :     pub fn disable_ttl(&self) {
     104            0 :         self.cache.disable_ttl();
     105            0 :     }
     106            0 :     pub fn enable_ttl(&self) {
     107            0 :         self.cache.enable_ttl();
     108            0 :     }
     109            0 :     #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
     110              :     async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
     111              :         use Notification::*;
     112              :         let payload: String = msg.get_payload()?;
     113            0 :         tracing::debug!(?payload, "received a message payload");
     114              : 
     115              :         let msg: Notification = match serde_json::from_str(&payload) {
     116              :             Ok(msg) => msg,
     117              :             Err(e) => {
     118            0 :                 tracing::error!("broken message: {e}");
     119              :                 return Ok(());
     120              :             }
     121              :         };
     122            0 :         tracing::debug!(?msg, "received a message");
     123              :         match msg {
     124              :             Cancel(cancel_session) => {
     125              :                 tracing::Span::current().record(
     126              :                     "session_id",
     127              :                     &tracing::field::display(cancel_session.session_id),
     128              :                 );
     129              :                 if let Some(cancel_region) = cancel_session.region_id {
     130              :                     // If the message is not for this region, ignore it.
     131              :                     if cancel_region != self.region_id {
     132              :                         return Ok(());
     133              :                     }
     134              :                 }
     135              :                 // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
     136              :                 match self
     137              :                     .cancellation_handler
     138              :                     .cancel_session_no_publish(cancel_session.cancel_key_data)
     139              :                     .await
     140              :                 {
     141              :                     Ok(()) => {}
     142              :                     Err(e) => {
     143            0 :                         tracing::error!("failed to cancel session: {e}");
     144              :                     }
     145              :                 }
     146              :             }
     147              :             _ => {
     148              :                 invalidate_cache(self.cache.clone(), msg.clone());
     149              :                 // It might happen that the invalid entry is on the way to be cached.
     150              :                 // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
     151              :                 // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     152              :                 let cache = self.cache.clone();
     153            0 :                 tokio::spawn(async move {
     154            0 :                     tokio::time::sleep(INVALIDATION_LAG).await;
     155            0 :                     invalidate_cache(cache, msg);
     156            0 :                 });
     157              :             }
     158              :         }
     159              : 
     160              :         Ok(())
     161              :     }
     162              : }
     163              : 
     164            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
     165            0 :     use Notification::*;
     166            0 :     match msg {
     167            0 :         AllowedIpsUpdate { allowed_ips_update } => {
     168            0 :             cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
     169              :         }
     170            0 :         PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
     171            0 :             password_update.project_id,
     172            0 :             password_update.role_name,
     173            0 :         ),
     174            0 :         Cancel(_) => unreachable!("cancel message should be handled separately"),
     175              :     }
     176            0 : }
     177              : 
     178              : /// Handle console's invalidation messages.
     179            0 : #[tracing::instrument(name = "console_notifications", skip_all)]
     180              : pub async fn task_main<C>(
     181              :     url: String,
     182              :     cache: Arc<C>,
     183              :     cancel_map: CancelMap,
     184              :     region_id: String,
     185              : ) -> anyhow::Result<Infallible>
     186              : where
     187              :     C: ProjectInfoCache + Send + Sync + 'static,
     188              : {
     189              :     cache.enable_ttl();
     190              :     let handler = MessageHandler::new(
     191              :         cache,
     192              :         Arc::new(CancellationHandler::new(cancel_map, None)),
     193              :         region_id,
     194              :     );
     195              : 
     196              :     loop {
     197              :         let redis = RedisConsumerClient::new(&url)?;
     198              :         let conn = match redis.try_connect().await {
     199              :             Ok(conn) => {
     200              :                 handler.disable_ttl();
     201              :                 conn
     202              :             }
     203              :             Err(e) => {
     204            0 :                 tracing::error!(
     205            0 :                     "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     206            0 :                 );
     207              :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     208              :                 continue;
     209              :             }
     210              :         };
     211              :         let mut stream = conn.into_on_message();
     212              :         while let Some(msg) = stream.next().await {
     213              :             match handler.handle_message(msg).await {
     214              :                 Ok(()) => {}
     215              :                 Err(e) => {
     216            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     217              :                     break;
     218              :                 }
     219              :             }
     220              :         }
     221              :         handler.enable_ttl();
     222              :     }
     223              : }
     224              : 
     225              : #[cfg(test)]
     226              : mod tests {
     227              :     use crate::{ProjectId, RoleName};
     228              : 
     229              :     use super::*;
     230              :     use serde_json::json;
     231              : 
     232            2 :     #[test]
     233            2 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     234            2 :         let project_id: ProjectId = "new_project".into();
     235            2 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     236            2 :         let text = json!({
     237            2 :             "type": "message",
     238            2 :             "topic": "/allowed_ips_updated",
     239            2 :             "data": data,
     240            2 :             "extre_fields": "something"
     241            2 :         })
     242            2 :         .to_string();
     243              : 
     244            2 :         let result: Notification = serde_json::from_str(&text)?;
     245            2 :         assert_eq!(
     246            2 :             result,
     247            2 :             Notification::AllowedIpsUpdate {
     248            2 :                 allowed_ips_update: AllowedIpsUpdate {
     249            2 :                     project_id: (&project_id).into()
     250            2 :                 }
     251            2 :             }
     252            2 :         );
     253              : 
     254            2 :         Ok(())
     255            2 :     }
     256              : 
     257            2 :     #[test]
     258            2 :     fn parse_password_updated() -> anyhow::Result<()> {
     259            2 :         let project_id: ProjectId = "new_project".into();
     260            2 :         let role_name: RoleName = "new_role".into();
     261            2 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     262            2 :         let text = json!({
     263            2 :             "type": "message",
     264            2 :             "topic": "/password_updated",
     265            2 :             "data": data,
     266            2 :             "extre_fields": "something"
     267            2 :         })
     268            2 :         .to_string();
     269              : 
     270            2 :         let result: Notification = serde_json::from_str(&text)?;
     271            2 :         assert_eq!(
     272            2 :             result,
     273            2 :             Notification::PasswordUpdate {
     274            2 :                 password_update: PasswordUpdate {
     275            2 :                     project_id: (&project_id).into(),
     276            2 :                     role_name: (&role_name).into(),
     277            2 :                 }
     278            2 :             }
     279            2 :         );
     280              : 
     281            2 :         Ok(())
     282            2 :     }
     283            2 :     #[test]
     284            2 :     fn parse_cancel_session() -> anyhow::Result<()> {
     285            2 :         let cancel_key_data = CancelKeyData {
     286            2 :             backend_pid: 42,
     287            2 :             cancel_key: 41,
     288            2 :         };
     289            2 :         let uuid = uuid::Uuid::new_v4();
     290            2 :         let msg = Notification::Cancel(CancelSession {
     291            2 :             cancel_key_data,
     292            2 :             region_id: None,
     293            2 :             session_id: uuid,
     294            2 :         });
     295            2 :         let text = serde_json::to_string(&msg)?;
     296            2 :         let result: Notification = serde_json::from_str(&text)?;
     297            2 :         assert_eq!(msg, result);
     298              : 
     299            2 :         let msg = Notification::Cancel(CancelSession {
     300            2 :             cancel_key_data,
     301            2 :             region_id: Some("region".to_string()),
     302            2 :             session_id: uuid,
     303            2 :         });
     304            2 :         let text = serde_json::to_string(&msg)?;
     305            2 :         let result: Notification = serde_json::from_str(&text)?;
     306            2 :         assert_eq!(msg, result,);
     307              : 
     308            2 :         Ok(())
     309            2 :     }
     310              : }
        

Generated by: LCOV version 2.1-beta