Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::atomic::{AtomicBool, Ordering};
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use dashmap::DashSet;
7 : use redis::streams::{StreamReadOptions, StreamReadReply};
8 : use redis::{AsyncCommands, FromRedisValue, Value};
9 : use serde::Deserialize;
10 : use tokio::sync::Mutex;
11 : use tokio_util::sync::CancellationToken;
12 : use tracing::info;
13 :
14 : use crate::config::EndpointCacheConfig;
15 : use crate::context::RequestMonitoring;
16 : use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
17 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
18 : use crate::rate_limiter::GlobalRateLimiter;
19 : use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
20 : use crate::EndpointId;
21 :
22 5 : #[derive(Deserialize, Debug, Clone)]
23 : pub(crate) struct ControlPlaneEventKey {
24 : endpoint_created: Option<EndpointCreated>,
25 : branch_created: Option<BranchCreated>,
26 : project_created: Option<ProjectCreated>,
27 : }
28 2 : #[derive(Deserialize, Debug, Clone)]
29 : struct EndpointCreated {
30 : endpoint_id: String,
31 : }
32 0 : #[derive(Deserialize, Debug, Clone)]
33 : struct BranchCreated {
34 : branch_id: String,
35 : }
36 0 : #[derive(Deserialize, Debug, Clone)]
37 : struct ProjectCreated {
38 : project_id: String,
39 : }
40 :
41 : pub struct EndpointsCache {
42 : config: EndpointCacheConfig,
43 : endpoints: DashSet<EndpointIdInt>,
44 : branches: DashSet<BranchIdInt>,
45 : projects: DashSet<ProjectIdInt>,
46 : ready: AtomicBool,
47 : limiter: Arc<Mutex<GlobalRateLimiter>>,
48 : }
49 :
50 : impl EndpointsCache {
51 0 : pub(crate) fn new(config: EndpointCacheConfig) -> Self {
52 0 : Self {
53 0 : limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
54 0 : config.limiter_info.clone(),
55 0 : ))),
56 0 : config,
57 0 : endpoints: DashSet::new(),
58 0 : branches: DashSet::new(),
59 0 : projects: DashSet::new(),
60 0 : ready: AtomicBool::new(false),
61 0 : }
62 0 : }
63 0 : pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
64 0 : if !self.ready.load(Ordering::Acquire) {
65 0 : return true;
66 0 : }
67 0 : let rejected = self.should_reject(endpoint);
68 0 : ctx.set_rejected(rejected);
69 0 : info!(?rejected, "check endpoint is valid, disabled cache");
70 : // If cache is disabled, just collect the metrics and return or
71 : // If the limiter allows, we don't need to check the cache.
72 0 : if self.config.disable_cache || self.limiter.lock().await.check() {
73 0 : return true;
74 0 : }
75 0 : !rejected
76 0 : }
77 0 : fn should_reject(&self, endpoint: &EndpointId) -> bool {
78 0 : if endpoint.is_endpoint() {
79 0 : !self.endpoints.contains(&EndpointIdInt::from(endpoint))
80 0 : } else if endpoint.is_branch() {
81 0 : !self
82 0 : .branches
83 0 : .contains(&BranchIdInt::from(&endpoint.as_branch()))
84 : } else {
85 0 : !self
86 0 : .projects
87 0 : .contains(&ProjectIdInt::from(&endpoint.as_project()))
88 : }
89 0 : }
90 0 : fn insert_event(&self, key: ControlPlaneEventKey) {
91 : // Do not do normalization here, we expect the events to be normalized.
92 0 : if let Some(endpoint_created) = key.endpoint_created {
93 0 : self.endpoints
94 0 : .insert(EndpointIdInt::from(&endpoint_created.endpoint_id.into()));
95 0 : Metrics::get()
96 0 : .proxy
97 0 : .redis_events_count
98 0 : .inc(RedisEventsCount::EndpointCreated);
99 0 : }
100 0 : if let Some(branch_created) = key.branch_created {
101 0 : self.branches
102 0 : .insert(BranchIdInt::from(&branch_created.branch_id.into()));
103 0 : Metrics::get()
104 0 : .proxy
105 0 : .redis_events_count
106 0 : .inc(RedisEventsCount::BranchCreated);
107 0 : }
108 0 : if let Some(project_created) = key.project_created {
109 0 : self.projects
110 0 : .insert(ProjectIdInt::from(&project_created.project_id.into()));
111 0 : Metrics::get()
112 0 : .proxy
113 0 : .redis_events_count
114 0 : .inc(RedisEventsCount::ProjectCreated);
115 0 : }
116 0 : }
117 0 : pub async fn do_read(
118 0 : &self,
119 0 : mut con: ConnectionWithCredentialsProvider,
120 0 : cancellation_token: CancellationToken,
121 0 : ) -> anyhow::Result<Infallible> {
122 0 : let mut last_id = "0-0".to_string();
123 : loop {
124 0 : if let Err(e) = con.connect().await {
125 0 : tracing::error!("error connecting to redis: {:?}", e);
126 0 : self.ready.store(false, Ordering::Release);
127 0 : }
128 0 : if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
129 0 : tracing::error!("error reading from redis: {:?}", e);
130 0 : self.ready.store(false, Ordering::Release);
131 0 : }
132 0 : if cancellation_token.is_cancelled() {
133 0 : info!("cancellation token is cancelled, exiting");
134 0 : tokio::time::sleep(Duration::from_secs(60 * 60 * 24 * 7)).await;
135 : // 1 week.
136 0 : }
137 0 : tokio::time::sleep(self.config.retry_interval).await;
138 : }
139 : }
140 0 : async fn read_from_stream(
141 0 : &self,
142 0 : con: &mut ConnectionWithCredentialsProvider,
143 0 : last_id: &mut String,
144 0 : ) -> anyhow::Result<()> {
145 0 : tracing::info!("reading endpoints/branches/projects from redis");
146 0 : self.batch_read(
147 0 : con,
148 0 : StreamReadOptions::default().count(self.config.initial_batch_size),
149 0 : last_id,
150 0 : true,
151 0 : )
152 0 : .await?;
153 0 : tracing::info!("ready to filter user requests");
154 0 : self.ready.store(true, Ordering::Release);
155 0 : self.batch_read(
156 0 : con,
157 0 : StreamReadOptions::default()
158 0 : .count(self.config.default_batch_size)
159 0 : .block(self.config.xread_timeout.as_millis() as usize),
160 0 : last_id,
161 0 : false,
162 0 : )
163 0 : .await
164 0 : }
165 0 : fn parse_key_value(value: &Value) -> anyhow::Result<ControlPlaneEventKey> {
166 0 : let s: String = FromRedisValue::from_redis_value(value)?;
167 0 : Ok(serde_json::from_str(&s)?)
168 0 : }
169 0 : async fn batch_read(
170 0 : &self,
171 0 : conn: &mut ConnectionWithCredentialsProvider,
172 0 : opts: StreamReadOptions,
173 0 : last_id: &mut String,
174 0 : return_when_finish: bool,
175 0 : ) -> anyhow::Result<()> {
176 0 : let mut total: usize = 0;
177 : loop {
178 0 : let mut res: StreamReadReply = conn
179 0 : .xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
180 0 : .await?;
181 :
182 0 : if res.keys.is_empty() {
183 0 : if return_when_finish {
184 0 : if total != 0 {
185 0 : break;
186 0 : }
187 0 : anyhow::bail!(
188 0 : "Redis stream {} is empty, cannot be used to filter endpoints",
189 0 : self.config.stream_name
190 0 : );
191 0 : }
192 0 : // If we are not returning when finish, we should wait for more data.
193 0 : continue;
194 0 : }
195 0 : if res.keys.len() != 1 {
196 0 : anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
197 0 : }
198 0 :
199 0 : let res = res.keys.pop().expect("Checked length above");
200 0 : let len = res.ids.len();
201 0 : for x in res.ids {
202 0 : total += 1;
203 0 : for (_, v) in x.map {
204 0 : let key = match Self::parse_key_value(&v) {
205 0 : Ok(x) => x,
206 0 : Err(e) => {
207 0 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
208 0 : channel: &self.config.stream_name,
209 0 : });
210 0 : tracing::error!("error parsing value {v:?}: {e:?}");
211 0 : continue;
212 : }
213 : };
214 0 : self.insert_event(key);
215 : }
216 0 : if total.is_power_of_two() {
217 0 : tracing::debug!("endpoints read {}", total);
218 0 : }
219 0 : *last_id = x.id;
220 : }
221 0 : if return_when_finish && len <= self.config.default_batch_size {
222 0 : break;
223 0 : }
224 : }
225 0 : tracing::info!("read {} endpoints/branches/projects from redis", total);
226 0 : Ok(())
227 0 : }
228 : }
229 :
#[cfg(test)]
mod tests {
    use super::ControlPlaneEventKey;

    /// An `endpoint_created` event deserializes with the endpoint id populated,
    /// the other payloads empty, and the extra `type` field ignored.
    #[test]
    fn test() {
        let s = "{\"branch_created\":null,\"endpoint_created\":{\"endpoint_id\":\"ep-rapid-thunder-w0qqw2q9\"},\"project_created\":null,\"type\":\"endpoint_created\"}";
        let key = serde_json::from_str::<ControlPlaneEventKey>(s).unwrap();

        let endpoint_id = key
            .endpoint_created
            .as_ref()
            .map(|created| created.endpoint_id.as_str());
        assert_eq!(endpoint_id, Some("ep-rapid-thunder-w0qqw2q9"));
        assert!(key.branch_created.is_none());
        assert!(key.project_created.is_none());
    }
}
|