LCOV - code coverage report
Current view: top level - proxy/src/cache - endpoints.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 9.9 % 192 19
Test Date: 2025-02-20 13:11:02 Functions: 9.4 % 32 3

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : use std::future::pending;
       3              : use std::sync::atomic::{AtomicBool, Ordering};
       4              : use std::sync::{Arc, Mutex};
       5              : 
       6              : use clashmap::ClashSet;
       7              : use redis::streams::{StreamReadOptions, StreamReadReply};
       8              : use redis::{AsyncCommands, FromRedisValue, Value};
       9              : use serde::Deserialize;
      10              : use tokio_util::sync::CancellationToken;
      11              : use tracing::info;
      12              : 
      13              : use crate::config::EndpointCacheConfig;
      14              : use crate::context::RequestContext;
      15              : use crate::ext::LockExt;
      16              : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
      17              : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
      18              : use crate::rate_limiter::GlobalRateLimiter;
      19              : use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
      20              : use crate::types::EndpointId;
      21              : 
      22              : // TODO: this could be an enum, but events in Redis need to be fixed first.
      23              : // ProjectCreated was sent with type:branch_created. So we ignore type.
      24            4 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      25              : struct ControlPlaneEvent {
      26              :     endpoint_created: Option<EndpointCreated>,
      27              :     branch_created: Option<BranchCreated>,
      28              :     project_created: Option<ProjectCreated>,
      29              :     #[serde(rename = "type")]
      30              :     _type: Option<String>,
      31              : }
      32              : 
      33            1 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      34              : struct EndpointCreated {
      35              :     endpoint_id: EndpointIdInt,
      36              : }
      37              : 
      38            0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      39              : struct BranchCreated {
      40              :     branch_id: BranchIdInt,
      41              : }
      42              : 
      43            0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
      44              : struct ProjectCreated {
      45              :     project_id: ProjectIdInt,
      46              : }
      47              : 
      48              : impl TryFrom<&Value> for ControlPlaneEvent {
      49              :     type Error = anyhow::Error;
      50            0 :     fn try_from(value: &Value) -> Result<Self, Self::Error> {
      51            0 :         let json = String::from_redis_value(value)?;
      52            0 :         Ok(serde_json::from_str(&json)?)
      53            0 :     }
      54              : }
      55              : 
      56              : pub struct EndpointsCache {
      57              :     config: EndpointCacheConfig,
      58              :     endpoints: ClashSet<EndpointIdInt>,
      59              :     branches: ClashSet<BranchIdInt>,
      60              :     projects: ClashSet<ProjectIdInt>,
      61              :     ready: AtomicBool,
      62              :     limiter: Arc<Mutex<GlobalRateLimiter>>,
      63              : }
      64              : 
      65              : impl EndpointsCache {
      66            0 :     pub(crate) fn new(config: EndpointCacheConfig) -> Self {
      67            0 :         Self {
      68            0 :             limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
      69            0 :                 config.limiter_info.clone(),
      70            0 :             ))),
      71            0 :             config,
      72            0 :             endpoints: ClashSet::new(),
      73            0 :             branches: ClashSet::new(),
      74            0 :             projects: ClashSet::new(),
      75            0 :             ready: AtomicBool::new(false),
      76            0 :         }
      77            0 :     }
      78              : 
      79            0 :     pub(crate) fn is_valid(&self, ctx: &RequestContext, endpoint: &EndpointId) -> bool {
      80            0 :         if !self.ready.load(Ordering::Acquire) {
      81              :             // the endpoint cache is not yet fully initialised.
      82            0 :             return true;
      83            0 :         }
      84            0 : 
      85            0 :         if !self.should_reject(endpoint) {
      86            0 :             ctx.set_rejected(false);
      87            0 :             return true;
      88            0 :         }
      89            0 : 
      90            0 :         // report that we might want to reject this endpoint
      91            0 :         ctx.set_rejected(true);
      92            0 : 
      93            0 :         // If cache is disabled, just collect the metrics and return.
      94            0 :         if self.config.disable_cache {
      95            0 :             return true;
      96            0 :         }
      97            0 : 
      98            0 :         // If the limiter allows, we can pretend like it's valid
      99            0 :         // (incase it is, due to redis channel lag).
     100            0 :         if self.limiter.lock_propagate_poison().check() {
     101            0 :             return true;
     102            0 :         }
     103            0 : 
     104            0 :         // endpoint not found, and there's too much load.
     105            0 :         false
     106            0 :     }
     107              : 
     108            0 :     fn should_reject(&self, endpoint: &EndpointId) -> bool {
     109            0 :         if endpoint.is_endpoint() {
     110            0 :             let Some(endpoint) = EndpointIdInt::get(endpoint) else {
     111              :                 // if we haven't interned this endpoint, it's not in the cache.
     112            0 :                 return true;
     113              :             };
     114            0 :             !self.endpoints.contains(&endpoint)
     115            0 :         } else if endpoint.is_branch() {
     116            0 :             let Some(branch) = BranchIdInt::get(endpoint) else {
     117              :                 // if we haven't interned this branch, it's not in the cache.
     118            0 :                 return true;
     119              :             };
     120            0 :             !self.branches.contains(&branch)
     121              :         } else {
     122            0 :             let Some(project) = ProjectIdInt::get(endpoint) else {
     123              :                 // if we haven't interned this project, it's not in the cache.
     124            0 :                 return true;
     125              :             };
     126            0 :             !self.projects.contains(&project)
     127              :         }
     128            0 :     }
     129              : 
     130            0 :     fn insert_event(&self, event: ControlPlaneEvent) {
     131            0 :         if let Some(endpoint_created) = event.endpoint_created {
     132            0 :             self.endpoints.insert(endpoint_created.endpoint_id);
     133            0 :             Metrics::get()
     134            0 :                 .proxy
     135            0 :                 .redis_events_count
     136            0 :                 .inc(RedisEventsCount::EndpointCreated);
     137            0 :         } else if let Some(branch_created) = event.branch_created {
     138            0 :             self.branches.insert(branch_created.branch_id);
     139            0 :             Metrics::get()
     140            0 :                 .proxy
     141            0 :                 .redis_events_count
     142            0 :                 .inc(RedisEventsCount::BranchCreated);
     143            0 :         } else if let Some(project_created) = event.project_created {
     144            0 :             self.projects.insert(project_created.project_id);
     145            0 :             Metrics::get()
     146            0 :                 .proxy
     147            0 :                 .redis_events_count
     148            0 :                 .inc(RedisEventsCount::ProjectCreated);
     149            0 :         }
     150            0 :     }
     151              : 
     152            0 :     pub async fn do_read(
     153            0 :         &self,
     154            0 :         mut con: ConnectionWithCredentialsProvider,
     155            0 :         cancellation_token: CancellationToken,
     156            0 :     ) -> anyhow::Result<Infallible> {
     157            0 :         let mut last_id = "0-0".to_string();
     158              :         loop {
     159            0 :             if let Err(e) = con.connect().await {
     160            0 :                 tracing::error!("error connecting to redis: {:?}", e);
     161            0 :                 self.ready.store(false, Ordering::Release);
     162            0 :             }
     163            0 :             if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
     164            0 :                 tracing::error!("error reading from redis: {:?}", e);
     165            0 :                 self.ready.store(false, Ordering::Release);
     166            0 :             }
     167            0 :             if cancellation_token.is_cancelled() {
     168            0 :                 info!("cancellation token is cancelled, exiting");
     169              :                 // Maintenance tasks run forever. Sleep forever when canceled.
     170            0 :                 pending::<()>().await;
     171            0 :             }
     172            0 :             tokio::time::sleep(self.config.retry_interval).await;
     173              :         }
     174              :     }
     175              : 
     176            0 :     async fn read_from_stream(
     177            0 :         &self,
     178            0 :         con: &mut ConnectionWithCredentialsProvider,
     179            0 :         last_id: &mut String,
     180            0 :     ) -> anyhow::Result<()> {
     181            0 :         tracing::info!("reading endpoints/branches/projects from redis");
     182            0 :         self.batch_read(
     183            0 :             con,
     184            0 :             StreamReadOptions::default().count(self.config.initial_batch_size),
     185            0 :             last_id,
     186            0 :             true,
     187            0 :         )
     188            0 :         .await?;
     189            0 :         tracing::info!("ready to filter user requests");
     190            0 :         self.ready.store(true, Ordering::Release);
     191            0 :         self.batch_read(
     192            0 :             con,
     193            0 :             StreamReadOptions::default()
     194            0 :                 .count(self.config.default_batch_size)
     195            0 :                 .block(self.config.xread_timeout.as_millis() as usize),
     196            0 :             last_id,
     197            0 :             false,
     198            0 :         )
     199            0 :         .await
     200            0 :     }
     201              : 
     202            0 :     async fn batch_read(
     203            0 :         &self,
     204            0 :         conn: &mut ConnectionWithCredentialsProvider,
     205            0 :         opts: StreamReadOptions,
     206            0 :         last_id: &mut String,
     207            0 :         return_when_finish: bool,
     208            0 :     ) -> anyhow::Result<()> {
     209            0 :         let mut total: usize = 0;
     210              :         loop {
     211            0 :             let mut res: StreamReadReply = conn
     212            0 :                 .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
     213            0 :                 .await?;
     214              : 
     215            0 :             if res.keys.is_empty() {
     216            0 :                 if return_when_finish {
     217            0 :                     if total != 0 {
     218            0 :                         break;
     219            0 :                     }
     220            0 :                     anyhow::bail!(
     221            0 :                         "Redis stream {} is empty, cannot be used to filter endpoints",
     222            0 :                         self.config.stream_name
     223            0 :                     );
     224            0 :                 }
     225            0 :                 // If we are not returning when finish, we should wait for more data.
     226            0 :                 continue;
     227            0 :             }
     228            0 :             if res.keys.len() != 1 {
     229            0 :                 anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
     230            0 :             }
     231            0 : 
     232            0 :             let key = res.keys.pop().expect("Checked length above");
     233            0 :             let len = key.ids.len();
     234            0 :             for stream_id in key.ids {
     235            0 :                 total += 1;
     236            0 :                 for value in stream_id.map.values() {
     237            0 :                     match value.try_into() {
     238            0 :                         Ok(event) => self.insert_event(event),
     239            0 :                         Err(err) => {
     240            0 :                             Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
     241            0 :                                 channel: &self.config.stream_name,
     242            0 :                             });
     243            0 :                             tracing::error!("error parsing value {value:?}: {err:?}");
     244              :                         }
     245              :                     }
     246              :                 }
     247            0 :                 if total.is_power_of_two() {
     248            0 :                     tracing::debug!("endpoints read {}", total);
     249            0 :                 }
     250            0 :                 *last_id = stream_id.id;
     251              :             }
     252            0 :             if return_when_finish && len <= self.config.default_batch_size {
     253            0 :                 break;
     254            0 :             }
     255              :         }
     256            0 :         tracing::info!("read {} endpoints/branches/projects from redis", total);
     257            0 :         Ok(())
     258            0 :     }
     259              : }
     260              : 
     261              : #[cfg(test)]
     262              : #[expect(clippy::unwrap_used)]
     263              : mod tests {
     264              :     use super::*;
     265              : 
     266              :     #[test]
     267            1 :     fn test_parse_control_plane_event() {
     268            1 :         let s = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;
     269            1 : 
     270            1 :         let endpoint_id: EndpointId = "ep-rapid-thunder-w0qqw2q9".into();
     271            1 : 
     272            1 :         assert_eq!(
     273            1 :             serde_json::from_str::<ControlPlaneEvent>(s).unwrap(),
     274            1 :             ControlPlaneEvent {
     275            1 :                 endpoint_created: Some(EndpointCreated {
     276            1 :                     endpoint_id: endpoint_id.into(),
     277            1 :                 }),
     278            1 :                 branch_created: None,
     279            1 :                 project_created: None,
     280            1 :                 _type: Some("endpoint_created".into()),
     281            1 :             }
     282            1 :         );
     283            1 :     }
     284              : }
        

Generated by: LCOV version 2.1-beta