LCOV - code coverage report
Current view: top level - proxy/src/redis - notifications.rs (source / functions) Coverage Total Hit
Test: 32f4a56327bc9da697706839ed4836b2a00a408f.info Lines: 63.3 % 90 57
Test Date: 2024-02-07 07:37:29 Functions: 24.5 % 94 23

            Line data    Source code
       1              : use std::{convert::Infallible, sync::Arc};
       2              : 
       3              : use futures::StreamExt;
       4              : use redis::aio::PubSub;
       5              : use serde::Deserialize;
       6              : 
       7              : use crate::{
       8              :     cache::project_info::ProjectInfoCache,
       9              :     intern::{ProjectIdInt, RoleNameInt},
      10              : };
      11              : 
      12              : const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
      13              : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
      14              : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
      15              : 
      16              : struct ConsoleRedisClient {
      17              :     client: redis::Client,
      18              : }
      19              : 
      20              : impl ConsoleRedisClient {
      21            0 :     pub fn new(url: &str) -> anyhow::Result<Self> {
      22            0 :         let client = redis::Client::open(url)?;
      23            0 :         Ok(Self { client })
      24            0 :     }
      25            0 :     async fn try_connect(&self) -> anyhow::Result<PubSub> {
      26            0 :         let mut conn = self.client.get_async_connection().await?.into_pubsub();
      27            0 :         tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
      28            0 :         conn.subscribe(CHANNEL_NAME).await?;
      29            0 :         Ok(conn)
      30            0 :     }
      31              : }
      32              : 
      33            8 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      34              : #[serde(tag = "topic", content = "data")]
      35              : enum Notification {
      36              :     #[serde(
      37              :         rename = "/allowed_ips_updated",
      38              :         deserialize_with = "deserialize_json_string"
      39              :     )]
      40              :     AllowedIpsUpdate {
      41              :         allowed_ips_update: AllowedIpsUpdate,
      42              :     },
      43              :     #[serde(
      44              :         rename = "/password_updated",
      45              :         deserialize_with = "deserialize_json_string"
      46              :     )]
      47              :     PasswordUpdate { password_update: PasswordUpdate },
      48              : }
      49            6 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      50              : struct AllowedIpsUpdate {
      51              :     project_id: ProjectIdInt,
      52              : }
      53           10 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
      54              : struct PasswordUpdate {
      55              :     project_id: ProjectIdInt,
      56              :     role_name: RoleNameInt,
      57              : }
      58            4 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
      59            4 : where
      60            4 :     T: for<'de2> serde::Deserialize<'de2>,
      61            4 :     D: serde::Deserializer<'de>,
      62            4 : {
      63            4 :     let s = String::deserialize(deserializer)?;
      64            4 :     serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
      65            4 : }
      66              : 
      67            0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
      68            0 :     use Notification::*;
      69            0 :     match msg {
      70            0 :         AllowedIpsUpdate { allowed_ips_update } => {
      71            0 :             cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
      72              :         }
      73            0 :         PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
      74            0 :             password_update.project_id,
      75            0 :             password_update.role_name,
      76            0 :         ),
      77              :     }
      78            0 : }
      79              : 
      80            0 : #[tracing::instrument(skip(cache))]
      81              : fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
      82              : where
      83              :     C: ProjectInfoCache + Send + Sync + 'static,
      84              : {
      85              :     let payload: String = msg.get_payload()?;
      86            0 :     tracing::debug!(?payload, "received a message payload");
      87              : 
      88              :     let msg: Notification = match serde_json::from_str(&payload) {
      89              :         Ok(msg) => msg,
      90              :         Err(e) => {
      91            0 :             tracing::error!("broken message: {e}");
      92              :             return Ok(());
      93              :         }
      94              :     };
      95            0 :     tracing::debug!(?msg, "received a message");
      96              :     invalidate_cache(cache.clone(), msg.clone());
      97              :     // It might happen that the invalid entry is on the way to be cached.
      98              :     // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
      99              :     // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
     100            0 :     tokio::spawn(async move {
     101            0 :         tokio::time::sleep(INVALIDATION_LAG).await;
     102            0 :         invalidate_cache(cache, msg.clone());
     103            0 :     });
     104              : 
     105              :     Ok(())
     106              : }
     107              : 
     108              : /// Handle console's invalidation messages.
     109            0 : #[tracing::instrument(name = "console_notifications", skip_all)]
     110              : pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
     111              : where
     112              :     C: ProjectInfoCache + Send + Sync + 'static,
     113              : {
     114              :     cache.enable_ttl();
     115              : 
     116              :     loop {
     117              :         let redis = ConsoleRedisClient::new(&url)?;
     118              :         let conn = match redis.try_connect().await {
     119              :             Ok(conn) => {
     120              :                 cache.disable_ttl();
     121              :                 conn
     122              :             }
     123              :             Err(e) => {
     124            0 :                 tracing::error!(
     125            0 :                     "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
     126            0 :                 );
     127              :                 tokio::time::sleep(RECONNECT_TIMEOUT).await;
     128              :                 continue;
     129              :             }
     130              :         };
     131              :         let mut stream = conn.into_on_message();
     132              :         while let Some(msg) = stream.next().await {
     133              :             match handle_message(msg, cache.clone()) {
     134              :                 Ok(()) => {}
     135              :                 Err(e) => {
     136            0 :                     tracing::error!("failed to handle message: {e}, will try to reconnect");
     137              :                     break;
     138              :                 }
     139              :             }
     140              :         }
     141              :         cache.enable_ttl();
     142              :     }
     143              : }
     144              : 
     145              : #[cfg(test)]
     146              : mod tests {
     147              :     use crate::{ProjectId, RoleName};
     148              : 
     149              :     use super::*;
     150              :     use serde_json::json;
     151              : 
     152            2 :     #[test]
     153            2 :     fn parse_allowed_ips() -> anyhow::Result<()> {
     154            2 :         let project_id: ProjectId = "new_project".into();
     155            2 :         let data = format!("{{\"project_id\": \"{project_id}\"}}");
     156            2 :         let text = json!({
     157            2 :             "type": "message",
     158            2 :             "topic": "/allowed_ips_updated",
     159            2 :             "data": data,
     160            2 :             "extre_fields": "something"
     161            2 :         })
     162            2 :         .to_string();
     163              : 
     164            2 :         let result: Notification = serde_json::from_str(&text)?;
     165            2 :         assert_eq!(
     166            2 :             result,
     167            2 :             Notification::AllowedIpsUpdate {
     168            2 :                 allowed_ips_update: AllowedIpsUpdate {
     169            2 :                     project_id: (&project_id).into()
     170            2 :                 }
     171            2 :             }
     172            2 :         );
     173              : 
     174            2 :         Ok(())
     175            2 :     }
     176              : 
     177            2 :     #[test]
     178            2 :     fn parse_password_updated() -> anyhow::Result<()> {
     179            2 :         let project_id: ProjectId = "new_project".into();
     180            2 :         let role_name: RoleName = "new_role".into();
     181            2 :         let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
     182            2 :         let text = json!({
     183            2 :             "type": "message",
     184            2 :             "topic": "/password_updated",
     185            2 :             "data": data,
     186            2 :             "extre_fields": "something"
     187            2 :         })
     188            2 :         .to_string();
     189              : 
     190            2 :         let result: Notification = serde_json::from_str(&text)?;
     191            2 :         assert_eq!(
     192            2 :             result,
     193            2 :             Notification::PasswordUpdate {
     194            2 :                 password_update: PasswordUpdate {
     195            2 :                     project_id: (&project_id).into(),
     196            2 :                     role_name: (&role_name).into(),
     197            2 :                 }
     198            2 :             }
     199            2 :         );
     200              : 
     201            2 :         Ok(())
     202            2 :     }
     203              : }
        

Generated by: LCOV version 2.1-beta