Line data Source code
1 : use std::sync::Arc;
2 :
3 : use pq_proto::CancelKeyData;
4 : use redis::AsyncCommands;
5 : use tokio::sync::Mutex;
6 : use uuid::Uuid;
7 :
8 : use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
9 :
10 : use super::{
11 : connection_with_credentials_provider::ConnectionWithCredentialsProvider,
12 : notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME},
13 : };
14 :
15 : pub trait CancellationPublisherMut: Send + Sync + 'static {
16 : #[allow(async_fn_in_trait)]
17 : async fn try_publish(
18 : &mut self,
19 : cancel_key_data: CancelKeyData,
20 : session_id: Uuid,
21 : ) -> anyhow::Result<()>;
22 : }
23 :
24 : pub trait CancellationPublisher: Send + Sync + 'static {
25 : #[allow(async_fn_in_trait)]
26 : async fn try_publish(
27 : &self,
28 : cancel_key_data: CancelKeyData,
29 : session_id: Uuid,
30 : ) -> anyhow::Result<()>;
31 : }
32 :
33 : impl CancellationPublisher for () {
34 1 : async fn try_publish(
35 1 : &self,
36 1 : _cancel_key_data: CancelKeyData,
37 1 : _session_id: Uuid,
38 1 : ) -> anyhow::Result<()> {
39 1 : Ok(())
40 1 : }
41 : }
42 :
43 : impl<P: CancellationPublisher> CancellationPublisherMut for P {
44 0 : async fn try_publish(
45 0 : &mut self,
46 0 : cancel_key_data: CancelKeyData,
47 0 : session_id: Uuid,
48 0 : ) -> anyhow::Result<()> {
49 0 : <P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id).await
50 0 : }
51 : }
52 :
53 : impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
54 0 : async fn try_publish(
55 0 : &self,
56 0 : cancel_key_data: CancelKeyData,
57 0 : session_id: Uuid,
58 0 : ) -> anyhow::Result<()> {
59 0 : if let Some(p) = self {
60 0 : p.try_publish(cancel_key_data, session_id).await
61 : } else {
62 0 : Ok(())
63 : }
64 0 : }
65 : }
66 :
67 : impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
68 0 : async fn try_publish(
69 0 : &self,
70 0 : cancel_key_data: CancelKeyData,
71 0 : session_id: Uuid,
72 0 : ) -> anyhow::Result<()> {
73 0 : self.lock()
74 0 : .await
75 0 : .try_publish(cancel_key_data, session_id)
76 0 : .await
77 0 : }
78 : }
79 :
80 : pub struct RedisPublisherClient {
81 : client: ConnectionWithCredentialsProvider,
82 : region_id: String,
83 : limiter: GlobalRateLimiter,
84 : }
85 :
86 : impl RedisPublisherClient {
87 0 : pub fn new(
88 0 : client: ConnectionWithCredentialsProvider,
89 0 : region_id: String,
90 0 : info: &'static [RateBucketInfo],
91 0 : ) -> anyhow::Result<Self> {
92 0 : Ok(Self {
93 0 : client,
94 0 : region_id,
95 0 : limiter: GlobalRateLimiter::new(info.into()),
96 0 : })
97 0 : }
98 :
99 0 : async fn publish(
100 0 : &mut self,
101 0 : cancel_key_data: CancelKeyData,
102 0 : session_id: Uuid,
103 0 : ) -> anyhow::Result<()> {
104 0 : let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
105 0 : region_id: Some(self.region_id.clone()),
106 0 : cancel_key_data,
107 0 : session_id,
108 0 : }))?;
109 0 : let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
110 0 : Ok(())
111 0 : }
112 0 : pub(crate) async fn try_connect(&mut self) -> anyhow::Result<()> {
113 0 : match self.client.connect().await {
114 0 : Ok(()) => {}
115 0 : Err(e) => {
116 0 : tracing::error!("failed to connect to redis: {e}");
117 0 : return Err(e);
118 : }
119 : }
120 0 : Ok(())
121 0 : }
122 0 : async fn try_publish_internal(
123 0 : &mut self,
124 0 : cancel_key_data: CancelKeyData,
125 0 : session_id: Uuid,
126 0 : ) -> anyhow::Result<()> {
127 0 : if !self.limiter.check() {
128 0 : tracing::info!("Rate limit exceeded. Skipping cancellation message");
129 0 : return Err(anyhow::anyhow!("Rate limit exceeded"));
130 0 : }
131 0 : match self.publish(cancel_key_data, session_id).await {
132 0 : Ok(()) => return Ok(()),
133 0 : Err(e) => {
134 0 : tracing::error!("failed to publish a message: {e}");
135 : }
136 : }
137 0 : tracing::info!("Publisher is disconnected. Reconnectiong...");
138 0 : self.try_connect().await?;
139 0 : self.publish(cancel_key_data, session_id).await
140 0 : }
141 : }
142 :
143 : impl CancellationPublisherMut for RedisPublisherClient {
144 0 : async fn try_publish(
145 0 : &mut self,
146 0 : cancel_key_data: CancelKeyData,
147 0 : session_id: Uuid,
148 0 : ) -> anyhow::Result<()> {
149 0 : tracing::info!("publishing cancellation key to Redis");
150 0 : match self.try_publish_internal(cancel_key_data, session_id).await {
151 : Ok(()) => {
152 0 : tracing::info!("cancellation key successfuly published to Redis");
153 0 : Ok(())
154 : }
155 0 : Err(e) => {
156 0 : tracing::error!("failed to publish a message: {e}");
157 0 : Err(e)
158 : }
159 : }
160 0 : }
161 : }
|