Line data Source code
1 : use std::convert::Infallible;
2 : use std::future::pending;
3 : use std::sync::atomic::{AtomicBool, Ordering};
4 : use std::sync::{Arc, Mutex};
5 :
6 : use dashmap::DashSet;
7 : use redis::streams::{StreamReadOptions, StreamReadReply};
8 : use redis::{AsyncCommands, FromRedisValue, Value};
9 : use serde::Deserialize;
10 : use tokio_util::sync::CancellationToken;
11 : use tracing::info;
12 :
13 : use crate::config::EndpointCacheConfig;
14 : use crate::context::RequestMonitoring;
15 : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
16 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
17 : use crate::rate_limiter::GlobalRateLimiter;
18 : use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
19 : use crate::types::EndpointId;
20 :
21 : // TODO: this could be an enum, but events in Redis need to be fixed first.
22 : // ProjectCreated was sent with type:branch_created. So we ignore type.
23 5 : #[derive(Deserialize, Debug, Clone, PartialEq)]
24 : struct ControlPlaneEvent {
25 : endpoint_created: Option<EndpointCreated>,
26 : branch_created: Option<BranchCreated>,
27 : project_created: Option<ProjectCreated>,
28 : #[serde(rename = "type")]
29 : _type: Option<String>,
30 : }
31 :
32 2 : #[derive(Deserialize, Debug, Clone, PartialEq)]
33 : struct EndpointCreated {
34 : endpoint_id: EndpointIdInt,
35 : }
36 :
37 0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
38 : struct BranchCreated {
39 : branch_id: BranchIdInt,
40 : }
41 :
42 0 : #[derive(Deserialize, Debug, Clone, PartialEq)]
43 : struct ProjectCreated {
44 : project_id: ProjectIdInt,
45 : }
46 :
47 : impl TryFrom<&Value> for ControlPlaneEvent {
48 : type Error = anyhow::Error;
49 0 : fn try_from(value: &Value) -> Result<Self, Self::Error> {
50 0 : let json = String::from_redis_value(value)?;
51 0 : Ok(serde_json::from_str(&json)?)
52 0 : }
53 : }
54 :
55 : pub struct EndpointsCache {
56 : config: EndpointCacheConfig,
57 : endpoints: DashSet<EndpointIdInt>,
58 : branches: DashSet<BranchIdInt>,
59 : projects: DashSet<ProjectIdInt>,
60 : ready: AtomicBool,
61 : limiter: Arc<Mutex<GlobalRateLimiter>>,
62 : }
63 :
64 : impl EndpointsCache {
65 0 : pub(crate) fn new(config: EndpointCacheConfig) -> Self {
66 0 : Self {
67 0 : limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
68 0 : config.limiter_info.clone(),
69 0 : ))),
70 0 : config,
71 0 : endpoints: DashSet::new(),
72 0 : branches: DashSet::new(),
73 0 : projects: DashSet::new(),
74 0 : ready: AtomicBool::new(false),
75 0 : }
76 0 : }
77 :
78 0 : pub(crate) fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
79 0 : if !self.ready.load(Ordering::Acquire) {
80 : // the endpoint cache is not yet fully initialised.
81 0 : return true;
82 0 : }
83 0 :
84 0 : if !self.should_reject(endpoint) {
85 0 : ctx.set_rejected(false);
86 0 : return true;
87 0 : }
88 0 :
89 0 : // report that we might want to reject this endpoint
90 0 : ctx.set_rejected(true);
91 0 :
92 0 : // If cache is disabled, just collect the metrics and return.
93 0 : if self.config.disable_cache {
94 0 : return true;
95 0 : }
96 0 :
97 0 : // If the limiter allows, we can pretend like it's valid
98 0 : // (incase it is, due to redis channel lag).
99 0 : if self.limiter.lock().unwrap().check() {
100 0 : return true;
101 0 : }
102 0 :
103 0 : // endpoint not found, and there's too much load.
104 0 : false
105 0 : }
106 :
107 0 : fn should_reject(&self, endpoint: &EndpointId) -> bool {
108 0 : if endpoint.is_endpoint() {
109 0 : let Some(endpoint) = EndpointIdInt::get(endpoint) else {
110 : // if we haven't interned this endpoint, it's not in the cache.
111 0 : return true;
112 : };
113 0 : !self.endpoints.contains(&endpoint)
114 0 : } else if endpoint.is_branch() {
115 0 : let Some(branch) = BranchIdInt::get(endpoint) else {
116 : // if we haven't interned this branch, it's not in the cache.
117 0 : return true;
118 : };
119 0 : !self.branches.contains(&branch)
120 : } else {
121 0 : let Some(project) = ProjectIdInt::get(endpoint) else {
122 : // if we haven't interned this project, it's not in the cache.
123 0 : return true;
124 : };
125 0 : !self.projects.contains(&project)
126 : }
127 0 : }
128 :
129 0 : fn insert_event(&self, event: ControlPlaneEvent) {
130 0 : if let Some(endpoint_created) = event.endpoint_created {
131 0 : self.endpoints.insert(endpoint_created.endpoint_id);
132 0 : Metrics::get()
133 0 : .proxy
134 0 : .redis_events_count
135 0 : .inc(RedisEventsCount::EndpointCreated);
136 0 : } else if let Some(branch_created) = event.branch_created {
137 0 : self.branches.insert(branch_created.branch_id);
138 0 : Metrics::get()
139 0 : .proxy
140 0 : .redis_events_count
141 0 : .inc(RedisEventsCount::BranchCreated);
142 0 : } else if let Some(project_created) = event.project_created {
143 0 : self.projects.insert(project_created.project_id);
144 0 : Metrics::get()
145 0 : .proxy
146 0 : .redis_events_count
147 0 : .inc(RedisEventsCount::ProjectCreated);
148 0 : }
149 0 : }
150 :
151 0 : pub async fn do_read(
152 0 : &self,
153 0 : mut con: ConnectionWithCredentialsProvider,
154 0 : cancellation_token: CancellationToken,
155 0 : ) -> anyhow::Result<Infallible> {
156 0 : let mut last_id = "0-0".to_string();
157 : loop {
158 0 : if let Err(e) = con.connect().await {
159 0 : tracing::error!("error connecting to redis: {:?}", e);
160 0 : self.ready.store(false, Ordering::Release);
161 0 : }
162 0 : if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
163 0 : tracing::error!("error reading from redis: {:?}", e);
164 0 : self.ready.store(false, Ordering::Release);
165 0 : }
166 0 : if cancellation_token.is_cancelled() {
167 0 : info!("cancellation token is cancelled, exiting");
168 : // Maintenance tasks run forever. Sleep forever when canceled.
169 0 : pending::<()>().await;
170 0 : }
171 0 : tokio::time::sleep(self.config.retry_interval).await;
172 : }
173 : }
174 :
175 0 : async fn read_from_stream(
176 0 : &self,
177 0 : con: &mut ConnectionWithCredentialsProvider,
178 0 : last_id: &mut String,
179 0 : ) -> anyhow::Result<()> {
180 0 : tracing::info!("reading endpoints/branches/projects from redis");
181 0 : self.batch_read(
182 0 : con,
183 0 : StreamReadOptions::default().count(self.config.initial_batch_size),
184 0 : last_id,
185 0 : true,
186 0 : )
187 0 : .await?;
188 0 : tracing::info!("ready to filter user requests");
189 0 : self.ready.store(true, Ordering::Release);
190 0 : self.batch_read(
191 0 : con,
192 0 : StreamReadOptions::default()
193 0 : .count(self.config.default_batch_size)
194 0 : .block(self.config.xread_timeout.as_millis() as usize),
195 0 : last_id,
196 0 : false,
197 0 : )
198 0 : .await
199 0 : }
200 :
201 0 : async fn batch_read(
202 0 : &self,
203 0 : conn: &mut ConnectionWithCredentialsProvider,
204 0 : opts: StreamReadOptions,
205 0 : last_id: &mut String,
206 0 : return_when_finish: bool,
207 0 : ) -> anyhow::Result<()> {
208 0 : let mut total: usize = 0;
209 : loop {
210 0 : let mut res: StreamReadReply = conn
211 0 : .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
212 0 : .await?;
213 :
214 0 : if res.keys.is_empty() {
215 0 : if return_when_finish {
216 0 : if total != 0 {
217 0 : break;
218 0 : }
219 0 : anyhow::bail!(
220 0 : "Redis stream {} is empty, cannot be used to filter endpoints",
221 0 : self.config.stream_name
222 0 : );
223 0 : }
224 0 : // If we are not returning when finish, we should wait for more data.
225 0 : continue;
226 0 : }
227 0 : if res.keys.len() != 1 {
228 0 : anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
229 0 : }
230 0 :
231 0 : let key = res.keys.pop().expect("Checked length above");
232 0 : let len = key.ids.len();
233 0 : for stream_id in key.ids {
234 0 : total += 1;
235 0 : for value in stream_id.map.values() {
236 0 : match value.try_into() {
237 0 : Ok(event) => self.insert_event(event),
238 0 : Err(err) => {
239 0 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
240 0 : channel: &self.config.stream_name,
241 0 : });
242 0 : tracing::error!("error parsing value {value:?}: {err:?}");
243 : }
244 : };
245 : }
246 0 : if total.is_power_of_two() {
247 0 : tracing::debug!("endpoints read {}", total);
248 0 : }
249 0 : *last_id = stream_id.id;
250 : }
251 0 : if return_when_finish && len <= self.config.default_batch_size {
252 0 : break;
253 0 : }
254 : }
255 0 : tracing::info!("read {} endpoints/branches/projects from redis", total);
256 0 : Ok(())
257 0 : }
258 : }
259 :
#[cfg(test)]
mod tests {
    use super::*;

    /// An `endpoint_created` payload deserializes into an event carrying only
    /// the endpoint id and the raw `type` tag.
    #[test]
    fn test_parse_control_plane_event() {
        let payload = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;

        let endpoint_id: EndpointId = "ep-rapid-thunder-w0qqw2q9".into();
        let expected = ControlPlaneEvent {
            endpoint_created: Some(EndpointCreated {
                endpoint_id: endpoint_id.into(),
            }),
            branch_created: None,
            project_created: None,
            _type: Some("endpoint_created".into()),
        };

        let parsed = serde_json::from_str::<ControlPlaneEvent>(payload).unwrap();
        assert_eq!(parsed, expected);
    }
}
|