LCOV - code coverage report
Current view: top level - proxy/src/cache - endpoints.rs (source / functions) Coverage Total Hit
Test: f8d8f5b90fa487a9e82c42da223f012f5d4fece7.info Lines: 3.6 % 169 6
Test Date: 2024-09-19 20:36:02 Functions: 15.6 % 45 7

            Line data    Source code
       1              : use std::{
       2              :     convert::Infallible,
       3              :     sync::{
       4              :         atomic::{AtomicBool, Ordering},
       5              :         Arc,
       6              :     },
       7              :     time::Duration,
       8              : };
       9              : 
      10              : use dashmap::DashSet;
      11              : use redis::{
      12              :     streams::{StreamReadOptions, StreamReadReply},
      13              :     AsyncCommands, FromRedisValue, Value,
      14              : };
      15              : use serde::Deserialize;
      16              : use tokio::sync::Mutex;
      17              : use tokio_util::sync::CancellationToken;
      18              : use tracing::info;
      19              : 
      20              : use crate::{
      21              :     config::EndpointCacheConfig,
      22              :     context::RequestMonitoring,
      23              :     intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
      24              :     metrics::{Metrics, RedisErrors, RedisEventsCount},
      25              :     rate_limiter::GlobalRateLimiter,
      26              :     redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
      27              :     EndpointId,
      28              : };
      29              : 
      30            5 : #[derive(Deserialize, Debug, Clone)]
      31              : pub(crate) struct ControlPlaneEventKey {
      32              :     endpoint_created: Option<EndpointCreated>,
      33              :     branch_created: Option<BranchCreated>,
      34              :     project_created: Option<ProjectCreated>,
      35              : }
      36            2 : #[derive(Deserialize, Debug, Clone)]
      37              : struct EndpointCreated {
      38              :     endpoint_id: String,
      39              : }
      40            0 : #[derive(Deserialize, Debug, Clone)]
      41              : struct BranchCreated {
      42              :     branch_id: String,
      43              : }
      44            0 : #[derive(Deserialize, Debug, Clone)]
      45              : struct ProjectCreated {
      46              :     project_id: String,
      47              : }
      48              : 
      49              : pub struct EndpointsCache {
      50              :     config: EndpointCacheConfig,
      51              :     endpoints: DashSet<EndpointIdInt>,
      52              :     branches: DashSet<BranchIdInt>,
      53              :     projects: DashSet<ProjectIdInt>,
      54              :     ready: AtomicBool,
      55              :     limiter: Arc<Mutex<GlobalRateLimiter>>,
      56              : }
      57              : 
      58              : impl EndpointsCache {
      59            0 :     pub(crate) fn new(config: EndpointCacheConfig) -> Self {
      60            0 :         Self {
      61            0 :             limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
      62            0 :                 config.limiter_info.clone(),
      63            0 :             ))),
      64            0 :             config,
      65            0 :             endpoints: DashSet::new(),
      66            0 :             branches: DashSet::new(),
      67            0 :             projects: DashSet::new(),
      68            0 :             ready: AtomicBool::new(false),
      69            0 :         }
      70            0 :     }
      71            0 :     pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
      72            0 :         if !self.ready.load(Ordering::Acquire) {
      73            0 :             return true;
      74            0 :         }
      75            0 :         let rejected = self.should_reject(endpoint);
      76            0 :         ctx.set_rejected(rejected);
      77            0 :         info!(?rejected, "check endpoint is valid, disabled cache");
      78              :         // If cache is disabled, just collect the metrics and return or
      79              :         // If the limiter allows, we don't need to check the cache.
      80            0 :         if self.config.disable_cache || self.limiter.lock().await.check() {
      81            0 :             return true;
      82            0 :         }
      83            0 :         !rejected
      84            0 :     }
      85            0 :     fn should_reject(&self, endpoint: &EndpointId) -> bool {
      86            0 :         if endpoint.is_endpoint() {
      87            0 :             !self.endpoints.contains(&EndpointIdInt::from(endpoint))
      88            0 :         } else if endpoint.is_branch() {
      89            0 :             !self
      90            0 :                 .branches
      91            0 :                 .contains(&BranchIdInt::from(&endpoint.as_branch()))
      92              :         } else {
      93            0 :             !self
      94            0 :                 .projects
      95            0 :                 .contains(&ProjectIdInt::from(&endpoint.as_project()))
      96              :         }
      97            0 :     }
      98            0 :     fn insert_event(&self, key: ControlPlaneEventKey) {
      99              :         // Do not do normalization here, we expect the events to be normalized.
     100            0 :         if let Some(endpoint_created) = key.endpoint_created {
     101            0 :             self.endpoints
     102            0 :                 .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
     103            0 :             Metrics::get()
     104            0 :                 .proxy
     105            0 :                 .redis_events_count
     106            0 :                 .inc(RedisEventsCount::EndpointCreated);
     107            0 :         }
     108            0 :         if let Some(branch_created) = key.branch_created {
     109            0 :             self.branches
     110            0 :                 .insert(BranchIdInt::from(&branch_created.branch_id.into()));
     111            0 :             Metrics::get()
     112            0 :                 .proxy
     113            0 :                 .redis_events_count
     114            0 :                 .inc(RedisEventsCount::BranchCreated);
     115            0 :         }
     116            0 :         if let Some(project_created) = key.project_created {
     117            0 :             self.projects
     118            0 :                 .insert(ProjectIdInt::from(&project_created.project_id.into()));
     119            0 :             Metrics::get()
     120            0 :                 .proxy
     121            0 :                 .redis_events_count
     122            0 :                 .inc(RedisEventsCount::ProjectCreated);
     123            0 :         }
     124            0 :     }
     125            0 :     pub async fn do_read(
     126            0 :         &self,
     127            0 :         mut con: ConnectionWithCredentialsProvider,
     128            0 :         cancellation_token: CancellationToken,
     129            0 :     ) -> anyhow::Result<Infallible> {
     130            0 :         let mut last_id = "0-0".to_string();
     131              :         loop {
     132            0 :             if let Err(e) = con.connect().await {
     133            0 :                 tracing::error!("error connecting to redis: {:?}", e);
     134            0 :                 self.ready.store(false, Ordering::Release);
     135            0 :             }
     136            0 :             if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
     137            0 :                 tracing::error!("error reading from redis: {:?}", e);
     138            0 :                 self.ready.store(false, Ordering::Release);
     139            0 :             }
     140            0 :             if cancellation_token.is_cancelled() {
     141            0 :                 info!("cancellation token is cancelled, exiting");
     142            0 :                 tokio::time::sleep(Duration::from_secs(60 * 60 * 24 * 7)).await;
     143              :                 // 1 week.
     144            0 :             }
     145            0 :             tokio::time::sleep(self.config.retry_interval).await;
     146              :         }
     147              :     }
     148            0 :     async fn read_from_stream(
     149            0 :         &self,
     150            0 :         con: &mut ConnectionWithCredentialsProvider,
     151            0 :         last_id: &mut String,
     152            0 :     ) -> anyhow::Result<()> {
     153            0 :         tracing::info!("reading endpoints/branches/projects from redis");
     154            0 :         self.batch_read(
     155            0 :             con,
     156            0 :             StreamReadOptions::default().count(self.config.initial_batch_size),
     157            0 :             last_id,
     158            0 :             true,
     159            0 :         )
     160            0 :         .await?;
     161            0 :         tracing::info!("ready to filter user requests");
     162            0 :         self.ready.store(true, Ordering::Release);
     163            0 :         self.batch_read(
     164            0 :             con,
     165            0 :             StreamReadOptions::default()
     166            0 :                 .count(self.config.default_batch_size)
     167            0 :                 .block(self.config.xread_timeout.as_millis() as usize),
     168            0 :             last_id,
     169            0 :             false,
     170            0 :         )
     171            0 :         .await
     172            0 :     }
     173            0 :     fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
     174            0 :         let s: String = FromRedisValue::from_redis_value(value)?;
     175            0 :         Ok(serde_json::from_str(&s)?)
     176            0 :     }
     177            0 :     async fn batch_read(
     178            0 :         &self,
     179            0 :         conn: &mut ConnectionWithCredentialsProvider,
     180            0 :         opts: StreamReadOptions,
     181            0 :         last_id: &mut String,
     182            0 :         return_when_finish: bool,
     183            0 :     ) -> anyhow::Result<()> {
     184            0 :         let mut total: usize = 0;
     185              :         loop {
     186            0 :             let mut res: StreamReadReply = conn
     187            0 :                 .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
     188            0 :                 .await?;
     189              : 
     190            0 :             if res.keys.is_empty() {
     191            0 :                 if return_when_finish {
     192            0 :                     if total != 0 {
     193            0 :                         break;
     194            0 :                     }
     195            0 :                     anyhow::bail!(
     196            0 :                         "Redis stream {} is empty, cannot be used to filter endpoints",
     197            0 :                         self.config.stream_name
     198            0 :                     );
     199            0 :                 }
     200            0 :                 // If we are not returning when finish, we should wait for more data.
     201            0 :                 continue;
     202            0 :             }
     203            0 :             if res.keys.len() != 1 {
     204            0 :                 anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
     205            0 :             }
     206            0 : 
     207            0 :             let res = res.keys.pop().expect("Checked length above");
     208            0 :             let len = res.ids.len();
     209            0 :             for x in res.ids {
     210            0 :                 total += 1;
     211            0 :                 for (_, v) in x.map {
     212            0 :                     let key = match Self::parse_key_value(&v) {
     213            0 :                         Ok(x) => x,
     214            0 :                         Err(e) => {
     215            0 :                             Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     216            0 :                                 channel: &self.config.stream_name,
     217            0 :                             });
     218            0 :                             tracing::error!("error parsing value {v:?}: {e:?}");
     219            0 :                             continue;
     220              :                         }
     221              :                     };
     222            0 :                     self.insert_event(key);
     223              :                 }
     224            0 :                 if total.is_power_of_two() {
     225            0 :                     tracing::debug!("endpoints read {}", total);
     226            0 :                 }
     227            0 :                 *last_id = x.id;
     228              :             }
     229            0 :             if return_when_finish && len <= self.config.default_batch_size {
     230            0 :                 break;
     231            0 :             }
     232              :         }
     233            0 :         tracing::info!("read {} endpoints/branches/projects from redis", total);
     234            0 :         Ok(())
     235            0 :     }
     236              : }
     237              : 
     238              : #[cfg(test)]
     239              : mod tests {
     240              :     use super::ControlPlaneEventKey;
     241              : 
     242              :     #[test]
     243            1 :     fn test() {
     244            1 :         let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}";
     245            1 :         serde_json::from_str::<ControlPlaneEventKey>(s).unwrap();
     246            1 :     }
     247              : }
        

Generated by: LCOV version 2.1-beta