Line data Source code
1 : use std::{convert::Infallible, sync::Arc};
2 :
3 : use futures::StreamExt;
4 : use pq_proto::CancelKeyData;
5 : use redis::aio::PubSub;
6 : use serde::{Deserialize, Serialize};
7 : use tokio_util::sync::CancellationToken;
8 : use uuid::Uuid;
9 :
10 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
11 : use crate::{
12 : cache::project_info::ProjectInfoCache,
13 : cancellation::{CancelMap, CancellationHandler},
14 : intern::{ProjectIdInt, RoleNameInt},
15 : metrics::{Metrics, RedisErrors, RedisEventsCount},
16 : };
17 :
/// First of the two Redis pub/sub channels subscribed to by `try_connect`.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Second Redis pub/sub channel subscribed to by `try_connect`; visible to the rest of the crate.
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
/// How long to wait after a failed Redis connection attempt before trying again.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay after which a cache invalidation is repeated, to cover entries that were
/// still on their way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
22 :
23 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
24 0 : let mut conn = client.get_async_pubsub().await?;
25 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
26 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
27 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
28 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
29 0 : Ok(conn)
30 0 : }
31 :
/// A notification received over Redis pub/sub.
///
/// Serialized in serde's adjacently tagged form: the `topic` field selects the
/// variant and the `data` field carries its content.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
    /// Topic `/allowed_ips_updated`. When deserializing, `data` is expected to be
    /// a string that itself contains the JSON encoding of the payload.
    #[serde(
        rename = "/allowed_ips_updated",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedIpsUpdate {
        allowed_ips_update: AllowedIpsUpdate,
    },
    /// Topic `/password_updated`. When deserializing, `data` is expected to be
    /// a string that itself contains the JSON encoding of the payload.
    #[serde(
        rename = "/password_updated",
        deserialize_with = "deserialize_json_string"
    )]
    PasswordUpdate { password_update: PasswordUpdate },
    /// Topic `/cancel_session`. Unlike the other variants, `data` is deserialized
    /// directly (no nested JSON string).
    #[serde(rename = "/cancel_session")]
    Cancel(CancelSession),
}
/// Payload of a `/allowed_ips_updated` notification: the project whose
/// cached allowed-IPs entry is invalidated.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
    project_id: ProjectIdInt,
}
/// Payload of a `/password_updated` notification: the project and role whose
/// cached role secret is invalidated.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
    project_id: ProjectIdInt,
    role_name: RoleNameInt,
}
/// Payload of a `/cancel_session` notification.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
    /// Target region. When set and different from the receiving handler's
    /// region, the message is ignored; when `None`, no region check is made.
    pub region_id: Option<String>,
    /// Cancel key identifying the session to cancel.
    pub cancel_key_data: CancelKeyData,
    /// Session id; recorded on the tracing span when the message is handled.
    pub session_id: Uuid,
}
65 :
66 4 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
67 4 : where
68 4 : T: for<'de2> serde::Deserialize<'de2>,
69 4 : D: serde::Deserializer<'de>,
70 4 : {
71 4 : let s = String::deserialize(deserializer)?;
72 4 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
73 4 : }
74 :
75 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
76 : cache: Arc<C>,
77 : cancellation_handler: Arc<CancellationHandler<()>>,
78 : region_id: String,
79 : }
80 :
81 : impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
82 0 : fn clone(&self) -> Self {
83 0 : Self {
84 0 : cache: self.cache.clone(),
85 0 : cancellation_handler: self.cancellation_handler.clone(),
86 0 : region_id: self.region_id.clone(),
87 0 : }
88 0 : }
89 : }
90 :
91 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
92 0 : pub fn new(
93 0 : cache: Arc<C>,
94 0 : cancellation_handler: Arc<CancellationHandler<()>>,
95 0 : region_id: String,
96 0 : ) -> Self {
97 0 : Self {
98 0 : cache,
99 0 : cancellation_handler,
100 0 : region_id,
101 0 : }
102 0 : }
103 0 : pub async fn increment_active_listeners(&self) {
104 0 : self.cache.increment_active_listeners().await;
105 0 : }
106 0 : pub async fn decrement_active_listeners(&self) {
107 0 : self.cache.decrement_active_listeners().await;
108 0 : }
109 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
110 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
111 : use Notification::*;
112 : let payload: String = msg.get_payload()?;
113 : tracing::debug!(?payload, "received a message payload");
114 :
115 : let msg: Notification = match serde_json::from_str(&payload) {
116 : Ok(msg) => msg,
117 : Err(e) => {
118 : Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
119 : channel: msg.get_channel_name(),
120 : });
121 : tracing::error!("broken message: {e}");
122 : return Ok(());
123 : }
124 : };
125 : tracing::debug!(?msg, "received a message");
126 : match msg {
127 : Cancel(cancel_session) => {
128 : tracing::Span::current().record(
129 : "session_id",
130 : &tracing::field::display(cancel_session.session_id),
131 : );
132 : Metrics::get()
133 : .proxy
134 : .redis_events_count
135 : .inc(RedisEventsCount::CancelSession);
136 : if let Some(cancel_region) = cancel_session.region_id {
137 : // If the message is not for this region, ignore it.
138 : if cancel_region != self.region_id {
139 : return Ok(());
140 : }
141 : }
142 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
143 : match self
144 : .cancellation_handler
145 : .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil())
146 : .await
147 : {
148 : Ok(()) => {}
149 : Err(e) => {
150 : tracing::error!("failed to cancel session: {e}");
151 : }
152 : }
153 : }
154 : _ => {
155 : invalidate_cache(self.cache.clone(), msg.clone());
156 : if matches!(msg, AllowedIpsUpdate { .. }) {
157 : Metrics::get()
158 : .proxy
159 : .redis_events_count
160 : .inc(RedisEventsCount::AllowedIpsUpdate);
161 : } else if matches!(msg, PasswordUpdate { .. }) {
162 : Metrics::get()
163 : .proxy
164 : .redis_events_count
165 : .inc(RedisEventsCount::PasswordUpdate);
166 : }
167 : // It might happen that the invalid entry is on the way to be cached.
168 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
169 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
170 : let cache = self.cache.clone();
171 0 : tokio::spawn(async move {
172 0 : tokio::time::sleep(INVALIDATION_LAG).await;
173 0 : invalidate_cache(cache, msg);
174 0 : });
175 : }
176 : }
177 :
178 : Ok(())
179 : }
180 : }
181 :
182 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
183 0 : use Notification::*;
184 0 : match msg {
185 0 : AllowedIpsUpdate { allowed_ips_update } => {
186 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
187 : }
188 0 : PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
189 0 : password_update.project_id,
190 0 : password_update.role_name,
191 0 : ),
192 0 : Cancel(_) => unreachable!("cancel message should be handled separately"),
193 : }
194 0 : }
195 :
196 0 : async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
197 0 : handler: MessageHandler<C>,
198 0 : redis: ConnectionWithCredentialsProvider,
199 0 : cancellation_token: CancellationToken,
200 0 : ) -> anyhow::Result<()> {
201 : loop {
202 0 : if cancellation_token.is_cancelled() {
203 0 : return Ok(());
204 0 : }
205 0 : let mut conn = match try_connect(&redis).await {
206 0 : Ok(conn) => {
207 0 : handler.increment_active_listeners().await;
208 0 : conn
209 : }
210 0 : Err(e) => {
211 0 : tracing::error!(
212 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
213 : );
214 0 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
215 0 : continue;
216 : }
217 : };
218 0 : let mut stream = conn.on_message();
219 0 : while let Some(msg) = stream.next().await {
220 0 : match handler.handle_message(msg).await {
221 0 : Ok(()) => {}
222 0 : Err(e) => {
223 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
224 0 : break;
225 : }
226 : }
227 0 : if cancellation_token.is_cancelled() {
228 0 : handler.decrement_active_listeners().await;
229 0 : return Ok(());
230 0 : }
231 : }
232 0 : handler.decrement_active_listeners().await;
233 : }
234 0 : }
235 :
236 : /// Handle console's invalidation messages.
237 0 : #[tracing::instrument(name = "redis_notifications", skip_all)]
238 : pub async fn task_main<C>(
239 : redis: ConnectionWithCredentialsProvider,
240 : cache: Arc<C>,
241 : cancel_map: CancelMap,
242 : region_id: String,
243 : ) -> anyhow::Result<Infallible>
244 : where
245 : C: ProjectInfoCache + Send + Sync + 'static,
246 : {
247 : let cancellation_handler = Arc::new(CancellationHandler::<()>::new(
248 : cancel_map,
249 : crate::metrics::CancellationSource::FromRedis,
250 : ));
251 : let handler = MessageHandler::new(cache, cancellation_handler, region_id);
252 : // 6h - 1m.
253 : // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
254 : let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
255 : loop {
256 : let cancellation_token = CancellationToken::new();
257 : interval.tick().await;
258 :
259 : tokio::spawn(handle_messages(
260 : handler.clone(),
261 : redis.clone(),
262 : cancellation_token.clone(),
263 : ));
264 0 : tokio::spawn(async move {
265 0 : tokio::time::sleep(std::time::Duration::from_secs(6 * 60 * 60)).await; // 6h.
266 0 : cancellation_token.cancel();
267 0 : });
268 : }
269 : }
270 :
#[cfg(test)]
mod tests {
    use crate::{ProjectId, RoleName};

    use super::*;
    use serde_json::json;

    /// An allowed-IPs notification whose `data` is a JSON string is parsed,
    /// and unknown top-level fields are ignored.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let expected = Notification::AllowedIpsUpdate {
            allowed_ips_update: AllowedIpsUpdate {
                project_id: (&project_id).into(),
            },
        };

        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let envelope = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extre_fields": "something"
        });
        let text = envelope.to_string();

        let parsed: Notification = serde_json::from_str(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// A password-update notification whose `data` is a JSON string is parsed,
    /// and unknown top-level fields are ignored.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let expected = Notification::PasswordUpdate {
            password_update: PasswordUpdate {
                project_id: (&project_id).into(),
                role_name: (&role_name).into(),
            },
        };

        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let envelope = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extre_fields": "something"
        });
        let text = envelope.to_string();

        let parsed: Notification = serde_json::from_str(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    /// Cancel notifications survive a serialize/deserialize round trip,
    /// both without and with a region id.
    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let session_id = uuid::Uuid::new_v4();

        for region_id in [None, Some("region".to_string())] {
            let original = Notification::Cancel(CancelSession {
                region_id,
                cancel_key_data,
                session_id,
            });
            let text = serde_json::to_string(&original)?;
            let round_tripped: Notification = serde_json::from_str(&text)?;
            assert_eq!(original, round_tripped);
        }

        Ok(())
    }
}
|