LCOV - code coverage report
Current view: top level - compute_tools/src/http - server.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 0.0 % 127 0
Test Date: 2025-02-20 13:11:02 Functions: 0.0 % 18 0

            Line data    Source code
       1              : use std::{
       2              :     fmt::Display,
       3              :     net::{IpAddr, Ipv6Addr, SocketAddr},
       4              :     sync::Arc,
       5              :     time::Duration,
       6              : };
       7              : 
       8              : use anyhow::Result;
       9              : use axum::{
      10              :     extract::Request,
      11              :     middleware::{self, Next},
      12              :     response::{IntoResponse, Response},
      13              :     routing::{get, post},
      14              :     Router,
      15              : };
      16              : use http::StatusCode;
      17              : use tokio::net::TcpListener;
      18              : use tower::ServiceBuilder;
      19              : use tower_http::{request_id::PropagateRequestIdLayer, trace::TraceLayer};
      20              : use tracing::{debug, error, info, Span};
      21              : use uuid::Uuid;
      22              : 
      23              : use super::routes::{
      24              :     check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
      25              :     grants, insights, metrics, metrics_json, status, terminate,
      26              : };
      27              : use crate::compute::ComputeNode;
      28              : 
      29              : const X_REQUEST_ID: &str = "x-request-id";
      30              : 
      31              : /// `compute_ctl` has two servers: internal and external. The internal server
      32              : /// binds to the loopback interface and handles communication from clients on
      33              : /// the compute. The external server is what receives communication from the
      34              : /// control plane, the metrics scraper, etc. We make the distinction because
      35              : /// certain routes in `compute_ctl` only need to be exposed to local processes
      36              : /// like Postgres via the neon extension and local_proxy.
      37              : #[derive(Clone, Copy, Debug)]
      38              : pub enum Server {
      39              :     Internal(u16),
      40              :     External(u16),
      41              : }
      42              : 
      43              : impl Display for Server {
      44            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      45            0 :         match self {
      46            0 :             Server::Internal(_) => f.write_str("internal"),
      47            0 :             Server::External(_) => f.write_str("external"),
      48              :         }
      49            0 :     }
      50              : }
      51              : 
      52              : impl From<Server> for Router<Arc<ComputeNode>> {
      53            0 :     fn from(server: Server) -> Self {
      54            0 :         let mut router = Router::<Arc<ComputeNode>>::new();
      55              : 
      56            0 :         router = match server {
      57              :             Server::Internal(_) => {
      58            0 :                 router = router
      59            0 :                     .route(
      60            0 :                         "/extension_server/{*filename}",
      61            0 :                         post(extension_server::download_extension),
      62            0 :                     )
      63            0 :                     .route("/extensions", post(extensions::install_extension))
      64            0 :                     .route("/grants", post(grants::add_grant));
      65            0 : 
      66            0 :                 // Add in any testing support
      67            0 :                 if cfg!(feature = "testing") {
      68            0 :                     use super::routes::failpoints;
      69            0 : 
      70            0 :                     router = router.route("/failpoints", post(failpoints::configure_failpoints));
      71            0 :                 }
      72              : 
      73            0 :                 router
      74              :             }
      75            0 :             Server::External(_) => router
      76            0 :                 .route("/check_writability", post(check_writability::is_writable))
      77            0 :                 .route("/configure", post(configure::configure))
      78            0 :                 .route("/database_schema", get(database_schema::get_schema_dump))
      79            0 :                 .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
      80            0 :                 .route("/insights", get(insights::get_insights))
      81            0 :                 .route("/metrics", get(metrics::get_metrics))
      82            0 :                 .route("/metrics.json", get(metrics_json::get_metrics))
      83            0 :                 .route("/status", get(status::get_status))
      84            0 :                 .route("/terminate", post(terminate::terminate)),
      85              :         };
      86              : 
      87            0 :         router.fallback(Server::handle_404).method_not_allowed_fallback(Server::handle_405).layer(
      88            0 :             ServiceBuilder::new()
      89            0 :                 // Add this middleware since we assume the request ID exists
      90            0 :                 .layer(middleware::from_fn(maybe_add_request_id_header))
      91            0 :                 .layer(
      92            0 :                     TraceLayer::new_for_http()
      93            0 :                         .on_request(|request: &http::Request<_>, _span: &Span| {
      94            0 :                             let request_id = request
      95            0 :                                 .headers()
      96            0 :                                 .get(X_REQUEST_ID)
      97            0 :                                 .unwrap()
      98            0 :                                 .to_str()
      99            0 :                                 .unwrap();
     100            0 : 
     101            0 :                             match request.uri().path() {
     102            0 :                                 "/metrics" => {
     103            0 :                                     debug!(%request_id, "{} {}", request.method(), request.uri())
     104              :                                 }
     105            0 :                                 _ => info!(%request_id, "{} {}", request.method(), request.uri()),
     106              :                             };
     107            0 :                         })
     108            0 :                         .on_response(
     109            0 :                             |response: &http::Response<_>, latency: Duration, _span: &Span| {
     110            0 :                                 let request_id = response
     111            0 :                                     .headers()
     112            0 :                                     .get(X_REQUEST_ID)
     113            0 :                                     .unwrap()
     114            0 :                                     .to_str()
     115            0 :                                     .unwrap();
     116            0 : 
     117            0 :                                 info!(
     118              :                                     %request_id,
     119            0 :                                     code = response.status().as_u16(),
     120            0 :                                     latency = latency.as_millis()
     121              :                                 )
     122            0 :                             },
     123            0 :                         ),
     124            0 :                 )
     125            0 :                 .layer(PropagateRequestIdLayer::x_request_id()),
     126            0 :         )
     127            0 :     }
     128              : }
     129              : 
     130              : impl Server {
     131            0 :     async fn handle_404() -> impl IntoResponse {
     132            0 :         StatusCode::NOT_FOUND
     133            0 :     }
     134              : 
     135            0 :     async fn handle_405() -> impl IntoResponse {
     136            0 :         StatusCode::METHOD_NOT_ALLOWED
     137            0 :     }
     138              : 
     139            0 :     async fn listener(&self) -> Result<TcpListener> {
     140            0 :         let addr = SocketAddr::new(self.ip(), self.port());
     141            0 :         let listener = TcpListener::bind(&addr).await?;
     142              : 
     143            0 :         Ok(listener)
     144            0 :     }
     145              : 
     146            0 :     fn ip(&self) -> IpAddr {
     147            0 :         match self {
     148              :             // TODO: Change this to Ipv6Addr::LOCALHOST when the GitHub runners
     149              :             // allow binding to localhost
     150            0 :             Server::Internal(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED),
     151            0 :             Server::External(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED),
     152              :         }
     153            0 :     }
     154              : 
     155            0 :     fn port(self) -> u16 {
     156            0 :         match self {
     157            0 :             Server::Internal(port) => port,
     158            0 :             Server::External(port) => port,
     159              :         }
     160            0 :     }
     161              : 
     162            0 :     async fn serve(self, compute: Arc<ComputeNode>) {
     163            0 :         let listener = self.listener().await.unwrap_or_else(|e| {
     164            0 :             // If we can't bind, the compute cannot operate correctly
     165            0 :             panic!(
     166            0 :                 "failed to bind the compute_ctl {} HTTP server to {}: {}",
     167            0 :                 self,
     168            0 :                 SocketAddr::new(self.ip(), self.port()),
     169            0 :                 e
     170            0 :             );
     171            0 :         });
     172            0 : 
     173            0 :         if tracing::enabled!(tracing::Level::INFO) {
     174            0 :             let local_addr = match listener.local_addr() {
     175            0 :                 Ok(local_addr) => local_addr,
     176            0 :                 Err(_) => SocketAddr::new(self.ip(), self.port()),
     177              :             };
     178              : 
     179            0 :             info!(
     180            0 :                 "compute_ctl {} HTTP server listening at {}",
     181              :                 self, local_addr
     182              :             );
     183            0 :         }
     184              : 
     185            0 :         let router = Router::from(self).with_state(compute);
     186              : 
     187            0 :         if let Err(e) = axum::serve(listener, router).await {
     188            0 :             error!("compute_ctl {} HTTP server error: {}", self, e);
     189            0 :         }
     190            0 :     }
     191              : 
     192            0 :     pub fn launch(self, compute: &Arc<ComputeNode>) {
     193            0 :         let state = Arc::clone(compute);
     194            0 : 
     195            0 :         info!("Launching the {} server", self);
     196              : 
     197            0 :         tokio::spawn(self.serve(state));
     198            0 :     }
     199              : }
     200              : 
     201              : /// This middleware function allows compute_ctl to generate its own request ID
     202              : /// if one isn't supplied. The control plane will always send one as a UUID. The
     203              : /// neon Postgres extension on the other hand does not send one.
     204            0 : async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response {
     205            0 :     let headers = request.headers_mut();
     206            0 :     if headers.get(X_REQUEST_ID).is_none() {
     207            0 :         headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap());
     208            0 :     }
     209              : 
     210            0 :     next.run(request).await
     211            0 : }
        

Generated by: LCOV version 2.1-beta