TLA Line data Source code
1 : #[cfg(feature = "testing")]
2 : pub mod mock;
3 : pub mod neon;
4 :
5 : use super::messages::MetricsAuxInfo;
6 : use crate::{
7 : auth::backend::ComputeUserInfo,
8 : cache::{timed_lru, TimedLru},
9 : compute,
10 : context::RequestMonitoring,
11 : scram,
12 : };
13 : use async_trait::async_trait;
14 : use dashmap::DashMap;
15 : use smol_str::SmolStr;
16 : use std::{sync::Arc, time::Duration};
17 : use tokio::{
18 : sync::{OwnedSemaphorePermit, Semaphore},
19 : time::Instant,
20 : };
21 : use tracing::info;
22 :
23 : pub mod errors {
24 : use crate::{
25 : error::{io_error, UserFacingError},
26 : http,
27 : proxy::retry::ShouldRetry,
28 : };
29 : use thiserror::Error;
30 :
31 : /// A go-to error message which doesn't leak any detail.
32 : const REQUEST_FAILED: &str = "Console request failed";
33 :
34 : /// Common console API error.
35 CBC 25 : #[derive(Debug, Error)]
36 : pub enum ApiError {
37 : /// Error returned by the console itself.
38 : #[error("{REQUEST_FAILED} with {}: {}", .status, .text)]
39 : Console {
40 : status: http::StatusCode,
41 : text: Box<str>,
42 : },
43 :
44 : /// Various IO errors like broken pipe or malformed payload.
45 : #[error("{REQUEST_FAILED}: {0}")]
46 : Transport(#[from] std::io::Error),
47 : }
48 :
49 : impl ApiError {
50 : /// Returns HTTP status code if it's the reason for failure.
51 3 : pub fn http_status_code(&self) -> Option<http::StatusCode> {
52 3 : use ApiError::*;
53 3 : match self {
54 2 : Console { status, .. } => Some(*status),
55 1 : _ => None,
56 : }
57 3 : }
58 : }
59 :
60 : impl UserFacingError for ApiError {
61 4 : fn to_string_client(&self) -> String {
62 4 : use ApiError::*;
63 4 : match self {
64 : // To minimize risks, only select errors are forwarded to users.
65 : // Ask @neondatabase/control-plane for review before adding more.
66 2 : Console { status, .. } => match *status {
67 : http::StatusCode::NOT_FOUND => {
68 : // Status 404: failed to get a project-related resource.
69 UBC 0 : format!("{REQUEST_FAILED}: endpoint cannot be found")
70 : }
71 : http::StatusCode::NOT_ACCEPTABLE => {
72 : // Status 406: endpoint is disabled (we don't allow connections).
73 0 : format!("{REQUEST_FAILED}: endpoint is disabled")
74 : }
75 : http::StatusCode::LOCKED => {
76 : // Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
77 0 : format!("{REQUEST_FAILED}: endpoint is temporary unavailable. check your quotas and/or contact our support")
78 : }
79 CBC 2 : _ => REQUEST_FAILED.to_owned(),
80 : },
81 2 : _ => REQUEST_FAILED.to_owned(),
82 : }
83 4 : }
84 : }
85 :
86 : impl ShouldRetry for ApiError {
87 4 : fn could_retry(&self) -> bool {
88 4 : match self {
89 : // retry some transport errors
90 UBC 0 : Self::Transport(io) => io.could_retry(),
91 : // retry some temporary failures because the compute was in a bad state
92 : // (bad request can be returned when the endpoint was in transition)
93 : Self::Console {
94 : status: http::StatusCode::BAD_REQUEST,
95 : ..
96 CBC 2 : } => true,
97 : // locked can be returned when the endpoint was in transition
98 : // or when quotas are exceeded. don't retry when quotas are exceeded
99 : Self::Console {
100 : status: http::StatusCode::LOCKED,
101 UBC 0 : ref text,
102 0 : } => {
103 0 : // written data quota exceeded
104 0 : // data transfer quota exceeded
105 0 : // compute time quota exceeded
106 0 : // logical size quota exceeded
107 0 : !text.contains("quota exceeded")
108 0 : && !text.contains("the limit for current plan reached")
109 : }
110 CBC 2 : _ => false,
111 : }
112 4 : }
113 : }
114 :
115 : impl From<reqwest::Error> for ApiError {
116 1 : fn from(e: reqwest::Error) -> Self {
117 1 : io_error(e).into()
118 1 : }
119 : }
120 :
121 : impl From<reqwest_middleware::Error> for ApiError {
122 1 : fn from(e: reqwest_middleware::Error) -> Self {
123 1 : io_error(e).into()
124 1 : }
125 : }
126 :
127 16 : #[derive(Debug, Error)]
128 : pub enum GetAuthInfoError {
129 : // We shouldn't include the actual secret here.
130 : #[error("Console responded with a malformed auth secret")]
131 : BadSecret,
132 :
133 : #[error(transparent)]
134 : ApiError(ApiError),
135 : }
136 :
137 : // This allows more useful interactions than `#[from]`.
138 : impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
139 4 : fn from(e: E) -> Self {
140 4 : Self::ApiError(e.into())
141 4 : }
142 : }
143 :
144 : impl UserFacingError for GetAuthInfoError {
145 4 : fn to_string_client(&self) -> String {
146 4 : use GetAuthInfoError::*;
147 4 : match self {
148 : // We absolutely should not leak any secrets!
149 UBC 0 : BadSecret => REQUEST_FAILED.to_owned(),
150 : // However, API might return a meaningful error.
151 CBC 4 : ApiError(e) => e.to_string_client(),
152 : }
153 4 : }
154 : }
155 UBC 0 : #[derive(Debug, Error)]
156 : pub enum WakeComputeError {
157 : #[error("Console responded with a malformed compute address: {0}")]
158 : BadComputeAddress(Box<str>),
159 :
160 : #[error(transparent)]
161 : ApiError(ApiError),
162 :
163 : #[error("Timeout waiting to acquire wake compute lock")]
164 : TimeoutError,
165 : }
166 :
167 : // This allows more useful interactions than `#[from]`.
168 : impl<E: Into<ApiError>> From<E> for WakeComputeError {
169 0 : fn from(e: E) -> Self {
170 0 : Self::ApiError(e.into())
171 0 : }
172 : }
173 :
174 : impl From<tokio::sync::AcquireError> for WakeComputeError {
175 0 : fn from(_: tokio::sync::AcquireError) -> Self {
176 0 : WakeComputeError::TimeoutError
177 0 : }
178 : }
179 : impl From<tokio::time::error::Elapsed> for WakeComputeError {
180 0 : fn from(_: tokio::time::error::Elapsed) -> Self {
181 0 : WakeComputeError::TimeoutError
182 0 : }
183 : }
184 :
185 : impl UserFacingError for WakeComputeError {
186 0 : fn to_string_client(&self) -> String {
187 0 : use WakeComputeError::*;
188 0 : match self {
189 : // We shouldn't show user the address even if it's broken.
190 : // Besides, user is unlikely to care about this detail.
191 0 : BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
192 : // However, API might return a meaningful error.
193 0 : ApiError(e) => e.to_string_client(),
194 :
195 0 : TimeoutError => "timeout while acquiring the compute resource lock".to_owned(),
196 : }
197 0 : }
198 : }
199 : }
200 :
/// Additional query parameters that are forwarded to the console.
pub struct ConsoleReqExtra {
    /// `(name, value)` option pairs.
    pub options: Vec<(String, String)>,
}

impl ConsoleReqExtra {
    /// Encodes the options in the DeepObject query format
    /// (see <https://swagger.io/docs/specification/serialization/>):
    /// `paramName[prop1]=value1&paramName[prop2]=value2&...`.
    ///
    /// Every `(k, v)` pair becomes `("options[k]", v)`; the input order is kept.
    pub fn options_as_deep_object(&self) -> Vec<(String, String)> {
        let mut encoded = Vec::with_capacity(self.options.len());
        for (name, value) in &self.options {
            encoded.push((format!("options[{name}]"), value.clone()));
        }
        encoded
    }
}
216 :
217 : /// Auth secret which is managed by the cloud.
218 CBC 37 : #[derive(Clone)]
219 : pub enum AuthSecret {
220 : #[cfg(feature = "testing")]
221 : /// Md5 hash of user's password.
222 : Md5([u8; 16]),
223 :
224 : /// [SCRAM](crate::scram) authentication info.
225 : Scram(scram::ServerSecret),
226 : }
227 :
228 UBC 0 : #[derive(Default)]
229 : pub struct AuthInfo {
230 : pub secret: Option<AuthSecret>,
231 : /// List of IP addresses allowed for the autorization.
232 : pub allowed_ips: Vec<String>,
233 : }
234 :
235 : /// Info for establishing a connection to a compute node.
236 : /// This is what we get after auth succeeded, but not before!
237 0 : #[derive(Clone)]
238 : pub struct NodeInfo {
239 : /// Compute node connection params.
240 : /// It's sad that we have to clone this, but this will improve
241 : /// once we migrate to a bespoke connection logic.
242 : pub config: compute::ConnCfg,
243 :
244 : /// Labels for proxy's metrics.
245 : pub aux: MetricsAuxInfo,
246 :
247 : /// Whether we should accept self-signed certificates (for testing)
248 : pub allow_self_signed_compute: bool,
249 : }
250 :
251 : pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
252 : pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
253 : pub type AllowedIpsCache = TimedLru<SmolStr, Arc<Vec<String>>>;
254 : pub type RoleSecretCache = TimedLru<(SmolStr, SmolStr), Option<AuthSecret>>;
255 : pub type CachedRoleSecret = timed_lru::Cached<&'static RoleSecretCache>;
256 :
257 : /// This will allocate per each call, but the http requests alone
258 : /// already require a few allocations, so it should be fine.
259 : #[async_trait]
260 : pub trait Api {
261 : /// Get the client's auth secret for authentication.
262 : async fn get_role_secret(
263 : &self,
264 : ctx: &mut RequestMonitoring,
265 : creds: &ComputeUserInfo,
266 : ) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
267 :
268 : async fn get_allowed_ips(
269 : &self,
270 : ctx: &mut RequestMonitoring,
271 : creds: &ComputeUserInfo,
272 : ) -> Result<Arc<Vec<String>>, errors::GetAuthInfoError>;
273 :
274 : /// Wake up the compute node and return the corresponding connection info.
275 : async fn wake_compute(
276 : &self,
277 : ctx: &mut RequestMonitoring,
278 : extra: &ConsoleReqExtra,
279 : creds: &ComputeUserInfo,
280 : ) -> Result<CachedNodeInfo, errors::WakeComputeError>;
281 : }
282 :
283 : /// Various caches for [`console`](super).
284 : pub struct ApiCaches {
285 : /// Cache for the `wake_compute` API method.
286 : pub node_info: NodeInfoCache,
287 : /// Cache for the `get_allowed_ips`. TODO(anna): use notifications listener instead.
288 : pub allowed_ips: AllowedIpsCache,
289 : /// Cache for the `get_role_secret`. TODO(anna): use notifications listener instead.
290 : pub role_secret: RoleSecretCache,
291 : }
292 :
293 : /// Various caches for [`console`](super).
294 : pub struct ApiLocks {
295 : name: &'static str,
296 : node_locks: DashMap<Arc<str>, Arc<Semaphore>>,
297 : permits: usize,
298 : timeout: Duration,
299 : registered: prometheus::IntCounter,
300 : unregistered: prometheus::IntCounter,
301 : reclamation_lag: prometheus::Histogram,
302 : lock_acquire_lag: prometheus::Histogram,
303 : }
304 :
305 : impl ApiLocks {
306 CBC 1 : pub fn new(
307 1 : name: &'static str,
308 1 : permits: usize,
309 1 : shards: usize,
310 1 : timeout: Duration,
311 1 : ) -> prometheus::Result<Self> {
312 1 : let registered = prometheus::IntCounter::with_opts(
313 1 : prometheus::Opts::new(
314 1 : "semaphores_registered",
315 1 : "Number of semaphores registered in this api lock",
316 1 : )
317 1 : .namespace(name),
318 1 : )?;
319 1 : prometheus::register(Box::new(registered.clone()))?;
320 1 : let unregistered = prometheus::IntCounter::with_opts(
321 1 : prometheus::Opts::new(
322 1 : "semaphores_unregistered",
323 1 : "Number of semaphores unregistered in this api lock",
324 1 : )
325 1 : .namespace(name),
326 1 : )?;
327 1 : prometheus::register(Box::new(unregistered.clone()))?;
328 1 : let reclamation_lag = prometheus::Histogram::with_opts(
329 1 : prometheus::HistogramOpts::new(
330 1 : "reclamation_lag_seconds",
331 1 : "Time it takes to reclaim unused semaphores in the api lock",
332 1 : )
333 1 : .namespace(name)
334 1 : // 1us -> 65ms
335 1 : // benchmarks on my mac indicate it's usually in the range of 256us and 512us
336 1 : .buckets(prometheus::exponential_buckets(1e-6, 2.0, 16)?),
337 UBC 0 : )?;
338 CBC 1 : prometheus::register(Box::new(reclamation_lag.clone()))?;
339 1 : let lock_acquire_lag = prometheus::Histogram::with_opts(
340 1 : prometheus::HistogramOpts::new(
341 1 : "semaphore_acquire_seconds",
342 1 : "Time it takes to reclaim unused semaphores in the api lock",
343 1 : )
344 1 : .namespace(name)
345 1 : // 0.1ms -> 6s
346 1 : .buckets(prometheus::exponential_buckets(1e-4, 2.0, 16)?),
347 UBC 0 : )?;
348 CBC 1 : prometheus::register(Box::new(lock_acquire_lag.clone()))?;
349 :
350 1 : Ok(Self {
351 1 : name,
352 1 : node_locks: DashMap::with_shard_amount(shards),
353 1 : permits,
354 1 : timeout,
355 1 : lock_acquire_lag,
356 1 : registered,
357 1 : unregistered,
358 1 : reclamation_lag,
359 1 : })
360 1 : }
361 :
362 UBC 0 : pub async fn get_wake_compute_permit(
363 0 : &self,
364 0 : key: &Arc<str>,
365 0 : ) -> Result<WakeComputePermit, errors::WakeComputeError> {
366 0 : if self.permits == 0 {
367 0 : return Ok(WakeComputePermit { permit: None });
368 0 : }
369 0 : let now = Instant::now();
370 0 : let semaphore = {
371 : // get fast path
372 0 : if let Some(semaphore) = self.node_locks.get(key) {
373 0 : semaphore.clone()
374 : } else {
375 0 : self.node_locks
376 0 : .entry(key.clone())
377 0 : .or_insert_with(|| {
378 0 : self.registered.inc();
379 0 : Arc::new(Semaphore::new(self.permits))
380 0 : })
381 0 : .clone()
382 : }
383 : };
384 0 : let permit = tokio::time::timeout_at(now + self.timeout, semaphore.acquire_owned()).await;
385 :
386 0 : self.lock_acquire_lag
387 0 : .observe((Instant::now() - now).as_secs_f64());
388 0 :
389 0 : Ok(WakeComputePermit {
390 0 : permit: Some(permit??),
391 : })
392 0 : }
393 :
394 CBC 1 : pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
395 1 : if self.permits == 0 {
396 1 : return;
397 UBC 0 : }
398 0 :
399 0 : let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
400 : loop {
401 0 : for (i, shard) in self.node_locks.shards().iter().enumerate() {
402 0 : interval.tick().await;
403 : // temporary lock a single shard and then clear any semaphores that aren't currently checked out
404 : // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
405 : // therefore releasing it is safe from race conditions
406 0 : info!(
407 0 : name = self.name,
408 0 : shard = i,
409 0 : "performing epoch reclamation on api lock"
410 0 : );
411 0 : let mut lock = shard.write();
412 0 : let timer = self.reclamation_lag.start_timer();
413 0 : let count = lock
414 0 : .extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
415 0 : .count();
416 0 : drop(lock);
417 0 : self.unregistered.inc_by(count as u64);
418 0 : timer.observe_duration()
419 : }
420 : }
421 CBC 1 : }
422 : }
423 :
424 : pub struct WakeComputePermit {
425 : // None if the lock is disabled
426 : permit: Option<OwnedSemaphorePermit>,
427 : }
428 :
429 : impl WakeComputePermit {
430 UBC 0 : pub fn should_check_cache(&self) -> bool {
431 0 : self.permit.is_some()
432 0 : }
433 : }
|