Line data Source code
1 : use std::{
2 : net::{IpAddr, Ipv6Addr, SocketAddr},
3 : sync::Arc,
4 : thread,
5 : time::Duration,
6 : };
7 :
8 : use anyhow::Result;
9 : use axum::{
10 : extract::Request,
11 : middleware::{self, Next},
12 : response::{IntoResponse, Response},
13 : routing::{get, post},
14 : Router,
15 : };
16 : use http::StatusCode;
17 : use tokio::net::TcpListener;
18 : use tower::ServiceBuilder;
19 : use tower_http::{request_id::PropagateRequestIdLayer, trace::TraceLayer};
20 : use tracing::{debug, error, info, Span};
21 : use uuid::Uuid;
22 :
23 : use super::routes::{
24 : check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
25 : grants, insights, metrics, metrics_json, status, terminate,
26 : };
27 : use crate::compute::ComputeNode;
28 :
29 0 : async fn handle_404() -> Response {
30 0 : StatusCode::NOT_FOUND.into_response()
31 0 : }
32 :
33 : const X_REQUEST_ID: &str = "x-request-id";
34 :
35 : /// This middleware function allows compute_ctl to generate its own request ID
36 : /// if one isn't supplied. The control plane will always send one as a UUID. The
37 : /// neon Postgres extension on the other hand does not send one.
38 0 : async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response {
39 0 : let headers = request.headers_mut();
40 0 :
41 0 : if headers.get(X_REQUEST_ID).is_none() {
42 0 : headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap());
43 0 : }
44 :
45 0 : next.run(request).await
46 0 : }
47 :
48 : /// Run the HTTP server and wait on it forever.
49 : #[tokio::main]
50 0 : async fn serve(port: u16, compute: Arc<ComputeNode>) {
51 0 : let mut app = Router::new()
52 0 : .route("/check_writability", post(check_writability::is_writable))
53 0 : .route("/configure", post(configure::configure))
54 0 : .route("/database_schema", get(database_schema::get_schema_dump))
55 0 : .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
56 0 : .route(
57 0 : "/extension_server/{*filename}",
58 0 : post(extension_server::download_extension),
59 0 : )
60 0 : .route("/extensions", post(extensions::install_extension))
61 0 : .route("/grants", post(grants::add_grant))
62 0 : .route("/insights", get(insights::get_insights))
63 0 : .route("/metrics", get(metrics::get_metrics))
64 0 : .route("/metrics.json", get(metrics_json::get_metrics))
65 0 : .route("/status", get(status::get_status))
66 0 : .route("/terminate", post(terminate::terminate))
67 0 : .fallback(handle_404)
68 0 : .layer(
69 0 : ServiceBuilder::new()
70 0 : // Add this middleware since we assume the request ID exists
71 0 : .layer(middleware::from_fn(maybe_add_request_id_header))
72 0 : .layer(
73 0 : TraceLayer::new_for_http()
74 0 : .on_request(|request: &http::Request<_>, _span: &Span| {
75 0 : let request_id = request
76 0 : .headers()
77 0 : .get(X_REQUEST_ID)
78 0 : .unwrap()
79 0 : .to_str()
80 0 : .unwrap();
81 0 :
82 0 : match request.uri().path() {
83 0 : "/metrics" => {
84 0 : debug!(%request_id, "{} {}", request.method(), request.uri())
85 0 : }
86 0 : _ => info!(%request_id, "{} {}", request.method(), request.uri()),
87 0 : };
88 0 : })
89 0 : .on_response(
90 0 : |response: &http::Response<_>, latency: Duration, _span: &Span| {
91 0 : let request_id = response
92 0 : .headers()
93 0 : .get(X_REQUEST_ID)
94 0 : .unwrap()
95 0 : .to_str()
96 0 : .unwrap();
97 0 :
98 0 : info!(
99 0 : %request_id,
100 0 : code = response.status().as_u16(),
101 0 : latency = latency.as_millis()
102 0 : )
103 0 : },
104 0 : ),
105 0 : )
106 0 : .layer(PropagateRequestIdLayer::x_request_id()),
107 0 : )
108 0 : .with_state(compute);
109 0 :
110 0 : // Add in any testing support
111 0 : if cfg!(feature = "testing") {
112 0 : use super::routes::failpoints;
113 0 :
114 0 : app = app.route("/failpoints", post(failpoints::configure_failpoints))
115 0 : }
116 0 :
117 0 : // This usually binds to both IPv4 and IPv6 on Linux, see
118 0 : // https://github.com/rust-lang/rust/pull/34440 for more information
119 0 : let addr = SocketAddr::new(IpAddr::from(Ipv6Addr::UNSPECIFIED), port);
120 0 : let listener = match TcpListener::bind(&addr).await {
121 0 : Ok(listener) => listener,
122 0 : Err(e) => {
123 0 : error!(
124 0 : "failed to bind the compute_ctl HTTP server to port {}: {}",
125 0 : port, e
126 0 : );
127 0 : return;
128 0 : }
129 0 : };
130 0 :
131 0 : if let Ok(local_addr) = listener.local_addr() {
132 0 : info!("compute_ctl HTTP server listening on {}", local_addr);
133 0 : } else {
134 0 : info!("compute_ctl HTTP server listening on port {}", port);
135 0 : }
136 0 :
137 0 : if let Err(e) = axum::serve(listener, app).await {
138 0 : error!("compute_ctl HTTP server error: {}", e);
139 0 : }
140 0 : }
141 :
142 : /// Launch a separate HTTP server thread and return its `JoinHandle`.
143 0 : pub fn launch_http_server(port: u16, state: &Arc<ComputeNode>) -> Result<thread::JoinHandle<()>> {
144 0 : let state = Arc::clone(state);
145 0 :
146 0 : Ok(thread::Builder::new()
147 0 : .name("http-server".into())
148 0 : .spawn(move || serve(port, state))?)
149 0 : }
|