Line data Source code
1 : use std::{
2 : convert::Infallible,
3 : sync::{
4 : atomic::{AtomicBool, Ordering},
5 : Arc,
6 : },
7 : };
8 :
9 : use dashmap::DashSet;
10 : use redis::{
11 : streams::{StreamReadOptions, StreamReadReply},
12 : AsyncCommands, FromRedisValue, Value,
13 : };
14 : use serde::Deserialize;
15 : use tokio::sync::Mutex;
16 : use tracing::info;
17 :
18 : use crate::{
19 : config::EndpointCacheConfig,
20 : context::RequestMonitoring,
21 : intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
22 : metrics::{Metrics, RedisErrors},
23 : rate_limiter::GlobalRateLimiter,
24 : redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
25 : EndpointId,
26 : };
27 :
28 16 : #[derive(Deserialize, Debug, Clone)]
29 : pub struct ControlPlaneEventKey {
30 : endpoint_created: Option<EndpointCreated>,
31 : branch_created: Option<BranchCreated>,
32 : project_created: Option<ProjectCreated>,
33 : }
34 6 : #[derive(Deserialize, Debug, Clone)]
35 : struct EndpointCreated {
36 : endpoint_id: String,
37 : }
38 0 : #[derive(Deserialize, Debug, Clone)]
39 : struct BranchCreated {
40 : branch_id: String,
41 : }
42 0 : #[derive(Deserialize, Debug, Clone)]
43 : struct ProjectCreated {
44 : project_id: String,
45 : }
46 :
47 : pub struct EndpointsCache {
48 : config: EndpointCacheConfig,
49 : endpoints: DashSet<EndpointIdInt>,
50 : branches: DashSet<BranchIdInt>,
51 : projects: DashSet<ProjectIdInt>,
52 : ready: AtomicBool,
53 : limiter: Arc<Mutex<GlobalRateLimiter>>,
54 : }
55 :
56 : impl EndpointsCache {
57 0 : pub fn new(config: EndpointCacheConfig) -> Self {
58 0 : Self {
59 0 : limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
60 0 : config.limiter_info.clone(),
61 0 : ))),
62 0 : config,
63 0 : endpoints: DashSet::new(),
64 0 : branches: DashSet::new(),
65 0 : projects: DashSet::new(),
66 0 : ready: AtomicBool::new(false),
67 0 : }
68 0 : }
69 0 : pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
70 0 : if !self.ready.load(Ordering::Acquire) {
71 0 : return true;
72 0 : }
73 0 : let rejected = self.should_reject(endpoint);
74 0 : ctx.set_rejected(rejected);
75 0 : info!(?rejected, "check endpoint is valid, disabled cache");
76 : // If cache is disabled, just collect the metrics and return or
77 : // If the limiter allows, we don't need to check the cache.
78 0 : if self.config.disable_cache || self.limiter.lock().await.check() {
79 0 : return true;
80 0 : }
81 0 : !rejected
82 0 : }
83 0 : fn should_reject(&self, endpoint: &EndpointId) -> bool {
84 0 : if endpoint.is_endpoint() {
85 0 : !self.endpoints.contains(&EndpointIdInt::from(endpoint))
86 0 : } else if endpoint.is_branch() {
87 0 : !self
88 0 : .branches
89 0 : .contains(&BranchIdInt::from(&endpoint.as_branch()))
90 : } else {
91 0 : !self
92 0 : .projects
93 0 : .contains(&ProjectIdInt::from(&endpoint.as_project()))
94 : }
95 0 : }
96 0 : fn insert_event(&self, key: ControlPlaneEventKey) {
97 : // Do not do normalization here, we expect the events to be normalized.
98 0 : if let Some(endpoint_created) = key.endpoint_created {
99 0 : self.endpoints
100 0 : .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
101 0 : }
102 0 : if let Some(branch_created) = key.branch_created {
103 0 : self.branches
104 0 : .insert(BranchIdInt::from(&branch_created.branch_id.into()));
105 0 : }
106 0 : if let Some(project_created) = key.project_created {
107 0 : self.projects
108 0 : .insert(ProjectIdInt::from(&project_created.project_id.into()));
109 0 : }
110 0 : }
111 0 : pub async fn do_read(
112 0 : &self,
113 0 : mut con: ConnectionWithCredentialsProvider,
114 0 : ) -> anyhow::Result<Infallible> {
115 0 : let mut last_id = "0-0".to_string();
116 : loop {
117 0 : self.ready.store(false, Ordering::Release);
118 0 : if let Err(e) = con.connect().await {
119 0 : tracing::error!("error connecting to redis: {:?}", e);
120 0 : continue;
121 0 : }
122 0 : if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
123 0 : tracing::error!("error reading from redis: {:?}", e);
124 0 : }
125 0 : tokio::time::sleep(self.config.retry_interval).await;
126 : }
127 : }
128 0 : async fn read_from_stream(
129 0 : &self,
130 0 : con: &mut ConnectionWithCredentialsProvider,
131 0 : last_id: &mut String,
132 0 : ) -> anyhow::Result<()> {
133 0 : tracing::info!("reading endpoints/branches/projects from redis");
134 0 : self.batch_read(
135 0 : con,
136 0 : StreamReadOptions::default().count(self.config.initial_batch_size),
137 0 : last_id,
138 0 : true,
139 0 : )
140 0 : .await?;
141 0 : tracing::info!("ready to filter user requests");
142 0 : self.ready.store(true, Ordering::Release);
143 0 : self.batch_read(
144 0 : con,
145 0 : StreamReadOptions::default()
146 0 : .count(self.config.default_batch_size)
147 0 : .block(self.config.xread_timeout.as_millis() as usize),
148 0 : last_id,
149 0 : false,
150 0 : )
151 0 : .await
152 0 : }
153 0 : fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
154 0 : let s: String = FromRedisValue::from_redis_value(value)?;
155 0 : Ok(serde_json::from_str(&s)?)
156 0 : }
157 0 : async fn batch_read(
158 0 : &self,
159 0 : conn: &mut ConnectionWithCredentialsProvider,
160 0 : opts: StreamReadOptions,
161 0 : last_id: &mut String,
162 0 : return_when_finish: bool,
163 0 : ) -> anyhow::Result<()> {
164 0 : let mut total: usize = 0;
165 : loop {
166 0 : let mut res: StreamReadReply = conn
167 0 : .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
168 0 : .await?;
169 :
170 0 : if res.keys.is_empty() {
171 0 : if return_when_finish {
172 0 : if total != 0 {
173 0 : break;
174 0 : }
175 0 : anyhow::bail!(
176 0 : "Redis stream {} is empty, cannot be used to filter endpoints",
177 0 : self.config.stream_name
178 0 : );
179 0 : }
180 0 : // If we are not returning when finish, we should wait for more data.
181 0 : continue;
182 0 : }
183 0 : if res.keys.len() != 1 {
184 0 : anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
185 0 : }
186 0 :
187 0 : let res = res.keys.pop().expect("Checked length above");
188 0 : let len = res.ids.len();
189 0 : for x in res.ids {
190 0 : total += 1;
191 0 : for (_, v) in x.map {
192 0 : let key = match Self::parse_key_value(&v) {
193 0 : Ok(x) => x,
194 0 : Err(e) => {
195 0 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
196 0 : channel: &self.config.stream_name,
197 0 : });
198 0 : tracing::error!("error parsing value {v:?}: {e:?}");
199 0 : continue;
200 : }
201 : };
202 0 : self.insert_event(key);
203 : }
204 0 : if total.is_power_of_two() {
205 0 : tracing::debug!("endpoints read {}", total);
206 0 : }
207 0 : *last_id = x.id;
208 : }
209 0 : if return_when_finish && len <= self.config.default_batch_size {
210 0 : break;
211 0 : }
212 : }
213 0 : tracing::info!("read {} endpoints/branches/projects from redis", total);
214 0 : Ok(())
215 0 : }
216 : }
217 :
#[cfg(test)]
mod tests {
    use super::ControlPlaneEventKey;

    /// An `endpoint_created` event must deserialize with the endpoint id
    /// populated, the `null` fields mapped to `None`, and the extra `type`
    /// key ignored.
    #[test]
    fn test() {
        let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}";
        let key: ControlPlaneEventKey = serde_json::from_str(s).unwrap();

        // Check the parsed content, not merely that parsing did not panic.
        assert_eq!(
            key.endpoint_created
                .as_ref()
                .map(|created| created.endpoint_id.as_str()),
            Some("ep-rapid-thunder-w0qqw2q9")
        );
        assert!(key.branch_created.is_none());
        assert!(key.project_created.is_none());
    }
}
|