Line data Source code
1 : use std::fmt::Display;
2 : use std::net::{IpAddr, Ipv6Addr, SocketAddr};
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use anyhow::Result;
7 : use axum::Router;
8 : use axum::middleware::{self};
9 : use axum::response::IntoResponse;
10 : use axum::routing::{get, post};
11 : use compute_api::responses::ComputeCtlConfig;
12 : use http::StatusCode;
13 : use tokio::net::TcpListener;
14 : use tower::ServiceBuilder;
15 : use tower_http::{
16 : auth::AsyncRequireAuthorizationLayer, request_id::PropagateRequestIdLayer, trace::TraceLayer,
17 : };
18 : use tracing::{Span, error, info};
19 :
20 : use super::middleware::request_id::maybe_add_request_id_header;
21 : use super::{
22 : headers::X_REQUEST_ID,
23 : middleware::authorize::Authorize,
24 : routes::{
25 : check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions,
26 : grants, hadron_liveness_probe, insights, lfc, metrics, metrics_json, promote,
27 : refresh_configuration, status, terminate,
28 : },
29 : };
30 : use crate::compute::ComputeNode;
31 :
32 : /// `compute_ctl` has two servers: internal and external. The internal server
33 : /// binds to the loopback interface and handles communication from clients on
34 : /// the compute. The external server is what receives communication from the
35 : /// control plane, the metrics scraper, etc. We make the distinction because
36 : /// certain routes in `compute_ctl` only need to be exposed to local processes
37 : /// like Postgres via the neon extension and local_proxy.
38 : #[derive(Clone, Debug)]
39 : pub enum Server {
40 : Internal {
41 : port: u16,
42 : },
43 : External {
44 : port: u16,
45 : config: ComputeCtlConfig,
46 : compute_id: String,
47 : instance_id: Option<String>,
48 : },
49 : }
50 :
51 : impl Display for Server {
52 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 0 : match self {
54 0 : Server::Internal { .. } => f.write_str("internal"),
55 0 : Server::External { .. } => f.write_str("external"),
56 : }
57 0 : }
58 : }
59 :
60 : impl From<&Server> for Router<Arc<ComputeNode>> {
61 0 : fn from(server: &Server) -> Self {
62 0 : let mut router = Router::<Arc<ComputeNode>>::new();
63 :
64 0 : router = match server {
65 : Server::Internal { .. } => {
66 0 : router = router
67 0 : .route(
68 0 : "/extension_server/{*filename}",
69 0 : post(extension_server::download_extension),
70 : )
71 0 : .route("/extensions", post(extensions::install_extension))
72 0 : .route("/grants", post(grants::add_grant))
73 : // Hadron: Compute-initiated configuration refresh
74 0 : .route(
75 0 : "/refresh_configuration",
76 0 : post(refresh_configuration::refresh_configuration),
77 : );
78 :
79 : // Add in any testing support
80 0 : if cfg!(feature = "testing") {
81 : use super::routes::failpoints;
82 :
83 0 : router = router.route("/failpoints", post(failpoints::configure_failpoints));
84 0 : }
85 :
86 0 : router
87 : }
88 : Server::External {
89 0 : config,
90 0 : compute_id,
91 0 : instance_id,
92 : ..
93 : } => {
94 0 : let unauthenticated_router = Router::<Arc<ComputeNode>>::new()
95 0 : .route("/metrics", get(metrics::get_metrics))
96 0 : .route(
97 0 : "/autoscaling_metrics",
98 0 : get(metrics::get_autoscaling_metrics),
99 : );
100 :
101 0 : let authenticated_router = Router::<Arc<ComputeNode>>::new()
102 0 : .route(
103 0 : "/lfc/prewarm",
104 0 : get(lfc::prewarm_state)
105 0 : .post(lfc::prewarm)
106 0 : .delete(lfc::cancel_prewarm),
107 : )
108 0 : .route("/lfc/offload", get(lfc::offload_state).post(lfc::offload))
109 0 : .route("/promote", post(promote::promote))
110 0 : .route("/check_writability", post(check_writability::is_writable))
111 0 : .route("/configure", post(configure::configure))
112 0 : .route("/database_schema", get(database_schema::get_schema_dump))
113 0 : .route("/dbs_and_roles", get(dbs_and_roles::get_catalog_objects))
114 0 : .route("/insights", get(insights::get_insights))
115 0 : .route("/metrics.json", get(metrics_json::get_metrics))
116 0 : .route("/status", get(status::get_status))
117 0 : .route("/terminate", post(terminate::terminate))
118 0 : .route(
119 0 : "/hadron_liveness_probe",
120 0 : get(hadron_liveness_probe::hadron_liveness_probe),
121 : )
122 0 : .layer(AsyncRequireAuthorizationLayer::new(Authorize::new(
123 0 : compute_id.clone(),
124 0 : instance_id.clone(),
125 0 : config.jwks.clone(),
126 : )));
127 :
128 0 : router
129 0 : .merge(unauthenticated_router)
130 0 : .merge(authenticated_router)
131 : }
132 : };
133 :
134 0 : router
135 0 : .fallback(Server::handle_404)
136 0 : .method_not_allowed_fallback(Server::handle_405)
137 0 : .layer(
138 0 : ServiceBuilder::new()
139 0 : .layer(tower_otel::trace::HttpLayer::server(tracing::Level::INFO))
140 : // Add this middleware since we assume the request ID exists
141 0 : .layer(middleware::from_fn(maybe_add_request_id_header))
142 0 : .layer(
143 0 : TraceLayer::new_for_http()
144 0 : .on_request(|request: &http::Request<_>, _span: &Span| {
145 0 : let request_id = request
146 0 : .headers()
147 0 : .get(X_REQUEST_ID)
148 0 : .unwrap()
149 0 : .to_str()
150 0 : .unwrap();
151 :
152 0 : info!(%request_id, "{} {}", request.method(), request.uri());
153 0 : })
154 0 : .on_response(
155 0 : |response: &http::Response<_>, latency: Duration, _span: &Span| {
156 0 : let request_id = response
157 0 : .headers()
158 0 : .get(X_REQUEST_ID)
159 0 : .unwrap()
160 0 : .to_str()
161 0 : .unwrap();
162 :
163 0 : info!(
164 : %request_id,
165 0 : code = response.status().as_u16(),
166 0 : latency = latency.as_millis()
167 : );
168 0 : },
169 : ),
170 : )
171 0 : .layer(PropagateRequestIdLayer::x_request_id()),
172 : )
173 0 : }
174 : }
175 :
176 : impl Server {
177 0 : async fn handle_404() -> impl IntoResponse {
178 0 : StatusCode::NOT_FOUND
179 0 : }
180 :
181 0 : async fn handle_405() -> impl IntoResponse {
182 0 : StatusCode::METHOD_NOT_ALLOWED
183 0 : }
184 :
185 0 : async fn listener(&self) -> Result<TcpListener> {
186 0 : let addr = SocketAddr::new(self.ip(), self.port());
187 0 : let listener = TcpListener::bind(&addr).await?;
188 :
189 0 : Ok(listener)
190 0 : }
191 :
192 0 : fn ip(&self) -> IpAddr {
193 0 : match self {
194 : // TODO: Change this to Ipv6Addr::LOCALHOST when the GitHub runners
195 : // allow binding to localhost
196 0 : Server::Internal { .. } => IpAddr::from(Ipv6Addr::UNSPECIFIED),
197 0 : Server::External { .. } => IpAddr::from(Ipv6Addr::UNSPECIFIED),
198 : }
199 0 : }
200 :
201 0 : fn port(&self) -> u16 {
202 0 : match self {
203 0 : Server::Internal { port, .. } => *port,
204 0 : Server::External { port, .. } => *port,
205 : }
206 0 : }
207 :
208 0 : async fn serve(self, compute: Arc<ComputeNode>) {
209 0 : let listener = self.listener().await.unwrap_or_else(|e| {
210 : // If we can't bind, the compute cannot operate correctly
211 0 : panic!(
212 0 : "failed to bind the compute_ctl {} HTTP server to {}: {}",
213 : self,
214 0 : SocketAddr::new(self.ip(), self.port()),
215 : e
216 : );
217 : });
218 :
219 0 : if tracing::enabled!(tracing::Level::INFO) {
220 0 : let local_addr = match listener.local_addr() {
221 0 : Ok(local_addr) => local_addr,
222 0 : Err(_) => SocketAddr::new(self.ip(), self.port()),
223 : };
224 :
225 0 : info!(
226 0 : "compute_ctl {} HTTP server listening at {}",
227 : self, local_addr
228 : );
229 0 : }
230 :
231 0 : let router = Router::from(&self)
232 0 : .with_state(compute)
233 0 : .into_make_service_with_connect_info::<SocketAddr>();
234 :
235 0 : if let Err(e) = axum::serve(listener, router).await {
236 0 : error!("compute_ctl {} HTTP server error: {}", self, e);
237 0 : }
238 0 : }
239 :
240 0 : pub fn launch(self, compute: &Arc<ComputeNode>) {
241 0 : let state = Arc::clone(compute);
242 :
243 0 : info!("Launching the {} server", self);
244 :
245 0 : tokio::spawn(self.serve(state));
246 0 : }
247 : }
|