Line data Source code
1 : use std::convert::Infallible;
2 : use std::future::pending;
3 : use std::sync::atomic::{AtomicBool, Ordering};
4 : use std::sync::{Arc, Mutex};
5 :
6 : use clashmap::ClashSet;
7 : use redis::streams::{StreamReadOptions, StreamReadReply};
8 : use redis::{AsyncCommands, FromRedisValue, Value};
9 : use serde::Deserialize;
10 : use tokio_util::sync::CancellationToken;
11 : use tracing::info;
12 :
13 : use crate::config::EndpointCacheConfig;
14 : use crate::context::RequestContext;
15 : use crate::ext::LockExt;
16 : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
17 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
18 : use crate::rate_limiter::GlobalRateLimiter;
19 : use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
20 : use crate::types::EndpointId;
21 :
22 : // TODO: this could be an enum, but events in Redis need to be fixed first.
23 : // ProjectCreated was sent with type:branch_created. So we ignore type.
24 4 : #[derive(Deserialize, Debug, Clone, PartialEq)]
25 : struct ControlPlaneEvent {
26 : endpoint_created: Option<EndpointCreated>,
27 : branch_created: Option<BranchCreated>,
28 : project_created: Option<ProjectCreated>,
29 : #[serde(rename = "type")]
30 : _type: Option<String>,
31 : }
32 :
33 1 : #[derive(Deserialize, Debug, Clone, PartialEq)]
34 : struct EndpointCreated {
35 : endpoint_id: EndpointIdInt,
36 : }
37 :
38 0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
39 : struct BranchCreated {
40 : branch_id: BranchIdInt,
41 : }
42 :
43 0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
44 : struct ProjectCreated {
45 : project_id: ProjectIdInt,
46 : }
47 :
48 : impl TryFrom<&Value> for ControlPlaneEvent {
49 : type Error = anyhow::Error;
50 0 : fn try_from(value: &Value) -> Result<Self, Self::Error> {
51 0 : let json = String::from_redis_value(value)?;
52 0 : Ok(serde_json::from_str(&json)?)
53 0 : }
54 : }
55 :
56 : pub struct EndpointsCache {
57 : config: EndpointCacheConfig,
58 : endpoints: ClashSet<EndpointIdInt>,
59 : branches: ClashSet<BranchIdInt>,
60 : projects: ClashSet<ProjectIdInt>,
61 : ready: AtomicBool,
62 : limiter: Arc<Mutex<GlobalRateLimiter>>,
63 : }
64 :
65 : impl EndpointsCache {
66 0 : pub(crate) fn new(config: EndpointCacheConfig) -> Self {
67 0 : Self {
68 0 : limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
69 0 : config.limiter_info.clone(),
70 0 : ))),
71 0 : config,
72 0 : endpoints: ClashSet::new(),
73 0 : branches: ClashSet::new(),
74 0 : projects: ClashSet::new(),
75 0 : ready: AtomicBool::new(false),
76 0 : }
77 0 : }
78 :
79 0 : pub(crate) fn is_valid(&self, ctx: &RequestContext, endpoint: &EndpointId) -> bool {
80 0 : if !self.ready.load(Ordering::Acquire) {
81 : // the endpoint cache is not yet fully initialised.
82 0 : return true;
83 0 : }
84 0 :
85 0 : if !self.should_reject(endpoint) {
86 0 : ctx.set_rejected(false);
87 0 : return true;
88 0 : }
89 0 :
90 0 : // report that we might want to reject this endpoint
91 0 : ctx.set_rejected(true);
92 0 :
93 0 : // If cache is disabled, just collect the metrics and return.
94 0 : if self.config.disable_cache {
95 0 : return true;
96 0 : }
97 0 :
98 0 : // If the limiter allows, we can pretend like it's valid
99 0 : // (incase it is, due to redis channel lag).
100 0 : if self.limiter.lock_propagate_poison().check() {
101 0 : return true;
102 0 : }
103 0 :
104 0 : // endpoint not found, and there's too much load.
105 0 : false
106 0 : }
107 :
108 0 : fn should_reject(&self, endpoint: &EndpointId) -> bool {
109 0 : if endpoint.is_endpoint() {
110 0 : let Some(endpoint) = EndpointIdInt::get(endpoint) else {
111 : // if we haven't interned this endpoint, it's not in the cache.
112 0 : return true;
113 : };
114 0 : !self.endpoints.contains(&endpoint)
115 0 : } else if endpoint.is_branch() {
116 0 : let Some(branch) = BranchIdInt::get(endpoint) else {
117 : // if we haven't interned this branch, it's not in the cache.
118 0 : return true;
119 : };
120 0 : !self.branches.contains(&branch)
121 : } else {
122 0 : let Some(project) = ProjectIdInt::get(endpoint) else {
123 : // if we haven't interned this project, it's not in the cache.
124 0 : return true;
125 : };
126 0 : !self.projects.contains(&project)
127 : }
128 0 : }
129 :
130 0 : fn insert_event(&self, event: ControlPlaneEvent) {
131 0 : if let Some(endpoint_created) = event.endpoint_created {
132 0 : self.endpoints.insert(endpoint_created.endpoint_id);
133 0 : Metrics::get()
134 0 : .proxy
135 0 : .redis_events_count
136 0 : .inc(RedisEventsCount::EndpointCreated);
137 0 : } else if let Some(branch_created) = event.branch_created {
138 0 : self.branches.insert(branch_created.branch_id);
139 0 : Metrics::get()
140 0 : .proxy
141 0 : .redis_events_count
142 0 : .inc(RedisEventsCount::BranchCreated);
143 0 : } else if let Some(project_created) = event.project_created {
144 0 : self.projects.insert(project_created.project_id);
145 0 : Metrics::get()
146 0 : .proxy
147 0 : .redis_events_count
148 0 : .inc(RedisEventsCount::ProjectCreated);
149 0 : }
150 0 : }
151 :
152 0 : pub async fn do_read(
153 0 : &self,
154 0 : mut con: ConnectionWithCredentialsProvider,
155 0 : cancellation_token: CancellationToken,
156 0 : ) -> anyhow::Result<Infallible> {
157 0 : let mut last_id = "0-0".to_string();
158 : loop {
159 0 : if let Err(e) = con.connect().await {
160 0 : tracing::error!("error connecting to redis: {:?}", e);
161 0 : self.ready.store(false, Ordering::Release);
162 0 : }
163 0 : if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
164 0 : tracing::error!("error reading from redis: {:?}", e);
165 0 : self.ready.store(false, Ordering::Release);
166 0 : }
167 0 : if cancellation_token.is_cancelled() {
168 0 : info!("cancellation token is cancelled, exiting");
169 : // Maintenance tasks run forever. Sleep forever when canceled.
170 0 : pending::<()>().await;
171 0 : }
172 0 : tokio::time::sleep(self.config.retry_interval).await;
173 : }
174 : }
175 :
176 0 : async fn read_from_stream(
177 0 : &self,
178 0 : con: &mut ConnectionWithCredentialsProvider,
179 0 : last_id: &mut String,
180 0 : ) -> anyhow::Result<()> {
181 0 : tracing::info!("reading endpoints/branches/projects from redis");
182 0 : self.batch_read(
183 0 : con,
184 0 : StreamReadOptions::default().count(self.config.initial_batch_size),
185 0 : last_id,
186 0 : true,
187 0 : )
188 0 : .await?;
189 0 : tracing::info!("ready to filter user requests");
190 0 : self.ready.store(true, Ordering::Release);
191 0 : self.batch_read(
192 0 : con,
193 0 : StreamReadOptions::default()
194 0 : .count(self.config.default_batch_size)
195 0 : .block(self.config.xread_timeout.as_millis() as usize),
196 0 : last_id,
197 0 : false,
198 0 : )
199 0 : .await
200 0 : }
201 :
202 0 : async fn batch_read(
203 0 : &self,
204 0 : conn: &mut ConnectionWithCredentialsProvider,
205 0 : opts: StreamReadOptions,
206 0 : last_id: &mut String,
207 0 : return_when_finish: bool,
208 0 : ) -> anyhow::Result<()> {
209 0 : let mut total: usize = 0;
210 : loop {
211 0 : let mut res: StreamReadReply = conn
212 0 : .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
213 0 : .await?;
214 :
215 0 : if res.keys.is_empty() {
216 0 : if return_when_finish {
217 0 : if total != 0 {
218 0 : break;
219 0 : }
220 0 : anyhow::bail!(
221 0 : "Redis stream {} is empty, cannot be used to filter endpoints",
222 0 : self.config.stream_name
223 0 : );
224 0 : }
225 0 : // If we are not returning when finish, we should wait for more data.
226 0 : continue;
227 0 : }
228 0 : if res.keys.len() != 1 {
229 0 : anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
230 0 : }
231 0 :
232 0 : let key = res.keys.pop().expect("Checked length above");
233 0 : let len = key.ids.len();
234 0 : for stream_id in key.ids {
235 0 : total += 1;
236 0 : for value in stream_id.map.values() {
237 0 : match value.try_into() {
238 0 : Ok(event) => self.insert_event(event),
239 0 : Err(err) => {
240 0 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
241 0 : channel: &self.config.stream_name,
242 0 : });
243 0 : tracing::error!("error parsing value {value:?}: {err:?}");
244 : }
245 : }
246 : }
247 0 : if total.is_power_of_two() {
248 0 : tracing::debug!("endpoints read {}", total);
249 0 : }
250 0 : *last_id = stream_id.id;
251 : }
252 0 : if return_when_finish && len <= self.config.default_batch_size {
253 0 : break;
254 0 : }
255 : }
256 0 : tracing::info!("read {} endpoints/branches/projects from redis", total);
257 0 : Ok(())
258 0 : }
259 : }
260 :
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// An `endpoint_created` JSON document deserializes into the matching
    /// `ControlPlaneEvent`, with the other payloads left as `None`.
    #[test]
    fn test_parse_control_plane_event() {
        let raw = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;

        let endpoint_id: EndpointId = "ep-rapid-thunder-w0qqw2q9".into();

        let parsed = serde_json::from_str::<ControlPlaneEvent>(raw).unwrap();
        let expected = ControlPlaneEvent {
            endpoint_created: Some(EndpointCreated {
                endpoint_id: endpoint_id.into(),
            }),
            branch_created: None,
            project_created: None,
            _type: Some("endpoint_created".into()),
        };

        assert_eq!(parsed, expected);
    }
}
|