LCOV - code coverage report
Current view: top level - proxy/src/cache - endpoints.rs (source / functions) Coverage Total Hit
Test: b4ae4c4857f9ef3e144e982a35ee23bc84c71983.info Lines: 3.6 % 169 6
Test Date: 2024-10-22 22:13:45 Functions: 15.6 % 45 7

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::sync::atomic::{AtomicBool, Ordering};
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use dashmap::DashSet;
       7              : use redis::streams::{StreamReadOptions, StreamReadReply};
       8              : use redis::{AsyncCommands, FromRedisValue, Value};
       9              : use serde::Deserialize;
      10              : use tokio::sync::Mutex;
      11              : use tokio_util::sync::CancellationToken;
      12              : use tracing::info;
      13              : 
      14              : use crate::config::EndpointCacheConfig;
      15              : use crate::context::RequestMonitoring;
      16              : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
      17              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      18              : use crate::rate_limiter::GlobalRateLimiter;
      19              : use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      20              : use crate::EndpointId;
      21              : 
      22            5 : #[derive(Deserialize, Debug, Clone)]
      23              : pub(crate) struct ControlPlaneEventKey {
      24              :     endpoint_created: Option<EndpointCreated>,
      25              :     branch_created: Option<BranchCreated>,
      26              :     project_created: Option<ProjectCreated>,
      27              : }
      28            2 : #[derive(Deserialize, Debug, Clone)]
      29              : struct EndpointCreated {
      30              :     endpoint_id: String,
      31              : }
      32            0 : #[derive(Deserialize, Debug, Clone)]
      33              : struct BranchCreated {
      34              :     branch_id: String,
      35              : }
      36            0 : #[derive(Deserialize, Debug, Clone)]
      37              : struct ProjectCreated {
      38              :     project_id: String,
      39              : }
      40              : 
      41              : pub struct EndpointsCache {
      42              :     config: EndpointCacheConfig,
      43              :     endpoints: DashSet<EndpointIdInt>,
      44              :     branches: DashSet<BranchIdInt>,
      45              :     projects: DashSet<ProjectIdInt>,
      46              :     ready: AtomicBool,
      47              :     limiter: Arc<Mutex<GlobalRateLimiter>>,
      48              : }
      49              : 
      50              : impl EndpointsCache {
      51            0 :     pub(crate) fn new(config: EndpointCacheConfig) -> Self {
      52            0 :         Self {
      53            0 :             limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
      54            0 :                 config.limiter_info.clone(),
      55            0 :             ))),
      56            0 :             config,
      57            0 :             endpoints: DashSet::new(),
      58            0 :             branches: DashSet::new(),
      59            0 :             projects: DashSet::new(),
      60            0 :             ready: AtomicBool::new(false),
      61            0 :         }
      62            0 :     }
      63            0 :     pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
      64            0 :         if !self.ready.load(Ordering::Acquire) {
      65            0 :             return true;
      66            0 :         }
      67            0 :         let rejected = self.should_reject(endpoint);
      68            0 :         ctx.set_rejected(rejected);
      69            0 :         info!(?rejected, "check endpoint is valid, disabled cache");
      70              :         // If cache is disabled, just collect the metrics and return or
      71              :         // If the limiter allows, we don't need to check the cache.
      72            0 :         if self.config.disable_cache || self.limiter.lock().await.check() {
      73            0 :             return true;
      74            0 :         }
      75            0 :         !rejected
      76            0 :     }
      77            0 :     fn should_reject(&self, endpoint: &EndpointId) -> bool {
      78            0 :         if endpoint.is_endpoint() {
      79            0 :             !self.endpoints.contains(&EndpointIdInt::from(endpoint))
      80            0 :         } else if endpoint.is_branch() {
      81            0 :             !self
      82            0 :                 .branches
      83            0 :                 .contains(&BranchIdInt::from(&endpoint.as_branch()))
      84              :         } else {
      85            0 :             !self
      86            0 :                 .projects
      87            0 :                 .contains(&ProjectIdInt::from(&endpoint.as_project()))
      88              :         }
      89            0 :     }
      90            0 :     fn insert_event(&self, key: ControlPlaneEventKey) {
      91              :         // Do not do normalization here, we expect the events to be normalized.
      92            0 :         if let Some(endpoint_created) = key.endpoint_created {
      93            0 :             self.endpoints
      94            0 :                 .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
      95            0 :             Metrics::get()
      96            0 :                 .proxy
      97            0 :                 .redis_events_count
      98            0 :                 .inc(RedisEventsCount::EndpointCreated);
      99            0 :         }
     100            0 :         if let Some(branch_created) = key.branch_created {
     101            0 :             self.branches
     102            0 :                 .insert(BranchIdInt::from(&branch_created.branch_id.into()));
     103            0 :             Metrics::get()
     104            0 :                 .proxy
     105            0 :                 .redis_events_count
     106            0 :                 .inc(RedisEventsCount::BranchCreated);
     107            0 :         }
     108            0 :         if let Some(project_created) = key.project_created {
     109            0 :             self.projects
     110            0 :                 .insert(ProjectIdInt::from(&project_created.project_id.into()));
     111            0 :             Metrics::get()
     112            0 :                 .proxy
     113            0 :                 .redis_events_count
     114            0 :                 .inc(RedisEventsCount::ProjectCreated);
     115            0 :         }
     116            0 :     }
     117            0 :     pub async fn do_read(
     118            0 :         &self,
     119            0 :         mut con: ConnectionWithCredentialsProvider,
     120            0 :         cancellation_token: CancellationToken,
     121            0 :     ) -> anyhow::Result<Infallible> {
     122            0 :         let mut last_id = "0-0".to_string();
     123              :         loop {
     124            0 :             if let Err(e) = con.connect().await {
     125            0 :                 tracing::error!("error connecting to redis: {:?}", e);
     126            0 :                 self.ready.store(false, Ordering::Release);
     127            0 :             }
     128            0 :             if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
     129            0 :                 tracing::error!("error reading from redis: {:?}", e);
     130            0 :                 self.ready.store(false, Ordering::Release);
     131            0 :             }
     132            0 :             if cancellation_token.is_cancelled() {
     133            0 :                 info!("cancellation token is cancelled, exiting");
     134            0 :                 tokio::time::sleep(Duration::from_secs(60 * 60 * 24 * 7)).await;
     135              :                 // 1 week.
     136            0 :             }
     137            0 :             tokio::time::sleep(self.config.retry_interval).await;
     138              :         }
     139              :     }
     140            0 :     async fn read_from_stream(
     141            0 :         &self,
     142            0 :         con: &mut ConnectionWithCredentialsProvider,
     143            0 :         last_id: &mut String,
     144            0 :     ) -> anyhow::Result<()> {
     145            0 :         tracing::info!("reading endpoints/branches/projects from redis");
     146            0 :         self.batch_read(
     147            0 :             con,
     148            0 :             StreamReadOptions::default().count(self.config.initial_batch_size),
     149            0 :             last_id,
     150            0 :             true,
     151            0 :         )
     152            0 :         .await?;
     153            0 :         tracing::info!("ready to filter user requests");
     154            0 :         self.ready.store(true, Ordering::Release);
     155            0 :         self.batch_read(
     156            0 :             con,
     157            0 :             StreamReadOptions::default()
     158            0 :                 .count(self.config.default_batch_size)
     159            0 :                 .block(self.config.xread_timeout.as_millis() as usize),
     160            0 :             last_id,
     161            0 :             false,
     162            0 :         )
     163            0 :         .await
     164            0 :     }
     165            0 :     fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
     166            0 :         let s: String = FromRedisValue::from_redis_value(value)?;
     167            0 :         Ok(serde_json::from_str(&s)?)
     168            0 :     }
     169            0 :     async fn batch_read(
     170            0 :         &self,
     171            0 :         conn: &mut ConnectionWithCredentialsProvider,
     172            0 :         opts: StreamReadOptions,
     173            0 :         last_id: &mut String,
     174            0 :         return_when_finish: bool,
     175            0 :     ) -> anyhow::Result<()> {
     176            0 :         let mut total: usize = 0;
     177              :         loop {
     178            0 :             let mut res: StreamReadReply = conn
     179            0 :                 .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
     180            0 :                 .await?;
     181              : 
     182            0 :             if res.keys.is_empty() {
     183            0 :                 if return_when_finish {
     184            0 :                     if total != 0 {
     185            0 :                         break;
     186            0 :                     }
     187            0 :                     anyhow::bail!(
     188            0 :                         "Redis stream {} is empty, cannot be used to filter endpoints",
     189            0 :                         self.config.stream_name
     190            0 :                     );
     191            0 :                 }
     192            0 :                 // If we are not returning when finish, we should wait for more data.
     193            0 :                 continue;
     194            0 :             }
     195            0 :             if res.keys.len() != 1 {
     196            0 :                 anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
     197            0 :             }
     198            0 : 
     199            0 :             let res = res.keys.pop().expect("Checked length above");
     200            0 :             let len = res.ids.len();
     201            0 :             for x in res.ids {
     202            0 :                 total += 1;
     203            0 :                 for (_, v) in x.map {
     204            0 :                     let key = match Self::parse_key_value(&v) {
     205            0 :                         Ok(x) => x,
     206            0 :                         Err(e) => {
     207            0 :                             Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     208            0 :                                 channel: &self.config.stream_name,
     209            0 :                             });
     210            0 :                             tracing::error!("error parsing value {v:?}: {e:?}");
     211            0 :                             continue;
     212              :                         }
     213              :                     };
     214            0 :                     self.insert_event(key);
     215              :                 }
     216            0 :                 if total.is_power_of_two() {
     217            0 :                     tracing::debug!("endpoints read {}", total);
     218            0 :                 }
     219            0 :                 *last_id = x.id;
     220              :             }
     221            0 :             if return_when_finish && len <= self.config.default_batch_size {
     222            0 :                 break;
     223            0 :             }
     224              :         }
     225            0 :         tracing::info!("read {} endpoints/branches/projects from redis", total);
     226            0 :         Ok(())
     227            0 :     }
     228              : }
     229              : 
     230              : #[cfg(test)]
     231              : mod tests {
     232              :     use super::ControlPlaneEventKey;
     233              : 
     234              :     #[test]
     235            1 :     fn test() {
     236            1 :         let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}";
     237            1 :         serde_json::from_str::<ControlPlaneEventKey>(s).unwrap();
     238            1 :     }
     239              : }
        

Generated by: LCOV version 2.1-beta