Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use redis::aio::PubSub;
6 : use serde::Deserialize;
7 : use tokio_util::sync::CancellationToken;
8 :
9 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
10 : use crate::cache::project_info::ProjectInfoCache;
11 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
12 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
13 :
/// Name of the Redis pub/sub channel this module subscribes to.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";

/// How long to wait after a failed connection attempt before trying again.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);

/// Delay before an invalidation is repeated, so that an entry which was still
/// being cached when the notification arrived gets invalidated as well.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
17 :
18 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
19 0 : let mut conn = client.get_async_pubsub().await?;
20 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
21 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
22 0 : Ok(conn)
23 0 : }
24 :
25 0 : #[derive(Debug, Deserialize)]
26 : struct NotificationHeader<'a> {
27 : topic: &'a str,
28 : }
29 :
30 8 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
31 : #[serde(tag = "topic", content = "data")]
32 : enum Notification {
33 : #[serde(
34 : rename = "/account_settings_update",
35 : alias = "/allowed_vpc_endpoints_updated_for_org",
36 : deserialize_with = "deserialize_json_string"
37 : )]
38 : AccountSettingsUpdate(InvalidateAccount),
39 :
40 : #[serde(
41 : rename = "/endpoint_settings_update",
42 : deserialize_with = "deserialize_json_string"
43 : )]
44 : EndpointSettingsUpdate(InvalidateEndpoint),
45 :
46 : #[serde(
47 : rename = "/project_settings_update",
48 : alias = "/allowed_ips_updated",
49 : alias = "/block_public_or_vpc_access_updated",
50 : alias = "/allowed_vpc_endpoints_updated_for_projects",
51 : deserialize_with = "deserialize_json_string"
52 : )]
53 : ProjectSettingsUpdate(InvalidateProject),
54 :
55 : #[serde(
56 : rename = "/role_setting_update",
57 : alias = "/password_updated",
58 : deserialize_with = "deserialize_json_string"
59 : )]
60 : RoleSettingUpdate(InvalidateRole),
61 :
62 : #[serde(
63 : other,
64 : deserialize_with = "deserialize_unknown_topic",
65 : skip_serializing
66 : )]
67 : UnknownTopic,
68 : }
69 :
70 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
71 : #[serde(rename_all = "snake_case")]
72 : enum InvalidateEndpoint {
73 : EndpointId(EndpointIdInt),
74 : EndpointIds(Vec<EndpointIdInt>),
75 : }
76 : impl std::ops::Deref for InvalidateEndpoint {
77 : type Target = [EndpointIdInt];
78 0 : fn deref(&self) -> &Self::Target {
79 0 : match self {
80 0 : Self::EndpointId(id) => std::slice::from_ref(id),
81 0 : Self::EndpointIds(ids) => ids,
82 : }
83 0 : }
84 : }
85 :
86 2 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
87 : #[serde(rename_all = "snake_case")]
88 : enum InvalidateProject {
89 : ProjectId(ProjectIdInt),
90 : ProjectIds(Vec<ProjectIdInt>),
91 : }
92 : impl std::ops::Deref for InvalidateProject {
93 : type Target = [ProjectIdInt];
94 0 : fn deref(&self) -> &Self::Target {
95 0 : match self {
96 0 : Self::ProjectId(id) => std::slice::from_ref(id),
97 0 : Self::ProjectIds(ids) => ids,
98 : }
99 0 : }
100 : }
101 :
102 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
103 : #[serde(rename_all = "snake_case")]
104 : enum InvalidateAccount {
105 : AccountId(AccountIdInt),
106 : AccountIds(Vec<AccountIdInt>),
107 : }
108 : impl std::ops::Deref for InvalidateAccount {
109 : type Target = [AccountIdInt];
110 0 : fn deref(&self) -> &Self::Target {
111 0 : match self {
112 0 : Self::AccountId(id) => std::slice::from_ref(id),
113 0 : Self::AccountIds(ids) => ids,
114 : }
115 0 : }
116 : }
117 :
118 2 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
119 : struct InvalidateRole {
120 : project_id: ProjectIdInt,
121 : role_name: RoleNameInt,
122 : }
123 :
124 3 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
125 3 : where
126 3 : T: for<'de2> serde::Deserialize<'de2>,
127 3 : D: serde::Deserializer<'de>,
128 3 : {
129 3 : let s = String::deserialize(deserializer)?;
130 3 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
131 3 : }
132 :
133 : // https://github.com/serde-rs/serde/issues/1714
134 1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
135 1 : where
136 1 : D: serde::Deserializer<'de>,
137 1 : {
138 1 : deserializer.deserialize_any(serde::de::IgnoredAny)?;
139 1 : Ok(())
140 1 : }
141 :
142 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
143 : cache: Arc<C>,
144 : region_id: String,
145 : }
146 :
147 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
148 0 : fn clone(&self) -> Self {
149 0 : Self {
150 0 : cache: self.cache.clone(),
151 0 : region_id: self.region_id.clone(),
152 0 : }
153 0 : }
154 : }
155 :
156 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
157 0 : pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
158 0 : Self { cache, region_id }
159 0 : }
160 :
161 0 : pub(crate) async fn increment_active_listeners(&self) {
162 0 : self.cache.increment_active_listeners().await;
163 0 : }
164 :
165 0 : pub(crate) async fn decrement_active_listeners(&self) {
166 0 : self.cache.decrement_active_listeners().await;
167 0 : }
168 :
169 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
170 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
171 : let payload: String = msg.get_payload()?;
172 : tracing::debug!(?payload, "received a message payload");
173 :
174 : let msg: Notification = match serde_json::from_str(&payload) {
175 : Ok(Notification::UnknownTopic) => {
176 : match serde_json::from_str::<NotificationHeader>(&payload) {
177 : // don't update the metric for redis errors if it's just a topic we don't know about.
178 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
179 : Err(e) => {
180 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
181 : channel: msg.get_channel_name(),
182 : });
183 : tracing::error!("broken message: {e}");
184 : }
185 : }
186 : return Ok(());
187 : }
188 : Ok(msg) => msg,
189 : Err(e) => {
190 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
191 : channel: msg.get_channel_name(),
192 : });
193 : match serde_json::from_str::<NotificationHeader>(&payload) {
194 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
195 : Err(_) => tracing::error!("broken message: {e}"),
196 : }
197 : return Ok(());
198 : }
199 : };
200 :
201 : tracing::debug!(?msg, "received a message");
202 : match msg {
203 : Notification::RoleSettingUpdate { .. }
204 : | Notification::EndpointSettingsUpdate { .. }
205 : | Notification::ProjectSettingsUpdate { .. }
206 : | Notification::AccountSettingsUpdate { .. } => {
207 : invalidate_cache(self.cache.clone(), msg.clone());
208 :
209 : let m = &Metrics::get().proxy.redis_events_count;
210 : match msg {
211 : Notification::RoleSettingUpdate { .. } => {
212 : m.inc(RedisEventsCount::InvalidateRole);
213 : }
214 : Notification::EndpointSettingsUpdate { .. } => {
215 : m.inc(RedisEventsCount::InvalidateEndpoint);
216 : }
217 : Notification::ProjectSettingsUpdate { .. } => {
218 : m.inc(RedisEventsCount::InvalidateProject);
219 : }
220 : Notification::AccountSettingsUpdate { .. } => {
221 : m.inc(RedisEventsCount::InvalidateOrg);
222 : }
223 : Notification::UnknownTopic => {}
224 : }
225 :
226 : // TODO: add additional metrics for the other event types.
227 :
228 : // It might happen that the invalid entry is on the way to be cached.
229 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
230 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
231 : let cache = self.cache.clone();
232 0 : tokio::spawn(async move {
233 0 : tokio::time::sleep(INVALIDATION_LAG).await;
234 0 : invalidate_cache(cache, msg);
235 0 : });
236 : }
237 :
238 : Notification::UnknownTopic => unreachable!(),
239 : }
240 :
241 : Ok(())
242 : }
243 : }
244 :
245 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
246 0 : match msg {
247 0 : Notification::EndpointSettingsUpdate(ids) => ids
248 0 : .iter()
249 0 : .for_each(|&id| cache.invalidate_endpoint_access(id)),
250 :
251 0 : Notification::AccountSettingsUpdate(ids) => ids
252 0 : .iter()
253 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
254 :
255 0 : Notification::ProjectSettingsUpdate(ids) => ids
256 0 : .iter()
257 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
258 :
259 : Notification::RoleSettingUpdate(InvalidateRole {
260 0 : project_id,
261 0 : role_name,
262 0 : }) => cache.invalidate_role_secret_for_project(project_id, role_name),
263 :
264 0 : Notification::UnknownTopic => unreachable!(),
265 : }
266 0 : }
267 :
268 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
269 0 : handler: MessageHandler<C>,
270 0 : redis: ConnectionWithCredentialsProvider,
271 0 : cancellation_token: CancellationToken,
272 0 : ) -> anyhow::Result<()> {
273 : loop {
274 0 : if cancellation_token.is_cancelled() {
275 0 : return Ok(());
276 0 : }
277 0 : let mut conn = match try_connect(&redis).await {
278 0 : Ok(conn) => {
279 0 : handler.increment_active_listeners().await;
280 0 : conn
281 : }
282 0 : Err(e) => {
283 0 : tracing::error!(
284 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
285 : );
286 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
287 0 : continue;
288 : }
289 : };
290 0 : let mut stream = conn.on_message();
291 0 : while let Some(msg) = stream.next().await {
292 0 : match handler.handle_message(msg).await {
293 0 : Ok(()) => {}
294 0 : Err(e) => {
295 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
296 0 : break;
297 : }
298 : }
299 0 : if cancellation_token.is_cancelled() {
300 0 : handler.decrement_active_listeners().await;
301 0 : return Ok(());
302 0 : }
303 : }
304 0 : handler.decrement_active_listeners().await;
305 : }
306 0 : }
307 :
308 : /// Handle console's invalidation messages.
309 : #[tracing::instrument(name = "redis_notifications", skip_all)]
310 : pub async fn task_main<C>(
311 : redis: ConnectionWithCredentialsProvider,
312 : cache: Arc<C>,
313 : region_id: String,
314 : ) -> anyhow::Result<Infallible>
315 : where
316 : C: ProjectInfoCache + Send + Sync + 'static,
317 : {
318 : let handler = MessageHandler::new(cache, region_id);
319 : // 6h - 1m.
320 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
321 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
322 : loop {
323 : let cancellation_token = CancellationToken::new();
324 : interval.tick().await;
325 :
326 : tokio::spawn(handle_messages(
327 : handler.clone(),
328 : redis.clone(),
329 : cancellation_token.clone(),
330 : ));
331 0 : tokio::spawn(async move {
332 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
333 0 : cancellation_token.cancel();
334 0 : });
335 : }
336 : }
337 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// Wraps `data` (a JSON document encoded as a string) into the message
    /// envelope used by these tests, including one field the parser must ignore.
    fn envelope(topic: &str, data: &str) -> String {
        json!({
            "type": "message",
            "topic": topic,
            "data": data,
            "extre_fields": "something"
        })
        .to_string()
    }

    /// A legacy topic alias carrying a single project id.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!(r#"{{"project_id": "{project_id}"}}"#);
        let text = envelope("/allowed_ips_updated", &data);

        let parsed: Notification = serde_json::from_str(&text)?;
        let expected =
            Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()));
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// A project notification carrying a list of project ids.
    #[test]
    fn parse_multiple_projects() -> anyhow::Result<()> {
        let project_id1: ProjectId = "new_project1".into();
        let project_id2: ProjectId = "new_project2".into();
        let data = format!(r#"{{"project_ids": ["{project_id1}","{project_id2}"]}}"#);
        let text = envelope("/allowed_vpc_endpoints_updated_for_projects", &data);

        let parsed: Notification = serde_json::from_str(&text)?;
        let expected = Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
            (&project_id1).into(),
            (&project_id2).into(),
        ]));
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// A role notification received under its legacy topic alias.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!(r#"{{"project_id": "{project_id}", "role_name": "{role_name}"}}"#);
        let text = envelope("/password_updated", &data);

        let parsed: Notification = serde_json::from_str(&text)?;
        let expected = Notification::RoleSettingUpdate(InvalidateRole {
            project_id: (&project_id).into(),
            role_name: (&role_name).into(),
        });
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// Unknown topics parse to `UnknownTopic`, with or without a `data` field.
    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        let with_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "data": {
                "payload": "ignored"
            },
            "extra_fields": "something"
        });
        let without_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "extra_fields": "something"
        });

        for message in [with_data, without_data] {
            let parsed: Notification = serde_json::from_str(&message.to_string())?;
            assert_eq!(parsed, Notification::UnknownTopic);
        }

        Ok(())
    }
}
|