TLA Line data Source code
1 : /// The attachment service mimics the aspects of the control plane API
2 : /// that are required for a pageserver to operate.
3 : ///
4 : /// This enables running & testing pageservers without a full-blown
5 : /// deployment of the Neon cloud platform.
6 : ///
7 : use anyhow::anyhow;
8 : use clap::Parser;
9 : use hex::FromHex;
10 : use hyper::StatusCode;
11 : use hyper::{Body, Request, Response};
12 : use serde::{Deserialize, Serialize};
13 : use std::path::{Path, PathBuf};
14 : use std::{collections::HashMap, sync::Arc};
15 : use utils::logging::{self, LogFormat};
16 :
17 : use utils::{
18 : http::{
19 : endpoint::{self},
20 : error::ApiError,
21 : json::{json_request, json_response},
22 : RequestExt, RouterBuilder,
23 : },
24 : id::{NodeId, TenantId},
25 : tcp_listener,
26 : };
27 :
28 : use pageserver_api::control_api::{
29 : ReAttachRequest, ReAttachResponse, ReAttachResponseTenant, ValidateRequest, ValidateResponse,
30 : ValidateResponseTenant,
31 : };
32 :
33 : use control_plane::attachment_service::{AttachHookRequest, AttachHookResponse};
34 :
35 UBC 0 : #[derive(Parser)]
36 : #[command(author, version, about, long_about = None)]
37 : #[command(arg_required_else_help(true))]
38 : struct Cli {
39 : /// Host and port to listen on, like `127.0.0.1:1234`
40 : #[arg(short, long)]
41 0 : listen: std::net::SocketAddr,
42 :
43 : /// Path to the .json file to store state (will be created if it doesn't exist)
44 : #[arg(short, long)]
45 0 : path: PathBuf,
46 : }
47 :
48 : // The persistent state of each Tenant
49 0 : #[derive(Serialize, Deserialize, Clone)]
50 : struct TenantState {
51 : // Currently attached pageserver
52 : pageserver: Option<NodeId>,
53 :
54 : // Latest generation number: next time we attach, increment this
55 : // and use the incremented number when attaching
56 : generation: u32,
57 : }
58 :
59 0 : fn to_hex_map<S, V>(input: &HashMap<TenantId, V>, serializer: S) -> Result<S::Ok, S::Error>
60 0 : where
61 0 : S: serde::Serializer,
62 0 : V: Clone + Serialize,
63 0 : {
64 0 : let transformed = input.iter().map(|(k, v)| (hex::encode(k), v.clone()));
65 0 :
66 0 : transformed
67 0 : .collect::<HashMap<String, V>>()
68 0 : .serialize(serializer)
69 0 : }
70 :
71 0 : fn from_hex_map<'de, D, V>(deserializer: D) -> Result<HashMap<TenantId, V>, D::Error>
72 0 : where
73 0 : D: serde::de::Deserializer<'de>,
74 0 : V: Deserialize<'de>,
75 0 : {
76 0 : let hex_map = HashMap::<String, V>::deserialize(deserializer)?;
77 0 : hex_map
78 0 : .into_iter()
79 0 : .map(|(k, v)| {
80 0 : TenantId::from_hex(k)
81 0 : .map(|k| (k, v))
82 0 : .map_err(serde::de::Error::custom)
83 0 : })
84 0 : .collect()
85 0 : }
86 :
87 : // Top level state available to all HTTP handlers
88 0 : #[derive(Serialize, Deserialize)]
89 : struct PersistentState {
90 : #[serde(serialize_with = "to_hex_map", deserialize_with = "from_hex_map")]
91 : tenants: HashMap<TenantId, TenantState>,
92 :
93 : #[serde(skip)]
94 : path: PathBuf,
95 : }
96 :
97 : impl PersistentState {
98 0 : async fn save(&self) -> anyhow::Result<()> {
99 0 : let bytes = serde_json::to_vec(self)?;
100 0 : tokio::fs::write(&self.path, &bytes).await?;
101 :
102 0 : Ok(())
103 0 : }
104 :
105 0 : async fn load(path: &Path) -> anyhow::Result<Self> {
106 0 : let bytes = tokio::fs::read(path).await?;
107 0 : let mut decoded = serde_json::from_slice::<Self>(&bytes)?;
108 0 : decoded.path = path.to_owned();
109 0 : Ok(decoded)
110 0 : }
111 :
112 0 : async fn load_or_new(path: &Path) -> Self {
113 0 : match Self::load(path).await {
114 0 : Ok(s) => {
115 0 : tracing::info!("Loaded state file at {}", path.display());
116 0 : s
117 : }
118 0 : Err(e)
119 0 : if e.downcast_ref::<std::io::Error>()
120 0 : .map(|e| e.kind() == std::io::ErrorKind::NotFound)
121 0 : .unwrap_or(false) =>
122 : {
123 0 : tracing::info!("Will create state file at {}", path.display());
124 0 : Self {
125 0 : tenants: HashMap::new(),
126 0 : path: path.to_owned(),
127 0 : }
128 : }
129 0 : Err(e) => {
130 0 : panic!("Failed to load state from '{}': {e:#} (maybe your .neon/ dir was written by an older version?)", path.display())
131 : }
132 : }
133 0 : }
134 : }
135 :
136 : /// State available to HTTP request handlers
137 0 : #[derive(Clone)]
138 : struct State {
139 : inner: Arc<tokio::sync::RwLock<PersistentState>>,
140 : }
141 :
142 : impl State {
143 0 : fn new(persistent_state: PersistentState) -> State {
144 0 : Self {
145 0 : inner: Arc::new(tokio::sync::RwLock::new(persistent_state)),
146 0 : }
147 0 : }
148 : }
149 :
150 : #[inline(always)]
151 0 : fn get_state(request: &Request<Body>) -> &State {
152 0 : request
153 0 : .data::<Arc<State>>()
154 0 : .expect("unknown state type")
155 0 : .as_ref()
156 0 : }
157 :
158 : /// Pageserver calls into this on startup, to learn which tenants it should attach
159 0 : async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
160 0 : let reattach_req = json_request::<ReAttachRequest>(&mut req).await?;
161 :
162 0 : let state = get_state(&req).inner.clone();
163 0 : let mut locked = state.write().await;
164 :
165 0 : let mut response = ReAttachResponse {
166 0 : tenants: Vec::new(),
167 0 : };
168 0 : for (t, state) in &mut locked.tenants {
169 0 : if state.pageserver == Some(reattach_req.node_id) {
170 0 : state.generation += 1;
171 0 : response.tenants.push(ReAttachResponseTenant {
172 0 : id: *t,
173 0 : generation: state.generation,
174 0 : });
175 0 : }
176 : }
177 :
178 0 : locked.save().await.map_err(ApiError::InternalServerError)?;
179 :
180 0 : json_response(StatusCode::OK, response)
181 0 : }
182 :
183 : /// Pageserver calls into this before doing deletions, to confirm that it still
184 : /// holds the latest generation for the tenants with deletions enqueued
185 0 : async fn handle_validate(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
186 0 : let validate_req = json_request::<ValidateRequest>(&mut req).await?;
187 :
188 0 : let locked = get_state(&req).inner.read().await;
189 :
190 0 : let mut response = ValidateResponse {
191 0 : tenants: Vec::new(),
192 0 : };
193 :
194 0 : for req_tenant in validate_req.tenants {
195 0 : if let Some(tenant_state) = locked.tenants.get(&req_tenant.id) {
196 0 : let valid = tenant_state.generation == req_tenant.gen;
197 0 : response.tenants.push(ValidateResponseTenant {
198 0 : id: req_tenant.id,
199 0 : valid,
200 0 : });
201 0 : }
202 : }
203 :
204 0 : json_response(StatusCode::OK, response)
205 0 : }
206 : /// Call into this before attaching a tenant to a pageserver, to acquire a generation number
207 : /// (in the real control plane this is unnecessary, because the same program is managing
208 : /// generation numbers and doing attachments).
209 0 : async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
210 0 : let attach_req = json_request::<AttachHookRequest>(&mut req).await?;
211 :
212 0 : let state = get_state(&req).inner.clone();
213 0 : let mut locked = state.write().await;
214 :
215 0 : let tenant_state = locked
216 0 : .tenants
217 0 : .entry(attach_req.tenant_id)
218 0 : .or_insert_with(|| TenantState {
219 0 : pageserver: attach_req.pageserver_id,
220 0 : generation: 0,
221 0 : });
222 0 :
223 0 : if attach_req.pageserver_id.is_some() {
224 0 : tenant_state.generation += 1;
225 0 : }
226 0 : tenant_state.pageserver = attach_req.pageserver_id;
227 0 : let generation = tenant_state.generation;
228 0 :
229 0 : locked.save().await.map_err(ApiError::InternalServerError)?;
230 :
231 0 : json_response(
232 0 : StatusCode::OK,
233 0 : AttachHookResponse {
234 0 : gen: attach_req.pageserver_id.map(|_| generation),
235 0 : },
236 0 : )
237 0 : }
238 :
239 0 : fn make_router(persistent_state: PersistentState) -> RouterBuilder<hyper::Body, ApiError> {
240 0 : endpoint::make_router()
241 0 : .data(Arc::new(State::new(persistent_state)))
242 0 : .post("/re-attach", handle_re_attach)
243 0 : .post("/validate", handle_validate)
244 0 : .post("/attach_hook", handle_attach_hook)
245 0 : }
246 :
247 : #[tokio::main]
248 0 : async fn main() -> anyhow::Result<()> {
249 0 : logging::init(
250 0 : LogFormat::Plain,
251 0 : logging::TracingErrorLayerEnablement::Disabled,
252 0 : )?;
253 :
254 0 : let args = Cli::parse();
255 0 : tracing::info!(
256 0 : "Starting, state at {}, listening on {}",
257 0 : args.path.to_string_lossy(),
258 0 : args.listen
259 0 : );
260 :
261 0 : let persistent_state = PersistentState::load_or_new(&args.path).await;
262 :
263 0 : let http_listener = tcp_listener::bind(args.listen)?;
264 0 : let router = make_router(persistent_state)
265 0 : .build()
266 0 : .map_err(|err| anyhow!(err))?;
267 0 : let service = utils::http::RouterService::new(router).unwrap();
268 0 : let server = hyper::Server::from_tcp(http_listener)?.serve(service);
269 0 :
270 0 : tracing::info!("Serving on {0}", args.listen);
271 0 : server.await?;
272 :
273 0 : Ok(())
274 : }
|