Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use redis::aio::PubSub;
6 : use serde::Deserialize;
7 : use tokio_util::sync::CancellationToken;
8 :
9 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
10 : use crate::cache::project_info::ProjectInfoCache;
11 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
12 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
13 :
/// Name of the Redis pub/sub channel this module subscribes to.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";

/// How long to wait after a failed connection attempt before trying again.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);

/// Delay before an invalidation is applied a second time, to catch entries that were still
/// on their way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
17 :
18 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
19 0 : let mut conn = client.get_async_pubsub().await?;
20 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
21 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
22 0 : Ok(conn)
23 0 : }
24 :
25 0 : #[derive(Debug, Deserialize)]
26 : struct NotificationHeader<'a> {
27 : topic: &'a str,
28 : }
29 :
30 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
31 : #[serde(tag = "topic", content = "data")]
32 : enum Notification {
33 : #[serde(
34 : rename = "/account_settings_update",
35 : alias = "/allowed_vpc_endpoints_updated_for_org",
36 : deserialize_with = "deserialize_json_string"
37 : )]
38 : AccountSettingsUpdate(InvalidateAccount),
39 :
40 : #[serde(
41 : rename = "/endpoint_settings_update",
42 : deserialize_with = "deserialize_json_string"
43 : )]
44 : EndpointSettingsUpdate(InvalidateEndpoint),
45 :
46 : #[serde(
47 : rename = "/project_settings_update",
48 : alias = "/allowed_ips_updated",
49 : alias = "/block_public_or_vpc_access_updated",
50 : alias = "/allowed_vpc_endpoints_updated_for_projects",
51 : deserialize_with = "deserialize_json_string"
52 : )]
53 : ProjectSettingsUpdate(InvalidateProject),
54 :
55 : #[serde(
56 : rename = "/role_setting_update",
57 : alias = "/password_updated",
58 : deserialize_with = "deserialize_json_string"
59 : )]
60 : RoleSettingUpdate(InvalidateRole),
61 :
62 : #[serde(
63 : other,
64 : deserialize_with = "deserialize_unknown_topic",
65 : skip_serializing
66 : )]
67 : UnknownTopic,
68 : }
69 :
70 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
71 : #[serde(rename_all = "snake_case")]
72 : enum InvalidateEndpoint {
73 : EndpointId(EndpointIdInt),
74 : EndpointIds(Vec<EndpointIdInt>),
75 : }
76 : impl std::ops::Deref for InvalidateEndpoint {
77 : type Target = [EndpointIdInt];
78 0 : fn deref(&self) -> &Self::Target {
79 0 : match self {
80 0 : Self::EndpointId(id) => std::slice::from_ref(id),
81 0 : Self::EndpointIds(ids) => ids,
82 : }
83 0 : }
84 : }
85 :
86 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
87 : #[serde(rename_all = "snake_case")]
88 : enum InvalidateProject {
89 : ProjectId(ProjectIdInt),
90 : ProjectIds(Vec<ProjectIdInt>),
91 : }
92 : impl std::ops::Deref for InvalidateProject {
93 : type Target = [ProjectIdInt];
94 0 : fn deref(&self) -> &Self::Target {
95 0 : match self {
96 0 : Self::ProjectId(id) => std::slice::from_ref(id),
97 0 : Self::ProjectIds(ids) => ids,
98 : }
99 0 : }
100 : }
101 :
102 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
103 : #[serde(rename_all = "snake_case")]
104 : enum InvalidateAccount {
105 : AccountId(AccountIdInt),
106 : AccountIds(Vec<AccountIdInt>),
107 : }
108 : impl std::ops::Deref for InvalidateAccount {
109 : type Target = [AccountIdInt];
110 0 : fn deref(&self) -> &Self::Target {
111 0 : match self {
112 0 : Self::AccountId(id) => std::slice::from_ref(id),
113 0 : Self::AccountIds(ids) => ids,
114 : }
115 0 : }
116 : }
117 :
118 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
119 : struct InvalidateRole {
120 : project_id: ProjectIdInt,
121 : role_name: RoleNameInt,
122 : }
123 :
124 3 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
125 3 : where
126 3 : T: for<'de2> serde::Deserialize<'de2>,
127 3 : D: serde::Deserializer<'de>,
128 : {
129 3 : let s = String::deserialize(deserializer)?;
130 3 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
131 3 : }
132 :
133 : // https://github.com/serde-rs/serde/issues/1714
134 1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
135 1 : where
136 1 : D: serde::Deserializer<'de>,
137 : {
138 1 : deserializer.deserialize_any(serde::de::IgnoredAny)?;
139 1 : Ok(())
140 1 : }
141 :
142 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
143 : cache: Arc<C>,
144 : }
145 :
146 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
147 0 : fn clone(&self) -> Self {
148 0 : Self {
149 0 : cache: self.cache.clone(),
150 0 : }
151 0 : }
152 : }
153 :
154 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
155 0 : pub(crate) fn new(cache: Arc<C>) -> Self {
156 0 : Self { cache }
157 0 : }
158 :
159 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
160 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
161 : let payload: String = msg.get_payload()?;
162 : tracing::debug!(?payload, "received a message payload");
163 :
164 : let msg: Notification = match serde_json::from_str(&payload) {
165 : Ok(Notification::UnknownTopic) => {
166 : match serde_json::from_str::<NotificationHeader>(&payload) {
167 : // don't update the metric for redis errors if it's just a topic we don't know about.
168 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
169 : Err(e) => {
170 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
171 : channel: msg.get_channel_name(),
172 : });
173 : tracing::error!("broken message: {e}");
174 : }
175 : }
176 : return Ok(());
177 : }
178 : Ok(msg) => msg,
179 : Err(e) => {
180 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
181 : channel: msg.get_channel_name(),
182 : });
183 : match serde_json::from_str::<NotificationHeader>(&payload) {
184 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
185 : Err(_) => tracing::error!("broken message: {e}"),
186 : }
187 : return Ok(());
188 : }
189 : };
190 :
191 : tracing::debug!(?msg, "received a message");
192 : match msg {
193 : Notification::RoleSettingUpdate { .. }
194 : | Notification::EndpointSettingsUpdate { .. }
195 : | Notification::ProjectSettingsUpdate { .. }
196 : | Notification::AccountSettingsUpdate { .. } => {
197 : invalidate_cache(self.cache.clone(), msg.clone());
198 :
199 : let m = &Metrics::get().proxy.redis_events_count;
200 : match msg {
201 : Notification::RoleSettingUpdate { .. } => {
202 : m.inc(RedisEventsCount::InvalidateRole);
203 : }
204 : Notification::EndpointSettingsUpdate { .. } => {
205 : m.inc(RedisEventsCount::InvalidateEndpoint);
206 : }
207 : Notification::ProjectSettingsUpdate { .. } => {
208 : m.inc(RedisEventsCount::InvalidateProject);
209 : }
210 : Notification::AccountSettingsUpdate { .. } => {
211 : m.inc(RedisEventsCount::InvalidateOrg);
212 : }
213 : Notification::UnknownTopic => {}
214 : }
215 :
216 : // TODO: add additional metrics for the other event types.
217 :
218 : // It might happen that the invalid entry is on the way to be cached.
219 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
220 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
221 : let cache = self.cache.clone();
222 0 : tokio::spawn(async move {
223 0 : tokio::time::sleep(INVALIDATION_LAG).await;
224 0 : invalidate_cache(cache, msg);
225 0 : });
226 : }
227 :
228 : Notification::UnknownTopic => unreachable!(),
229 : }
230 :
231 : Ok(())
232 : }
233 : }
234 :
235 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
236 0 : match msg {
237 0 : Notification::EndpointSettingsUpdate(ids) => ids
238 0 : .iter()
239 0 : .for_each(|&id| cache.invalidate_endpoint_access(id)),
240 :
241 0 : Notification::AccountSettingsUpdate(ids) => ids
242 0 : .iter()
243 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
244 :
245 0 : Notification::ProjectSettingsUpdate(ids) => ids
246 0 : .iter()
247 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
248 :
249 : Notification::RoleSettingUpdate(InvalidateRole {
250 0 : project_id,
251 0 : role_name,
252 0 : }) => cache.invalidate_role_secret_for_project(project_id, role_name),
253 :
254 0 : Notification::UnknownTopic => unreachable!(),
255 : }
256 0 : }
257 :
258 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
259 0 : handler: MessageHandler<C>,
260 0 : redis: ConnectionWithCredentialsProvider,
261 0 : cancellation_token: CancellationToken,
262 0 : ) -> anyhow::Result<()> {
263 : loop {
264 0 : if cancellation_token.is_cancelled() {
265 0 : return Ok(());
266 0 : }
267 0 : let mut conn = match try_connect(&redis).await {
268 0 : Ok(conn) => {
269 0 : handler.cache.increment_active_listeners().await;
270 0 : conn
271 : }
272 0 : Err(e) => {
273 0 : tracing::error!(
274 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
275 : );
276 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
277 0 : continue;
278 : }
279 : };
280 0 : let mut stream = conn.on_message();
281 0 : while let Some(msg) = stream.next().await {
282 0 : match handler.handle_message(msg).await {
283 0 : Ok(()) => {}
284 0 : Err(e) => {
285 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
286 0 : break;
287 : }
288 : }
289 0 : if cancellation_token.is_cancelled() {
290 0 : handler.cache.decrement_active_listeners().await;
291 0 : return Ok(());
292 0 : }
293 : }
294 0 : handler.cache.decrement_active_listeners().await;
295 : }
296 0 : }
297 :
298 : /// Handle console's invalidation messages.
299 : #[tracing::instrument(name = "redis_notifications", skip_all)]
300 : pub async fn task_main<C>(
301 : redis: ConnectionWithCredentialsProvider,
302 : cache: Arc<C>,
303 : ) -> anyhow::Result<Infallible>
304 : where
305 : C: ProjectInfoCache + Send + Sync + 'static,
306 : {
307 : let handler = MessageHandler::new(cache);
308 : // 6h - 1m.
309 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
310 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
311 : loop {
312 : let cancellation_token = CancellationToken::new();
313 : interval.tick().await;
314 :
315 : tokio::spawn(handle_messages(
316 : handler.clone(),
317 : redis.clone(),
318 : cancellation_token.clone(),
319 : ));
320 0 : tokio::spawn(async move {
321 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
322 0 : cancellation_token.cancel();
323 0 : });
324 : }
325 : }
326 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// Builds a notification envelope whose `data` field is a string holding nested JSON.
    fn envelope(topic: &str, data: &str) -> String {
        json!({
            "type": "message",
            "topic": topic,
            "data": data,
            "extre_fields": "something"
        })
        .to_string()
    }

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let text = envelope("/allowed_ips_updated", &data);

        let expected =
            Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()));
        let parsed = serde_json::from_str::<Notification>(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_multiple_projects() -> anyhow::Result<()> {
        let project_id1: ProjectId = "new_project1".into();
        let project_id2: ProjectId = "new_project2".into();
        let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
        let text = envelope("/allowed_vpc_endpoints_updated_for_projects", &data);

        let expected = Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
            (&project_id1).into(),
            (&project_id2).into(),
        ]));
        let parsed = serde_json::from_str::<Notification>(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = envelope("/password_updated", &data);

        let expected = Notification::RoleSettingUpdate(InvalidateRole {
            project_id: (&project_id).into(),
            role_name: (&role_name).into(),
        });
        let parsed = serde_json::from_str::<Notification>(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        // An unknown topic must parse whether or not it carries a `data` field.
        let with_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "data": {
                "payload": "ignored"
            },
            "extra_fields": "something"
        })
        .to_string();
        let parsed = serde_json::from_str::<Notification>(&with_data)?;
        assert_eq!(parsed, Notification::UnknownTopic);

        let without_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "extra_fields": "something"
        })
        .to_string();
        let parsed = serde_json::from_str::<Notification>(&without_data)?;
        assert_eq!(parsed, Notification::UnknownTopic);

        Ok(())
    }
}
|