LCOV - code coverage report
Current view: top level - compute_tools/src/http - server.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 128 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 16 0

            Line data    Source code
       1              : use std::fmt::Display;
       2              : use std::net::{IpAddr, Ipv6Addr, SocketAddr};
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use anyhow::Result;
       7              : use axum::Router;
       8              : use axum::middleware::{self};
       9              : use axum::response::IntoResponse;
      10              : use axum::routing::{get, post};
      11              : use compute_api::responses::ComputeCtlConfig;
      12              : use http::StatusCode;
      13              : use tokio::net::TcpListener;
      14              : use tower::ServiceBuilder;
      15              : use tower_http::{
      16              :     auth::AsyncRequireAuthorizationLayer, request_id::PropagateRequestIdLayer, trace::TraceLayer,
      17              : };
      18              : use tracing::{Span, error, info};
      19              : 
      20              : use super::middleware::request_id::maybe_add_request_id_header;
      21              : use super::{
      22              :     headers::X_REQUEST_ID,
      23              :     middleware::authorize::Authorize,
      24              :     routes::{
      25              :         check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
      26              :         grants, hadron_liveness_probe, insights, lfc, metrics, metrics_json, promote,
      27              :         refresh_configuration, status, terminate,
      28              :     },
      29              : };
      30              : use crate::compute::ComputeNode;
      31              : 
      32              : /// `compute_ctl` has two servers: internal and external. The internal server
      33              : /// binds to the loopback interface and handles communication from clients on
      34              : /// the compute. The external server is what receives communication from the
      35              : /// control plane, the metrics scraper, etc. We make the distinction because
      36              : /// certain routes in `compute_ctl` only need to be exposed to local processes
      37              : /// like Postgres via the neon extension and local_proxy.
      38              : #[derive(Clone, Debug)]
      39              : pub enum Server {
      40              :     Internal {
      41              :         port: u16,
      42              :     },
      43              :     External {
      44              :         port: u16,
      45              :         config: ComputeCtlConfig,
      46              :         compute_id: String,
      47              :         instance_id: Option<String>,
      48              :     },
      49              : }
      50              : 
      51              : impl Display for Server {
      52            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      53            0 :         match self {
      54            0 :             Server::Internal { .. } => f.write_str("internal"),
      55            0 :             Server::External { .. } => f.write_str("external"),
      56              :         }
      57            0 :     }
      58              : }
      59              : 
      60              : impl From<&Server> for Router<Arc<ComputeNode>> {
      61            0 :     fn from(server: &Server) -> Self {
      62            0 :         let mut router = Router::<Arc<ComputeNode>>::new();
      63              : 
      64            0 :         router = match server {
      65              :             Server::Internal { .. } => {
      66            0 :                 router = router
      67            0 :                     .route(
      68            0 :                         "/extension_server/{*filename}",
      69            0 :                         post(extension_server::download_extension),
      70              :                     )
      71            0 :                     .route("/extensions", post(extensions::install_extension))
      72            0 :                     .route("/grants", post(grants::add_grant))
      73              :                     // Hadron: Compute-initiated configuration refresh
      74            0 :                     .route(
      75            0 :                         "/refresh_configuration",
      76            0 :                         post(refresh_configuration::refresh_configuration),
      77              :                     );
      78              : 
      79              :                 // Add in any testing support
      80            0 :                 if cfg!(feature = "testing") {
      81              :                     use super::routes::failpoints;
      82              : 
      83            0 :                     router = router.route("/failpoints", post(failpoints::configure_failpoints));
      84            0 :                 }
      85              : 
      86            0 :                 router
      87              :             }
      88              :             Server::External {
      89            0 :                 config,
      90            0 :                 compute_id,
      91            0 :                 instance_id,
      92              :                 ..
      93              :             } => {
      94            0 :                 let unauthenticated_router = Router::<Arc<ComputeNode>>::new()
      95            0 :                     .route("/metrics", get(metrics::get_metrics))
      96            0 :                     .route(
      97            0 :                         "/autoscaling_metrics",
      98            0 :                         get(metrics::get_autoscaling_metrics),
      99              :                     );
     100              : 
     101            0 :                 let authenticated_router = Router::<Arc<ComputeNode>>::new()
     102            0 :                     .route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))
     103            0 :                     .route("/lfc/offload", get(lfc::offload_state).post(lfc::offload))
     104            0 :                     .route("/promote", post(promote::promote))
     105            0 :                     .route("/check_writability", post(check_writability::is_writable))
     106            0 :                     .route("/configure", post(configure::configure))
     107            0 :                     .route("/database_schema", get(database_schema::get_schema_dump))
     108            0 :                     .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
     109            0 :                     .route("/insights", get(insights::get_insights))
     110            0 :                     .route("/metrics.json", get(metrics_json::get_metrics))
     111            0 :                     .route("/status", get(status::get_status))
     112            0 :                     .route("/terminate", post(terminate::terminate))
     113            0 :                     .route(
     114            0 :                         "/hadron_liveness_probe",
     115            0 :                         get(hadron_liveness_probe::hadron_liveness_probe),
     116              :                     )
     117            0 :                     .layer(AsyncRequireAuthorizationLayer::new(Authorize::new(
     118            0 :                         compute_id.clone(),
     119            0 :                         instance_id.clone(),
     120            0 :                         config.jwks.clone(),
     121              :                     )));
     122              : 
     123            0 :                 router
     124            0 :                     .merge(unauthenticated_router)
     125            0 :                     .merge(authenticated_router)
     126              :             }
     127              :         };
     128              : 
     129            0 :         router
     130            0 :             .fallback(Server::handle_404)
     131            0 :             .method_not_allowed_fallback(Server::handle_405)
     132            0 :             .layer(
     133            0 :                 ServiceBuilder::new()
     134            0 :                     .layer(tower_otel::trace::HttpLayer::server(tracing::Level::INFO))
     135              :                     // Add this middleware since we assume the request ID exists
     136            0 :                     .layer(middleware::from_fn(maybe_add_request_id_header))
     137            0 :                     .layer(
     138            0 :                         TraceLayer::new_for_http()
     139            0 :                             .on_request(|request: &http::Request<_>, _span: &Span| {
     140            0 :                                 let request_id = request
     141            0 :                                     .headers()
     142            0 :                                     .get(X_REQUEST_ID)
     143            0 :                                     .unwrap()
     144            0 :                                     .to_str()
     145            0 :                                     .unwrap();
     146              : 
     147            0 :                                 info!(%request_id, "{} {}", request.method(), request.uri());
     148            0 :                             })
     149            0 :                             .on_response(
     150            0 :                                 |response: &http::Response<_>, latency: Duration, _span: &Span| {
     151            0 :                                     let request_id = response
     152            0 :                                         .headers()
     153            0 :                                         .get(X_REQUEST_ID)
     154            0 :                                         .unwrap()
     155            0 :                                         .to_str()
     156            0 :                                         .unwrap();
     157              : 
     158            0 :                                     info!(
     159              :                                         %request_id,
     160            0 :                                         code = response.status().as_u16(),
     161            0 :                                         latency = latency.as_millis()
     162              :                                     );
     163            0 :                                 },
     164              :                             ),
     165              :                     )
     166            0 :                     .layer(PropagateRequestIdLayer::x_request_id()),
     167              :             )
     168            0 :     }
     169              : }
     170              : 
     171              : impl Server {
     172            0 :     async fn handle_404() -> impl IntoResponse {
     173            0 :         StatusCode::NOT_FOUND
     174            0 :     }
     175              : 
     176            0 :     async fn handle_405() -> impl IntoResponse {
     177            0 :         StatusCode::METHOD_NOT_ALLOWED
     178            0 :     }
     179              : 
     180            0 :     async fn listener(&self) -> Result<TcpListener> {
     181            0 :         let addr = SocketAddr::new(self.ip(), self.port());
     182            0 :         let listener = TcpListener::bind(&addr).await?;
     183              : 
     184            0 :         Ok(listener)
     185            0 :     }
     186              : 
     187            0 :     fn ip(&self) -> IpAddr {
     188            0 :         match self {
     189              :             // TODO: Change this to Ipv6Addr::LOCALHOST when the GitHub runners
     190              :             // allow binding to localhost
     191            0 :             Server::Internal { .. } => IpAddr::from(Ipv6Addr::UNSPECIFIED),
     192            0 :             Server::External { .. } => IpAddr::from(Ipv6Addr::UNSPECIFIED),
     193              :         }
     194            0 :     }
     195              : 
     196            0 :     fn port(&self) -> u16 {
     197            0 :         match self {
     198            0 :             Server::Internal { port, .. } => *port,
     199            0 :             Server::External { port, .. } => *port,
     200              :         }
     201            0 :     }
     202              : 
     203            0 :     async fn serve(self, compute: Arc<ComputeNode>) {
     204            0 :         let listener = self.listener().await.unwrap_or_else(|e| {
     205              :             // If we can't bind, the compute cannot operate correctly
     206            0 :             panic!(
     207            0 :                 "failed to bind the compute_ctl {} HTTP server to {}: {}",
     208              :                 self,
     209            0 :                 SocketAddr::new(self.ip(), self.port()),
     210              :                 e
     211              :             );
     212              :         });
     213              : 
     214            0 :         if tracing::enabled!(tracing::Level::INFO) {
     215            0 :             let local_addr = match listener.local_addr() {
     216            0 :                 Ok(local_addr) => local_addr,
     217            0 :                 Err(_) => SocketAddr::new(self.ip(), self.port()),
     218              :             };
     219              : 
     220            0 :             info!(
     221            0 :                 "compute_ctl {} HTTP server listening at {}",
     222              :                 self, local_addr
     223              :             );
     224            0 :         }
     225              : 
     226            0 :         let router = Router::from(&self)
     227            0 :             .with_state(compute)
     228            0 :             .into_make_service_with_connect_info::<SocketAddr>();
     229              : 
     230            0 :         if let Err(e) = axum::serve(listener, router).await {
     231            0 :             error!("compute_ctl {} HTTP server error: {}", self, e);
     232            0 :         }
     233            0 :     }
     234              : 
     235            0 :     pub fn launch(self, compute: &Arc<ComputeNode>) {
     236            0 :         let state = Arc::clone(compute);
     237              : 
     238            0 :         info!("Launching the {} server", self);
     239              : 
     240            0 :         tokio::spawn(self.serve(state));
     241            0 :     }
     242              : }
        

Generated by: LCOV version 2.1-beta