Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use pq_proto::CancelKeyData;
6 : use redis::aio::PubSub;
7 : use serde::{Deserialize, Serialize};
8 : use tokio_util::sync::CancellationToken;
9 : use tracing::Instrument;
10 : use uuid::Uuid;
11 :
12 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
13 : use crate::cache::project_info::ProjectInfoCache;
14 : use crate::cancellation::{CancelMap, CancellationHandler};
15 : use crate::config::ProxyConfig;
16 : use crate::intern::{ProjectIdInt, RoleNameInt};
17 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
18 :
/// Pub/sub channel carrying updates published by the control plane.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Pub/sub channel shared between proxy instances ("proxy-to-proxy" updates).
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
/// How long to wait before retrying after a failed connection attempt.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(20_000);
/// Delay after which a cache invalidation is repeated once more.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_millis(20_000);
23 :
24 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
25 0 : let mut conn = client.get_async_pubsub().await?;
26 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
27 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
28 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
29 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
30 0 : Ok(conn)
31 0 : }
32 :
33 0 : #[derive(Debug, Deserialize)]
34 : struct NotificationHeader<'a> {
35 : topic: &'a str,
36 : }
37 :
38 10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
39 : #[serde(tag = "topic", content = "data")]
40 : pub(crate) enum Notification {
41 : #[serde(
42 : rename = "/allowed_ips_updated",
43 : deserialize_with = "deserialize_json_string"
44 : )]
45 : AllowedIpsUpdate {
46 : allowed_ips_update: AllowedIpsUpdate,
47 : },
48 : #[serde(
49 : rename = "/block_public_or_vpc_access_updated",
50 : deserialize_with = "deserialize_json_string"
51 : )]
52 : BlockPublicOrVpcAccessUpdated {
53 : block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
54 : },
55 : #[serde(
56 : rename = "/allowed_vpc_endpoints_updated_for_org",
57 : deserialize_with = "deserialize_json_string"
58 : )]
59 : AllowedVpcEndpointsUpdatedForOrg {
60 : allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
61 : },
62 : #[serde(
63 : rename = "/allowed_vpc_endpoints_updated_for_projects",
64 : deserialize_with = "deserialize_json_string"
65 : )]
66 : AllowedVpcEndpointsUpdatedForProjects {
67 : allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
68 : },
69 : #[serde(
70 : rename = "/password_updated",
71 : deserialize_with = "deserialize_json_string"
72 : )]
73 : PasswordUpdate { password_update: PasswordUpdate },
74 : #[serde(rename = "/cancel_session")]
75 : Cancel(CancelSession),
76 :
77 : #[serde(other, skip_serializing)]
78 : UnknownTopic,
79 : }
80 :
81 2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
82 : pub(crate) struct AllowedIpsUpdate {
83 : project_id: ProjectIdInt,
84 : }
85 :
86 0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
87 : pub(crate) struct BlockPublicOrVpcAccessUpdated {
88 : project_id: ProjectIdInt,
89 : }
90 :
91 0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
92 : pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
93 : // TODO: change type once the implementation is more fully fledged.
94 : // See e.g. https://github.com/neondatabase/neon/pull/10073.
95 : account_id: ProjectIdInt,
96 : }
97 :
98 0 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
99 : pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
100 : project_ids: Vec<ProjectIdInt>,
101 : }
102 :
103 3 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
104 : pub(crate) struct PasswordUpdate {
105 : project_id: ProjectIdInt,
106 : role_name: RoleNameInt,
107 : }
108 :
109 10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
110 : pub(crate) struct CancelSession {
111 : pub(crate) region_id: Option<String>,
112 : pub(crate) cancel_key_data: CancelKeyData,
113 : pub(crate) session_id: Uuid,
114 : pub(crate) peer_addr: Option<std::net::IpAddr>,
115 : }
116 :
117 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
118 2 : where
119 2 : T: for<'de2> serde::Deserialize<'de2>,
120 2 : D: serde::Deserializer<'de>,
121 2 : {
122 2 : let s = String::deserialize(deserializer)?;
123 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
124 2 : }
125 :
126 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
127 : cache: Arc<C>,
128 : cancellation_handler: Arc<CancellationHandler<()>>,
129 : region_id: String,
130 : }
131 :
132 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
133 0 : fn clone(&self) -> Self {
134 0 : Self {
135 0 : cache: self.cache.clone(),
136 0 : cancellation_handler: self.cancellation_handler.clone(),
137 0 : region_id: self.region_id.clone(),
138 0 : }
139 0 : }
140 : }
141 :
142 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
143 0 : pub(crate) fn new(
144 0 : cache: Arc<C>,
145 0 : cancellation_handler: Arc<CancellationHandler<()>>,
146 0 : region_id: String,
147 0 : ) -> Self {
148 0 : Self {
149 0 : cache,
150 0 : cancellation_handler,
151 0 : region_id,
152 0 : }
153 0 : }
154 :
155 0 : pub(crate) async fn increment_active_listeners(&self) {
156 0 : self.cache.increment_active_listeners().await;
157 0 : }
158 :
159 0 : pub(crate) async fn decrement_active_listeners(&self) {
160 0 : self.cache.decrement_active_listeners().await;
161 0 : }
162 :
163 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
164 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
165 : let payload: String = msg.get_payload()?;
166 : tracing::debug!(?payload, "received a message payload");
167 :
168 : let msg: Notification = match serde_json::from_str(&payload) {
169 : Ok(Notification::UnknownTopic) => {
170 : match serde_json::from_str::<NotificationHeader>(&payload) {
171 : // don't update the metric for redis errors if it's just a topic we don't know about.
172 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
173 : Err(e) => {
174 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
175 : channel: msg.get_channel_name(),
176 : });
177 : tracing::error!("broken message: {e}");
178 : }
179 : };
180 : return Ok(());
181 : }
182 : Ok(msg) => msg,
183 : Err(e) => {
184 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
185 : channel: msg.get_channel_name(),
186 : });
187 : match serde_json::from_str::<NotificationHeader>(&payload) {
188 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
189 : Err(_) => tracing::error!("broken message: {e}"),
190 : };
191 : return Ok(());
192 : }
193 : };
194 :
195 : tracing::debug!(?msg, "received a message");
196 : match msg {
197 : Notification::Cancel(cancel_session) => {
198 : tracing::Span::current().record(
199 : "session_id",
200 : tracing::field::display(cancel_session.session_id),
201 : );
202 : Metrics::get()
203 : .proxy
204 : .redis_events_count
205 : .inc(RedisEventsCount::CancelSession);
206 : if let Some(cancel_region) = cancel_session.region_id {
207 : // If the message is not for this region, ignore it.
208 : if cancel_region != self.region_id {
209 : return Ok(());
210 : }
211 : }
212 :
213 : // TODO: Remove unspecified peer_addr after the complete migration to the new format
214 : let peer_addr = cancel_session
215 : .peer_addr
216 : .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
217 : let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id);
218 : cancel_span.follows_from(tracing::Span::current());
219 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
220 : match self
221 : .cancellation_handler
222 : .cancel_session(
223 : cancel_session.cancel_key_data,
224 : uuid::Uuid::nil(),
225 : peer_addr,
226 : cancel_session.peer_addr.is_some(),
227 : )
228 : .instrument(cancel_span)
229 : .await
230 : {
231 : Ok(()) => {}
232 : Err(e) => {
233 : tracing::warn!("failed to cancel session: {e}");
234 : }
235 : }
236 : }
237 : Notification::AllowedIpsUpdate { .. }
238 : | Notification::PasswordUpdate { .. }
239 : | Notification::BlockPublicOrVpcAccessUpdated { .. }
240 : | Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
241 : | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
242 : invalidate_cache(self.cache.clone(), msg.clone());
243 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
244 : Metrics::get()
245 : .proxy
246 : .redis_events_count
247 : .inc(RedisEventsCount::AllowedIpsUpdate);
248 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
249 : Metrics::get()
250 : .proxy
251 : .redis_events_count
252 : .inc(RedisEventsCount::PasswordUpdate);
253 : }
254 : // TODO: add additional metrics for the other event types.
255 :
256 : // It might happen that the invalid entry is on the way to be cached.
257 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
258 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
259 : let cache = self.cache.clone();
260 0 : tokio::spawn(async move {
261 0 : tokio::time::sleep(INVALIDATION_LAG).await;
262 0 : invalidate_cache(cache, msg);
263 0 : });
264 : }
265 :
266 : Notification::UnknownTopic => unreachable!(),
267 : }
268 :
269 : Ok(())
270 : }
271 : }
272 :
273 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
274 0 : match msg {
275 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
276 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
277 0 : }
278 0 : Notification::PasswordUpdate { password_update } => cache
279 0 : .invalidate_role_secret_for_project(
280 0 : password_update.project_id,
281 0 : password_update.role_name,
282 0 : ),
283 0 : Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
284 0 : Notification::BlockPublicOrVpcAccessUpdated { .. } => {
285 0 : // https://github.com/neondatabase/neon/pull/10073
286 0 : }
287 0 : Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
288 0 : // https://github.com/neondatabase/neon/pull/10073
289 0 : }
290 0 : Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
291 0 : // https://github.com/neondatabase/neon/pull/10073
292 0 : }
293 0 : Notification::UnknownTopic => unreachable!(),
294 : }
295 0 : }
296 :
297 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
298 0 : handler: MessageHandler<C>,
299 0 : redis: ConnectionWithCredentialsProvider,
300 0 : cancellation_token: CancellationToken,
301 0 : ) -> anyhow::Result<()> {
302 : loop {
303 0 : if cancellation_token.is_cancelled() {
304 0 : return Ok(());
305 0 : }
306 0 : let mut conn = match try_connect(&redis).await {
307 0 : Ok(conn) => {
308 0 : handler.increment_active_listeners().await;
309 0 : conn
310 : }
311 0 : Err(e) => {
312 0 : tracing::error!(
313 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
314 : );
315 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
316 0 : continue;
317 : }
318 : };
319 0 : let mut stream = conn.on_message();
320 0 : while let Some(msg) = stream.next().await {
321 0 : match handler.handle_message(msg).await {
322 0 : Ok(()) => {}
323 0 : Err(e) => {
324 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
325 0 : break;
326 : }
327 : }
328 0 : if cancellation_token.is_cancelled() {
329 0 : handler.decrement_active_listeners().await;
330 0 : return Ok(());
331 0 : }
332 : }
333 0 : handler.decrement_active_listeners().await;
334 : }
335 0 : }
336 :
337 : /// Handle console's invalidation messages.
338 : #[tracing::instrument(name = "redis_notifications", skip_all)]
339 : pub async fn task_main<C>(
340 : config: &'static ProxyConfig,
341 : redis: ConnectionWithCredentialsProvider,
342 : cache: Arc<C>,
343 : cancel_map: CancelMap,
344 : region_id: String,
345 : ) -> anyhow::Result<Infallible>
346 : where
347 : C: ProjectInfoCache + Send + Sync + 'static,
348 : {
349 : let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
350 : &config.connect_to_compute,
351 : cancel_map,
352 : crate::metrics::CancellationSource::FromRedis,
353 : ));
354 : let handler = MessageHandler::new(cache, cancellation_handler, region_id);
355 : // 6h - 1m.
356 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
357 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
358 : loop {
359 : let cancellation_token = CancellationToken::new();
360 : interval.tick().await;
361 :
362 : tokio::spawn(handle_messages(
363 : handler.clone(),
364 : redis.clone(),
365 : cancellation_token.clone(),
366 : ));
367 0 : tokio::spawn(async move {
368 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
369 0 : cancellation_token.cancel();
370 0 : });
371 : }
372 : }
373 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::AllowedIpsUpdate {
                allowed_ips_update: AllowedIpsUpdate {
                    project_id: (&project_id).into()
                }
            }
        );

        Ok(())
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::PasswordUpdate {
                password_update: PasswordUpdate {
                    project_id: (&project_id).into(),
                    role_name: (&role_name).into(),
                }
            }
        );

        Ok(())
    }

    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let uuid = uuid::Uuid::new_v4();
        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: None,
            session_id: uuid,
            peer_addr: None,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result);

        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: Some("region".to_string()),
            session_id: uuid,
            peer_addr: None,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result);

        Ok(())
    }

    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        // A topic this proxy does not know must decode to `UnknownTopic` rather than fail, so
        // that it is logged as a warning instead of being counted as a broken message.
        let text = json!({
            "type": "message",
            "topic": "/some_future_topic",
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(result, Notification::UnknownTopic);

        Ok(())
    }
}
|