LCOV - code coverage report
Current view: top level - compute_tools/src/http - server.rs (source / functions) Coverage Total Hit
Test: b9728233c33232dfae45024a493738ef141ccd5d.info Lines: 0.0 % 122 0
Test Date: 2025-01-10 20:41:15 Functions: 0.0 % 9 0

            Line data    Source code
       1              : use std::{
       2              :     net::{IpAddr, Ipv6Addr, SocketAddr},
       3              :     sync::{
       4              :         atomic::{AtomicU64, Ordering},
       5              :         Arc,
       6              :     },
       7              :     thread,
       8              :     time::Duration,
       9              : };
      10              : 
      11              : use anyhow::Result;
      12              : use axum::{
      13              :     response::{IntoResponse, Response},
      14              :     routing::{get, post},
      15              :     Router,
      16              : };
      17              : use http::StatusCode;
      18              : use tokio::net::TcpListener;
      19              : use tower::ServiceBuilder;
      20              : use tower_http::{
      21              :     request_id::{MakeRequestId, PropagateRequestIdLayer, RequestId, SetRequestIdLayer},
      22              :     trace::TraceLayer,
      23              : };
      24              : use tracing::{debug, error, info, Span};
      25              : 
      26              : use super::routes::{
      27              :     check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
      28              :     grants, info as info_route, insights, installed_extensions, metrics, metrics_json, status,
      29              :     terminate,
      30              : };
      31              : use crate::compute::ComputeNode;
      32              : 
      33            0 : async fn handle_404() -> Response {
      34            0 :     StatusCode::NOT_FOUND.into_response()
      35            0 : }
      36              : 
      37              : #[derive(Clone, Default)]
      38              : struct ComputeMakeRequestId(Arc<AtomicU64>);
      39              : 
      40              : impl MakeRequestId for ComputeMakeRequestId {
      41            0 :     fn make_request_id<B>(
      42            0 :         &mut self,
      43            0 :         _request: &http::Request<B>,
      44            0 :     ) -> Option<tower_http::request_id::RequestId> {
      45            0 :         let request_id = self
      46            0 :             .0
      47            0 :             .fetch_add(1, Ordering::SeqCst)
      48            0 :             .to_string()
      49            0 :             .parse()
      50            0 :             .unwrap();
      51            0 : 
      52            0 :         Some(RequestId::new(request_id))
      53            0 :     }
      54              : }
      55              : 
      56              : /// Run the HTTP server and wait on it forever.
      57              : #[tokio::main]
      58            0 : async fn serve(port: u16, compute: Arc<ComputeNode>) {
      59            0 :     const X_REQUEST_ID: &str = "x-request-id";
      60            0 : 
      61            0 :     let mut app = Router::new()
      62            0 :         .route("/check_writability", post(check_writability::is_writable))
      63            0 :         .route("/configure", post(configure::configure))
      64            0 :         .route("/database_schema", get(database_schema::get_schema_dump))
      65            0 :         .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
      66            0 :         .route(
      67            0 :             "/extension_server/*filename",
      68            0 :             post(extension_server::download_extension),
      69            0 :         )
      70            0 :         .route("/extensions", post(extensions::install_extension))
      71            0 :         .route("/grants", post(grants::add_grant))
      72            0 :         .route("/info", get(info_route::get_info))
      73            0 :         .route("/insights", get(insights::get_insights))
      74            0 :         .route(
      75            0 :             "/installed_extensions",
      76            0 :             get(installed_extensions::get_installed_extensions),
      77            0 :         )
      78            0 :         .route("/metrics", get(metrics::get_metrics))
      79            0 :         .route("/metrics.json", get(metrics_json::get_metrics))
      80            0 :         .route("/status", get(status::get_status))
      81            0 :         .route("/terminate", post(terminate::terminate))
      82            0 :         .fallback(handle_404)
      83            0 :         .layer(
      84            0 :             ServiceBuilder::new()
      85            0 :                 .layer(SetRequestIdLayer::x_request_id(
      86            0 :                     ComputeMakeRequestId::default(),
      87            0 :                 ))
      88            0 :                 .layer(
      89            0 :                     TraceLayer::new_for_http()
      90            0 :                         .on_request(|request: &http::Request<_>, _span: &Span| {
      91            0 :                             let request_id = request
      92            0 :                                 .headers()
      93            0 :                                 .get(X_REQUEST_ID)
      94            0 :                                 .unwrap()
      95            0 :                                 .to_str()
      96            0 :                                 .unwrap();
      97            0 : 
      98            0 :                             match request.uri().path() {
      99            0 :                                 "/metrics" => {
     100            0 :                                     debug!(%request_id, "{} {}", request.method(), request.uri())
     101            0 :                                 }
     102            0 :                                 _ => info!(%request_id, "{} {}", request.method(), request.uri()),
     103            0 :                             };
     104            0 :                         })
     105            0 :                         .on_response(
     106            0 :                             |response: &http::Response<_>, latency: Duration, _span: &Span| {
     107            0 :                                 let request_id = response
     108            0 :                                     .headers()
     109            0 :                                     .get(X_REQUEST_ID)
     110            0 :                                     .unwrap()
     111            0 :                                     .to_str()
     112            0 :                                     .unwrap();
     113            0 : 
     114            0 :                                 info!(
     115            0 :                                     %request_id,
     116            0 :                                     code = response.status().as_u16(),
     117            0 :                                     latency = latency.as_millis()
     118            0 :                                 )
     119            0 :                             },
     120            0 :                         ),
     121            0 :                 )
     122            0 :                 .layer(PropagateRequestIdLayer::x_request_id()),
     123            0 :         )
     124            0 :         .with_state(compute);
     125            0 : 
     126            0 :     // Add in any testing support
     127            0 :     if cfg!(feature = "testing") {
     128            0 :         use super::routes::failpoints;
     129            0 : 
     130            0 :         app = app.route("/failpoints", post(failpoints::configure_failpoints))
     131            0 :     }
     132            0 : 
     133            0 :     // This usually binds to both IPv4 and IPv6 on Linux, see
     134            0 :     // https://github.com/rust-lang/rust/pull/34440 for more information
     135            0 :     let addr = SocketAddr::new(IpAddr::from(Ipv6Addr::UNSPECIFIED), port);
     136            0 :     let listener = match TcpListener::bind(&addr).await {
     137            0 :         Ok(listener) => listener,
     138            0 :         Err(e) => {
     139            0 :             error!(
     140            0 :                 "failed to bind the compute_ctl HTTP server to port {}: {}",
     141            0 :                 port, e
     142            0 :             );
     143            0 :             return;
     144            0 :         }
     145            0 :     };
     146            0 : 
     147            0 :     if let Ok(local_addr) = listener.local_addr() {
     148            0 :         info!("compute_ctl HTTP server listening on {}", local_addr);
     149            0 :     } else {
     150            0 :         info!("compute_ctl HTTP server listening on port {}", port);
     151            0 :     }
     152            0 : 
     153            0 :     if let Err(e) = axum::serve(listener, app).await {
     154            0 :         error!("compute_ctl HTTP server error: {}", e);
     155            0 :     }
     156            0 : }
     157              : 
     158              : /// Launch a separate HTTP server thread and return its `JoinHandle`.
     159            0 : pub fn launch_http_server(port: u16, state: &Arc<ComputeNode>) -> Result<thread::JoinHandle<()>> {
     160            0 :     let state = Arc::clone(state);
     161            0 : 
     162            0 :     Ok(thread::Builder::new()
     163            0 :         .name("http-server".into())
     164            0 :         .spawn(move || serve(port, state))?)
     165            0 : }
        

Generated by: LCOV version 2.1-beta