LCOV - code coverage report
Current view: top level - proxy/src/cache - endpoints.rs (source / functions) Coverage Total Hit
Test: a2f0f8a80fbf1089336086fa360ce27fa555cb1a.info Lines: 9.9 % 192 19
Test Date: 2024-11-20 17:59:39 Functions: 15.9 % 44 7

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::future::pending;
       3              : use std::sync::atomic::{AtomicBool, Ordering};
       4              : use std::sync::{Arc, Mutex};
       5              : 
       6              : use dashmap::DashSet;
       7              : use redis::streams::{StreamReadOptions, StreamReadReply};
       8              : use redis::{AsyncCommands, FromRedisValue, Value};
       9              : use serde::Deserialize;
      10              : use tokio_util::sync::CancellationToken;
      11              : use tracing::info;
      12              : 
      13              : use crate::config::EndpointCacheConfig;
      14              : use crate::context::RequestContext;
      15              : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
      16              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      17              : use crate::rate_limiter::GlobalRateLimiter;
      18              : use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      19              : use crate::types::EndpointId;
      20              : 
      21              : // TODO: this could be an enum, but events in Redis need to be fixed first.
      22              : // ProjectCreated was sent with type:branch_created. So we ignore type.
      23            5 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      24              : struct ControlPlaneEvent {
      25              :     endpoint_created: Option<EndpointCreated>,
      26              :     branch_created: Option<BranchCreated>,
      27              :     project_created: Option<ProjectCreated>,
      28              :     #[serde(rename = "type")]
      29              :     _type: Option<String>,
      30              : }
      31              : 
      32            2 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      33              : struct EndpointCreated {
      34              :     endpoint_id: EndpointIdInt,
      35              : }
      36              : 
      37            0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      38              : struct BranchCreated {
      39              :     branch_id: BranchIdInt,
      40              : }
      41              : 
      42            0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      43              : struct ProjectCreated {
      44              :     project_id: ProjectIdInt,
      45              : }
      46              : 
      47              : impl TryFrom<&Value> for ControlPlaneEvent {
      48              :     type Error = anyhow::Error;
      49            0 :     fn try_from(value: &Value) -> Result<Self, Self::Error> {
      50            0 :         let json = String::from_redis_value(value)?;
      51            0 :         Ok(serde_json::from_str(&json)?)
      52            0 :     }
      53              : }
      54              : 
      55              : pub struct EndpointsCache {
      56              :     config: EndpointCacheConfig,
      57              :     endpoints: DashSet<EndpointIdInt>,
      58              :     branches: DashSet<BranchIdInt>,
      59              :     projects: DashSet<ProjectIdInt>,
      60              :     ready: AtomicBool,
      61              :     limiter: Arc<Mutex<GlobalRateLimiter>>,
      62              : }
      63              : 
      64              : impl EndpointsCache {
      65            0 :     pub(crate) fn new(config: EndpointCacheConfig) -> Self {
      66            0 :         Self {
      67            0 :             limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
      68            0 :                 config.limiter_info.clone(),
      69            0 :             ))),
      70            0 :             config,
      71            0 :             endpoints: DashSet::new(),
      72            0 :             branches: DashSet::new(),
      73            0 :             projects: DashSet::new(),
      74            0 :             ready: AtomicBool::new(false),
      75            0 :         }
      76            0 :     }
      77              : 
      78            0 :     pub(crate) fn is_valid(&self, ctx: &RequestContext, endpoint: &EndpointId) -> bool {
      79            0 :         if !self.ready.load(Ordering::Acquire) {
      80              :             // the endpoint cache is not yet fully initialised.
      81            0 :             return true;
      82            0 :         }
      83            0 : 
      84            0 :         if !self.should_reject(endpoint) {
      85            0 :             ctx.set_rejected(false);
      86            0 :             return true;
      87            0 :         }
      88            0 : 
      89            0 :         // report that we might want to reject this endpoint
      90            0 :         ctx.set_rejected(true);
      91            0 : 
      92            0 :         // If cache is disabled, just collect the metrics and return.
      93            0 :         if self.config.disable_cache {
      94            0 :             return true;
      95            0 :         }
      96            0 : 
      97            0 :         // If the limiter allows, we can pretend like it's valid
      98            0 :         // (incase it is, due to redis channel lag).
      99            0 :         if self.limiter.lock().unwrap().check() {
     100            0 :             return true;
     101            0 :         }
     102            0 : 
     103            0 :         // endpoint not found, and there's too much load.
     104            0 :         false
     105            0 :     }
     106              : 
     107            0 :     fn should_reject(&self, endpoint: &EndpointId) -> bool {
     108            0 :         if endpoint.is_endpoint() {
     109            0 :             let Some(endpoint) = EndpointIdInt::get(endpoint) else {
     110              :                 // if we haven't interned this endpoint, it's not in the cache.
     111            0 :                 return true;
     112              :             };
     113            0 :             !self.endpoints.contains(&endpoint)
     114            0 :         } else if endpoint.is_branch() {
     115            0 :             let Some(branch) = BranchIdInt::get(endpoint) else {
     116              :                 // if we haven't interned this branch, it's not in the cache.
     117            0 :                 return true;
     118              :             };
     119            0 :             !self.branches.contains(&branch)
     120              :         } else {
     121            0 :             let Some(project) = ProjectIdInt::get(endpoint) else {
     122              :                 // if we haven't interned this project, it's not in the cache.
     123            0 :                 return true;
     124              :             };
     125            0 :             !self.projects.contains(&project)
     126              :         }
     127            0 :     }
     128              : 
     129            0 :     fn insert_event(&self, event: ControlPlaneEvent) {
     130            0 :         if let Some(endpoint_created) = event.endpoint_created {
     131            0 :             self.endpoints.insert(endpoint_created.endpoint_id);
     132            0 :             Metrics::get()
     133            0 :                 .proxy
     134            0 :                 .redis_events_count
     135            0 :                 .inc(RedisEventsCount::EndpointCreated);
     136            0 :         } else if let Some(branch_created) = event.branch_created {
     137            0 :             self.branches.insert(branch_created.branch_id);
     138            0 :             Metrics::get()
     139            0 :                 .proxy
     140            0 :                 .redis_events_count
     141            0 :                 .inc(RedisEventsCount::BranchCreated);
     142            0 :         } else if let Some(project_created) = event.project_created {
     143            0 :             self.projects.insert(project_created.project_id);
     144            0 :             Metrics::get()
     145            0 :                 .proxy
     146            0 :                 .redis_events_count
     147            0 :                 .inc(RedisEventsCount::ProjectCreated);
     148            0 :         }
     149            0 :     }
     150              : 
     151            0 :     pub async fn do_read(
     152            0 :         &self,
     153            0 :         mut con: ConnectionWithCredentialsProvider,
     154            0 :         cancellation_token: CancellationToken,
     155            0 :     ) -> anyhow::Result<Infallible> {
     156            0 :         let mut last_id = "0-0".to_string();
     157              :         loop {
     158            0 :             if let Err(e) = con.connect().await {
     159            0 :                 tracing::error!("error connecting to redis: {:?}", e);
     160            0 :                 self.ready.store(false, Ordering::Release);
     161            0 :             }
     162            0 :             if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
     163            0 :                 tracing::error!("error reading from redis: {:?}", e);
     164            0 :                 self.ready.store(false, Ordering::Release);
     165            0 :             }
     166            0 :             if cancellation_token.is_cancelled() {
     167            0 :                 info!("cancellation token is cancelled, exiting");
     168              :                 // Maintenance tasks run forever. Sleep forever when canceled.
     169            0 :                 pending::<()>().await;
     170            0 :             }
     171            0 :             tokio::time::sleep(self.config.retry_interval).await;
     172              :         }
     173              :     }
     174              : 
     175            0 :     async fn read_from_stream(
     176            0 :         &self,
     177            0 :         con: &mut ConnectionWithCredentialsProvider,
     178            0 :         last_id: &mut String,
     179            0 :     ) -> anyhow::Result<()> {
     180            0 :         tracing::info!("reading endpoints/branches/projects from redis");
     181            0 :         self.batch_read(
     182            0 :             con,
     183            0 :             StreamReadOptions::default().count(self.config.initial_batch_size),
     184            0 :             last_id,
     185            0 :             true,
     186            0 :         )
     187            0 :         .await?;
     188            0 :         tracing::info!("ready to filter user requests");
     189            0 :         self.ready.store(true, Ordering::Release);
     190            0 :         self.batch_read(
     191            0 :             con,
     192            0 :             StreamReadOptions::default()
     193            0 :                 .count(self.config.default_batch_size)
     194            0 :                 .block(self.config.xread_timeout.as_millis() as usize),
     195            0 :             last_id,
     196            0 :             false,
     197            0 :         )
     198            0 :         .await
     199            0 :     }
     200              : 
     201            0 :     async fn batch_read(
     202            0 :         &self,
     203            0 :         conn: &mut ConnectionWithCredentialsProvider,
     204            0 :         opts: StreamReadOptions,
     205            0 :         last_id: &mut String,
     206            0 :         return_when_finish: bool,
     207            0 :     ) -> anyhow::Result<()> {
     208            0 :         let mut total: usize = 0;
     209              :         loop {
     210            0 :             let mut res: StreamReadReply = conn
     211            0 :                 .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
     212            0 :                 .await?;
     213              : 
     214            0 :             if res.keys.is_empty() {
     215            0 :                 if return_when_finish {
     216            0 :                     if total != 0 {
     217            0 :                         break;
     218            0 :                     }
     219            0 :                     anyhow::bail!(
     220            0 :                         "Redis stream {} is empty, cannot be used to filter endpoints",
     221            0 :                         self.config.stream_name
     222            0 :                     );
     223            0 :                 }
     224            0 :                 // If we are not returning when finish, we should wait for more data.
     225            0 :                 continue;
     226            0 :             }
     227            0 :             if res.keys.len() != 1 {
     228            0 :                 anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
     229            0 :             }
     230            0 : 
     231            0 :             let key = res.keys.pop().expect("Checked length above");
     232            0 :             let len = key.ids.len();
     233            0 :             for stream_id in key.ids {
     234            0 :                 total += 1;
     235            0 :                 for value in stream_id.map.values() {
     236            0 :                     match value.try_into() {
     237            0 :                         Ok(event) => self.insert_event(event),
     238            0 :                         Err(err) => {
     239            0 :                             Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     240            0 :                                 channel: &self.config.stream_name,
     241            0 :                             });
     242            0 :                             tracing::error!("error parsing value {value:?}: {err:?}");
     243              :                         }
     244              :                     };
     245              :                 }
     246            0 :                 if total.is_power_of_two() {
     247            0 :                     tracing::debug!("endpoints read {}", total);
     248            0 :                 }
     249            0 :                 *last_id = stream_id.id;
     250              :             }
     251            0 :             if return_when_finish && len <= self.config.default_batch_size {
     252            0 :                 break;
     253            0 :             }
     254              :         }
     255            0 :         tracing::info!("read {} endpoints/branches/projects from redis", total);
     256            0 :         Ok(())
     257            0 :     }
     258              : }
     259              : 
     260              : #[cfg(test)]
     261              : mod tests {
     262              :     use super::*;
     263              : 
     264              :     #[test]
     265            1 :     fn test_parse_control_plane_event() {
     266            1 :         let s = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;
     267            1 : 
     268            1 :         let endpoint_id: EndpointId = "ep-rapid-thunder-w0qqw2q9".into();
     269            1 : 
     270            1 :         assert_eq!(
     271            1 :             serde_json::from_str::<ControlPlaneEvent>(s).unwrap(),
     272            1 :             ControlPlaneEvent {
     273            1 :                 endpoint_created: Some(EndpointCreated {
     274            1 :                     endpoint_id: endpoint_id.into(),
     275            1 :                 }),
     276            1 :                 branch_created: None,
     277            1 :                 project_created: None,
     278            1 :                 _type: Some("endpoint_created".into()),
     279            1 :             }
     280            1 :         );
     281            1 :     }
     282              : }
        

Generated by: LCOV version 2.1-beta