Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use redis::aio::PubSub;
6 : use serde::Deserialize;
7 : use tokio_util::sync::CancellationToken;
8 :
9 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
10 : use crate::cache::project_info::ProjectInfoCache;
11 : use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
12 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
13 : use crate::util::deserialize_json_string;
14 :
/// Name of the Redis pub/sub channel this module subscribes to.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";

/// How long to wait after a failed connection attempt before trying again.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);

/// Delay after which every invalidation is repeated, to cover entries that
/// were still on their way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
18 :
19 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
20 0 : let mut conn = client.get_async_pubsub().await?;
21 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
22 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
23 0 : Ok(conn)
24 0 : }
25 :
26 0 : #[derive(Debug, Deserialize)]
27 : struct NotificationHeader<'a> {
28 : topic: &'a str,
29 : }
30 :
31 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
32 : #[serde(tag = "topic", content = "data")]
33 : enum Notification {
34 : #[serde(
35 : rename = "/account_settings_update",
36 : alias = "/allowed_vpc_endpoints_updated_for_org",
37 : deserialize_with = "deserialize_json_string"
38 : )]
39 : AccountSettingsUpdate(InvalidateAccount),
40 :
41 : #[serde(
42 : rename = "/endpoint_settings_update",
43 : deserialize_with = "deserialize_json_string"
44 : )]
45 : EndpointSettingsUpdate(InvalidateEndpoint),
46 :
47 : #[serde(
48 : rename = "/project_settings_update",
49 : alias = "/allowed_ips_updated",
50 : alias = "/block_public_or_vpc_access_updated",
51 : alias = "/allowed_vpc_endpoints_updated_for_projects",
52 : deserialize_with = "deserialize_json_string"
53 : )]
54 : ProjectSettingsUpdate(InvalidateProject),
55 :
56 : #[serde(
57 : rename = "/role_setting_update",
58 : alias = "/password_updated",
59 : deserialize_with = "deserialize_json_string"
60 : )]
61 : RoleSettingUpdate(InvalidateRole),
62 :
63 : #[serde(
64 : other,
65 : deserialize_with = "deserialize_unknown_topic",
66 : skip_serializing
67 : )]
68 : UnknownTopic,
69 : }
70 :
71 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
72 : #[serde(rename_all = "snake_case")]
73 : enum InvalidateEndpoint {
74 : EndpointId(EndpointIdInt),
75 : EndpointIds(Vec<EndpointIdInt>),
76 : }
77 : impl std::ops::Deref for InvalidateEndpoint {
78 : type Target = [EndpointIdInt];
79 0 : fn deref(&self) -> &Self::Target {
80 0 : match self {
81 0 : Self::EndpointId(id) => std::slice::from_ref(id),
82 0 : Self::EndpointIds(ids) => ids,
83 : }
84 0 : }
85 : }
86 :
87 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
88 : #[serde(rename_all = "snake_case")]
89 : enum InvalidateProject {
90 : ProjectId(ProjectIdInt),
91 : ProjectIds(Vec<ProjectIdInt>),
92 : }
93 : impl std::ops::Deref for InvalidateProject {
94 : type Target = [ProjectIdInt];
95 0 : fn deref(&self) -> &Self::Target {
96 0 : match self {
97 0 : Self::ProjectId(id) => std::slice::from_ref(id),
98 0 : Self::ProjectIds(ids) => ids,
99 : }
100 0 : }
101 : }
102 :
103 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
104 : #[serde(rename_all = "snake_case")]
105 : enum InvalidateAccount {
106 : AccountId(AccountIdInt),
107 : AccountIds(Vec<AccountIdInt>),
108 : }
109 : impl std::ops::Deref for InvalidateAccount {
110 : type Target = [AccountIdInt];
111 0 : fn deref(&self) -> &Self::Target {
112 0 : match self {
113 0 : Self::AccountId(id) => std::slice::from_ref(id),
114 0 : Self::AccountIds(ids) => ids,
115 : }
116 0 : }
117 : }
118 :
119 0 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
120 : struct InvalidateRole {
121 : project_id: ProjectIdInt,
122 : role_name: RoleNameInt,
123 : }
124 :
125 : // https://github.com/serde-rs/serde/issues/1714
126 1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
127 1 : where
128 1 : D: serde::Deserializer<'de>,
129 : {
130 1 : deserializer.deserialize_any(serde::de::IgnoredAny)?;
131 1 : Ok(())
132 1 : }
133 :
/// Handles pub/sub messages on behalf of a shared cache of type `C`.
struct MessageHandler<C: Send + Sync + 'static> {
    /// Shared handle to the cache that notifications are applied to.
    cache: Arc<C>,
}

/// Cloning only bumps the reference count; the clone shares the same cache.
/// Implemented by hand so that `C` itself does not need to be `Clone`.
impl<C: Send + Sync + 'static> Clone for MessageHandler<C> {
    fn clone(&self) -> Self {
        let cache = Arc::clone(&self.cache);
        Self { cache }
    }
}
145 :
146 : impl MessageHandler<ProjectInfoCache> {
147 0 : pub(crate) fn new(cache: Arc<ProjectInfoCache>) -> Self {
148 0 : Self { cache }
149 0 : }
150 :
151 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
152 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
153 : let payload: String = msg.get_payload()?;
154 : tracing::debug!(?payload, "received a message payload");
155 :
156 : let msg: Notification = match serde_json::from_str(&payload) {
157 : Ok(Notification::UnknownTopic) => {
158 : match serde_json::from_str::<NotificationHeader>(&payload) {
159 : // don't update the metric for redis errors if it's just a topic we don't know about.
160 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
161 : Err(e) => {
162 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
163 : channel: msg.get_channel_name(),
164 : });
165 : tracing::error!("broken message: {e}");
166 : }
167 : }
168 : return Ok(());
169 : }
170 : Ok(msg) => msg,
171 : Err(e) => {
172 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
173 : channel: msg.get_channel_name(),
174 : });
175 : match serde_json::from_str::<NotificationHeader>(&payload) {
176 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
177 : Err(_) => tracing::error!("broken message: {e}"),
178 : }
179 : return Ok(());
180 : }
181 : };
182 :
183 : tracing::debug!(?msg, "received a message");
184 : match msg {
185 : Notification::RoleSettingUpdate { .. }
186 : | Notification::EndpointSettingsUpdate { .. }
187 : | Notification::ProjectSettingsUpdate { .. }
188 : | Notification::AccountSettingsUpdate { .. } => {
189 : invalidate_cache(self.cache.clone(), msg.clone());
190 :
191 : let m = &Metrics::get().proxy.redis_events_count;
192 : match msg {
193 : Notification::RoleSettingUpdate { .. } => {
194 : m.inc(RedisEventsCount::InvalidateRole);
195 : }
196 : Notification::EndpointSettingsUpdate { .. } => {
197 : m.inc(RedisEventsCount::InvalidateEndpoint);
198 : }
199 : Notification::ProjectSettingsUpdate { .. } => {
200 : m.inc(RedisEventsCount::InvalidateProject);
201 : }
202 : Notification::AccountSettingsUpdate { .. } => {
203 : m.inc(RedisEventsCount::InvalidateOrg);
204 : }
205 : Notification::UnknownTopic => {}
206 : }
207 :
208 : // TODO: add additional metrics for the other event types.
209 :
210 : // It might happen that the invalid entry is on the way to be cached.
211 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
212 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
213 : let cache = self.cache.clone();
214 0 : tokio::spawn(async move {
215 0 : tokio::time::sleep(INVALIDATION_LAG).await;
216 0 : invalidate_cache(cache, msg);
217 0 : });
218 : }
219 :
220 : Notification::UnknownTopic => unreachable!(),
221 : }
222 :
223 : Ok(())
224 : }
225 : }
226 :
227 0 : fn invalidate_cache(cache: Arc<ProjectInfoCache>, msg: Notification) {
228 0 : match msg {
229 0 : Notification::EndpointSettingsUpdate(ids) => ids
230 0 : .iter()
231 0 : .for_each(|&id| cache.invalidate_endpoint_access(id)),
232 :
233 0 : Notification::AccountSettingsUpdate(ids) => ids
234 0 : .iter()
235 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)),
236 :
237 0 : Notification::ProjectSettingsUpdate(ids) => ids
238 0 : .iter()
239 0 : .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)),
240 :
241 : Notification::RoleSettingUpdate(InvalidateRole {
242 0 : project_id,
243 0 : role_name,
244 0 : }) => cache.invalidate_role_secret_for_project(project_id, role_name),
245 :
246 0 : Notification::UnknownTopic => unreachable!(),
247 : }
248 0 : }
249 :
250 0 : async fn handle_messages(
251 0 : handler: MessageHandler<ProjectInfoCache>,
252 0 : redis: ConnectionWithCredentialsProvider,
253 0 : cancellation_token: CancellationToken,
254 0 : ) -> anyhow::Result<()> {
255 : loop {
256 0 : if cancellation_token.is_cancelled() {
257 0 : return Ok(());
258 0 : }
259 0 : let mut conn = match try_connect(&redis).await {
260 0 : Ok(conn) => conn,
261 0 : Err(e) => {
262 0 : tracing::error!(
263 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
264 : );
265 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
266 0 : continue;
267 : }
268 : };
269 0 : let mut stream = conn.on_message();
270 0 : while let Some(msg) = stream.next().await {
271 0 : match handler.handle_message(msg).await {
272 0 : Ok(()) => {}
273 0 : Err(e) => {
274 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
275 0 : break;
276 : }
277 : }
278 0 : if cancellation_token.is_cancelled() {
279 0 : return Ok(());
280 0 : }
281 : }
282 : }
283 0 : }
284 :
285 : /// Handle console's invalidation messages.
286 : #[tracing::instrument(name = "redis_notifications", skip_all)]
287 : pub async fn task_main(
288 : redis: ConnectionWithCredentialsProvider,
289 : cache: Arc<ProjectInfoCache>,
290 : ) -> anyhow::Result<Infallible> {
291 : let handler = MessageHandler::new(cache);
292 : // 6h - 1m.
293 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
294 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
295 : loop {
296 : let cancellation_token = CancellationToken::new();
297 : interval.tick().await;
298 :
299 : tokio::spawn(handle_messages(
300 : handler.clone(),
301 : redis.clone(),
302 : cancellation_token.clone(),
303 : ));
304 0 : tokio::spawn(async move {
305 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
306 0 : cancellation_token.cancel();
307 0 : });
308 : }
309 : }
310 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// Wraps `data` (JSON carried as a string) in the message envelope used by
    /// the known-topic tests, including one field the parser must ignore.
    fn envelope(topic: &str, data: String) -> String {
        json!({
            "type": "message",
            "topic": topic,
            "data": data,
            "extre_fields": "something"
        })
        .to_string()
    }

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let text = envelope(
            "/allowed_ips_updated",
            format!(r#"{{"project_id": "{project_id}"}}"#),
        );

        let expected =
            Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into()));
        let parsed: Notification = serde_json::from_str(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_multiple_projects() -> anyhow::Result<()> {
        let project_id1: ProjectId = "new_project1".into();
        let project_id2: ProjectId = "new_project2".into();
        let text = envelope(
            "/allowed_vpc_endpoints_updated_for_projects",
            format!(r#"{{"project_ids": ["{project_id1}","{project_id2}"]}}"#),
        );

        let expected = Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![
            (&project_id1).into(),
            (&project_id2).into(),
        ]));
        let parsed: Notification = serde_json::from_str(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let text = envelope(
            "/password_updated",
            format!(r#"{{"project_id": "{project_id}", "role_name": "{role_name}"}}"#),
        );

        let expected = Notification::RoleSettingUpdate(InvalidateRole {
            project_id: (&project_id).into(),
            role_name: (&role_name).into(),
        });
        let parsed: Notification = serde_json::from_str(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        // An unknown topic must be accepted whether or not `data` is present.
        let with_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "data": {
                "payload": "ignored"
            },
            "extra_fields": "something"
        })
        .to_string();
        let without_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "extra_fields": "something"
        })
        .to_string();

        for text in [&with_data, &without_data] {
            let parsed: Notification = serde_json::from_str(text)?;
            assert_eq!(parsed, Notification::UnknownTopic);
        }

        Ok(())
    }
}
|