Line data Source code
1 : use std::sync::Arc;
2 :
3 : use pq_proto::CancelKeyData;
4 : use redis::AsyncCommands;
5 : use tokio::sync::Mutex;
6 : use uuid::Uuid;
7 :
8 : use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
9 : use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
10 : use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
11 :
12 : pub trait CancellationPublisherMut: Send + Sync + 'static {
13 : #[allow(async_fn_in_trait)]
14 : async fn try_publish(
15 : &mut self,
16 : cancel_key_data: CancelKeyData,
17 : session_id: Uuid,
18 : ) -> anyhow::Result<()>;
19 : }
20 :
21 : pub trait CancellationPublisher: Send + Sync + 'static {
22 : #[allow(async_fn_in_trait)]
23 : async fn try_publish(
24 : &self,
25 : cancel_key_data: CancelKeyData,
26 : session_id: Uuid,
27 : ) -> anyhow::Result<()>;
28 : }
29 :
30 : impl CancellationPublisher for () {
31 1 : async fn try_publish(
32 1 : &self,
33 1 : _cancel_key_data: CancelKeyData,
34 1 : _session_id: Uuid,
35 1 : ) -> anyhow::Result<()> {
36 1 : Ok(())
37 1 : }
38 : }
39 :
40 : impl<P: CancellationPublisher> CancellationPublisherMut for P {
41 0 : async fn try_publish(
42 0 : &mut self,
43 0 : cancel_key_data: CancelKeyData,
44 0 : session_id: Uuid,
45 0 : ) -> anyhow::Result<()> {
46 0 : <P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id).await
47 0 : }
48 : }
49 :
50 : impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
51 0 : async fn try_publish(
52 0 : &self,
53 0 : cancel_key_data: CancelKeyData,
54 0 : session_id: Uuid,
55 0 : ) -> anyhow::Result<()> {
56 0 : if let Some(p) = self {
57 0 : p.try_publish(cancel_key_data, session_id).await
58 : } else {
59 0 : Ok(())
60 : }
61 0 : }
62 : }
63 :
64 : impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
65 0 : async fn try_publish(
66 0 : &self,
67 0 : cancel_key_data: CancelKeyData,
68 0 : session_id: Uuid,
69 0 : ) -> anyhow::Result<()> {
70 0 : self.lock()
71 0 : .await
72 0 : .try_publish(cancel_key_data, session_id)
73 0 : .await
74 0 : }
75 : }
76 :
77 : pub struct RedisPublisherClient {
78 : client: ConnectionWithCredentialsProvider,
79 : region_id: String,
80 : limiter: GlobalRateLimiter,
81 : }
82 :
83 : impl RedisPublisherClient {
84 0 : pub fn new(
85 0 : client: ConnectionWithCredentialsProvider,
86 0 : region_id: String,
87 0 : info: &'static [RateBucketInfo],
88 0 : ) -> anyhow::Result<Self> {
89 0 : Ok(Self {
90 0 : client,
91 0 : region_id,
92 0 : limiter: GlobalRateLimiter::new(info.into()),
93 0 : })
94 0 : }
95 :
96 0 : async fn publish(
97 0 : &mut self,
98 0 : cancel_key_data: CancelKeyData,
99 0 : session_id: Uuid,
100 0 : ) -> anyhow::Result<()> {
101 0 : let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
102 0 : region_id: Some(self.region_id.clone()),
103 0 : cancel_key_data,
104 0 : session_id,
105 0 : }))?;
106 0 : let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
107 0 : Ok(())
108 0 : }
109 0 : pub(crate) async fn try_connect(&mut self) -> anyhow::Result<()> {
110 0 : match self.client.connect().await {
111 0 : Ok(()) => {}
112 0 : Err(e) => {
113 0 : tracing::error!("failed to connect to redis: {e}");
114 0 : return Err(e);
115 : }
116 : }
117 0 : Ok(())
118 0 : }
119 0 : async fn try_publish_internal(
120 0 : &mut self,
121 0 : cancel_key_data: CancelKeyData,
122 0 : session_id: Uuid,
123 0 : ) -> anyhow::Result<()> {
124 0 : if !self.limiter.check() {
125 0 : tracing::info!("Rate limit exceeded. Skipping cancellation message");
126 0 : return Err(anyhow::anyhow!("Rate limit exceeded"));
127 0 : }
128 0 : match self.publish(cancel_key_data, session_id).await {
129 0 : Ok(()) => return Ok(()),
130 0 : Err(e) => {
131 0 : tracing::error!("failed to publish a message: {e}");
132 : }
133 : }
134 0 : tracing::info!("Publisher is disconnected. Reconnectiong...");
135 0 : self.try_connect().await?;
136 0 : self.publish(cancel_key_data, session_id).await
137 0 : }
138 : }
139 :
140 : impl CancellationPublisherMut for RedisPublisherClient {
141 0 : async fn try_publish(
142 0 : &mut self,
143 0 : cancel_key_data: CancelKeyData,
144 0 : session_id: Uuid,
145 0 : ) -> anyhow::Result<()> {
146 0 : tracing::info!("publishing cancellation key to Redis");
147 0 : match self.try_publish_internal(cancel_key_data, session_id).await {
148 : Ok(()) => {
149 0 : tracing::info!("cancellation key successfuly published to Redis");
150 0 : Ok(())
151 : }
152 0 : Err(e) => {
153 0 : tracing::error!("failed to publish a message: {e}");
154 0 : Err(e)
155 : }
156 : }
157 0 : }
158 : }
|