Line data Source code
1 : use std::fmt::Display;
2 : use std::net::{IpAddr, Ipv6Addr, SocketAddr};
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use anyhow::Result;
7 : use axum::Router;
8 : use axum::extract::Request;
9 : use axum::middleware::{self, Next};
10 : use axum::response::{IntoResponse, Response};
11 : use axum::routing::{get, post};
12 : use http::StatusCode;
13 : use tokio::net::TcpListener;
14 : use tower::ServiceBuilder;
15 : use tower_http::request_id::PropagateRequestIdLayer;
16 : use tower_http::trace::TraceLayer;
17 : use tracing::{Span, debug, error, info};
18 : use uuid::Uuid;
19 :
20 : use super::routes::{
21 : check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
22 : grants, insights, metrics, metrics_json, status, terminate,
23 : };
24 : use crate::compute::ComputeNode;
25 :
/// Name of the HTTP header that carries the per-request ID.
const X_REQUEST_ID: &str = "x-request-id";

/// The two HTTP servers run by `compute_ctl`.
///
/// The internal server handles communication from clients on the compute
/// itself (Postgres via the neon extension, local_proxy). It is intended to
/// bind to the loopback interface, although for now it binds to the
/// unspecified address (see the TODO in `Server::ip`). The external server
/// receives communication from the control plane, the metrics scraper, etc.
/// The split exists because certain routes only need to be exposed to local
/// processes.
#[derive(Debug, Clone, Copy)]
pub enum Server {
    /// Server for local clients on the compute; the payload is its TCP port.
    Internal(u16),
    /// Server for the control plane, metrics scraper, etc.; the payload is its TCP port.
    External(u16),
}
39 :
40 : impl Display for Server {
41 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 0 : match self {
43 0 : Server::Internal(_) => f.write_str("internal"),
44 0 : Server::External(_) => f.write_str("external"),
45 : }
46 0 : }
47 : }
48 :
49 : impl From<Server> for Router<Arc<ComputeNode>> {
50 0 : fn from(server: Server) -> Self {
51 0 : let mut router = Router::<Arc<ComputeNode>>::new();
52 :
53 0 : router = match server {
54 : Server::Internal(_) => {
55 0 : router = router
56 0 : .route(
57 0 : "/extension_server/{*filename}",
58 0 : post(extension_server::download_extension),
59 0 : )
60 0 : .route("/extensions", post(extensions::install_extension))
61 0 : .route("/grants", post(grants::add_grant));
62 0 :
63 0 : // Add in any testing support
64 0 : if cfg!(feature = "testing") {
65 0 : use super::routes::failpoints;
66 0 :
67 0 : router = router.route("/failpoints", post(failpoints::configure_failpoints));
68 0 : }
69 :
70 0 : router
71 : }
72 0 : Server::External(_) => router
73 0 : .route("/check_writability", post(check_writability::is_writable))
74 0 : .route("/configure", post(configure::configure))
75 0 : .route("/database_schema", get(database_schema::get_schema_dump))
76 0 : .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
77 0 : .route("/insights", get(insights::get_insights))
78 0 : .route("/metrics", get(metrics::get_metrics))
79 0 : .route("/metrics.json", get(metrics_json::get_metrics))
80 0 : .route("/status", get(status::get_status))
81 0 : .route("/terminate", post(terminate::terminate)),
82 : };
83 :
84 0 : router.fallback(Server::handle_404).method_not_allowed_fallback(Server::handle_405).layer(
85 0 : ServiceBuilder::new()
86 0 : // Add this middleware since we assume the request ID exists
87 0 : .layer(middleware::from_fn(maybe_add_request_id_header))
88 0 : .layer(
89 0 : TraceLayer::new_for_http()
90 0 : .on_request(|request: &http::Request<_>, _span: &Span| {
91 0 : let request_id = request
92 0 : .headers()
93 0 : .get(X_REQUEST_ID)
94 0 : .unwrap()
95 0 : .to_str()
96 0 : .unwrap();
97 0 :
98 0 : match request.uri().path() {
99 0 : "/metrics" => {
100 0 : debug!(%request_id, "{} {}", request.method(), request.uri())
101 : }
102 0 : _ => info!(%request_id, "{} {}", request.method(), request.uri()),
103 : };
104 0 : })
105 0 : .on_response(
106 0 : |response: &http::Response<_>, latency: Duration, _span: &Span| {
107 0 : let request_id = response
108 0 : .headers()
109 0 : .get(X_REQUEST_ID)
110 0 : .unwrap()
111 0 : .to_str()
112 0 : .unwrap();
113 0 :
114 0 : info!(
115 : %request_id,
116 0 : code = response.status().as_u16(),
117 0 : latency = latency.as_millis()
118 : )
119 0 : },
120 0 : ),
121 0 : )
122 0 : .layer(PropagateRequestIdLayer::x_request_id()),
123 0 : )
124 0 : .layer(tower_otel::trace::HttpLayer::server(tracing::Level::INFO))
125 0 : }
126 : }
127 :
128 : impl Server {
129 0 : async fn handle_404() -> impl IntoResponse {
130 0 : StatusCode::NOT_FOUND
131 0 : }
132 :
133 0 : async fn handle_405() -> impl IntoResponse {
134 0 : StatusCode::METHOD_NOT_ALLOWED
135 0 : }
136 :
137 0 : async fn listener(&self) -> Result<TcpListener> {
138 0 : let addr = SocketAddr::new(self.ip(), self.port());
139 0 : let listener = TcpListener::bind(&addr).await?;
140 :
141 0 : Ok(listener)
142 0 : }
143 :
144 0 : fn ip(&self) -> IpAddr {
145 0 : match self {
146 : // TODO: Change this to Ipv6Addr::LOCALHOST when the GitHub runners
147 : // allow binding to localhost
148 0 : Server::Internal(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED),
149 0 : Server::External(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED),
150 : }
151 0 : }
152 :
153 0 : fn port(self) -> u16 {
154 0 : match self {
155 0 : Server::Internal(port) => port,
156 0 : Server::External(port) => port,
157 : }
158 0 : }
159 :
160 0 : async fn serve(self, compute: Arc<ComputeNode>) {
161 0 : let listener = self.listener().await.unwrap_or_else(|e| {
162 0 : // If we can't bind, the compute cannot operate correctly
163 0 : panic!(
164 0 : "failed to bind the compute_ctl {} HTTP server to {}: {}",
165 0 : self,
166 0 : SocketAddr::new(self.ip(), self.port()),
167 0 : e
168 0 : );
169 0 : });
170 0 :
171 0 : if tracing::enabled!(tracing::Level::INFO) {
172 0 : let local_addr = match listener.local_addr() {
173 0 : Ok(local_addr) => local_addr,
174 0 : Err(_) => SocketAddr::new(self.ip(), self.port()),
175 : };
176 :
177 0 : info!(
178 0 : "compute_ctl {} HTTP server listening at {}",
179 : self, local_addr
180 : );
181 0 : }
182 :
183 0 : let router = Router::from(self).with_state(compute);
184 :
185 0 : if let Err(e) = axum::serve(listener, router).await {
186 0 : error!("compute_ctl {} HTTP server error: {}", self, e);
187 0 : }
188 0 : }
189 :
190 0 : pub fn launch(self, compute: &Arc<ComputeNode>) {
191 0 : let state = Arc::clone(compute);
192 0 :
193 0 : info!("Launching the {} server", self);
194 :
195 0 : tokio::spawn(self.serve(state));
196 0 : }
197 : }
198 :
199 : /// This middleware function allows compute_ctl to generate its own request ID
200 : /// if one isn't supplied. The control plane will always send one as a UUID. The
201 : /// neon Postgres extension on the other hand does not send one.
202 0 : async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response {
203 0 : let headers = request.headers_mut();
204 0 : if headers.get(X_REQUEST_ID).is_none() {
205 0 : headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap());
206 0 : }
207 :
208 0 : next.run(request).await
209 0 : }
|