       1              : use std::{collections::HashMap, time::Duration};
       2              : 
       3              : use control_plane::endpoint::{ComputeControlPlane, EndpointStatus};
       4              : use control_plane::local_env::LocalEnv;
       5              : use hyper::{Method, StatusCode};
       6              : use pageserver_api::shard::{ShardCount, ShardIndex, ShardNumber, TenantShardId};
       7              : use postgres_connection::parse_host_port;
       8              : use serde::{Deserialize, Serialize};
       9              : use tokio_util::sync::CancellationToken;
      10              : use utils::{
      11              :     backoff::{self},
      12              :     id::{NodeId, TenantId},
      13              : };
      14              : 
      15              : use crate::service::Config;
      16              : 
      /// How long `do_notify_iteration` waits (or until cancellation) after the control plane
      /// answers `423 Locked`, before reporting `NotifyError::Busy` to the retry loop.
      const BUSY_DELAY: Duration = Duration::from_millis(1_000);
      /// How long `do_notify_iteration` waits (or until cancellation) after the control plane
      /// answers `429 Too Many Requests`, before reporting `NotifyError::SlowDown`.
      const SLOWDOWN_DELAY: Duration = Duration::from_millis(5_000);

      /// Concurrency limit exported to the rest of the crate.
      // NOTE(review): not referenced in this file; presumably bounds the number of concurrent
      // control plane API calls issued by callers -- confirm at the use sites.
      pub(crate) const API_CONCURRENCY: usize = 32;
      21              : 
      22              : pub(super) struct ComputeHookTenant {
      23              :     shards: Vec<(ShardIndex, NodeId)>,
      24              : }
      25              : 
      26            3 : #[derive(Serialize, Deserialize, Debug)]
      27              : struct ComputeHookNotifyRequestShard {
      28              :     node_id: NodeId,
      29              :     shard_number: ShardNumber,
      30              : }
      31              : 
      32              : /// Request body that we send to the control plane to notify it of where a tenant is attached
      33            3 : #[derive(Serialize, Deserialize, Debug)]
      34              : struct ComputeHookNotifyRequest {
      35              :     tenant_id: TenantId,
      36              :     shards: Vec<ComputeHookNotifyRequestShard>,
      37              : }
      38              : 
      39              : /// Error type for attempts to call into the control plane compute notification hook
      40            0 : #[derive(thiserror::Error, Debug)]
      41              : pub(crate) enum NotifyError {
      42              :     // Request was not send successfully, e.g. transport error
      43              :     #[error("Sending request: {0}")]
      44              :     Request(#[from] reqwest::Error),
      45              :     // Request could not be serviced right now due to ongoing Operation in control plane, but should be possible soon.
      46              :     #[error("Control plane tenant busy")]
      47              :     Busy,
      48              :     // Explicit 429 response asking us to retry less frequently
      49              :     #[error("Control plane overloaded")]
      50              :     SlowDown,
      51              :     // A 503 response indicates the control plane can't handle the request right now
      52              :     #[error("Control plane unavailable (status {0})")]
      53              :     Unavailable(StatusCode),
      54              :     // API returned unexpected non-success status.  We will retry, but log a warning.
      55              :     #[error("Control plane returned unexpected status {0}")]
      56              :     Unexpected(StatusCode),
      57              :     // We shutdown while sending
      58              :     #[error("Shutting down")]
      59              :     ShuttingDown,
      60              :     // A response indicates we will never succeed, such as 400 or 404
      61              :     #[error("Non-retryable error {0}")]
      62              :     Fatal(StatusCode),
      63              : }
      64              : 
      65              : impl ComputeHookTenant {
      66          492 :     async fn maybe_reconfigure(&mut self, tenant_id: TenantId) -> Option<ComputeHookNotifyRequest> {
      67          492 :         // Find the highest shard count and drop any shards that aren't
      68          492 :         // for that shard count.
      69          532 :         let shard_count = self.shards.iter().map(|(k, _v)| k.shard_count).max();
      70          492 :         let Some(shard_count) = shard_count else {
      71              :             // No shards, nothing to do.
      72            0 :             tracing::info!("ComputeHookTenant::maybe_reconfigure: no shards");
      73            0 :             return None;
      74              :         };
      75              : 
      76          532 :         self.shards.retain(|(k, _v)| k.shard_count == shard_count);
      77          492 :         self.shards
      78          492 :             .sort_by_key(|(shard, _node_id)| shard.shard_number);
      79          492 : 
      80          492 :         if self.shards.len() == shard_count.0 as usize || shard_count == ShardCount(0) {
      81              :             // We have pageservers for all the shards: emit a configuration update
      82          468 :             return Some(ComputeHookNotifyRequest {
      83          468 :                 tenant_id,
      84          468 :                 shards: self
      85          468 :                     .shards
      86          468 :                     .iter()
      87          496 :                     .map(|(shard, node_id)| ComputeHookNotifyRequestShard {
      88          496 :                         shard_number: shard.shard_number,
      89          496 :                         node_id: *node_id,
      90          496 :                     })
      91          468 :                     .collect(),
      92          468 :             });
      93              :         } else {
      94           24 :             tracing::info!(
      95           24 :                 "ComputeHookTenant::maybe_reconfigure: not enough shards ({}/{})",
      96           24 :                 self.shards.len(),
      97           24 :                 shard_count.0
      98           24 :             );
      99              :         }
     100              : 
     101           24 :         None
     102          492 :     }
     103              : }
     104              : 
     105              : /// The compute hook is a destination for notifications about changes to tenant:pageserver
     106              : /// mapping.  It aggregates updates for the shards in a tenant, and when appropriate reconfigures
     107              : /// the compute connection string.
     108              : pub(super) struct ComputeHook {
     109              :     config: Config,
     110              :     state: tokio::sync::Mutex<HashMap<TenantId, ComputeHookTenant>>,
     111              :     authorization_header: Option<String>,
     112              : }
     113              : 
     114              : impl ComputeHook {
     115          361 :     pub(super) fn new(config: Config) -> Self {
     116          361 :         let authorization_header = config
     117          361 :             .control_plane_jwt_token
     118          361 :             .clone()
     119          361 :             .map(|jwt| format!("Bearer {}", jwt));
     120          361 : 
     121          361 :         Self {
     122          361 :             state: Default::default(),
     123          361 :             config,
     124          361 :             authorization_header,
     125          361 :         }
     126          361 :     }
     127              : 
     128              :     /// For test environments: use neon_local's LocalEnv to update compute
     129          465 :     async fn do_notify_local(
     130          465 :         &self,
     131          465 :         reconfigure_request: ComputeHookNotifyRequest,
     132          465 :     ) -> anyhow::Result<()> {
     133          465 :         let env = match LocalEnv::load_config() {
     134          465 :             Ok(e) => e,
     135            0 :             Err(e) => {
     136            0 :                 tracing::warn!("Couldn't load neon_local config, skipping compute update ({e})");
     137            0 :                 return Ok(());
     138              :             }
     139              :         };
     140          465 :         let cplane =
     141          465 :             ComputeControlPlane::load(env.clone()).expect("Error loading compute control plane");
     142          465 :         let ComputeHookNotifyRequest { tenant_id, shards } = reconfigure_request;
     143          465 : 
     144          465 :         let compute_pageservers = shards
     145          465 :             .into_iter()
     146          493 :             .map(|shard| {
     147          493 :                 let ps_conf = env
     148          493 :                     .get_pageserver_conf(shard.node_id)
     149          493 :                     .expect("Unknown pageserver");
     150          493 :                 let (pg_host, pg_port) = parse_host_port(&ps_conf.listen_pg_addr)
     151          493 :                     .expect("Unable to parse listen_pg_addr");
     152          493 :                 (pg_host, pg_port.unwrap_or(5432))
     153          493 :             })
     154          465 :             .collect::<Vec<_>>();
     155              : 
     156          512 :         for (endpoint_name, endpoint) in &cplane.endpoints {
     157           47 :             if endpoint.tenant_id == tenant_id && endpoint.status() == EndpointStatus::Running {
     158            0 :                 tracing::info!("🔁 Reconfiguring endpoint {}", endpoint_name,);
     159            0 :                 endpoint.reconfigure(compute_pageservers.clone()).await?;
     160           47 :             }
     161              :         }
     162              : 
     163          465 :         Ok(())
     164          465 :     }
     165              : 
     166            3 :     async fn do_notify_iteration(
     167            3 :         &self,
     168            3 :         client: &reqwest::Client,
     169            3 :         url: &String,
     170            3 :         reconfigure_request: &ComputeHookNotifyRequest,
     171            3 :         cancel: &CancellationToken,
     172            3 :     ) -> Result<(), NotifyError> {
     173            3 :         let req = client.request(Method::POST, url);
     174            3 :         let req = if let Some(value) = &self.authorization_header {
     175            0 :             req.header(reqwest::header::AUTHORIZATION, value)
     176              :         } else {
     177            3 :             req
     178              :         };
     179              : 
     180            0 :         tracing::debug!(
     181            0 :             "Sending notify request to {} ({:?})",
     182            0 :             url,
     183            0 :             reconfigure_request
     184            0 :         );
     185           12 :         let send_result = req.json(&reconfigure_request).send().await;
     186            3 :         let response = match send_result {
     187            3 :             Ok(r) => r,
     188            0 :             Err(e) => return Err(e.into()),
     189              :         };
     190              : 
     191              :         // Treat all 2xx responses as success
     192            3 :         if response.status() >= StatusCode::OK && response.status() < StatusCode::MULTIPLE_CHOICES {
     193            3 :             if response.status() != StatusCode::OK {
     194              :                 // Non-200 2xx response: it doesn't make sense to retry, but this is unexpected, so
     195              :                 // log a warning.
     196            0 :                 tracing::warn!(
     197            0 :                     "Unexpected 2xx response code {} from control plane",
     198            0 :                     response.status()
     199            0 :                 );
     200            3 :             }
     201              : 
     202            3 :             return Ok(());
     203            0 :         }
     204            0 : 
     205            0 :         // Error response codes
     206            0 :         match response.status() {
     207              :             StatusCode::TOO_MANY_REQUESTS => {
     208              :                 // TODO: 429 handling should be global: set some state visible to other requests
     209              :                 // so that they will delay before starting, rather than all notifications trying
     210              :                 // once before backing off.
     211            0 :                 tokio::time::timeout(SLOWDOWN_DELAY, cancel.cancelled())
     212            0 :                     .await
     213            0 :                     .ok();
     214            0 :                 Err(NotifyError::SlowDown)
     215              :             }
     216              :             StatusCode::LOCKED => {
     217              :                 // Delay our retry if busy: the usual fast exponential backoff in backoff::retry
     218              :                 // is not appropriate
     219            0 :                 tokio::time::timeout(BUSY_DELAY, cancel.cancelled())
     220            0 :                     .await
     221            0 :                     .ok();
     222            0 :                 Err(NotifyError::Busy)
     223              :             }
     224              :             StatusCode::SERVICE_UNAVAILABLE
     225              :             | StatusCode::GATEWAY_TIMEOUT
     226            0 :             | StatusCode::BAD_GATEWAY => Err(NotifyError::Unavailable(response.status())),
     227              :             StatusCode::BAD_REQUEST | StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
     228            0 :                 Err(NotifyError::Fatal(response.status()))
     229              :             }
     230            0 :             _ => Err(NotifyError::Unexpected(response.status())),
     231              :         }
     232            3 :     }
     233              : 
     234            3 :     async fn do_notify(
     235            3 :         &self,
     236            3 :         url: &String,
     237            3 :         reconfigure_request: ComputeHookNotifyRequest,
     238            3 :         cancel: &CancellationToken,
     239            3 :     ) -> Result<(), NotifyError> {
     240            3 :         let client = reqwest::Client::new();
     241            3 :         backoff::retry(
     242            3 :             || self.do_notify_iteration(&client, url, &reconfigure_request, cancel),
     243            3 :             |e| matches!(e, NotifyError::Fatal(_)),
     244            3 :             3,
     245            3 :             10,
     246            3 :             "Send compute notification",
     247            3 :             cancel,
     248            3 :         )
     249           12 :         .await
     250            3 :         .ok_or_else(|| NotifyError::ShuttingDown)
     251            3 :         .and_then(|x| x)
     252            3 :     }
     253              : 
     254              :     /// Call this to notify the compute (postgres) tier of new pageservers to use
     255              :     /// for a tenant.  notify() is called by each shard individually, and this function
     256              :     /// will decide whether an update to the tenant is sent.  An update is sent on the
     257              :     /// condition that:
     258              :     /// - We know a pageserver for every shard.
     259              :     /// - All the shards have the same shard_count (i.e. we are not mid-split)
     260              :     ///
     261              :     /// Cancellation token enables callers to drop out, e.g. if calling from a Reconciler
     262              :     /// that is cancelled.
     263              :     ///
     264              :     /// This function is fallible, including in the case that the control plane is transiently
     265              :     /// unavailable.  A limited number of retries are done internally to efficiently hide short unavailability
     266              :     /// periods, but we don't retry forever.  The **caller** is responsible for handling failures and
     267              :     /// ensuring that they eventually call again to ensure that the compute is eventually notified of
     268              :     /// the proper pageserver nodes for a tenant.
     269            5 :     #[tracing::instrument(skip_all, fields(tenant_shard_id, node_id))]
     270              :     pub(super) async fn notify(
     271              :         &self,
     272              :         tenant_shard_id: TenantShardId,
     273              :         node_id: NodeId,
     274              :         cancel: &CancellationToken,
     275              :     ) -> Result<(), NotifyError> {
     276              :         let mut locked = self.state.lock().await;
     277              :         let entry = locked
     278              :             .entry(tenant_shard_id.tenant_id)
     279          463 :             .or_insert_with(|| ComputeHookTenant { shards: Vec::new() });
     280              : 
     281              :         let shard_index = ShardIndex {
     282              :             shard_count: tenant_shard_id.shard_count,
     283              :             shard_number: tenant_shard_id.shard_number,
     284              :         };
     285              : 
     286              :         let mut set = false;
     287              :         for (existing_shard, existing_node) in &mut entry.shards {
     288              :             if *existing_shard == shard_index {
     289              :                 *existing_node = node_id;
     290              :                 set = true;
     291              :             }
     292              :         }
     293              :         if !set {
     294              :             entry.shards.push((shard_index, node_id));
     295              :         }
     296              : 
     297              :         let reconfigure_request = entry.maybe_reconfigure(tenant_shard_id.tenant_id).await;
     298              :         let Some(reconfigure_request) = reconfigure_request else {
     299              :             // The tenant doesn't yet have pageservers for all its shards: we won't notify anything
     300              :             // until it does.
     301            0 :             tracing::debug!("Tenant isn't yet ready to emit a notification",);
     302              :             return Ok(());
     303              :         };
     304              : 
     305              :         if let Some(notify_url) = &self.config.compute_hook_url {
     306              :             self.do_notify(notify_url, reconfigure_request, cancel)
     307              :                 .await
     308              :         } else {
     309              :             self.do_notify_local(reconfigure_request)
     310              :                 .await
     311            0 :                 .map_err(|e| {
     312            0 :                     // This path is for testing only, so munge the error into our prod-style error type.
     313            0 :                     tracing::error!("Local notification hook failed: {e}");
     314            0 :                     NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR)
     315            0 :                 })
     316              :         }
     317              :     }
     318              : }

