Line data Source code
1 : use std::convert::Infallible;
2 : use std::sync::Arc;
3 :
4 : use futures::StreamExt;
5 : use pq_proto::CancelKeyData;
6 : use redis::aio::PubSub;
7 : use serde::{Deserialize, Serialize};
8 : use tokio_util::sync::CancellationToken;
9 : use uuid::Uuid;
10 :
11 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
12 : use crate::cache::project_info::ProjectInfoCache;
13 : use crate::cancellation::{CancelMap, CancellationHandler};
14 : use crate::intern::{ProjectIdInt, RoleNameInt};
15 : use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
16 :
/// Pub/sub channel subscribed to in `try_connect`; presumably fed by the control plane
/// (per its name) — TODO confirm which publishers use it.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Pub/sub channel for proxy-to-proxy messages; also subscribed to in `try_connect`.
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
/// Delay before retrying after a failed attempt to connect to redis.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay after which a cache invalidation is repeated, to catch entries that were being
/// cached concurrently with the first invalidation.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
21 :
22 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
23 0 : let mut conn = client.get_async_pubsub().await?;
24 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
25 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
26 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
27 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
28 0 : Ok(conn)
29 0 : }
30 :
/// A message received on one of the redis pub/sub channels.
///
/// Represented as an adjacently tagged JSON object: the variant's renamed identifier is stored
/// under `"topic"` and its payload under `"data"`. Other keys (e.g. `"type"`) are ignored, as
/// exercised by the tests.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
    /// `/allowed_ips_updated`: the allowed IPs of a project must be invalidated.
    /// `data` arrives as a string containing JSON and is decoded by `deserialize_json_string`.
    // NOTE(review): only deserialization is customised, so serializing this variant presumably
    // emits `data` as a nested object and would not round-trip — confirm if that matters.
    #[serde(
        rename = "/allowed_ips_updated",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedIpsUpdate {
        allowed_ips_update: AllowedIpsUpdate,
    },
    /// `/password_updated`: a role's secret in a project must be invalidated.
    /// `data` arrives as a string containing JSON and is decoded by `deserialize_json_string`.
    #[serde(
        rename = "/password_updated",
        deserialize_with = "deserialize_json_string"
    )]
    PasswordUpdate { password_update: PasswordUpdate },
    /// `/cancel_session`: a request to cancel a running session; `data` is a plain JSON object.
    #[serde(rename = "/cancel_session")]
    Cancel(CancelSession),
}
/// Payload of `/allowed_ips_updated`: the project whose cached allowed IPs are invalidated.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
    project_id: ProjectIdInt,
}
/// Payload of `/password_updated`: identifies the cached role secret to invalidate.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
    // Project that owns the role.
    project_id: ProjectIdInt,
    // Role whose cached secret is invalidated.
    role_name: RoleNameInt,
}
/// Payload of `/cancel_session`: a request to cancel a session.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
    /// Region the request is addressed to. A handler in another region ignores the message;
    /// `None` means the message is processed regardless of region.
    pub(crate) region_id: Option<String>,
    /// Cancel key (backend pid + cancel key) of the session to cancel.
    pub(crate) cancel_key_data: CancelKeyData,
    /// Id of the session; recorded as the `session_id` field of the handling span.
    pub(crate) session_id: Uuid,
    /// Address of the client requesting cancellation. `None` is the legacy message format
    /// that is still accepted during migration.
    pub(crate) peer_addr: Option<std::net::IpAddr>,
}
65 :
66 2 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
67 2 : where
68 2 : T: for<'de2> serde::Deserialize<'de2>,
69 2 : D: serde::Deserializer<'de>,
70 2 : {
71 2 : let s = String::deserialize(deserializer)?;
72 2 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
73 2 : }
74 :
/// Dispatches decoded pub/sub notifications: cache invalidations go to `cache`, and session
/// cancellations for this region go to `cancellation_handler`.
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
    // Cache whose entries are invalidated by `AllowedIpsUpdate` / `PasswordUpdate` notifications;
    // it also tracks the number of active listeners.
    cache: Arc<C>,
    // Handler used to cancel sessions. It has no redis publisher, so it cannot republish
    // cancellations.
    cancellation_handler: Arc<CancellationHandler<()>>,
    // Region of this proxy; cancel requests tagged with a different region are ignored.
    region_id: String,
}
80 :
81 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
82 0 : fn clone(&self) -> Self {
83 0 : Self {
84 0 : cache: self.cache.clone(),
85 0 : cancellation_handler: self.cancellation_handler.clone(),
86 0 : region_id: self.region_id.clone(),
87 0 : }
88 0 : }
89 : }
90 :
91 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
92 0 : pub(crate) fn new(
93 0 : cache: Arc<C>,
94 0 : cancellation_handler: Arc<CancellationHandler<()>>,
95 0 : region_id: String,
96 0 : ) -> Self {
97 0 : Self {
98 0 : cache,
99 0 : cancellation_handler,
100 0 : region_id,
101 0 : }
102 0 : }
103 0 : pub(crate) async fn increment_active_listeners(&self) {
104 0 : self.cache.increment_active_listeners().await;
105 0 : }
106 0 : pub(crate) async fn decrement_active_listeners(&self) {
107 0 : self.cache.decrement_active_listeners().await;
108 0 : }
109 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
110 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
111 : let payload: String = msg.get_payload()?;
112 : tracing::debug!(?payload, "received a message payload");
113 :
114 : let msg: Notification = match serde_json::from_str(&payload) {
115 : Ok(msg) => msg,
116 : Err(e) => {
117 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
118 : channel: msg.get_channel_name(),
119 : });
120 : tracing::error!("broken message: {e}");
121 : return Ok(());
122 : }
123 : };
124 : tracing::debug!(?msg, "received a message");
125 : match msg {
126 : Notification::Cancel(cancel_session) => {
127 : tracing::Span::current().record(
128 : "session_id",
129 : tracing::field::display(cancel_session.session_id),
130 : );
131 : Metrics::get()
132 : .proxy
133 : .redis_events_count
134 : .inc(RedisEventsCount::CancelSession);
135 : if let Some(cancel_region) = cancel_session.region_id {
136 : // If the message is not for this region, ignore it.
137 : if cancel_region != self.region_id {
138 : return Ok(());
139 : }
140 : }
141 :
142 : // TODO: Remove unspecified peer_addr after the complete migration to the new format
143 : let peer_addr = cancel_session
144 : .peer_addr
145 : .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
146 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
147 : match self
148 : .cancellation_handler
149 : .cancel_session(
150 : cancel_session.cancel_key_data,
151 : uuid::Uuid::nil(),
152 : &peer_addr,
153 : cancel_session.peer_addr.is_some(),
154 : )
155 : .await
156 : {
157 : Ok(()) => {}
158 : Err(e) => {
159 : tracing::warn!("failed to cancel session: {e}");
160 : }
161 : }
162 : }
163 : Notification::AllowedIpsUpdate { .. } | Notification::PasswordUpdate { .. } => {
164 : invalidate_cache(self.cache.clone(), msg.clone());
165 : if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
166 : Metrics::get()
167 : .proxy
168 : .redis_events_count
169 : .inc(RedisEventsCount::AllowedIpsUpdate);
170 : } else if matches!(msg, Notification::PasswordUpdate { .. }) {
171 : Metrics::get()
172 : .proxy
173 : .redis_events_count
174 : .inc(RedisEventsCount::PasswordUpdate);
175 : }
176 : // It might happen that the invalid entry is on the way to be cached.
177 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
178 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
179 : let cache = self.cache.clone();
180 0 : tokio::spawn(async move {
181 0 : tokio::time::sleep(INVALIDATION_LAG).await;
182 0 : invalidate_cache(cache, msg);
183 0 : });
184 : }
185 : }
186 :
187 : Ok(())
188 : }
189 : }
190 :
191 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
192 0 : match msg {
193 0 : Notification::AllowedIpsUpdate { allowed_ips_update } => {
194 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
195 0 : }
196 0 : Notification::PasswordUpdate { password_update } => cache
197 0 : .invalidate_role_secret_for_project(
198 0 : password_update.project_id,
199 0 : password_update.role_name,
200 0 : ),
201 0 : Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
202 : }
203 0 : }
204 :
205 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
206 0 : handler: MessageHandler<C>,
207 0 : redis: ConnectionWithCredentialsProvider,
208 0 : cancellation_token: CancellationToken,
209 0 : ) -> anyhow::Result<()> {
210 : loop {
211 0 : if cancellation_token.is_cancelled() {
212 0 : return Ok(());
213 0 : }
214 0 : let mut conn = match try_connect(&redis).await {
215 0 : Ok(conn) => {
216 0 : handler.increment_active_listeners().await;
217 0 : conn
218 : }
219 0 : Err(e) => {
220 0 : tracing::error!(
221 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
222 : );
223 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
224 0 : continue;
225 : }
226 : };
227 0 : let mut stream = conn.on_message();
228 0 : while let Some(msg) = stream.next().await {
229 0 : match handler.handle_message(msg).await {
230 0 : Ok(()) => {}
231 0 : Err(e) => {
232 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
233 0 : break;
234 : }
235 : }
236 0 : if cancellation_token.is_cancelled() {
237 0 : handler.decrement_active_listeners().await;
238 0 : return Ok(());
239 0 : }
240 : }
241 0 : handler.decrement_active_listeners().await;
242 : }
243 0 : }
244 :
245 : /// Handle console's invalidation messages.
246 : #[tracing::instrument(name = "redis_notifications", skip_all)]
247 : pub async fn task_main<C>(
248 : redis: ConnectionWithCredentialsProvider,
249 : cache: Arc<C>,
250 : cancel_map: CancelMap,
251 : region_id: String,
252 : ) -> anyhow::Result<Infallible>
253 : where
254 : C: ProjectInfoCache + Send + Sync + 'static,
255 : {
256 : let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
257 : cancel_map,
258 : crate::metrics::CancellationSource::FromRedis,
259 : ));
260 : let handler = MessageHandler::new(cache, cancellation_handler, region_id);
261 : // 6h - 1m.
262 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
263 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
264 : loop {
265 : let cancellation_token = CancellationToken::new();
266 : interval.tick().await;
267 :
268 : tokio::spawn(handle_messages(
269 : handler.clone(),
270 : redis.clone(),
271 : cancellation_token.clone(),
272 : ));
273 0 : tokio::spawn(async move {
274 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
275 0 : cancellation_token.cancel();
276 0 : });
277 : }
278 : }
279 :
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::types::{ProjectId, RoleName};

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        // Keys other than `topic` and `data` must be ignored.
        let text = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::AllowedIpsUpdate {
                allowed_ips_update: AllowedIpsUpdate {
                    project_id: (&project_id).into()
                }
            }
        );

        Ok(())
    }

    #[test]
    fn parse_allowed_ips_requires_string_data() {
        // Invalidation payloads are JSON documents encoded as strings, not nested objects.
        let text = json!({
            "topic": "/allowed_ips_updated",
            "data": { "project_id": "new_project" }
        })
        .to_string();

        assert!(serde_json::from_str::<Notification>(&text).is_err());
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::PasswordUpdate {
                password_update: PasswordUpdate {
                    project_id: (&project_id).into(),
                    role_name: (&role_name).into(),
                }
            }
        );

        Ok(())
    }

    #[test]
    fn parse_unknown_topic_fails() {
        let text = json!({
            "topic": "/unknown_topic",
            "data": "{}"
        })
        .to_string();

        assert!(serde_json::from_str::<Notification>(&text).is_err());
    }

    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let uuid = uuid::Uuid::new_v4();
        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: None,
            session_id: uuid,
            peer_addr: None,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result);

        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: Some("region".to_string()),
            session_id: uuid,
            peer_addr: None,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result);

        // New message format: the requesting peer address is included.
        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: Some("region".to_string()),
            session_id: uuid,
            peer_addr: Some(std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)),
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result);

        Ok(())
    }
}
|