Line data Source code
1 : use std::{convert::Infallible, sync::Arc};
2 :
3 : use futures::StreamExt;
4 : use pq_proto::CancelKeyData;
5 : use redis::aio::PubSub;
6 : use serde::{Deserialize, Serialize};
7 : use uuid::Uuid;
8 :
9 : use crate::{
10 : cache::project_info::ProjectInfoCache,
11 : cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
12 : intern::{ProjectIdInt, RoleNameInt},
13 : };
14 :
/// First Redis pub/sub channel subscribed to by `RedisConsumerClient::try_connect`.
// NOTE(review): the name suggests control-plane updates — confirm with the publisher side.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";

/// Second Redis pub/sub channel subscribed to by `RedisConsumerClient::try_connect`.
/// Crate-visible so other modules can refer to the same channel name.
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";

/// How long `task_main` sleeps after a failed connection attempt before retrying.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);

/// Delay before a cache invalidation is repeated, to catch an entry that was
/// still on its way into the cache when the first invalidation ran.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
19 :
20 : struct RedisConsumerClient {
21 : client: redis::Client,
22 : }
23 :
24 : impl RedisConsumerClient {
25 0 : pub fn new(url: &str) -> anyhow::Result<Self> {
26 0 : let client = redis::Client::open(url)?;
27 0 : Ok(Self { client })
28 0 : }
29 0 : async fn try_connect(&self) -> anyhow::Result<PubSub> {
30 0 : let mut conn = self.client.get_async_connection().await?.into_pubsub();
31 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
32 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
33 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
34 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
35 0 : Ok(conn)
36 0 : }
37 : }
38 :
39 20 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
40 : #[serde(tag = "topic", content = "data")]
41 : pub(crate) enum Notification {
42 : #[serde(
43 : rename = "/allowed_ips_updated",
44 : deserialize_with = "deserialize_json_string"
45 : )]
46 : AllowedIpsUpdate {
47 : allowed_ips_update: AllowedIpsUpdate,
48 : },
49 : #[serde(
50 : rename = "/password_updated",
51 : deserialize_with = "deserialize_json_string"
52 : )]
53 : PasswordUpdate { password_update: PasswordUpdate },
54 : #[serde(rename = "/cancel_session")]
55 : Cancel(CancelSession),
56 : }
57 6 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
58 : pub(crate) struct AllowedIpsUpdate {
59 : project_id: ProjectIdInt,
60 : }
61 10 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
62 : pub(crate) struct PasswordUpdate {
63 : project_id: ProjectIdInt,
64 : role_name: RoleNameInt,
65 : }
66 28 : #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
67 : pub(crate) struct CancelSession {
68 : pub region_id: Option<String>,
69 : pub cancel_key_data: CancelKeyData,
70 : pub session_id: Uuid,
71 : }
72 :
73 4 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
74 4 : where
75 4 : T: for<'de2> serde::Deserialize<'de2>,
76 4 : D: serde::Deserializer<'de>,
77 4 : {
78 4 : let s = String::deserialize(deserializer)?;
79 4 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
80 4 : }
81 :
82 : struct MessageHandler<
83 : C: ProjectInfoCache + Send + Sync + 'static,
84 : H: NotificationsCancellationHandler + Send + Sync + 'static,
85 : > {
86 : cache: Arc<C>,
87 : cancellation_handler: Arc<H>,
88 : region_id: String,
89 : }
90 :
91 : impl<
92 : C: ProjectInfoCache + Send + Sync + 'static,
93 : H: NotificationsCancellationHandler + Send + Sync + 'static,
94 : > MessageHandler<C, H>
95 : {
96 0 : pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
97 0 : Self {
98 0 : cache,
99 0 : cancellation_handler,
100 0 : region_id,
101 0 : }
102 0 : }
103 0 : pub fn disable_ttl(&self) {
104 0 : self.cache.disable_ttl();
105 0 : }
106 0 : pub fn enable_ttl(&self) {
107 0 : self.cache.enable_ttl();
108 0 : }
109 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
110 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
111 : use Notification::*;
112 : let payload: String = msg.get_payload()?;
113 0 : tracing::debug!(?payload, "received a message payload");
114 :
115 : let msg: Notification = match serde_json::from_str(&payload) {
116 : Ok(msg) => msg,
117 : Err(e) => {
118 0 : tracing::error!("broken message: {e}");
119 : return Ok(());
120 : }
121 : };
122 0 : tracing::debug!(?msg, "received a message");
123 : match msg {
124 : Cancel(cancel_session) => {
125 : tracing::Span::current().record(
126 : "session_id",
127 : &tracing::field::display(cancel_session.session_id),
128 : );
129 : if let Some(cancel_region) = cancel_session.region_id {
130 : // If the message is not for this region, ignore it.
131 : if cancel_region != self.region_id {
132 : return Ok(());
133 : }
134 : }
135 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
136 : match self
137 : .cancellation_handler
138 : .cancel_session_no_publish(cancel_session.cancel_key_data)
139 : .await
140 : {
141 : Ok(()) => {}
142 : Err(e) => {
143 0 : tracing::error!("failed to cancel session: {e}");
144 : }
145 : }
146 : }
147 : _ => {
148 : invalidate_cache(self.cache.clone(), msg.clone());
149 : // It might happen that the invalid entry is on the way to be cached.
150 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
151 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
152 : let cache = self.cache.clone();
153 0 : tokio::spawn(async move {
154 0 : tokio::time::sleep(INVALIDATION_LAG).await;
155 0 : invalidate_cache(cache, msg);
156 0 : });
157 : }
158 : }
159 :
160 : Ok(())
161 : }
162 : }
163 :
164 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
165 0 : use Notification::*;
166 0 : match msg {
167 0 : AllowedIpsUpdate { allowed_ips_update } => {
168 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
169 : }
170 0 : PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
171 0 : password_update.project_id,
172 0 : password_update.role_name,
173 0 : ),
174 0 : Cancel(_) => unreachable!("cancel message should be handled separately"),
175 : }
176 0 : }
177 :
178 : /// Handle console's invalidation messages.
179 0 : #[tracing::instrument(name = "console_notifications", skip_all)]
180 : pub async fn task_main<C>(
181 : url: String,
182 : cache: Arc<C>,
183 : cancel_map: CancelMap,
184 : region_id: String,
185 : ) -> anyhow::Result<Infallible>
186 : where
187 : C: ProjectInfoCache + Send + Sync + 'static,
188 : {
189 : cache.enable_ttl();
190 : let handler = MessageHandler::new(
191 : cache,
192 : Arc::new(CancellationHandler::new(cancel_map, None)),
193 : region_id,
194 : );
195 :
196 : loop {
197 : let redis = RedisConsumerClient::new(&url)?;
198 : let conn = match redis.try_connect().await {
199 : Ok(conn) => {
200 : handler.disable_ttl();
201 : conn
202 : }
203 : Err(e) => {
204 0 : tracing::error!(
205 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
206 0 : );
207 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
208 : continue;
209 : }
210 : };
211 : let mut stream = conn.into_on_message();
212 : while let Some(msg) = stream.next().await {
213 : match handler.handle_message(msg).await {
214 : Ok(()) => {}
215 : Err(e) => {
216 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
217 : break;
218 : }
219 : }
220 : }
221 : handler.enable_ttl();
222 : }
223 : }
224 :
#[cfg(test)]
mod tests {
    use crate::{ProjectId, RoleName};

    use super::*;
    use serde_json::json;

    /// Builds the outer JSON message for `topic` with the stringified `data`,
    /// plus fields the parser is expected to ignore.
    fn envelope(topic: &str, data: &str) -> String {
        json!({
            "type": "message",
            "topic": topic,
            "data": data,
            "extre_fields": "something"
        })
        .to_string()
    }

    /// Serializes a cancel notification and checks it parses back unchanged.
    fn assert_cancel_roundtrip(session: CancelSession) -> anyhow::Result<()> {
        let original = Notification::Cancel(session);
        let encoded = serde_json::to_string(&original)?;
        let decoded: Notification = serde_json::from_str(&encoded)?;
        assert_eq!(original, decoded);
        Ok(())
    }

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        let text = envelope("/allowed_ips_updated", &data);

        let expected = Notification::AllowedIpsUpdate {
            allowed_ips_update: AllowedIpsUpdate {
                project_id: (&project_id).into(),
            },
        };
        let parsed: Notification = serde_json::from_str(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = envelope("/password_updated", &data);

        let expected = Notification::PasswordUpdate {
            password_update: PasswordUpdate {
                project_id: (&project_id).into(),
                role_name: (&role_name).into(),
            },
        };
        let parsed: Notification = serde_json::from_str(&text)?;
        assert_eq!(parsed, expected);

        Ok(())
    }

    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let session_id = uuid::Uuid::new_v4();

        // Without a region.
        assert_cancel_roundtrip(CancelSession {
            cancel_key_data,
            region_id: None,
            session_id,
        })?;

        // With a region.
        assert_cancel_roundtrip(CancelSession {
            cancel_key_data,
            region_id: Some("region".to_string()),
            session_id,
        })?;

        Ok(())
    }
}
|