Line data Source code
1 : use std::{convert::Infallible, sync::Arc};
2 :
3 : use futures::StreamExt;
4 : use redis::aio::PubSub;
5 : use serde::Deserialize;
6 :
7 : use crate::{
8 : cache::project_info::ProjectInfoCache,
9 : intern::{ProjectIdInt, RoleNameInt},
10 : };
11 :
/// Redis pub/sub channel the listener subscribes to for invalidation notifications.
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
/// How long to wait before retrying after a failed attempt to connect to redis.
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Delay after which each received notification is applied to the cache a second time.
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
15 :
16 : struct ConsoleRedisClient {
17 : client: redis::Client,
18 : }
19 :
20 : impl ConsoleRedisClient {
21 0 : pub fn new(url: &str) -> anyhow::Result<Self> {
22 0 : let client = redis::Client::open(url)?;
23 0 : Ok(Self { client })
24 0 : }
25 0 : async fn try_connect(&self) -> anyhow::Result<PubSub> {
26 0 : let mut conn = self.client.get_async_connection().await?.into_pubsub();
27 0 : tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
28 0 : conn.subscribe(CHANNEL_NAME).await?;
29 0 : Ok(conn)
30 0 : }
31 : }
32 :
33 8 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
34 : #[serde(tag = "topic", content = "data")]
35 : enum Notification {
36 : #[serde(
37 : rename = "/allowed_ips_updated",
38 : deserialize_with = "deserialize_json_string"
39 : )]
40 : AllowedIpsUpdate {
41 : allowed_ips_update: AllowedIpsUpdate,
42 : },
43 : #[serde(
44 : rename = "/password_updated",
45 : deserialize_with = "deserialize_json_string"
46 : )]
47 : PasswordUpdate { password_update: PasswordUpdate },
48 : }
49 6 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
50 : struct AllowedIpsUpdate {
51 : project_id: ProjectIdInt,
52 : }
53 10 : #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
54 : struct PasswordUpdate {
55 : project_id: ProjectIdInt,
56 : role_name: RoleNameInt,
57 : }
58 4 : fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
59 4 : where
60 4 : T: for<'de2> serde::Deserialize<'de2>,
61 4 : D: serde::Deserializer<'de>,
62 4 : {
63 4 : let s = String::deserialize(deserializer)?;
64 4 : serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
65 4 : }
66 :
67 0 : fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
68 0 : use Notification::*;
69 0 : match msg {
70 0 : AllowedIpsUpdate { allowed_ips_update } => {
71 0 : cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
72 : }
73 0 : PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
74 0 : password_update.project_id,
75 0 : password_update.role_name,
76 0 : ),
77 : }
78 0 : }
79 :
80 0 : #[tracing::instrument(skip(cache))]
81 : fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
82 : where
83 : C: ProjectInfoCache + Send + Sync + 'static,
84 : {
85 : let payload: String = msg.get_payload()?;
86 0 : tracing::debug!(?payload, "received a message payload");
87 :
88 : let msg: Notification = match serde_json::from_str(&payload) {
89 : Ok(msg) => msg,
90 : Err(e) => {
91 0 : tracing::error!("broken message: {e}");
92 : return Ok(());
93 : }
94 : };
95 0 : tracing::debug!(?msg, "received a message");
96 : invalidate_cache(cache.clone(), msg.clone());
97 : // It might happen that the invalid entry is on the way to be cached.
98 : // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
99 : // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
100 0 : tokio::spawn(async move {
101 0 : tokio::time::sleep(INVALIDATION_LAG).await;
102 0 : invalidate_cache(cache, msg.clone());
103 0 : });
104 :
105 : Ok(())
106 : }
107 :
108 : /// Handle console's invalidation messages.
109 0 : #[tracing::instrument(name = "console_notifications", skip_all)]
110 : pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
111 : where
112 : C: ProjectInfoCache + Send + Sync + 'static,
113 : {
114 : cache.enable_ttl();
115 :
116 : loop {
117 : let redis = ConsoleRedisClient::new(&url)?;
118 : let conn = match redis.try_connect().await {
119 : Ok(conn) => {
120 : cache.disable_ttl();
121 : conn
122 : }
123 : Err(e) => {
124 0 : tracing::error!(
125 0 : "failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
126 0 : );
127 : tokio::time::sleep(RECONNECT_TIMEOUT).await;
128 : continue;
129 : }
130 : };
131 : let mut stream = conn.into_on_message();
132 : while let Some(msg) = stream.next().await {
133 : match handle_message(msg, cache.clone()) {
134 : Ok(()) => {}
135 : Err(e) => {
136 0 : tracing::error!("failed to handle message: {e}, will try to reconnect");
137 : break;
138 : }
139 : }
140 : }
141 : cache.enable_ttl();
142 : }
143 : }
144 :
#[cfg(test)]
mod tests {
    use crate::{ProjectId, RoleName};

    use super::*;
    use serde_json::json;

    #[test]
    fn parse_allowed_ips() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        // `data` is a JSON document serialized into a string.
        let data = json!({ "project_id": project_id.to_string() }).to_string();
        let text = json!({
            "type": "message",
            "topic": "/allowed_ips_updated",
            "data": data,
            // Unknown top-level fields must be ignored.
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::AllowedIpsUpdate {
                allowed_ips_update: AllowedIpsUpdate {
                    project_id: (&project_id).into()
                }
            }
        );

        Ok(())
    }

    #[test]
    fn parse_password_updated() -> anyhow::Result<()> {
        let project_id: ProjectId = "new_project".into();
        let role_name: RoleName = "new_role".into();
        let data = json!({
            "project_id": project_id.to_string(),
            "role_name": role_name.to_string()
        })
        .to_string();
        let text = json!({
            "type": "message",
            "topic": "/password_updated",
            "data": data,
            "extra_fields": "something"
        })
        .to_string();

        let result: Notification = serde_json::from_str(&text)?;
        assert_eq!(
            result,
            Notification::PasswordUpdate {
                password_update: PasswordUpdate {
                    project_id: (&project_id).into(),
                    role_name: (&role_name).into(),
                }
            }
        );

        Ok(())
    }

    #[test]
    fn parse_unknown_topic_fails() {
        let text = json!({ "topic": "/unknown_topic", "data": "{}" }).to_string();
        assert!(serde_json::from_str::<Notification>(&text).is_err());
    }

    #[test]
    fn parse_malformed_data_fails() {
        // `data` is a string but not valid JSON for the payload type.
        let text = json!({ "topic": "/allowed_ips_updated", "data": "not json" }).to_string();
        assert!(serde_json::from_str::<Notification>(&text).is_err());
    }
}
|