Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use pq_proto::CancelKeyData;
6 : use redis::aio::PubSub;
7 : use serde::{Deserialize, Serialize};
8 : use tokio_util::sync::CancellationToken;
9 : use tracing::Instrument;
10 : use uuid::Uuid;
11 :
12 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
13 : use crate::cache::project_info::ProjectInfoCache;
14 : use crate::cancellation::{CancelMap, CancellationHandler};
15 : use crate::config::ProxyConfig;
16 : use crate::intern::{ProjectIdInt, RoleNameInt};
17 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
18 :
19 : const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
20 : pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
21 : const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
22 : const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
23 :
24 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
25 0 : let mut conn = client.get_async_pubsub().await?;
26 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
27 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
28 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
29 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
30 0 : Ok(conn)
31 0 : }
32 :
/// Minimal view of a notification that extracts only its `topic`, borrowed
/// from the input text.
///
/// `MessageHandler::handle_message` parses this from payloads it could not
/// (or chose not to) handle, so that the topic can be included in the log line.
#[derive(Debug, Deserialize)]
struct NotificationHeader<'a> {
    /// Value of the `topic` field of the message.
    topic: &'a str,
}
37 :
/// A notification received over the Redis pub/sub channels.
///
/// The wire format is adjacently tagged: the variant is selected by the
/// `topic` field and its payload is carried in the `data` field.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
    /// Topic `/allowed_ips_updated`. On deserialization `data` is a string
    /// that itself contains the JSON-encoded payload.
    #[serde(
        rename = "/allowed_ips_updated",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedIpsUpdate {
        allowed_ips_update: AllowedIpsUpdate,
    },
    /// Topic `/block_public_or_vpc_access_updated`. On deserialization `data`
    /// is a string that itself contains the JSON-encoded payload.
    #[serde(
        rename = "/block_public_or_vpc_access_updated",
        deserialize_with = "deserialize_json_string"
    )]
    BlockPublicOrVpcAccessUpdated {
        block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
    },
    /// Topic `/allowed_vpc_endpoints_updated_for_org`. On deserialization
    /// `data` is a string that itself contains the JSON-encoded payload.
    #[serde(
        rename = "/allowed_vpc_endpoints_updated_for_org",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedVpcEndpointsUpdatedForOrg {
        allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
    },
    /// Topic `/allowed_vpc_endpoints_updated_for_projects`. On deserialization
    /// `data` is a string that itself contains the JSON-encoded payload.
    #[serde(
        rename = "/allowed_vpc_endpoints_updated_for_projects",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedVpcEndpointsUpdatedForProjects {
        allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
    },
    /// Topic `/password_updated`. On deserialization `data` is a string that
    /// itself contains the JSON-encoded payload.
    #[serde(
        rename = "/password_updated",
        deserialize_with = "deserialize_json_string"
    )]
    PasswordUpdate { password_update: PasswordUpdate },
    /// Topic `/cancel_session`. Unlike the variants above, `data` is
    /// deserialized directly rather than through `deserialize_json_string`.
    #[serde(rename = "/cancel_session")]
    Cancel(CancelSession),

    /// Catch-all for any topic not listed above; whatever `data` holds is
    /// ignored. This variant is never serialized.
    #[serde(
        other,
        deserialize_with = "deserialize_unknown_topic",
        skip_serializing
    )]
    UnknownTopic,
}
84 :
/// Payload of [`Notification::AllowedIpsUpdate`].
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
    /// Project whose cached allowed-IPs entry is invalidated by `invalidate_cache`.
    project_id: ProjectIdInt,
}
89 :
/// Payload of [`Notification::BlockPublicOrVpcAccessUpdated`].
///
/// `invalidate_cache` currently performs no action for this notification.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
    project_id: ProjectIdInt,
}
94 :
/// Payload of [`Notification::AllowedVpcEndpointsUpdatedForOrg`].
///
/// `invalidate_cache` currently performs no action for this notification.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
    // TODO: change type once the implementation is more fully fledged.
    // See e.g. https://github.com/neondatabase/neon/pull/10073.
    // NOTE(review): the field is named `account_id` but is typed as a project
    // id; presumably a placeholder type as the TODO above suggests — confirm.
    account_id: ProjectIdInt,
}
101 :
/// Payload of [`Notification::AllowedVpcEndpointsUpdatedForProjects`].
///
/// `invalidate_cache` currently performs no action for this notification.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
    project_ids: Vec<ProjectIdInt>,
}
106 :
/// Payload of [`Notification::PasswordUpdate`].
///
/// Both fields are passed to `invalidate_role_secret_for_project` by
/// `invalidate_cache`.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
    project_id: ProjectIdInt,
    role_name: RoleNameInt,
}
112 :
/// Payload of [`Notification::Cancel`]: a request to cancel a session.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
    /// When present, the message handler ignores the request unless this
    /// equals the handler's own region id. When absent, no region filtering
    /// is applied.
    pub(crate) region_id: Option<String>,
    /// Key data handed to the cancellation handler.
    pub(crate) cancel_key_data: CancelKeyData,
    /// Recorded on the tracing spans of the message handler.
    pub(crate) session_id: Uuid,
    /// When absent, the message handler substitutes the unspecified IPv4
    /// address and tells the cancellation handler that no peer address was
    /// supplied.
    pub(crate) peer_addr: Option<std::net::IpAddr>,
}
120 :
121 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
122 2 : where
123 2 : T: for<'de2> serde::Deserialize<'de2>,
124 2 : D: serde::Deserializer<'de>,
125 2 : {
126 2 : let s = String::deserialize(deserializer)?;
127 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
128 2 : }
129 :
130 : // https://github.com/serde-rs/serde/issues/1714
131 1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
132 1 : where
133 1 : D: serde::Deserializer<'de>,
134 1 : {
135 1 : deserializer.deserialize_any(serde::de::IgnoredAny)?;
136 1 : Ok(())
137 1 : }
138 :
/// Processes notifications received from Redis: cache invalidations are
/// applied to `cache`, session cancellations are forwarded to
/// `cancellation_handler`.
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
    /// Project info cache that invalidation notifications are applied to.
    cache: Arc<C>,
    /// Handler that cancel-session notifications are forwarded to.
    cancellation_handler: Arc<CancellationHandler<()>>,
    /// Region of this handler; cancel requests addressed to a different
    /// region are ignored.
    region_id: String,
}
144 :
145 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
146 0 : fn clone(&self) -> Self {
147 0 : Self {
148 0 : cache: self.cache.clone(),
149 0 : cancellation_handler: self.cancellation_handler.clone(),
150 0 : region_id: self.region_id.clone(),
151 0 : }
152 0 : }
153 : }
154 :
155 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
156 0 : pub(crate) fn new(
157 0 : cache: Arc<C>,
158 0 : cancellation_handler: Arc<CancellationHandler<()>>,
159 0 : region_id: String,
160 0 : ) -> Self {
161 0 : Self {
162 0 : cache,
163 0 : cancellation_handler,
164 0 : region_id,
165 0 : }
166 0 : }
167 :
168 0 : pub(crate) async fn increment_active_listeners(&self) {
169 0 : self.cache.increment_active_listeners().await;
170 0 : }
171 :
172 0 : pub(crate) async fn decrement_active_listeners(&self) {
173 0 : self.cache.decrement_active_listeners().await;
174 0 : }
175 :
176 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
177 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
178 : let payload: String = msg.get_payload()?;
179 : tracing::debug!(?payload, "received a message payload");
180 :
181 : let msg: Notification = match serde_json::from_str(&payload) {
182 : Ok(Notification::UnknownTopic) => {
183 : match serde_json::from_str::<NotificationHeader>(&payload) {
184 : // don't update the metric for redis errors if it's just a topic we don't know about.
185 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
186 : Err(e) => {
187 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
188 : channel: msg.get_channel_name(),
189 : });
190 : tracing::error!("broken message: {e}");
191 : }
192 : };
193 : return Ok(());
194 : }
195 : Ok(msg) => msg,
196 : Err(e) => {
197 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
198 : channel: msg.get_channel_name(),
199 : });
200 : match serde_json::from_str::<NotificationHeader>(&payload) {
201 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
202 : Err(_) => tracing::error!("broken message: {e}"),
203 : };
204 : return Ok(());
205 : }
206 : };
207 :
208 : tracing::debug!(?msg, "received a message");
209 : match msg {
210 : Notification::Cancel(cancel_session) => {
211 : tracing::Span::current().record(
212 : "session_id",
213 : tracing::field::display(cancel_session.session_id),
214 : );
215 : Metrics::get()
216 : .proxy
217 : .redis_events_count
218 : .inc(RedisEventsCount::CancelSession);
219 : if let Some(cancel_region) = cancel_session.region_id {
220 : // If the message is not for this region, ignore it.
221 : if cancel_region != self.region_id {
222 : return Ok(());
223 : }
224 : }
225 :
226 : // TODO: Remove unspecified peer_addr after the complete migration to the new format
227 : let peer_addr = cancel_session
228 : .peer_addr
229 : .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
230 : let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id);
231 : cancel_span.follows_from(tracing::Span::current());
232 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
233 : match self
234 : .cancellation_handler
235 : .cancel_session(
236 : cancel_session.cancel_key_data,
237 : uuid::Uuid::nil(),
238 : peer_addr,
239 : cancel_session.peer_addr.is_some(),
240 : )
241 : .instrument(cancel_span)
242 : .await
243 : {
244 : Ok(()) => {}
245 : Err(e) => {
246 : tracing::warn!("failed to cancel session: {e}");
247 : }
248 : }
249 : }
250 : Notification::AllowedIpsUpdate { .. }
251 : | Notification::PasswordUpdate { .. }
252 : | Notification::BlockPublicOrVpcAccessUpdated { .. }
253 : | Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
254 : | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
255 : invalidate_cache(self.cache.clone(), msg.clone());
256 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
257 : Metrics::get()
258 : .proxy
259 : .redis_events_count
260 : .inc(RedisEventsCount::AllowedIpsUpdate);
261 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
262 : Metrics::get()
263 : .proxy
264 : .redis_events_count
265 : .inc(RedisEventsCount::PasswordUpdate);
266 : }
267 : // TODO: add additional metrics for the other event types.
268 :
269 : // It might happen that the invalid entry is on the way to be cached.
270 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
271 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
272 : let cache = self.cache.clone();
273 0 : tokio::spawn(async move {
274 0 : tokio::time::sleep(INVALIDATION_LAG).await;
275 0 : invalidate_cache(cache, msg);
276 0 : });
277 : }
278 :
279 : Notification::UnknownTopic => unreachable!(),
280 : }
281 :
282 : Ok(())
283 : }
284 : }
285 :
286 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
287 0 : match msg {
288 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
289 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
290 0 : }
291 0 : Notification::PasswordUpdate { password_update } => cache
292 0 : .invalidate_role_secret_for_project(
293 0 : password_update.project_id,
294 0 : password_update.role_name,
295 0 : ),
296 0 : Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
297 0 : Notification::BlockPublicOrVpcAccessUpdated { .. } => {
298 0 : // https://github.com/neondatabase/neon/pull/10073
299 0 : }
300 0 : Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
301 0 : // https://github.com/neondatabase/neon/pull/10073
302 0 : }
303 0 : Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
304 0 : // https://github.com/neondatabase/neon/pull/10073
305 0 : }
306 0 : Notification::UnknownTopic => unreachable!(),
307 : }
308 0 : }
309 :
310 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
311 0 : handler: MessageHandler<C>,
312 0 : redis: ConnectionWithCredentialsProvider,
313 0 : cancellation_token: CancellationToken,
314 0 : ) -> anyhow::Result<()> {
315 : loop {
316 0 : if cancellation_token.is_cancelled() {
317 0 : return Ok(());
318 0 : }
319 0 : let mut conn = match try_connect(&redis).await {
320 0 : Ok(conn) => {
321 0 : handler.increment_active_listeners().await;
322 0 : conn
323 : }
324 0 : Err(e) => {
325 0 : tracing::error!(
326 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
327 : );
328 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
329 0 : continue;
330 : }
331 : };
332 0 : let mut stream = conn.on_message();
333 0 : while let Some(msg) = stream.next().await {
334 0 : match handler.handle_message(msg).await {
335 0 : Ok(()) => {}
336 0 : Err(e) => {
337 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
338 0 : break;
339 : }
340 : }
341 0 : if cancellation_token.is_cancelled() {
342 0 : handler.decrement_active_listeners().await;
343 0 : return Ok(());
344 0 : }
345 : }
346 0 : handler.decrement_active_listeners().await;
347 : }
348 0 : }
349 :
350 : /// Handle console's invalidation messages.
351 : #[tracing::instrument(name = "redis_notifications", skip_all)]
352 : pub async fn task_main<C>(
353 : config: &'static ProxyConfig,
354 : redis: ConnectionWithCredentialsProvider,
355 : cache: Arc<C>,
356 : cancel_map: CancelMap,
357 : region_id: String,
358 : ) -> anyhow::Result<Infallible>
359 : where
360 : C: ProjectInfoCache + Send + Sync + 'static,
361 : {
362 : let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
363 : &config.connect_to_compute,
364 : cancel_map,
365 : crate::metrics::CancellationSource::FromRedis,
366 : ));
367 : let handler = MessageHandler::new(cache, cancellation_handler, region_id);
368 : // 6h - 1m.
369 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
370 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
371 : loop {
372 : let cancellation_token = CancellationToken::new();
373 : interval.tick().await;
374 :
375 : tokio::spawn(handle_messages(
376 : handler.clone(),
377 : redis.clone(),
378 : cancellation_token.clone(),
379 : ));
380 0 : tokio::spawn(async move {
381 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
382 0 : cancellation_token.cancel();
383 0 : });
384 : }
385 : }
386 :
// Deserialization tests for `Notification`, covering each wire shape the
// message handler has to accept.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// `/allowed_ips_updated` carries its payload as a JSON string inside
    /// `data`; fields other than `topic` and `data` must be tolerated.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        // The payload is itself JSON, embedded below as a string value.
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::AllowedIpsUpdate {
                allowed_ips_update: AllowedIpsUpdate {
                    project_id: (&project_id).into()
                }
            }
        );

        Ok(())
    }

    /// `/password_updated` carries its payload as a JSON string inside
    /// `data`; fields other than `topic` and `data` must be tolerated.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        // The payload is itself JSON, embedded below as a string value.
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::PasswordUpdate {
                password_update: PasswordUpdate {
                    project_id: (&project_id).into(),
                    role_name: (&role_name).into(),
                }
            }
        );

        Ok(())
    }
    /// A cancel notification must survive a serialize/deserialize round trip,
    /// both without and with a region id.
    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let uuid = uuid::Uuid::new_v4();
        // Round trip without a region.
        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: None,
            session_id: uuid,
            peer_addr: None,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result);

        // Round trip with a region.
        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: Some("region".to_string()),
            session_id: uuid,
            peer_addr: None,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result,);

        Ok(())
    }

    /// An unrecognized topic must deserialize to `UnknownTopic` whether or
    /// not the message has a `data` field.
    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        // `data` present (and not a string): its content must be ignored.
        let with_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "data": {
                "payload": "ignored"
            },
            "extra_fields": "something"
        })
        .to_string();
        let result: Notification = serde_json::from_str(&with_data)?;
        assert_eq!(result, Notification::UnknownTopic);

        // `data` missing entirely.
        let without_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "extra_fields": "something"
        })
        .to_string();
        let result: Notification = serde_json::from_str(&without_data)?;
        assert_eq!(result, Notification::UnknownTopic);

        Ok(())
    }
}
|