Line data Source code
1 : use std::{
2 : net::{IpAddr, Ipv6Addr, SocketAddr},
3 : sync::Arc,
4 : thread,
5 : time::Duration,
6 : };
7 :
8 : use anyhow::Result;
9 : use axum::{
10 : extract::Request,
11 : middleware::{self, Next},
12 : response::{IntoResponse, Response},
13 : routing::{get, post},
14 : Router,
15 : };
16 : use http::StatusCode;
17 : use tokio::net::TcpListener;
18 : use tower::ServiceBuilder;
19 : use tower_http::{request_id::PropagateRequestIdLayer, trace::TraceLayer};
20 : use tracing::{debug, error, info, Span};
21 : use uuid::Uuid;
22 :
23 : use super::routes::{
24 : check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
25 : grants, info as info_route, insights, installed_extensions, metrics, metrics_json, status,
26 : terminate,
27 : };
28 : use crate::compute::ComputeNode;
29 :
30 0 : async fn handle_404() -> Response {
31 0 : StatusCode::NOT_FOUND.into_response()
32 0 : }
33 :
/// Name of the header carrying the per-request ID. `maybe_add_request_id_header`
/// fills it in when a request arrives without one, and the tracing callbacks in
/// `serve` read it from both the request and the response for log correlation.
const X_REQUEST_ID: &str = "x-request-id";
35 :
36 : /// This middleware function allows compute_ctl to generate its own request ID
37 : /// if one isn't supplied. The control plane will always send one as a UUID. The
38 : /// neon Postgres extension on the other hand does not send one.
39 0 : async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response {
40 0 : let headers = request.headers_mut();
41 0 :
42 0 : if headers.get(X_REQUEST_ID).is_none() {
43 0 : headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap());
44 0 : }
45 :
46 0 : next.run(request).await
47 0 : }
48 :
49 : /// Run the HTTP server and wait on it forever.
50 : #[tokio::main]
51 0 : async fn serve(port: u16, compute: Arc<ComputeNode>) {
52 0 : let mut app = Router::new()
53 0 : .route("/check_writability", post(check_writability::is_writable))
54 0 : .route("/configure", post(configure::configure))
55 0 : .route("/database_schema", get(database_schema::get_schema_dump))
56 0 : .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
57 0 : .route(
58 0 : "/extension_server/{*filename}",
59 0 : post(extension_server::download_extension),
60 0 : )
61 0 : .route("/extensions", post(extensions::install_extension))
62 0 : .route("/grants", post(grants::add_grant))
63 0 : .route("/info", get(info_route::get_info))
64 0 : .route("/insights", get(insights::get_insights))
65 0 : .route(
66 0 : "/installed_extensions",
67 0 : get(installed_extensions::get_installed_extensions),
68 0 : )
69 0 : .route("/metrics", get(metrics::get_metrics))
70 0 : .route("/metrics.json", get(metrics_json::get_metrics))
71 0 : .route("/status", get(status::get_status))
72 0 : .route("/terminate", post(terminate::terminate))
73 0 : .fallback(handle_404)
74 0 : .layer(
75 0 : ServiceBuilder::new()
76 0 : // Add this middleware since we assume the request ID exists
77 0 : .layer(middleware::from_fn(maybe_add_request_id_header))
78 0 : .layer(
79 0 : TraceLayer::new_for_http()
80 0 : .on_request(|request: &http::Request<_>, _span: &Span| {
81 0 : let request_id = request
82 0 : .headers()
83 0 : .get(X_REQUEST_ID)
84 0 : .unwrap()
85 0 : .to_str()
86 0 : .unwrap();
87 0 :
88 0 : match request.uri().path() {
89 0 : "/metrics" => {
90 0 : debug!(%request_id, "{} {}", request.method(), request.uri())
91 0 : }
92 0 : _ => info!(%request_id, "{} {}", request.method(), request.uri()),
93 0 : };
94 0 : })
95 0 : .on_response(
96 0 : |response: &http::Response<_>, latency: Duration, _span: &Span| {
97 0 : let request_id = response
98 0 : .headers()
99 0 : .get(X_REQUEST_ID)
100 0 : .unwrap()
101 0 : .to_str()
102 0 : .unwrap();
103 0 :
104 0 : info!(
105 0 : %request_id,
106 0 : code = response.status().as_u16(),
107 0 : latency = latency.as_millis()
108 0 : )
109 0 : },
110 0 : ),
111 0 : )
112 0 : .layer(PropagateRequestIdLayer::x_request_id()),
113 0 : )
114 0 : .with_state(compute);
115 0 :
116 0 : // Add in any testing support
117 0 : if cfg!(feature = "testing") {
118 0 : use super::routes::failpoints;
119 0 :
120 0 : app = app.route("/failpoints", post(failpoints::configure_failpoints))
121 0 : }
122 0 :
123 0 : // This usually binds to both IPv4 and IPv6 on Linux, see
124 0 : // https://github.com/rust-lang/rust/pull/34440 for more information
125 0 : let addr = SocketAddr::new(IpAddr::from(Ipv6Addr::UNSPECIFIED), port);
126 0 : let listener = match TcpListener::bind(&addr).await {
127 0 : Ok(listener) => listener,
128 0 : Err(e) => {
129 0 : error!(
130 0 : "failed to bind the compute_ctl HTTP server to port {}: {}",
131 0 : port, e
132 0 : );
133 0 : return;
134 0 : }
135 0 : };
136 0 :
137 0 : if let Ok(local_addr) = listener.local_addr() {
138 0 : info!("compute_ctl HTTP server listening on {}", local_addr);
139 0 : } else {
140 0 : info!("compute_ctl HTTP server listening on port {}", port);
141 0 : }
142 0 :
143 0 : if let Err(e) = axum::serve(listener, app).await {
144 0 : error!("compute_ctl HTTP server error: {}", e);
145 0 : }
146 0 : }
147 :
148 : /// Launch a separate HTTP server thread and return its `JoinHandle`.
149 0 : pub fn launch_http_server(port: u16, state: &Arc<ComputeNode>) -> Result<thread::JoinHandle<()>> {
150 0 : let state = Arc::clone(state);
151 0 :
152 0 : Ok(thread::Builder::new()
153 0 : .name("http-server".into())
154 0 : .spawn(move || serve(port, state))?)
155 0 : }
|