Line data Source code
1 : use std::{convert::Infallible, sync::Arc};
2 :
3 : use futures::StreamExt;
4 : use pq_proto::CancelKeyData;
5 : use redis::aio::PubSub;
6 : use serde::{Deserialize, Serialize};
7 : use uuid::Uuid;
8 :
9 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
10 : use crate::{
11 : cache::project_info::ProjectInfoCache,
12 : cancellation::{CancelMap, CancellationHandler},
13 : intern::{ProjectIdInt, RoleNameInt},
14 : metrics::{NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS, REDIS_BROKEN_MESSAGES},
15 : };
16 :
/// Redis pub/sub channel subscribed to in `try_connect`; the name suggests it carries
/// control-plane ("cplane") updates such as the cache-invalidation notifications.
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// Redis pub/sub channel also subscribed to in `try_connect`; by its name, used for
/// proxy-to-proxy messages (NOTE(review): presumably published to elsewhere in the crate).
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
/// How long `task_main` sleeps before retrying after a failed connection attempt.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay after which `handle_message` repeats a cache invalidation, to catch an entry
/// that was being inserted concurrently with the first invalidation.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
21 :
22 0 : async fn try_connect(client: &ConnectionWithCredentialsProvider) -> anyhow::Result<PubSub> {
23 0 : let mut conn = client.get_async_pubsub().await?;
24 0 : tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
25 0 : conn.subscribe(CPLANE_CHANNEL_NAME).await?;
26 0 : tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
27 0 : conn.subscribe(PROXY_CHANNEL_NAME).await?;
28 0 : Ok(conn)
29 0 : }
30 :
/// A message received on one of the Redis notification channels.
///
/// Encoded as an adjacently tagged enum: the `topic` field names the variant and the
/// `data` field carries its payload. Unknown extra top-level fields are ignored
/// (exercised by the tests below).
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(tag = "topic", content = "data")]
pub(crate) enum Notification {
    /// The allowed-IP list of a project changed. `data` is a JSON *string* containing
    /// the encoded `AllowedIpsUpdate`, decoded by `deserialize_json_string`.
    // NOTE(review): only `deserialize_with` is set, so serializing this variant
    // presumably does not reproduce the string-encoded form — confirm if needed.
    #[serde(
        rename = "/allowed_ips_updated",
        deserialize_with = "deserialize_json_string"
    )]
    AllowedIpsUpdate {
        allowed_ips_update: AllowedIpsUpdate,
    },
    /// A role's password changed. `data` is a JSON *string* containing the encoded
    /// `PasswordUpdate`, decoded by `deserialize_json_string`.
    #[serde(
        rename = "/password_updated",
        deserialize_with = "deserialize_json_string"
    )]
    PasswordUpdate { password_update: PasswordUpdate },
    /// A request to cancel a session; `data` is a plain JSON object (no string wrapping).
    #[serde(rename = "/cancel_session")]
    Cancel(CancelSession),
}
/// Payload of `Notification::AllowedIpsUpdate`: the project whose cached allowed IPs
/// must be invalidated.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct AllowedIpsUpdate {
    project_id: ProjectIdInt,
}
/// Payload of `Notification::PasswordUpdate`: identifies the (project, role) pair
/// whose cached role secret must be invalidated.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct PasswordUpdate {
    project_id: ProjectIdInt,
    role_name: RoleNameInt,
}
/// Payload of `Notification::Cancel`.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
    /// Target region. `None` means the request is handled by every receiver;
    /// `Some(r)` is ignored by handlers whose region differs from `r`.
    pub region_id: Option<String>,
    /// Key identifying the session to cancel, passed to `CancellationHandler::cancel_session`.
    pub cancel_key_data: CancelKeyData,
    /// Session identifier; recorded on the tracing span of the handling call.
    pub session_id: Uuid,
}
64 :
65 4 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
66 4 : where
67 4 : T: for<'de2> serde::Deserialize<'de2>,
68 4 : D: serde::Deserializer<'de>,
69 4 : {
70 4 : let s = String::deserialize(deserializer)?;
71 4 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
72 4 : }
73 :
74 : struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
75 : cache: Arc<C>,
76 : cancellation_handler: Arc<CancellationHandler<()>>,
77 : region_id: String,
78 : }
79 :
80 : impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
81 0 : pub fn new(
82 0 : cache: Arc<C>,
83 0 : cancellation_handler: Arc<CancellationHandler<()>>,
84 0 : region_id: String,
85 0 : ) -> Self {
86 0 : Self {
87 0 : cache,
88 0 : cancellation_handler,
89 0 : region_id,
90 0 : }
91 0 : }
92 0 : pub fn disable_ttl(&self) {
93 0 : self.cache.disable_ttl();
94 0 : }
95 0 : pub fn enable_ttl(&self) {
96 0 : self.cache.enable_ttl();
97 0 : }
98 0 : #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
99 : async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
100 : use Notification::*;
101 : let payload: String = msg.get_payload()?;
102 0 : tracing::debug!(?payload, "received a message payload");
103 :
104 : let msg: Notification = match serde_json::from_str(&payload) {
105 : Ok(msg) => msg,
106 : Err(e) => {
107 : REDIS_BROKEN_MESSAGES
108 : .with_label_values(&[msg.get_channel_name()])
109 : .inc();
110 0 : tracing::error!("broken message: {e}");
111 : return Ok(());
112 : }
113 : };
114 0 : tracing::debug!(?msg, "received a message");
115 : match msg {
116 : Cancel(cancel_session) => {
117 : tracing::Span::current().record(
118 : "session_id",
119 : &tracing::field::display(cancel_session.session_id),
120 : );
121 : if let Some(cancel_region) = cancel_session.region_id {
122 : // If the message is not for this region, ignore it.
123 : if cancel_region != self.region_id {
124 : return Ok(());
125 : }
126 : }
127 : // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
128 : match self
129 : .cancellation_handler
130 : .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil())
131 : .await
132 : {
133 : Ok(()) => {}
134 : Err(e) => {
135 0 : tracing::error!("failed to cancel session: {e}");
136 : }
137 : }
138 : }
139 : _ => {
140 : invalidate_cache(self.cache.clone(), msg.clone());
141 : // It might happen that the invalid entry is on the way to be cached.
142 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
143 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
144 : let cache = self.cache.clone();
145 0 : tokio::spawn(async move {
146 0 : tokio::time::sleep(INVALIDATION_LAG).await;
147 0 : invalidate_cache(cache, msg);
148 0 : });
149 : }
150 : }
151 :
152 : Ok(())
153 : }
154 : }
155 :
156 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
157 0 : use Notification::*;
158 0 : match msg {
159 0 : AllowedIpsUpdate { allowed_ips_update } => {
160 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
161 : }
162 0 : PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
163 0 : password_update.project_id,
164 0 : password_update.role_name,
165 0 : ),
166 0 : Cancel(_) => unreachable!("cancel message should be handled separately"),
167 : }
168 0 : }
169 :
170 : /// Handle console's invalidation messages.
171 0 : #[tracing::instrument(name = "console_notifications", skip_all)]
172 : pub async fn task_main<C>(
173 : redis: ConnectionWithCredentialsProvider,
174 : cache: Arc<C>,
175 : cancel_map: CancelMap,
176 : region_id: String,
177 : ) -> anyhow::Result<Infallible>
178 : where
179 : C: ProjectInfoCache + Send + Sync + 'static,
180 : {
181 : cache.enable_ttl();
182 : let handler = MessageHandler::new(
183 : cache,
184 : Arc::new(CancellationHandler::<()>::new(
185 : cancel_map,
186 : NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS,
187 : )),
188 : region_id,
189 : );
190 :
191 : loop {
192 : let mut conn = match try_connect(&redis).await {
193 : Ok(conn) => {
194 : handler.disable_ttl();
195 : conn
196 : }
197 : Err(e) => {
198 0 : tracing::error!(
199 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
200 0 : );
201 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
202 : continue;
203 : }
204 : };
205 : let mut stream = conn.on_message();
206 : while let Some(msg) = stream.next().await {
207 : match handler.handle_message(msg).await {
208 : Ok(()) => {}
209 : Err(e) => {
210 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
211 : break;
212 : }
213 : }
214 : }
215 : handler.enable_ttl();
216 : }
217 : }
218 :
// Decoding tests for `Notification`.
#[cfg(test)]
mod tests {
    use crate::{ProjectId, RoleName};

    use super::*;
    use serde_json::json;

    /// An `/allowed_ips_updated` message decodes when `data` is a JSON-encoded string
    /// and unknown top-level fields are present.
    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        // `data` is JSON *text*, the shape `deserialize_json_string` expects.
        let data = format!("{{\"project_id\": \"{project_id}\"}}");
        // NOTE(review): "extre_fields" is presumably a typo for "extra_fields"; any
        // unknown key serves to check that extra fields are ignored.
        let text = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::AllowedIpsUpdate {
                allowed_ips_update: AllowedIpsUpdate {
                    project_id: (&project_id).into()
                }
            }
        );

        Ok(())
    }

    /// A `/password_updated` message decodes when `data` is a JSON-encoded string
    /// carrying both the project id and the role name.
    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        // `data` is JSON *text*, the shape `deserialize_json_string` expects.
        let data = format!("{{\"project_id\": \"{project_id}\", \"role_name\": \"{role_name}\"}}");
        let text = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extre_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::PasswordUpdate {
                password_update: PasswordUpdate {
                    project_id: (&project_id).into(),
                    role_name: (&role_name).into(),
                }
            }
        );

        Ok(())
    }
    /// `Cancel` payloads are plain JSON objects, so serialize -> deserialize must
    /// round-trip, both without a region and with one.
    #[test]
    fn parse_cancel_session() -> anyhow::Result<()> {
        let cancel_key_data = CancelKeyData {
            backend_pid: 42,
            cancel_key: 41,
        };
        let uuid = uuid::Uuid::new_v4();
        // Round trip with `region_id: None`.
        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: None,
            session_id: uuid,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result);

        // Round trip with an explicit region.
        let msg = Notification::Cancel(CancelSession {
            cancel_key_data,
            region_id: Some("region".to_string()),
            session_id: uuid,
        });
        let text = serde_json::to_string(&msg)?;
        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(msg, result,);

        Ok(())
    }
}
|