Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use pq_proto::CancelKeyData;
6 : use redis::aio::PubSub;
7 : use serde::{Deserialize, Serialize};
8 : use tokio_util::sync::CancellationToken;
9 : use uuid::Uuid;
10 :
11 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
12 : use crate::cache::project_info::ProjectInfoCache;
13 : use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt};
14 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
15 :
/// Name of the Redis pub/sub channel the listener subscribes to.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Pause between a failed connection attempt and the next one.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay before an invalidation is repeated, to catch entries that were still
/// on their way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
19 :
20 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
21 0 : let mut conn = client.get_async_pubsub().await?;
22 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
23 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
24 0 : Ok(conn)
25 0 : }
26 :
27 0 : #[derive(Debug, Deserialize)]
28 : struct NotificationHeader<'a> {
29 : topic: &'a str,
30 : }
31 :
32 6 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
33 : #[serde(tag = "topic", content = "data")]
34 : pub(crate) enum Notification {
35 : #[serde(
36 : rename = "/allowed_ips_updated",
37 : deserialize_with = "deserialize_json_string"
38 : )]
39 : AllowedIpsUpdate {
40 : allowed_ips_update: AllowedIpsUpdate,
41 : },
42 : #[serde(
43 : rename = "/block_public_or_vpc_access_updated",
44 : deserialize_with = "deserialize_json_string"
45 : )]
46 : BlockPublicOrVpcAccessUpdated {
47 : block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
48 : },
49 : #[serde(
50 : rename = "/allowed_vpc_endpoints_updated_for_org",
51 : deserialize_with = "deserialize_json_string"
52 : )]
53 : AllowedVpcEndpointsUpdatedForOrg {
54 : allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
55 : },
56 : #[serde(
57 : rename = "/allowed_vpc_endpoints_updated_for_projects",
58 : deserialize_with = "deserialize_json_string"
59 : )]
60 : AllowedVpcEndpointsUpdatedForProjects {
61 : allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
62 : },
63 : #[serde(
64 : rename = "/password_updated",
65 : deserialize_with = "deserialize_json_string"
66 : )]
67 : PasswordUpdate { password_update: PasswordUpdate },
68 :
69 : #[serde(
70 : other,
71 : deserialize_with = "deserialize_unknown_topic",
72 : skip_serializing
73 : )]
74 : UnknownTopic,
75 : }
76 :
77 1 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
78 : pub(crate) struct AllowedIpsUpdate {
79 : project_id: ProjectIdInt,
80 : }
81 :
82 0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
83 : pub(crate) struct BlockPublicOrVpcAccessUpdated {
84 : project_id: ProjectIdInt,
85 : }
86 :
87 0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
88 : pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
89 : account_id: AccountIdInt,
90 : }
91 :
92 0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
93 : pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
94 : project_ids: Vec<ProjectIdInt>,
95 : }
96 :
97 2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
98 : pub(crate) struct PasswordUpdate {
99 : project_id: ProjectIdInt,
100 : role_name: RoleNameInt,
101 : }
102 :
103 0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
104 : pub(crate) struct CancelSession {
105 : pub(crate) region_id: Option<String>,
106 : pub(crate) cancel_key_data: CancelKeyData,
107 : pub(crate) session_id: Uuid,
108 : pub(crate) peer_addr: Option<std::net::IpAddr>,
109 : }
110 :
111 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
112 2 : where
113 2 : T: for<'de2> serde::Deserialize<'de2>,
114 2 : D: serde::Deserializer<'de>,
115 2 : {
116 2 : let s = String::deserialize(deserializer)?;
117 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
118 2 : }
119 :
120 : // https://github.com/serde-rs/serde/issues/1714
121 1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
122 1 : where
123 1 : D: serde::Deserializer<'de>,
124 1 : {
125 1 : deserializer.deserialize_any(serde::de::IgnoredAny)?;
126 1 : Ok(())
127 1 : }
128 :
129 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
130 : cache: Arc<C>,
131 : region_id: String,
132 : }
133 :
134 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
135 0 : fn clone(&self) -> Self {
136 0 : Self {
137 0 : cache: self.cache.clone(),
138 0 : region_id: self.region_id.clone(),
139 0 : }
140 0 : }
141 : }
142 :
143 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
144 0 : pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
145 0 : Self { cache, region_id }
146 0 : }
147 :
148 0 : pub(crate) async fn increment_active_listeners(&self) {
149 0 : self.cache.increment_active_listeners().await;
150 0 : }
151 :
152 0 : pub(crate) async fn decrement_active_listeners(&self) {
153 0 : self.cache.decrement_active_listeners().await;
154 0 : }
155 :
156 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
157 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
158 : let payload: String = msg.get_payload()?;
159 : tracing::debug!(?payload, "received a message payload");
160 :
161 : let msg: Notification = match serde_json::from_str(&payload) {
162 : Ok(Notification::UnknownTopic) => {
163 : match serde_json::from_str::<NotificationHeader>(&payload) {
164 : // don't update the metric for redis errors if it's just a topic we don't know about.
165 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
166 : Err(e) => {
167 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
168 : channel: msg.get_channel_name(),
169 : });
170 : tracing::error!("broken message: {e}");
171 : }
172 : }
173 : return Ok(());
174 : }
175 : Ok(msg) => msg,
176 : Err(e) => {
177 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
178 : channel: msg.get_channel_name(),
179 : });
180 : match serde_json::from_str::<NotificationHeader>(&payload) {
181 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
182 : Err(_) => tracing::error!("broken message: {e}"),
183 : }
184 : return Ok(());
185 : }
186 : };
187 :
188 : tracing::debug!(?msg, "received a message");
189 : match msg {
190 : Notification::AllowedIpsUpdate { .. }
191 : | Notification::PasswordUpdate { .. }
192 : | Notification::BlockPublicOrVpcAccessUpdated { .. }
193 : | Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
194 : | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
195 : invalidate_cache(self.cache.clone(), msg.clone());
196 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
197 : Metrics::get()
198 : .proxy
199 : .redis_events_count
200 : .inc(RedisEventsCount::AllowedIpsUpdate);
201 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
202 : Metrics::get()
203 : .proxy
204 : .redis_events_count
205 : .inc(RedisEventsCount::PasswordUpdate);
206 : } else if matches!(
207 : msg,
208 : Notification::AllowedVpcEndpointsUpdatedForProjects { .. }
209 : ) {
210 : Metrics::get()
211 : .proxy
212 : .redis_events_count
213 : .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects);
214 : } else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) {
215 : Metrics::get()
216 : .proxy
217 : .redis_events_count
218 : .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg);
219 : } else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) {
220 : Metrics::get()
221 : .proxy
222 : .redis_events_count
223 : .inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate);
224 : }
225 : // TODO: add additional metrics for the other event types.
226 :
227 : // It might happen that the invalid entry is on the way to be cached.
228 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
229 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
230 : let cache = self.cache.clone();
231 0 : tokio::spawn(async move {
232 0 : tokio::time::sleep(INVALIDATION_LAG).await;
233 0 : invalidate_cache(cache, msg);
234 0 : });
235 : }
236 :
237 : Notification::UnknownTopic => unreachable!(),
238 : }
239 :
240 : Ok(())
241 : }
242 : }
243 :
244 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
245 0 : match msg {
246 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
247 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
248 0 : }
249 : Notification::BlockPublicOrVpcAccessUpdated {
250 0 : block_public_or_vpc_access_updated,
251 0 : } => cache.invalidate_block_public_or_vpc_access_for_project(
252 0 : block_public_or_vpc_access_updated.project_id,
253 0 : ),
254 : Notification::AllowedVpcEndpointsUpdatedForOrg {
255 0 : allowed_vpc_endpoints_updated_for_org,
256 0 : } => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
257 0 : allowed_vpc_endpoints_updated_for_org.account_id,
258 0 : ),
259 : Notification::AllowedVpcEndpointsUpdatedForProjects {
260 0 : allowed_vpc_endpoints_updated_for_projects,
261 0 : } => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
262 0 : allowed_vpc_endpoints_updated_for_projects.project_ids,
263 0 : ),
264 0 : Notification::PasswordUpdate { password_update } => cache
265 0 : .invalidate_role_secret_for_project(
266 0 : password_update.project_id,
267 0 : password_update.role_name,
268 0 : ),
269 0 : Notification::UnknownTopic => unreachable!(),
270 : }
271 0 : }
272 :
273 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
274 0 : handler: MessageHandler<C>,
275 0 : redis: ConnectionWithCredentialsProvider,
276 0 : cancellation_token: CancellationToken,
277 0 : ) -> anyhow::Result<()> {
278 : loop {
279 0 : if cancellation_token.is_cancelled() {
280 0 : return Ok(());
281 0 : }
282 0 : let mut conn = match try_connect(&redis).await {
283 0 : Ok(conn) => {
284 0 : handler.increment_active_listeners().await;
285 0 : conn
286 : }
287 0 : Err(e) => {
288 0 : tracing::error!(
289 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
290 : );
291 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
292 0 : continue;
293 : }
294 : };
295 0 : let mut stream = conn.on_message();
296 0 : while let Some(msg) = stream.next().await {
297 0 : match handler.handle_message(msg).await {
298 0 : Ok(()) => {}
299 0 : Err(e) => {
300 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
301 0 : break;
302 : }
303 : }
304 0 : if cancellation_token.is_cancelled() {
305 0 : handler.decrement_active_listeners().await;
306 0 : return Ok(());
307 0 : }
308 : }
309 0 : handler.decrement_active_listeners().await;
310 : }
311 0 : }
312 :
313 : /// Handle console's invalidation messages.
314 : #[tracing::instrument(name = "redis_notifications", skip_all)]
315 : pub async fn task_main<C>(
316 : redis: ConnectionWithCredentialsProvider,
317 : cache: Arc<C>,
318 : region_id: String,
319 : ) -> anyhow::Result<Infallible>
320 : where
321 : C: ProjectInfoCache + Send + Sync + 'static,
322 : {
323 : let handler = MessageHandler::new(cache, region_id);
324 : // 6h - 1m.
325 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
326 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
327 : loop {
328 : let cancellation_token = CancellationToken::new();
329 : interval.tick().await;
330 :
331 : tokio::spawn(handle_messages(
332 : handler.clone(),
333 : redis.clone(),
334 : cancellation_token.clone(),
335 : ));
336 0 : tokio::spawn(async move {
337 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
338 0 : cancellation_token.cancel();
339 0 : });
340 : }
341 : }
342 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// Builds the serialized envelope for a known topic, where `data` is a
    /// JSON document embedded as a string.
    fn envelope(topic: &str, data: &str) -> String {
        json!({
            "type": "message",
            "topic": topic,
            "data": data,
            "extre_fields": "something"
        })
        .to_string()
    }

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!(r#"{{"project_id": "{project_id}"}}"#);
        let text = envelope("/allowed_ips_updated", &data);

        let expected = Notification::AllowedIpsUpdate {
            allowed_ips_update: AllowedIpsUpdate {
                project_id: (&project_id).into(),
            },
        };
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!(r#"{{"project_id": "{project_id}", "role_name": "{role_name}"}}"#);
        let text = envelope("/password_updated", &data);

        let expected = Notification::PasswordUpdate {
            password_update: PasswordUpdate {
                project_id: (&project_id).into(),
                role_name: (&role_name).into(),
            },
        };
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(result, expected);

        Ok(())
    }

    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        // An unknown topic must parse whether or not `data` is present.
        let with_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "data": {
                "payload": "ignored"
            },
            "extra_fields": "something"
        })
        .to_string();
        let without_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&with_data)?;
        assert_eq!(result, Notification::UnknownTopic);

        let result: Notification = serde_json::from_str(&without_data)?;
        assert_eq!(result, Notification::UnknownTopic);

        Ok(())
    }
}
|