Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use pq_proto::CancelKeyData;
6 : use redis::aio::PubSub;
7 : use serde::{Deserialize, Serialize};
8 : use tokio_util::sync::CancellationToken;
9 : use uuid::Uuid;
10 :
11 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
12 : use crate::cache::project_info::ProjectInfoCache;
13 : use crate::intern::{ProjectIdInt, RoleNameInt};
14 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
15 :
/// Name of the Redis pub/sub channel that `try_connect` subscribes to.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Delay before `handle_messages` retries after a failed connection attempt.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay after which a cache invalidation is repeated, to catch entries that
/// were still on their way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
19 :
20 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
21 0 : let mut conn = client.get_async_pubsub().await?;
22 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
23 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
24 0 : Ok(conn)
25 0 : }
26 :
/// Minimal view of a notification that only extracts its `topic`.
///
/// Used by `handle_message` to report the topic of messages that could not be
/// parsed as (or mapped to a known) [`Notification`].
#[derive(Debug, Deserialize)]
struct NotificationHeader<'a> {
    /// Topic string borrowed from the payload being parsed.
    topic: &'a str,
}
31 :
/// A notification received on the Redis channel.
///
/// The variant is selected by the `topic` field of the JSON message and its
/// content is taken from the `data` field. For every known topic, `data` is a
/// string that itself contains JSON (see `deserialize_json_string`).
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
    /// Topic `/allowed_ips_updated`.
    #[serde(
        rename = "/allowed_ips_updated",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedIpsUpdate {
        allowed_ips_update: AllowedIpsUpdate,
    },
    /// Topic `/block_public_or_vpc_access_updated`.
    #[serde(
        rename = "/block_public_or_vpc_access_updated",
        deserialize_with = "deserialize_json_string"
    )]
    BlockPublicOrVpcAccessUpdated {
        block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated,
    },
    /// Topic `/allowed_vpc_endpoints_updated_for_org`.
    #[serde(
        rename = "/allowed_vpc_endpoints_updated_for_org",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedVpcEndpointsUpdatedForOrg {
        allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg,
    },
    /// Topic `/allowed_vpc_endpoints_updated_for_projects`.
    #[serde(
        rename = "/allowed_vpc_endpoints_updated_for_projects",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedVpcEndpointsUpdatedForProjects {
        allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects,
    },
    /// Topic `/password_updated`.
    #[serde(
        rename = "/password_updated",
        deserialize_with = "deserialize_json_string"
    )]
    PasswordUpdate { password_update: PasswordUpdate },

    /// Fallback for any topic not listed above; its `data` (if any) is
    /// ignored. This variant is never serialized.
    #[serde(
        other,
        deserialize_with = "deserialize_unknown_topic",
        skip_serializing
    )]
    UnknownTopic,
}
76 :
/// Payload of [`Notification::AllowedIpsUpdate`].
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
    /// Project whose cached allowed-IPs entry is invalidated.
    project_id: ProjectIdInt,
}
81 :
/// Payload of [`Notification::BlockPublicOrVpcAccessUpdated`].
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct BlockPublicOrVpcAccessUpdated {
    /// Project the notification refers to.
    project_id: ProjectIdInt,
}
86 :
/// Payload of [`Notification::AllowedVpcEndpointsUpdatedForOrg`].
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForOrg {
    // TODO: change type once the implementation is more fully fledged.
    // See e.g. https://github.com/neondatabase/neon/pull/10073.
    /// Account the notification refers to (temporarily typed as a project id,
    /// see the TODO above).
    account_id: ProjectIdInt,
}
93 :
/// Payload of [`Notification::AllowedVpcEndpointsUpdatedForProjects`].
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedVpcEndpointsUpdatedForProjects {
    /// Projects the notification refers to.
    project_ids: Vec<ProjectIdInt>,
}
98 :
/// Payload of [`Notification::PasswordUpdate`].
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
    /// Project that owns the role.
    project_id: ProjectIdInt,
    /// Role whose cached secret is invalidated.
    role_name: RoleNameInt,
}
104 :
/// Serializable description of a session to cancel.
///
/// NOTE(review): nothing in the notification handling visible in this file
/// constructs or consumes this type — confirm where it is still used.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
    pub(crate) region_id: Option<String>,
    pub(crate) cancel_key_data: CancelKeyData,
    pub(crate) session_id: Uuid,
    pub(crate) peer_addr: Option<std::net::IpAddr>,
}
112 :
113 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
114 2 : where
115 2 : T: for<'de2> serde::Deserialize<'de2>,
116 2 : D: serde::Deserializer<'de>,
117 2 : {
118 2 : let s = String::deserialize(deserializer)?;
119 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
120 2 : }
121 :
122 : // https://github.com/serde-rs/serde/issues/1714
123 1 : fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error>
124 1 : where
125 1 : D: serde::Deserializer<'de>,
126 1 : {
127 1 : deserializer.deserialize_any(serde::de::IgnoredAny)?;
128 1 : Ok(())
129 1 : }
130 :
/// Processes notifications received from Redis by invalidating entries of a
/// shared project-info cache.
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
    /// Cache whose entries are invalidated by incoming notifications.
    cache: Arc<C>,
    /// Region identifier supplied at construction.
    /// NOTE(review): not read by any code visible in this file.
    region_id: String,
}
135 :
136 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
137 0 : fn clone(&self) -> Self {
138 0 : Self {
139 0 : cache: self.cache.clone(),
140 0 : region_id: self.region_id.clone(),
141 0 : }
142 0 : }
143 : }
144 :
145 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
146 0 : pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
147 0 : Self { cache, region_id }
148 0 : }
149 :
150 0 : pub(crate) async fn increment_active_listeners(&self) {
151 0 : self.cache.increment_active_listeners().await;
152 0 : }
153 :
154 0 : pub(crate) async fn decrement_active_listeners(&self) {
155 0 : self.cache.decrement_active_listeners().await;
156 0 : }
157 :
158 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
159 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
160 : let payload: String = msg.get_payload()?;
161 : tracing::debug!(?payload, "received a message payload");
162 :
163 : let msg: Notification = match serde_json::from_str(&payload) {
164 : Ok(Notification::UnknownTopic) => {
165 : match serde_json::from_str::<NotificationHeader>(&payload) {
166 : // don't update the metric for redis errors if it's just a topic we don't know about.
167 : Ok(header) => tracing::warn!(topic = header.topic, "unknown topic"),
168 : Err(e) => {
169 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
170 : channel: msg.get_channel_name(),
171 : });
172 : tracing::error!("broken message: {e}");
173 : }
174 : };
175 : return Ok(());
176 : }
177 : Ok(msg) => msg,
178 : Err(e) => {
179 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
180 : channel: msg.get_channel_name(),
181 : });
182 : match serde_json::from_str::<NotificationHeader>(&payload) {
183 : Ok(header) => tracing::error!(topic = header.topic, "broken message: {e}"),
184 : Err(_) => tracing::error!("broken message: {e}"),
185 : };
186 : return Ok(());
187 : }
188 : };
189 :
190 : tracing::debug!(?msg, "received a message");
191 : match msg {
192 : Notification::AllowedIpsUpdate { .. }
193 : | Notification::PasswordUpdate { .. }
194 : | Notification::BlockPublicOrVpcAccessUpdated { .. }
195 : | Notification::AllowedVpcEndpointsUpdatedForOrg { .. }
196 : | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
197 : invalidate_cache(self.cache.clone(), msg.clone());
198 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
199 : Metrics::get()
200 : .proxy
201 : .redis_events_count
202 : .inc(RedisEventsCount::AllowedIpsUpdate);
203 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
204 : Metrics::get()
205 : .proxy
206 : .redis_events_count
207 : .inc(RedisEventsCount::PasswordUpdate);
208 : }
209 : // TODO: add additional metrics for the other event types.
210 :
211 : // It might happen that the invalid entry is on the way to be cached.
212 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
213 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
214 : let cache = self.cache.clone();
215 0 : tokio::spawn(async move {
216 0 : tokio::time::sleep(INVALIDATION_LAG).await;
217 0 : invalidate_cache(cache, msg);
218 0 : });
219 : }
220 :
221 : Notification::UnknownTopic => unreachable!(),
222 : }
223 :
224 : Ok(())
225 : }
226 : }
227 :
228 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
229 0 : match msg {
230 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
231 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
232 0 : }
233 0 : Notification::PasswordUpdate { password_update } => cache
234 0 : .invalidate_role_secret_for_project(
235 0 : password_update.project_id,
236 0 : password_update.role_name,
237 0 : ),
238 0 : Notification::BlockPublicOrVpcAccessUpdated { .. } => {
239 0 : // https://github.com/neondatabase/neon/pull/10073
240 0 : }
241 0 : Notification::AllowedVpcEndpointsUpdatedForOrg { .. } => {
242 0 : // https://github.com/neondatabase/neon/pull/10073
243 0 : }
244 0 : Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => {
245 0 : // https://github.com/neondatabase/neon/pull/10073
246 0 : }
247 0 : Notification::UnknownTopic => unreachable!(),
248 : }
249 0 : }
250 :
251 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
252 0 : handler: MessageHandler<C>,
253 0 : redis: ConnectionWithCredentialsProvider,
254 0 : cancellation_token: CancellationToken,
255 0 : ) -> anyhow::Result<()> {
256 : loop {
257 0 : if cancellation_token.is_cancelled() {
258 0 : return Ok(());
259 0 : }
260 0 : let mut conn = match try_connect(&redis).await {
261 0 : Ok(conn) => {
262 0 : handler.increment_active_listeners().await;
263 0 : conn
264 : }
265 0 : Err(e) => {
266 0 : tracing::error!(
267 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
268 : );
269 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
270 0 : continue;
271 : }
272 : };
273 0 : let mut stream = conn.on_message();
274 0 : while let Some(msg) = stream.next().await {
275 0 : match handler.handle_message(msg).await {
276 0 : Ok(()) => {}
277 0 : Err(e) => {
278 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
279 0 : break;
280 : }
281 : }
282 0 : if cancellation_token.is_cancelled() {
283 0 : handler.decrement_active_listeners().await;
284 0 : return Ok(());
285 0 : }
286 : }
287 0 : handler.decrement_active_listeners().await;
288 : }
289 0 : }
290 :
291 : /// Handle console's invalidation messages.
292 : #[tracing::instrument(name = "redis_notifications", skip_all)]
293 : pub async fn task_main<C>(
294 : redis: ConnectionWithCredentialsProvider,
295 : cache: Arc<C>,
296 : region_id: String,
297 : ) -> anyhow::Result<Infallible>
298 : where
299 : C: ProjectInfoCache + Send + Sync + 'static,
300 : {
301 : let handler = MessageHandler::new(cache, region_id);
302 : // 6h - 1m.
303 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
304 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
305 : loop {
306 : let cancellation_token = CancellationToken::new();
307 : interval.tick().await;
308 :
309 : tokio::spawn(handle_messages(
310 : handler.clone(),
311 : redis.clone(),
312 : cancellation_token.clone(),
313 : ));
314 0 : tokio::spawn(async move {
315 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
316 0 : cancellation_token.cancel();
317 0 : });
318 : }
319 : }
320 :
/// Deserialization tests for [`Notification`].
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// An `/allowed_ips_updated` message whose `data` is a JSON-encoded string
    /// parses into `Notification::AllowedIpsUpdate`; unrelated fields are ignored.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        // `data` is itself a JSON document, embedded as a string.
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::AllowedIpsUpdate {
                allowed_ips_update: AllowedIpsUpdate {
                    project_id: (&project_id).into()
                }
            }
        );

        Ok(())
    }

    /// A `/password_updated` message parses into `Notification::PasswordUpdate`
    /// carrying both the project id and the role name.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        // `data` is itself a JSON document, embedded as a string.
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::PasswordUpdate {
                password_update: PasswordUpdate {
                    project_id: (&project_id).into(),
                    role_name: (&role_name).into(),
                }
            }
        );

        Ok(())
    }

    /// An unrecognised topic parses into `Notification::UnknownTopic`, both
    /// when a `data` field is present and when it is missing.
    #[test]
    fn parse_unknown_topic() -> anyhow::Result<()> {
        // Here `data` is an object rather than a string; it must be ignored.
        let with_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "data": {
                "payload": "ignored"
            },
            "extra_fields": "something"
        })
        .to_string();
        let result: Notification = serde_json::from_str(&with_data)?;
        assert_eq!(result, Notification::UnknownTopic);

        let without_data = json!({
            "type": "message",
            "topic": "/doesnotexist",
            "extra_fields": "something"
        })
        .to_string();
        let result: Notification = serde_json::from_str(&without_data)?;
        assert_eq!(result, Notification::UnknownTopic);

        Ok(())
    }
}
|