use std::{
    convert::Infallible,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::Duration,
};

use dashmap::DashSet;
use redis::{
    streams::{StreamReadOptions, StreamReadReply},
    AsyncCommands, FromRedisValue, Value,
};
use serde::Deserialize;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
use tracing::info;

use crate::{
    config::EndpointCacheConfig,
    context::RequestMonitoring,
    intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
    metrics::{Metrics, RedisErrors, RedisEventsCount},
    rate_limiter::GlobalRateLimiter,
    redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
    EndpointId,
};

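/// One message from the control-plane Redis stream, deserialized from JSON.
/// Only the creation events we track are modelled; any other JSON fields
/// (such as `type`) are ignored during deserialization.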
#[derive(Deserialize, Debug, Clone)]
pub(crate) struct ControlPlaneEventKey {
    endpoint_created: Option<EndpointCreated>,
    branch_created: Option<BranchCreated>,
    project_created: Option<ProjectCreated>,
}
#[derive(Deserialize, Debug, Clone)]
struct EndpointCreated {
    endpoint_id: String,
}
#[derive(Deserialize, Debug, Clone)]
struct BranchCreated {
    branch_id: String,
}
#[derive(Deserialize, Debug, Clone)]
struct ProjectCreated {
    project_id: String,
}

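/// Cache of endpoint, branch, and project IDs seen on the control-plane
/// Redis stream. Until the initial stream read completes, `ready` stays
/// `false` and every endpoint is treated as valid.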
pub struct EndpointsCache {
    config: EndpointCacheConfig,
    endpoints: DashSet<EndpointIdInt>,
    branches: DashSet<BranchIdInt>,
    projects: DashSet<ProjectIdInt>,
    ready: AtomicBool,
    limiter: Arc<Mutex<GlobalRateLimiter>>,
}

impl EndpointsCache {
    pub(crate) fn new(config: EndpointCacheConfig) -> Self {
        Self {
            limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
                config.limiter_info.clone(),
            ))),
            config,
            endpoints: DashSet::new(),
            branches: DashSet::new(),
            projects: DashSet::new(),
            ready: AtomicBool::new(false),
        }
    }
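    /// Decide whether a request for `endpoint` may pass. Always `true` while
    /// the cache is not yet ready or when the cache is disabled; also `true`
    /// whenever the rate limiter allows skipping the cache check. Otherwise
    /// the endpoint must be present in the cache.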
    pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
        if !self.ready.load(Ordering::Acquire) {
            return true;
        }
        let rejected = self.should_reject(endpoint);
        ctx.set_rejected(rejected);
        info!(?rejected, "check endpoint is valid, disabled cache");
        // If the cache is disabled, just collect the metrics and return true.
        // Likewise, if the rate limiter allows it, we don't need to check the cache.
        if self.config.disable_cache || self.limiter.lock().await.check() {
            return true;
        }
        !rejected
    }
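    /// An endpoint is rejected when its endpoint, branch, or project ID is
    /// missing from the corresponding cached set.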
    fn should_reject(&self, endpoint: &EndpointId) -> bool {
        if endpoint.is_endpoint() {
            !self.endpoints.contains(&EndpointIdInt::from(endpoint))
        } else if endpoint.is_branch() {
            !self
                .branches
                .contains(&BranchIdInt::from(&endpoint.as_branch()))
        } else {
            !self
                .projects
                .contains(&ProjectIdInt::from(&endpoint.as_project()))
        }
    }
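    /// Apply one control-plane event to the cached sets, bumping the
    /// matching Redis event counter.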
    fn insert_event(&self, key: ControlPlaneEventKey) {
        // Do not normalize here; we expect the events to arrive already normalized.
        if let Some(endpoint_created) = key.endpoint_created {
            self.endpoints
                .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
            Metrics::get()
                .proxy
                .redis_events_count
                .inc(RedisEventsCount::EndpointCreated);
        }
        if let Some(branch_created) = key.branch_created {
            self.branches
                .insert(BranchIdInt::from(&branch_created.branch_id.into()));
            Metrics::get()
                .proxy
                .redis_events_count
                .inc(RedisEventsCount::BranchCreated);
        }
        if let Some(project_created) = key.project_created {
            self.projects
                .insert(ProjectIdInt::from(&project_created.project_id.into()));
            Metrics::get()
                .proxy
                .redis_events_count
                .inc(RedisEventsCount::ProjectCreated);
        }
    }
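    /// Main loop: (re)connect to Redis and tail the stream, marking the cache
    /// as not ready whenever connecting or reading fails, then retrying after
    /// `retry_interval`.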
    pub async fn do_read(
        &self,
        mut con: ConnectionWithCredentialsProvider,
        cancellation_token: CancellationToken,
    ) -> anyhow::Result<Infallible> {
        let mut last_id = "0-0".to_string();
        loop {
            if let Err(e) = con.connect().await {
                tracing::error!("error connecting to redis: {:?}", e);
                self.ready.store(false, Ordering::Release);
            }
            if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
                tracing::error!("error reading from redis: {:?}", e);
                self.ready.store(false, Ordering::Release);
            }
            if cancellation_token.is_cancelled() {
                info!("cancellation token is cancelled, exiting");
                tokio::time::sleep(Duration::from_secs(60 * 60 * 24 * 7)).await; // 1 week.
            }
            tokio::time::sleep(self.config.retry_interval).await;
        }
    }
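    /// Drain the stream in large batches, mark the cache ready, then keep
    /// reading new events in blocking batches.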
    async fn read_from_stream(
        &self,
        con: &mut ConnectionWithCredentialsProvider,
        last_id: &mut String,
    ) -> anyhow::Result<()> {
        tracing::info!("reading endpoints/branches/projects from redis");
        self.batch_read(
            con,
            StreamReadOptions::default().count(self.config.initial_batch_size),
            last_id,
            true,
        )
        .await?;
        tracing::info!("ready to filter user requests");
        self.ready.store(true, Ordering::Release);
        self.batch_read(
            con,
            StreamReadOptions::default()
                .count(self.config.default_batch_size)
                .block(self.config.xread_timeout.as_millis() as usize),
            last_id,
            false,
        )
        .await
    }
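    /// Stream values are JSON strings; parse one into a `ControlPlaneEventKey`.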
    fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
        let s: String = FromRedisValue::from_redis_value(value)?;
        Ok(serde_json::from_str(&s)?)
    }
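    /// Read batches of events starting after `last_id`. With
    /// `return_when_finish`, return once the stream is drained (failing if it
    /// was empty to begin with); otherwise block and keep reading forever.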
    async fn batch_read(
        &self,
        conn: &mut ConnectionWithCredentialsProvider,
        opts: StreamReadOptions,
        last_id: &mut String,
        return_when_finish: bool,
    ) -> anyhow::Result<()> {
        let mut total: usize = 0;
        loop {
            let mut res: StreamReadReply = conn
                .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
                .await?;

            if res.keys.is_empty() {
                if return_when_finish {
                    if total != 0 {
                        break;
                    }
                    anyhow::bail!(
                        "Redis stream {} is empty, cannot be used to filter endpoints",
                        self.config.stream_name
                    );
                }
                // If we are not returning when finished, wait for more data.
                continue;
            }
            if res.keys.len() != 1 {
                anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
            }

            let res = res.keys.pop().expect("Checked length above");
            let len = res.ids.len();
            for x in res.ids {
                total += 1;
                for (_, v) in x.map {
                    let key = match Self::parse_key_value(&v) {
                        Ok(x) => x,
                        Err(e) => {
                            Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
                                channel: &self.config.stream_name,
                            });
                            tracing::error!("error parsing value {v:?}: {e:?}");
                            continue;
                        }
                    };
                    self.insert_event(key);
                }
                if total.is_power_of_two() {
                    tracing::debug!("endpoints read {}", total);
                }
                *last_id = x.id;
            }
            if return_when_finish && len <= self.config.default_batch_size {
                break;
            }
        }
        tracing::info!("read {} endpoints/branches/projects from redis", total);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::ControlPlaneEventKey;

    #[test]
    fn test() {
        let s = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;
        serde_json::from_str::<ControlPlaneEventKey>(s).unwrap();
    }
}