Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use redis::aio::PubSub;
6 : use serde::Deserialize;
7 : use tokio_util::sync::CancellationToken;
8 :
9 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
10 : use crate::cache::project_info::ProjectInfoCache;
11 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
12 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
13 : use crate::util::deserialize_json_string;
14 :
/// Redis pub/sub channel that `try_connect` subscribes to.
15 : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// How long `handle_messages` sleeps after a failed connection attempt before retrying.
16 : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay before `handle_message` repeats an invalidation, to catch entries that were
/// still on their way into the cache when the first invalidation ran.
17 : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
18 :
19 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
20 0 : let mut conn = client.get_async_pubsub().await?;
21 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
22 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
23 0 : Ok(conn)
24 0 : }
25 :
/// Borrowed view of just the `topic` field of a notification.
///
/// `handle_message` parses the payload into this type a second time so that it can
/// include the topic name in its log line when the topic is unknown or the full
/// `Notification` could not be parsed.
26 0 : #[derive(Debug, Deserialize)]
27 : struct NotificationHeader<'a> {
28 : topic: &'a str,
29 : }
30 :
/// A cache-invalidation message received over the pub/sub channel.
///
/// The JSON object is adjacently tagged: the `topic` field selects the variant and the
/// `data` field carries its payload. Each `rename`/`alias` below is a topic name accepted
/// for that variant. The unit tests in this file pass `data` as a string containing JSON;
/// decoding is delegated to `deserialize_json_string` (defined elsewhere — presumably it
/// parses that inner string; confirm against the helper).
31 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
32 : #[serde(tag = "topic", content = "data")]
33 : enum Notification {
/// Settings changed for one or more accounts.
34 : #[serde(
35 : rename = "/account_settings_update",
36 : alias = "/allowed_vpc_endpoints_updated_for_org",
37 : deserialize_with = "deserialize_json_string"
38 : )]
39 : AccountSettingsUpdate(InvalidateAccount),
40 :
/// Settings changed for one or more endpoints.
41 : #[serde(
42 : rename = "/endpoint_settings_update",
43 : deserialize_with = "deserialize_json_string"
44 : )]
45 : EndpointSettingsUpdate(InvalidateEndpoint),
46 :
/// Settings changed for one or more projects.
47 : #[serde(
48 : rename = "/project_settings_update",
49 : alias = "/allowed_ips_updated",
50 : alias = "/block_public_or_vpc_access_updated",
51 : alias = "/allowed_vpc_endpoints_updated_for_projects",
52 : deserialize_with = "deserialize_json_string"
53 : )]
54 : ProjectSettingsUpdate(InvalidateProject),
55 :
/// A role's settings changed within a project.
56 : #[serde(
57 : rename = "/role_setting_update",
58 : alias = "/password_updated",
59 : deserialize_with = "deserialize_json_string"
60 : )]
61 : RoleSettingUpdate(InvalidateRole),
62 :
/// Catch-all for any topic not listed above. Whatever `data` holds (if present) is
/// discarded by `deserialize_unknown_topic`.
63 : #[serde(
64 : other,
65 : deserialize_with = "deserialize_unknown_topic",
66 : skip_serializing
67 : )]
68 : UnknownTopic,
69 : }
70 :
71 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
72 : #[serde(rename_all = "snake_case")]
73 : enum InvalidateEndpoint {
74 : EndpointId(EndpointIdInt),
75 : EndpointIds(Vec<EndpointIdInt>),
76 : }
77 : impl std::ops::Deref for InvalidateEndpoint {
78 : type Target = [EndpointIdInt];
79 0 : fn deref(&self) -> &Self::Target {
80 0 : match self {
81 0 : Self::EndpointId(id) => std::slice::from_ref(id),
82 0 : Self::EndpointIds(ids) => ids,
83 : }
84 0 : }
85 : }
86 :
87 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
88 : #[serde(rename_all = "snake_case")]
89 : enum InvalidateProject {
90 : ProjectId(ProjectIdInt),
91 : ProjectIds(Vec<ProjectIdInt>),
92 : }
93 : impl std::ops::Deref for InvalidateProject {
94 : type Target = [ProjectIdInt];
95 0 : fn deref(&self) -> &Self::Target {
96 0 : match self {
97 0 : Self::ProjectId(id) => std::slice::from_ref(id),
98 0 : Self::ProjectIds(ids) => ids,
99 : }
100 0 : }
101 : }
102 :
103 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
104 : #[serde(rename_all = "snake_case")]
105 : enum InvalidateAccount {
106 : AccountId(AccountIdInt),
107 : AccountIds(Vec<AccountIdInt>),
108 : }
109 : impl std::ops::Deref for InvalidateAccount {
110 : type Target = [AccountIdInt];
111 0 : fn deref(&self) -> &Self::Target {
112 0 : match self {
113 0 : Self::AccountId(id) => std::slice::from_ref(id),
114 0 : Self::AccountIds(ids) => ids,
115 : }
116 0 : }
117 : }
118 :
/// Payload of a role notification: identifies the role whose cached secret
/// `invalidate_cache` drops via `invalidate_role_secret_for_project`.
119 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
120 : struct InvalidateRole {
// Project the role belongs to.
121 : project_id: ProjectIdInt,
// Name of the role within that project.
122 : role_name: RoleNameInt,
123 : }
124 :
125 : // https://github.com/serde-rs/serde/issues/1714
126 1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
127 1 : where
128 1 : D: serde::Deserializer<'de>,
129 : {
130 1 : deserializer.deserialize_any(serde::de::IgnoredAny)?;
131 1 : Ok(())
132 1 : }
133 :
134 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
135 : cache: Arc<C>,
136 : }
137 :
138 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
139 0 : fn clone(&self) -> Self {
140 0 : Self {
141 0 : cache: self.cache.clone(),
142 0 : }
143 0 : }
144 : }
145 :
146 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
147 0 : pub(crate) fn new(cache: Arc<C>) -> Self {
148 0 : Self { cache }
149 0 : }
150 :
151 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
152 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
153 : let payload: String = msg.get_payload()?;
154 : tracing::debug!(?payload, "received a message payload");
155 :
156 : let msg: Notification = match serde_json::from_str(&payload) {
157 : Ok(Notification::UnknownTopic) => {
158 : match serde_json::from_str::<NotificationHeader>(&payload) {
159 : // don't update the metric for redis errors if it's just a topic we don't know about.
160 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
161 : Err(e) => {
162 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
163 : channel: msg.get_channel_name(),
164 : });
165 : tracing::error!("broken message: {e}");
166 : }
167 : }
168 : return Ok(());
169 : }
170 : Ok(msg) => msg,
171 : Err(e) => {
172 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
173 : channel: msg.get_channel_name(),
174 : });
175 : match serde_json::from_str::<NotificationHeader>(&payload) {
176 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
177 : Err(_) => tracing::error!("broken message: {e}"),
178 : }
179 : return Ok(());
180 : }
181 : };
182 :
183 : tracing::debug!(?msg, "received a message");
184 : match msg {
185 : Notification::RoleSettingUpdate { .. }
186 : | Notification::EndpointSettingsUpdate { .. }
187 : | Notification::ProjectSettingsUpdate { .. }
188 : | Notification::AccountSettingsUpdate { .. } => {
189 : invalidate_cache(self.cache.clone(), msg.clone());
190 :
191 : let m = &Metrics::get().proxy.redis_events_count;
192 : match msg {
193 : Notification::RoleSettingUpdate { .. } => {
194 : m.inc(RedisEventsCount::InvalidateRole);
195 : }
196 : Notification::EndpointSettingsUpdate { .. } => {
197 : m.inc(RedisEventsCount::InvalidateEndpoint);
198 : }
199 : Notification::ProjectSettingsUpdate { .. } => {
200 : m.inc(RedisEventsCount::InvalidateProject);
201 : }
202 : Notification::AccountSettingsUpdate { .. } => {
203 : m.inc(RedisEventsCount::InvalidateOrg);
204 : }
205 : Notification::UnknownTopic => {}
206 : }
207 :
208 : // TODO: add additional metrics for the other event types.
209 :
210 : // It might happen that the invalid entry is on the way to be cached.
211 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
212 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
213 : let cache = self.cache.clone();
214 0 : tokio::spawn(async move {
215 0 : tokio::time::sleep(INVALIDATION_LAG).await;
216 0 : invalidate_cache(cache, msg);
217 0 : });
218 : }
219 :
220 : Notification::UnknownTopic => unreachable!(),
221 : }
222 :
223 : Ok(())
224 : }
225 : }
226 :
227 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
228 0 : match msg {
229 0 : Notification::EndpointSettingsUpdate(ids) => ids
230 0 : .iter()
231 0 : .for_each(|&id| cache.invalidate_endpoint_access(id)),
232 :
233 0 : Notification::AccountSettingsUpdate(ids) => ids
234 0 : .iter()
235 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
236 :
237 0 : Notification::ProjectSettingsUpdate(ids) => ids
238 0 : .iter()
239 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
240 :
241 : Notification::RoleSettingUpdate(InvalidateRole {
242 0 : project_id,
243 0 : role_name,
244 0 : }) => cache.invalidate_role_secret_for_project(project_id, role_name),
245 :
246 0 : Notification::UnknownTopic => unreachable!(),
247 : }
248 0 : }
249 :
250 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
251 0 : handler: MessageHandler<C>,
252 0 : redis: ConnectionWithCredentialsProvider,
253 0 : cancellation_token: CancellationToken,
254 0 : ) -> anyhow::Result<()> {
255 : loop {
256 0 : if cancellation_token.is_cancelled() {
257 0 : return Ok(());
258 0 : }
259 0 : let mut conn = match try_connect(&redis).await {
260 0 : Ok(conn) => conn,
261 0 : Err(e) => {
262 0 : tracing::error!(
263 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
264 : );
265 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
266 0 : continue;
267 : }
268 : };
269 0 : let mut stream = conn.on_message();
270 0 : while let Some(msg) = stream.next().await {
271 0 : match handler.handle_message(msg).await {
272 0 : Ok(()) => {}
273 0 : Err(e) => {
274 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
275 0 : break;
276 : }
277 : }
278 0 : if cancellation_token.is_cancelled() {
279 0 : return Ok(());
280 0 : }
281 : }
282 : }
283 0 : }
284 :
285 : /// Handle console's invalidation messages.
286 : #[tracing::instrument(name = "redis_notifications", skip_all)]
287 : pub async fn task_main<C>(
288 : redis: ConnectionWithCredentialsProvider,
289 : cache: Arc<C>,
290 : ) -> anyhow::Result<Infallible>
291 : where
292 : C: ProjectInfoCache + Send + Sync + 'static,
293 : {
294 : let handler = MessageHandler::new(cache);
295 : // 6h - 1m.
296 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
297 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
298 : loop {
299 : let cancellation_token = CancellationToken::new();
300 : interval.tick().await;
301 :
302 : tokio::spawn(handle_messages(
303 : handler.clone(),
304 : redis.clone(),
305 : cancellation_token.clone(),
306 : ));
307 0 : tokio::spawn(async move {
308 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
309 0 : cancellation_token.cancel();
310 0 : });
311 : }
312 : }
313 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// Serializes `envelope` to text and parses it the same way `handle_message` does.
    fn parse(envelope: &serde_json::Value) -> serde_json::Result<Notification> {
        let raw = envelope.to_string();
        serde_json::from_str(&raw)
    }

    /// An alias topic with a single `project_id` maps to `ProjectSettingsUpdate`.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let envelope = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extre_fields": "something"
        });

        let parsed = parse(&envelope)?;
        let expected =
            Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()));
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// The plural `project_ids` form is accepted and keeps the order of the ids.
    #[test]
    fn parse_multiple_projects() -> anyhow::Result<()> {
        let project_id1: ProjectId = "new_project1".into();
        let project_id2: ProjectId = "new_project2".into();
        let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}");
        let envelope = json!({
            "type": "message",
            "topic": "/allowed_vpc_endpoints_updated_for_projects",
            "data": data,
            "extre_fields": "something"
        });

        let parsed = parse(&envelope)?;
        let expected = Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
            (&project_id1).into(),
            (&project_id2).into(),
        ]));
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// `/password_updated` carries both a project id and a role name.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let envelope = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extre_fields": "something"
        });

        let parsed = parse(&envelope)?;
        let expected = Notification::RoleSettingUpdate(InvalidateRole {
            project_id: (&project_id).into(),
            role_name: (&role_name).into(),
        });
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// Unrecognized topics parse to `UnknownTopic` whether or not `data` is present.
    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        let with_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "data": {
                "payload": "ignored"
            },
            "extra_fields": "something"
        });
        assert_eq!(parse(&with_data)?, Notification::UnknownTopic);

        let without_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "extra_fields": "something"
        });
        assert_eq!(parse(&without_data)?, Notification::UnknownTopic);

        Ok(())
    }
}
|