Line data Source code
1 : use std::{
2 : fmt::Display,
3 : net::{IpAddr, Ipv6Addr, SocketAddr},
4 : sync::Arc,
5 : time::Duration,
6 : };
7 :
8 : use anyhow::Result;
9 : use axum::{
10 : extract::Request,
11 : middleware::{self, Next},
12 : response::{IntoResponse, Response},
13 : routing::{get, post},
14 : Router,
15 : };
16 : use http::StatusCode;
17 : use tokio::net::TcpListener;
18 : use tower::ServiceBuilder;
19 : use tower_http::{request_id::PropagateRequestIdLayer, trace::TraceLayer};
20 : use tracing::{debug, error, info, Span};
21 : use uuid::Uuid;
22 :
23 : use super::routes::{
24 : check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
25 : grants, insights, metrics, metrics_json, status, terminate,
26 : };
27 : use crate::compute::ComputeNode;
28 :
/// Name of the HTTP header used to correlate a request with its log lines and
/// with the response (the middleware below fills it in when absent).
const X_REQUEST_ID: &str = "x-request-id";
30 :
/// `compute_ctl` has two servers: internal and external.
///
/// The internal server handles communication from clients on the compute
/// itself and is intended to bind only to the loopback interface. The external
/// server is what receives communication from the control plane, the metrics
/// scraper, etc. We make the distinction because certain routes in
/// `compute_ctl` only need to be exposed to local processes like Postgres via
/// the neon extension and local_proxy.
///
/// Note: at present both servers bind to the unspecified IPv6 address (`::`);
/// see the TODO in `Server::ip`. Until that is resolved, the internal server is
/// not restricted to loopback.
///
/// Each variant carries the TCP port that server listens on.
#[derive(Clone, Copy, Debug)]
pub enum Server {
    /// Server for local clients on the compute; holds its listen port.
    Internal(u16),
    /// Server for the control plane, metrics scraper, etc.; holds its listen port.
    External(u16),
}
42 :
43 : impl Display for Server {
44 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 0 : match self {
46 0 : Server::Internal(_) => f.write_str("internal"),
47 0 : Server::External(_) => f.write_str("external"),
48 : }
49 0 : }
50 : }
51 :
52 : impl From<Server> for Router<Arc<ComputeNode>> {
53 0 : fn from(server: Server) -> Self {
54 0 : let mut router = Router::<Arc<ComputeNode>>::new();
55 :
56 0 : router = match server {
57 : Server::Internal(_) => {
58 0 : router = router
59 0 : .route(
60 0 : "/extension_server/{*filename}",
61 0 : post(extension_server::download_extension),
62 0 : )
63 0 : .route("/extensions", post(extensions::install_extension))
64 0 : .route("/grants", post(grants::add_grant));
65 0 :
66 0 : // Add in any testing support
67 0 : if cfg!(feature = "testing") {
68 0 : use super::routes::failpoints;
69 0 :
70 0 : router = router.route("/failpoints", post(failpoints::configure_failpoints));
71 0 : }
72 :
73 0 : router
74 : }
75 0 : Server::External(_) => router
76 0 : .route("/check_writability", post(check_writability::is_writable))
77 0 : .route("/configure", post(configure::configure))
78 0 : .route("/database_schema", get(database_schema::get_schema_dump))
79 0 : .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
80 0 : .route("/insights", get(insights::get_insights))
81 0 : .route("/metrics", get(metrics::get_metrics))
82 0 : .route("/metrics.json", get(metrics_json::get_metrics))
83 0 : .route("/status", get(status::get_status))
84 0 : .route("/terminate", post(terminate::terminate)),
85 : };
86 :
87 0 : router.fallback(Server::handle_404).method_not_allowed_fallback(Server::handle_405).layer(
88 0 : ServiceBuilder::new()
89 0 : // Add this middleware since we assume the request ID exists
90 0 : .layer(middleware::from_fn(maybe_add_request_id_header))
91 0 : .layer(
92 0 : TraceLayer::new_for_http()
93 0 : .on_request(|request: &http::Request<_>, _span: &Span| {
94 0 : let request_id = request
95 0 : .headers()
96 0 : .get(X_REQUEST_ID)
97 0 : .unwrap()
98 0 : .to_str()
99 0 : .unwrap();
100 0 :
101 0 : match request.uri().path() {
102 0 : "/metrics" => {
103 0 : debug!(%request_id, "{} {}", request.method(), request.uri())
104 : }
105 0 : _ => info!(%request_id, "{} {}", request.method(), request.uri()),
106 : };
107 0 : })
108 0 : .on_response(
109 0 : |response: &http::Response<_>, latency: Duration, _span: &Span| {
110 0 : let request_id = response
111 0 : .headers()
112 0 : .get(X_REQUEST_ID)
113 0 : .unwrap()
114 0 : .to_str()
115 0 : .unwrap();
116 0 :
117 0 : info!(
118 : %request_id,
119 0 : code = response.status().as_u16(),
120 0 : latency = latency.as_millis()
121 : )
122 0 : },
123 0 : ),
124 0 : )
125 0 : .layer(PropagateRequestIdLayer::x_request_id()),
126 0 : )
127 0 : }
128 : }
129 :
130 : impl Server {
131 0 : async fn handle_404() -> impl IntoResponse {
132 0 : StatusCode::NOT_FOUND
133 0 : }
134 :
135 0 : async fn handle_405() -> impl IntoResponse {
136 0 : StatusCode::METHOD_NOT_ALLOWED
137 0 : }
138 :
139 0 : async fn listener(&self) -> Result<TcpListener> {
140 0 : let addr = SocketAddr::new(self.ip(), self.port());
141 0 : let listener = TcpListener::bind(&addr).await?;
142 :
143 0 : Ok(listener)
144 0 : }
145 :
146 0 : fn ip(&self) -> IpAddr {
147 0 : match self {
148 : // TODO: Change this to Ipv6Addr::LOCALHOST when the GitHub runners
149 : // allow binding to localhost
150 0 : Server::Internal(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED),
151 0 : Server::External(_) => IpAddr::from(Ipv6Addr::UNSPECIFIED),
152 : }
153 0 : }
154 :
155 0 : fn port(self) -> u16 {
156 0 : match self {
157 0 : Server::Internal(port) => port,
158 0 : Server::External(port) => port,
159 : }
160 0 : }
161 :
162 0 : async fn serve(self, compute: Arc<ComputeNode>) {
163 0 : let listener = self.listener().await.unwrap_or_else(|e| {
164 0 : // If we can't bind, the compute cannot operate correctly
165 0 : panic!(
166 0 : "failed to bind the compute_ctl {} HTTP server to {}: {}",
167 0 : self,
168 0 : SocketAddr::new(self.ip(), self.port()),
169 0 : e
170 0 : );
171 0 : });
172 0 :
173 0 : if tracing::enabled!(tracing::Level::INFO) {
174 0 : let local_addr = match listener.local_addr() {
175 0 : Ok(local_addr) => local_addr,
176 0 : Err(_) => SocketAddr::new(self.ip(), self.port()),
177 : };
178 :
179 0 : info!(
180 0 : "compute_ctl {} HTTP server listening at {}",
181 : self, local_addr
182 : );
183 0 : }
184 :
185 0 : let router = Router::from(self).with_state(compute);
186 :
187 0 : if let Err(e) = axum::serve(listener, router).await {
188 0 : error!("compute_ctl {} HTTP server error: {}", self, e);
189 0 : }
190 0 : }
191 :
192 0 : pub fn launch(self, compute: &Arc<ComputeNode>) {
193 0 : let state = Arc::clone(compute);
194 0 :
195 0 : info!("Launching the {} server", self);
196 :
197 0 : tokio::spawn(self.serve(state));
198 0 : }
199 : }
200 :
201 : /// This middleware function allows compute_ctl to generate its own request ID
202 : /// if one isn't supplied. The control plane will always send one as a UUID. The
203 : /// neon Postgres extension on the other hand does not send one.
204 0 : async fn maybe_add_request_id_header(mut request: Request, next: Next) -> Response {
205 0 : let headers = request.headers_mut();
206 0 : if headers.get(X_REQUEST_ID).is_none() {
207 0 : headers.append(X_REQUEST_ID, Uuid::new_v4().to_string().parse().unwrap());
208 0 : }
209 :
210 0 : next.run(request).await
211 0 : }
|