Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use pq_proto::CancelKeyData;
6 : use redis::aio::PubSub;
7 : use serde::{Deserialize, Serialize};
8 : use tokio_util::sync::CancellationToken;
9 : use tracing::Instrument;
10 : use uuid::Uuid;
11 :
12 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
13 : use crate::cache::project_info::ProjectInfoCache;
14 : use crate::cancellation::{CancelMap, CancellationHandler};
15 : use crate::intern::{ProjectIdInt, RoleNameInt};
16 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
17 :
/// Redis channel subscribed to alongside [`PROXY_CHANNEL_NAME`] in `try_connect`.
// NOTE(review): presumably the channel the control plane publishes on — confirm.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";

/// Redis channel for proxy-to-proxy updates; visible to the rest of the crate.
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";

/// How long to wait after a failed connection attempt before trying again.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);

/// Delay before a cache invalidation is repeated, to catch entries that were
/// still on their way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
22 :
23 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
24 0 : let mut conn = client.get_async_pubsub().await?;
25 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
26 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
27 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
28 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
29 0 : Ok(conn)
30 0 : }
31 :
32 10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
33 : #[serde(tag = "topic", content = "data")]
34 : pub(crate) enum Notification {
35 : #[serde(
36 : rename = "/allowed_ips_updated",
37 : deserialize_with = "deserialize_json_string"
38 : )]
39 : AllowedIpsUpdate {
40 : allowed_ips_update: AllowedIpsUpdate,
41 : },
42 : #[serde(
43 : rename = "/password_updated",
44 : deserialize_with = "deserialize_json_string"
45 : )]
46 : PasswordUpdate { password_update: PasswordUpdate },
47 : #[serde(rename = "/cancel_session")]
48 : Cancel(CancelSession),
49 : }
50 2 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
51 : pub(crate) struct AllowedIpsUpdate {
52 : project_id: ProjectIdInt,
53 : }
54 3 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
55 : pub(crate) struct PasswordUpdate {
56 : project_id: ProjectIdInt,
57 : role_name: RoleNameInt,
58 : }
59 10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
60 : pub(crate) struct CancelSession {
61 : pub(crate) region_id: Option<String>,
62 : pub(crate) cancel_key_data: CancelKeyData,
63 : pub(crate) session_id: Uuid,
64 : pub(crate) peer_addr: Option<std::net::IpAddr>,
65 : }
66 :
67 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
68 2 : where
69 2 : T: for<'de2> serde::Deserialize<'de2>,
70 2 : D: serde::Deserializer<'de>,
71 2 : {
72 2 : let s = String::deserialize(deserializer)?;
73 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
74 2 : }
75 :
76 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
77 : cache: Arc<C>,
78 : cancellation_handler: Arc<CancellationHandler<()>>,
79 : region_id: String,
80 : }
81 :
82 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
83 0 : fn clone(&self) -> Self {
84 0 : Self {
85 0 : cache: self.cache.clone(),
86 0 : cancellation_handler: self.cancellation_handler.clone(),
87 0 : region_id: self.region_id.clone(),
88 0 : }
89 0 : }
90 : }
91 :
92 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
93 0 : pub(crate) fn new(
94 0 : cache: Arc<C>,
95 0 : cancellation_handler: Arc<CancellationHandler<()>>,
96 0 : region_id: String,
97 0 : ) -> Self {
98 0 : Self {
99 0 : cache,
100 0 : cancellation_handler,
101 0 : region_id,
102 0 : }
103 0 : }
104 0 : pub(crate) async fn increment_active_listeners(&self) {
105 0 : self.cache.increment_active_listeners().await;
106 0 : }
107 0 : pub(crate) async fn decrement_active_listeners(&self) {
108 0 : self.cache.decrement_active_listeners().await;
109 0 : }
110 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
111 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
112 : let payload: String = msg.get_payload()?;
113 : tracing::debug!(?payload, "received a message payload");
114 :
115 : let msg: Notification = match serde_json::from_str(&payload) {
116 : Ok(msg) => msg,
117 : Err(e) => {
118 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
119 : channel: msg.get_channel_name(),
120 : });
121 : tracing::error!("broken message: {e}");
122 : return Ok(());
123 : }
124 : };
125 : tracing::debug!(?msg, "received a message");
126 : match msg {
127 : Notification::Cancel(cancel_session) => {
128 : tracing::Span::current().record(
129 : "session_id",
130 : tracing::field::display(cancel_session.session_id),
131 : );
132 : Metrics::get()
133 : .proxy
134 : .redis_events_count
135 : .inc(RedisEventsCount::CancelSession);
136 : if let Some(cancel_region) = cancel_session.region_id {
137 : // If the message is not for this region, ignore it.
138 : if cancel_region != self.region_id {
139 : return Ok(());
140 : }
141 : }
142 :
143 : // TODO: Remove unspecified peer_addr after the complete migration to the new format
144 : let peer_addr = cancel_session
145 : .peer_addr
146 : .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
147 : let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id);
148 : cancel_span.follows_from(tracing::Span::current());
149 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
150 : match self
151 : .cancellation_handler
152 : .cancel_session(
153 : cancel_session.cancel_key_data,
154 : uuid::Uuid::nil(),
155 : peer_addr,
156 : cancel_session.peer_addr.is_some(),
157 : )
158 : .instrument(cancel_span)
159 : .await
160 : {
161 : Ok(()) => {}
162 : Err(e) => {
163 : tracing::warn!("failed to cancel session: {e}");
164 : }
165 : }
166 : }
167 : Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
168 : invalidate_cache(self.cache.clone(), msg.clone());
169 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
170 : Metrics::get()
171 : .proxy
172 : .redis_events_count
173 : .inc(RedisEventsCount::AllowedIpsUpdate);
174 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
175 : Metrics::get()
176 : .proxy
177 : .redis_events_count
178 : .inc(RedisEventsCount::PasswordUpdate);
179 : }
180 : // It might happen that the invalid entry is on the way to be cached.
181 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
182 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
183 : let cache = self.cache.clone();
184 0 : tokio::spawn(async move {
185 0 : tokio::time::sleep(INVALIDATION_LAG).await;
186 0 : invalidate_cache(cache, msg);
187 0 : });
188 : }
189 : }
190 :
191 : Ok(())
192 : }
193 : }
194 :
195 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
196 0 : match msg {
197 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
198 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
199 0 : }
200 0 : Notification::PasswordUpdate { password_update } => cache
201 0 : .invalidate_role_secret_for_project(
202 0 : password_update.project_id,
203 0 : password_update.role_name,
204 0 : ),
205 0 : Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
206 : }
207 0 : }
208 :
209 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
210 0 : handler: MessageHandler<C>,
211 0 : redis: ConnectionWithCredentialsProvider,
212 0 : cancellation_token: CancellationToken,
213 0 : ) -> anyhow::Result<()> {
214 : loop {
215 0 : if cancellation_token.is_cancelled() {
216 0 : return Ok(());
217 0 : }
218 0 : let mut conn = match try_connect(&redis).await {
219 0 : Ok(conn) => {
220 0 : handler.increment_active_listeners().await;
221 0 : conn
222 : }
223 0 : Err(e) => {
224 0 : tracing::error!(
225 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
226 : );
227 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
228 0 : continue;
229 : }
230 : };
231 0 : let mut stream = conn.on_message();
232 0 : while let Some(msg) = stream.next().await {
233 0 : match handler.handle_message(msg).await {
234 0 : Ok(()) => {}
235 0 : Err(e) => {
236 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
237 0 : break;
238 : }
239 : }
240 0 : if cancellation_token.is_cancelled() {
241 0 : handler.decrement_active_listeners().await;
242 0 : return Ok(());
243 0 : }
244 : }
245 0 : handler.decrement_active_listeners().await;
246 : }
247 0 : }
248 :
249 : /// Handle console's invalidation messages.
250 : #[tracing::instrument(name = "redis_notifications", skip_all)]
251 : pub async fn task_main<C>(
252 : redis: ConnectionWithCredentialsProvider,
253 : cache: Arc<C>,
254 : cancel_map: CancelMap,
255 : region_id: String,
256 : ) -> anyhow::Result<Infallible>
257 : where
258 : C: ProjectInfoCache + Send + Sync + 'static,
259 : {
260 : let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
261 : cancel_map,
262 : crate::metrics::CancellationSource::FromRedis,
263 : ));
264 : let handler = MessageHandler::new(cache, cancellation_handler, region_id);
265 : // 6h - 1m.
266 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
267 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
268 : loop {
269 : let cancellation_token = CancellationToken::new();
270 : interval.tick().await;
271 :
272 : tokio::spawn(handle_messages(
273 : handler.clone(),
274 : redis.clone(),
275 : cancellation_token.clone(),
276 : ));
277 0 : tokio::spawn(async move {
278 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
279 0 : cancellation_token.cancel();
280 0 : });
281 : }
282 : }
283 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    /// Builds the outer JSON message for `topic`, carrying `data` as a JSON
    /// string plus some fields the parser is expected to ignore.
    fn envelope(topic: &str, data: &str) -> String {
        json!({
            "type": "message",
            "topic": topic,
            "data": data,
            "extre_fields": "something"
        })
        .to_string()
    }

    /// An allowed-IPs update with a string-encoded payload parses correctly.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let text = envelope("/allowed_ips_updated", &data);

        let expected = Notification::AllowedIpsUpdate {
            allowed_ips_update: AllowedIpsUpdate {
                project_id: (&project_id).into(),
            },
        };
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(result, expected);

        Ok(())
    }

    /// A password update with a string-encoded payload parses correctly.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = envelope("/password_updated", &data);

        let expected = Notification::PasswordUpdate {
            password_update: PasswordUpdate {
                project_id: (&project_id).into(),
                role_name: (&role_name).into(),
            },
        };
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(result, expected);

        Ok(())
    }

    /// Cancel messages survive a serialize/deserialize round trip, both
    /// without and with a region.
    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let session_id = uuid::Uuid::new_v4();

        for region_id in [None, Some("region".to_string())] {
            let msg = Notification::Cancel(CancelSession {
                cancel_key_data,
                region_id,
                session_id,
                peer_addr: None,
            });
            let text = serde_json::to_string(&msg)?;
            let result: Notification = serde_json::from_str(&text)?;
            assert_eq!(msg, result);
        }

        Ok(())
    }
}
|