Line data Source code
1 : //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
2 : //! protocol commands.
3 :
4 : use anyhow::Context;
5 : use pageserver_api::models::ShardParameters;
6 : use pageserver_api::shard::{ShardIdentity, ShardStripeSize};
7 : use safekeeper_api::models::ConnectionId;
8 : use safekeeper_api::Term;
9 : use std::future::Future;
10 : use std::str::{self, FromStr};
11 : use std::sync::Arc;
12 : use tokio::io::{AsyncRead, AsyncWrite};
13 : use tracing::{debug, info, info_span, Instrument};
14 : use utils::postgres_client::PostgresClientProtocol;
15 : use utils::shard::{ShardCount, ShardNumber};
16 :
17 : use crate::auth::check_permission;
18 : use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
19 :
20 : use crate::metrics::{TrafficMetrics, PG_QUERIES_GAUGE};
21 : use crate::timeline::TimelineError;
22 : use crate::{GlobalTimelines, SafeKeeperConf};
23 : use postgres_backend::PostgresBackend;
24 : use postgres_backend::QueryError;
25 : use postgres_ffi::PG_TLI;
26 : use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
27 : use regex::Regex;
28 : use utils::auth::{Claims, JwtAuth, Scope};
29 : use utils::{
30 : id::{TenantId, TenantTimelineId, TimelineId},
31 : lsn::Lsn,
32 : };
33 :
34 : /// Safekeeper handler of postgres commands
35 : pub struct SafekeeperPostgresHandler {
36 : pub conf: Arc<SafeKeeperConf>,
37 : /// assigned application name
38 : pub appname: Option<String>,
39 : pub tenant_id: Option<TenantId>,
40 : pub timeline_id: Option<TimelineId>,
41 : pub ttid: TenantTimelineId,
42 : pub shard: Option<ShardIdentity>,
43 : pub protocol: Option<PostgresClientProtocol>,
44 : /// Unique connection id is logged in spans for observability.
45 : pub conn_id: ConnectionId,
46 : pub global_timelines: Arc<GlobalTimelines>,
47 : /// Auth scope allowed on the connection and the public key used to check auth tokens. None if auth is not configured.
48 : auth: Option<(Scope, Arc<JwtAuth>)>,
49 : claims: Option<Claims>,
50 : io_metrics: Option<TrafficMetrics>,
51 : }
52 :
53 : /// Parsed Postgres command.
54 : enum SafekeeperPostgresCommand {
55 : StartWalPush,
56 : StartReplication { start_lsn: Lsn, term: Option<Term> },
57 : IdentifySystem,
58 : TimelineStatus,
59 : JSONCtrl { cmd: AppendLogicalMessage },
60 : }
61 :
62 0 : fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
63 0 : if cmd.starts_with("START_WAL_PUSH") {
64 0 : Ok(SafekeeperPostgresCommand::StartWalPush)
65 0 : } else if cmd.starts_with("START_REPLICATION") {
66 0 : let re = Regex::new(
67 0 : // We follow the postgres START_REPLICATION LOGICAL options syntax to pass the term.
68 0 : r"START_REPLICATION(?: SLOT [^ ]+)?(?: PHYSICAL)? ([[:xdigit:]]+/[[:xdigit:]]+)(?: \(term='(\d+)'\))?",
69 0 : )
70 0 : .unwrap();
71 0 : let caps = re
72 0 : .captures(cmd)
73 0 : .context(format!("failed to parse START_REPLICATION command {}", cmd))?;
74 0 : let start_lsn =
75 0 : Lsn::from_str(&caps[1]).context("parse start LSN from START_REPLICATION command")?;
76 0 : let term = if let Some(m) = caps.get(2) {
77 0 : Some(m.as_str().parse::<u64>().context("invalid term")?)
78 : } else {
79 0 : None
80 : };
81 0 : Ok(SafekeeperPostgresCommand::StartReplication { start_lsn, term })
82 0 : } else if cmd.starts_with("IDENTIFY_SYSTEM") {
83 0 : Ok(SafekeeperPostgresCommand::IdentifySystem)
84 0 : } else if cmd.starts_with("TIMELINE_STATUS") {
85 0 : Ok(SafekeeperPostgresCommand::TimelineStatus)
86 0 : } else if cmd.starts_with("JSON_CTRL") {
87 0 : let cmd = cmd.strip_prefix("JSON_CTRL").context("invalid prefix")?;
88 : Ok(SafekeeperPostgresCommand::JSONCtrl {
89 0 : cmd: serde_json::from_str(cmd)?,
90 : })
91 : } else {
92 0 : anyhow::bail!("unsupported command {cmd}");
93 : }
94 0 : }
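
The grammar accepted by parse_cmd is easiest to see against a concrete command string. Below is a minimal sketch, not part of this file, of a #[cfg(test)] module exercising the START_REPLICATION branch; the LSN and term values are made up.

#[cfg(test)]
mod parse_cmd_sketch {
    use super::*;

    #[test]
    fn start_replication_with_term() {
        // Matches the regex above: optional SLOT and PHYSICAL, an LSN, an optional term.
        let cmd = "START_REPLICATION PHYSICAL 0/16B9188 (term='42')";
        match parse_cmd(cmd).expect("command should parse") {
            SafekeeperPostgresCommand::StartReplication { start_lsn, term } => {
                assert_eq!(start_lsn, Lsn::from_str("0/16B9188").unwrap());
                assert_eq!(term, Some(42));
            }
            other => panic!("expected StartReplication, got {}", cmd_to_string(&other)),
        }
    }
}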
95 :
96 0 : fn cmd_to_string(cmd: &SafekeeperPostgresCommand) -> &str {
97 0 : match cmd {
98 0 : SafekeeperPostgresCommand::StartWalPush => "START_WAL_PUSH",
99 0 : SafekeeperPostgresCommand::StartReplication { .. } => "START_REPLICATION",
100 0 : SafekeeperPostgresCommand::TimelineStatus => "TIMELINE_STATUS",
101 0 : SafekeeperPostgresCommand::IdentifySystem => "IDENTIFY_SYSTEM",
102 0 : SafekeeperPostgresCommand::JSONCtrl { .. } => "JSON_CTRL",
103 : }
104 0 : }
105 :
106 : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
107 : for SafekeeperPostgresHandler
108 : {
109 : // tenant_id and timeline_id are passed in connection string params
110 0 : fn startup(
111 0 : &mut self,
112 0 : _pgb: &mut PostgresBackend<IO>,
113 0 : sm: &FeStartupPacket,
114 0 : ) -> Result<(), QueryError> {
115 0 : if let FeStartupPacket::StartupMessage { params, .. } = sm {
116 0 : if let Some(options) = params.options_raw() {
117 0 : let mut shard_count: Option<u8> = None;
118 0 : let mut shard_number: Option<u8> = None;
119 0 : let mut shard_stripe_size: Option<u32> = None;
120 :
121 0 : for opt in options {
122 : // FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy,
123 : // remove these after the PR gets deployed:
124 : // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
125 0 : match opt.split_once('=') {
126 0 : Some(("protocol", value)) => {
127 0 : self.protocol =
128 0 : Some(serde_json::from_str(value).with_context(|| {
129 0 : format!("Failed to parse {value} as protocol")
130 0 : })?);
131 : }
132 0 : Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
133 0 : self.tenant_id = Some(value.parse().with_context(|| {
134 0 : format!("Failed to parse {value} as tenant id")
135 0 : })?);
136 : }
137 0 : Some(("ztimelineid", value)) | Some(("timeline_id", value)) => {
138 0 : self.timeline_id = Some(value.parse().with_context(|| {
139 0 : format!("Failed to parse {value} as timeline id")
140 0 : })?);
141 : }
142 0 : Some(("availability_zone", client_az)) => {
143 0 : if let Some(metrics) = self.io_metrics.as_ref() {
144 0 : metrics.set_client_az(client_az)
145 0 : }
146 : }
147 0 : Some(("shard_count", value)) => {
148 0 : shard_count = Some(value.parse::<u8>().with_context(|| {
149 0 : format!("Failed to parse {value} as shard count")
150 0 : })?);
151 : }
152 0 : Some(("shard_number", value)) => {
153 0 : shard_number = Some(value.parse::<u8>().with_context(|| {
154 0 : format!("Failed to parse {value} as shard number")
155 0 : })?);
156 : }
157 0 : Some(("shard_stripe_size", value)) => {
158 0 : shard_stripe_size = Some(value.parse::<u32>().with_context(|| {
159 0 : format!("Failed to parse {value} as shard stripe size")
160 0 : })?);
161 : }
162 0 : _ => continue,
163 : }
164 : }
165 :
166 0 : match self.protocol() {
167 : PostgresClientProtocol::Vanilla => {
168 0 : if shard_count.is_some()
169 0 : || shard_number.is_some()
170 0 : || shard_stripe_size.is_some()
171 : {
172 0 : return Err(QueryError::Other(anyhow::anyhow!(
173 0 : "Shard params specified for vanilla protocol"
174 0 : )));
175 0 : }
176 : }
177 : PostgresClientProtocol::Interpreted { .. } => {
178 0 : match (shard_count, shard_number, shard_stripe_size) {
179 0 : (Some(count), Some(number), Some(stripe_size)) => {
180 0 : let params = ShardParameters {
181 0 : count: ShardCount(count),
182 0 : stripe_size: ShardStripeSize(stripe_size),
183 0 : };
184 0 : self.shard =
185 0 : Some(ShardIdentity::from_params(ShardNumber(number), ¶ms));
186 0 : }
187 : _ => {
188 0 : return Err(QueryError::Other(anyhow::anyhow!(
189 0 : "Shard params were not specified"
190 0 : )));
191 : }
192 : }
193 : }
194 : }
195 0 : }
196 :
197 0 : if let Some(app_name) = params.get("application_name") {
198 0 : self.appname = Some(app_name.to_owned());
199 0 : if let Some(metrics) = self.io_metrics.as_ref() {
200 0 : metrics.set_app_name(app_name)
201 0 : }
202 0 : }
203 :
204 0 : let ttid = TenantTimelineId::new(
205 0 : self.tenant_id.unwrap_or(TenantId::from([0u8; 16])),
206 0 : self.timeline_id.unwrap_or(TimelineId::from([0u8; 16])),
207 0 : );
208 0 : tracing::Span::current()
209 0 : .record("ttid", tracing::field::display(ttid))
210 0 : .record(
211 0 : "application_name",
212 0 : tracing::field::debug(self.appname.clone()),
213 0 : );
214 :
215 0 : if let Some(shard) = self.shard.as_ref() {
216 0 : if let Some(slug) = shard.shard_slug().strip_prefix("-") {
217 0 : tracing::Span::current().record("shard", tracing::field::display(slug));
218 0 : }
219 0 : }
220 :
221 0 : Ok(())
222 : } else {
223 0 : Err(QueryError::Other(anyhow::anyhow!(
224 0 : "Safekeeper received unexpected initial message: {sm:?}"
225 0 : )))
226 : }
227 0 : }
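
In the Interpreted branch above, the three shard options must all be present and are combined into a ShardIdentity. The same combination as a free-standing sketch (hypothetical helper mirroring the match arm; not part of this file):

// Hypothetical helper equivalent to the Interpreted arm of startup() above.
fn shard_identity_from_options(count: u8, number: u8, stripe_size: u32) -> ShardIdentity {
    let params = ShardParameters {
        count: ShardCount(count),
        stripe_size: ShardStripeSize(stripe_size),
    };
    ShardIdentity::from_params(ShardNumber(number), &params)
}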
228 :
229 0 : fn check_auth_jwt(
230 0 : &mut self,
231 0 : _pgb: &mut PostgresBackend<IO>,
232 0 : jwt_response: &[u8],
233 0 : ) -> Result<(), QueryError> {
234 0 : // This expect is never triggered, because check_auth_jwt is only called when
235 0 : // auth_type is NeonJWT, which requires auth to be present.
236 0 : let (allowed_auth_scope, auth) = self
237 0 : .auth
238 0 : .as_ref()
239 0 : .expect("auth_type is configured but .auth of handler is missing");
240 0 : let data = auth
241 0 : .decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)
242 0 : .map_err(|e| QueryError::Unauthorized(e.0))?;
243 :
244 : // The handler might be configured to allow only tenant scope tokens.
245 0 : if matches!(allowed_auth_scope, Scope::Tenant)
246 0 : && !matches!(data.claims.scope, Scope::Tenant)
247 : {
248 0 : return Err(QueryError::Unauthorized(
249 0 : "passed JWT token is for full access, but only tenant scope is allowed".into(),
250 0 : ));
251 0 : }
252 :
253 0 : if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() {
254 0 : return Err(QueryError::Unauthorized(
255 0 : "jwt token scope is Tenant, but tenant id is missing".into(),
256 0 : ));
257 0 : }
258 0 :
259 0 : debug!(
260 0 : "jwt scope check succeeded for scope: {:#?} by tenant id: {:?}",
261 : data.claims.scope, data.claims.tenant_id,
262 : );
263 :
264 0 : self.claims = Some(data.claims);
265 0 : Ok(())
266 0 : }
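
The two scope checks above amount to a single rule: a tenant-scoped handler refuses full-access tokens, and a tenant-scoped token must carry a tenant id. A standalone restatement (hypothetical helper, not part of this file), using only the Claims fields already referenced above:

// Hypothetical restatement of the scope checks in check_auth_jwt.
fn scope_allows(allowed: &Scope, claims: &Claims) -> Result<(), &'static str> {
    if matches!(allowed, Scope::Tenant) && !matches!(claims.scope, Scope::Tenant) {
        // A tenant-scoped handler refuses full-access tokens.
        return Err("passed JWT token is for full access, but only tenant scope is allowed");
    }
    if matches!(claims.scope, Scope::Tenant) && claims.tenant_id.is_none() {
        // A tenant-scoped token must identify the tenant it grants access to.
        return Err("jwt token scope is Tenant, but tenant id is missing");
    }
    Ok(())
}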
267 :
268 0 : fn process_query(
269 0 : &mut self,
270 0 : pgb: &mut PostgresBackend<IO>,
271 0 : query_string: &str,
272 0 : ) -> impl Future<Output = Result<(), QueryError>> {
273 0 : Box::pin(async move {
274 0 : if query_string
275 0 : .to_ascii_lowercase()
276 0 : .starts_with("set datestyle to ")
277 : {
278 : // Important for debugging: psycopg2 executes "SET datestyle TO 'ISO'" on connect.
279 0 : pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
280 0 : return Ok(());
281 0 : }
282 :
283 0 : let cmd = parse_cmd(query_string)?;
284 0 : let cmd_str = cmd_to_string(&cmd);
285 0 :
286 0 : let _guard = PG_QUERIES_GAUGE.with_label_values(&[cmd_str]).guard();
287 0 :
288 0 : info!("got query {:?}", query_string);
289 :
290 0 : let tenant_id = self.tenant_id.context("tenantid is required")?;
291 0 : let timeline_id = self.timeline_id.context("timelineid is required")?;
292 0 : self.check_permission(Some(tenant_id))?;
293 0 : self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
294 0 :
295 0 : match cmd {
296 : SafekeeperPostgresCommand::StartWalPush => {
297 0 : self.handle_start_wal_push(pgb)
298 0 : .instrument(info_span!("WAL receiver"))
299 0 : .await
300 : }
301 0 : SafekeeperPostgresCommand::StartReplication { start_lsn, term } => {
302 0 : self.handle_start_replication(pgb, start_lsn, term)
303 0 : .instrument(info_span!("WAL sender"))
304 0 : .await
305 : }
306 0 : SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
307 0 : SafekeeperPostgresCommand::TimelineStatus => self.handle_timeline_status(pgb).await,
308 0 : SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
309 0 : handle_json_ctrl(self, pgb, cmd).await
310 : }
311 : }
312 0 : })
313 0 : }
314 : }
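
Because process_query dispatches on the raw query string, the full command surface of this handler is visible from the client side. The strings below are illustrative only; the LSN, term, and JSON_CTRL payload are placeholders (the payload is an AppendLogicalMessage, elided here):

// Illustrative query strings a client might send; values are placeholders.
const _EXAMPLE_QUERIES: &[&str] = &[
    "IDENTIFY_SYSTEM",
    "TIMELINE_STATUS",
    "START_WAL_PUSH",
    "START_REPLICATION PHYSICAL 0/16B9188 (term='7')",
    "JSON_CTRL { ... }",
];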
315 :
316 : impl SafekeeperPostgresHandler {
317 0 : pub fn new(
318 0 : conf: Arc<SafeKeeperConf>,
319 0 : conn_id: u32,
320 0 : io_metrics: Option<TrafficMetrics>,
321 0 : auth: Option<(Scope, Arc<JwtAuth>)>,
322 0 : global_timelines: Arc<GlobalTimelines>,
323 0 : ) -> Self {
324 0 : SafekeeperPostgresHandler {
325 0 : conf,
326 0 : appname: None,
327 0 : tenant_id: None,
328 0 : timeline_id: None,
329 0 : ttid: TenantTimelineId::empty(),
330 0 : shard: None,
331 0 : protocol: None,
332 0 : conn_id,
333 0 : claims: None,
334 0 : auth,
335 0 : io_metrics,
336 0 : global_timelines,
337 0 : }
338 0 : }
339 :
340 0 : pub fn protocol(&self) -> PostgresClientProtocol {
341 0 : self.protocol.unwrap_or(PostgresClientProtocol::Vanilla)
342 0 : }
343 :
344 : // When accessing the management API, supply None as an argument;
345 : // when authorizing access to a tenant, pass the corresponding tenant id.
346 0 : fn check_permission(&self, tenant_id: Option<TenantId>) -> Result<(), QueryError> {
347 0 : if self.auth.is_none() {
348 : // Auth is set to Trust; nothing to check, so just return Ok.
349 0 : return Ok(());
350 0 : }
351 0 : // Auth is Some (just checked above), and when auth is Some the claims are
352 0 : // always present because of the checks during connection init,
353 0 : // so this expect won't trigger.
354 0 : let claims = self
355 0 : .claims
356 0 : .as_ref()
357 0 : .expect("claims presence already checked");
358 0 : check_permission(claims, tenant_id).map_err(|e| QueryError::Unauthorized(e.0))
359 0 : }
360 :
361 0 : async fn handle_timeline_status<IO: AsyncRead + AsyncWrite + Unpin>(
362 0 : &mut self,
363 0 : pgb: &mut PostgresBackend<IO>,
364 0 : ) -> Result<(), QueryError> {
365 : // Get timeline, handling "not found" error
366 0 : let tli = match self.global_timelines.get(self.ttid) {
367 0 : Ok(tli) => Ok(Some(tli)),
368 0 : Err(TimelineError::NotFound(_)) => Ok(None),
369 0 : Err(e) => Err(QueryError::Other(e.into())),
370 0 : }?;
371 :
372 : // Write row description
373 0 : pgb.write_message_noflush(&BeMessage::RowDescription(&[
374 0 : RowDescriptor::text_col(b"flush_lsn"),
375 0 : RowDescriptor::text_col(b"commit_lsn"),
376 0 : ]))?;
377 :
378 : // Write row if timeline exists
379 0 : if let Some(tli) = tli {
380 0 : let (inmem, _state) = tli.get_state().await;
381 0 : let flush_lsn = tli.get_flush_lsn().await;
382 0 : let commit_lsn = inmem.commit_lsn;
383 0 : pgb.write_message_noflush(&BeMessage::DataRow(&[
384 0 : Some(flush_lsn.to_string().as_bytes()),
385 0 : Some(commit_lsn.to_string().as_bytes()),
386 0 : ]))?;
387 0 : }
388 :
389 0 : pgb.write_message_noflush(&BeMessage::CommandComplete(b"TIMELINE_STATUS"))?;
390 0 : Ok(())
391 0 : }
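
From the client's point of view, TIMELINE_STATUS returns at most one row with two text columns, or no row if the timeline does not exist on this safekeeper. Illustrative output with made-up LSNs:

//  flush_lsn | commit_lsn
// -----------+------------
//  0/2000060 | 0/2000060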
392 :
393 : ///
394 : /// Handle IDENTIFY_SYSTEM replication command
395 : ///
396 0 : async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
397 0 : &mut self,
398 0 : pgb: &mut PostgresBackend<IO>,
399 0 : ) -> Result<(), QueryError> {
400 0 : let tli = self
401 0 : .global_timelines
402 0 : .get(self.ttid)
403 0 : .map_err(|e| QueryError::Other(e.into()))?;
404 :
405 0 : let lsn = if self.is_walproposer_recovery() {
406 : // walproposer should get all local WAL until flush_lsn
407 0 : tli.get_flush_lsn().await
408 : } else {
409 : // other clients shouldn't get any uncommitted WAL
410 0 : tli.get_state().await.0.commit_lsn
411 : }
412 0 : .to_string();
413 :
414 0 : let sysid = tli.get_state().await.1.server.system_id.to_string();
415 0 : let lsn_bytes = lsn.as_bytes();
416 0 : let tli = PG_TLI.to_string();
417 0 : let tli_bytes = tli.as_bytes();
418 0 : let sysid_bytes = sysid.as_bytes();
419 0 :
420 0 : pgb.write_message_noflush(&BeMessage::RowDescription(&[
421 0 : RowDescriptor {
422 0 : name: b"systemid",
423 0 : typoid: TEXT_OID,
424 0 : typlen: -1,
425 0 : ..Default::default()
426 0 : },
427 0 : RowDescriptor {
428 0 : name: b"timeline",
429 0 : typoid: INT4_OID,
430 0 : typlen: 4,
431 0 : ..Default::default()
432 0 : },
433 0 : RowDescriptor {
434 0 : name: b"xlogpos",
435 0 : typoid: TEXT_OID,
436 0 : typlen: -1,
437 0 : ..Default::default()
438 0 : },
439 0 : RowDescriptor {
440 0 : name: b"dbname",
441 0 : typoid: TEXT_OID,
442 0 : typlen: -1,
443 0 : ..Default::default()
444 0 : },
445 0 : ]))?
446 0 : .write_message_noflush(&BeMessage::DataRow(&[
447 0 : Some(sysid_bytes),
448 0 : Some(tli_bytes),
449 0 : Some(lsn_bytes),
450 0 : None,
451 0 : ]))?
452 0 : .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
453 0 : Ok(())
454 0 : }
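
The resulting row mirrors what vanilla Postgres returns for IDENTIFY_SYSTEM: systemid and xlogpos as text, timeline as int4 (always PG_TLI here), and a NULL dbname. Illustrative output with made-up values:

//       systemid       | timeline |  xlogpos  | dbname
// ---------------------+----------+-----------+--------
//  7234023412345678901 |        1 | 0/2000060 |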
455 :
456 : /// Returns true if the current connection is a replication connection originating
457 : /// from a walproposer recovery function. Such a connection gets special handling:
458 : /// the safekeeper must stream all local WAL up to flush_lsn, whether committed or not.
459 0 : pub fn is_walproposer_recovery(&self) -> bool {
460 0 : match &self.appname {
461 0 : None => false,
462 0 : Some(appname) => {
463 0 : appname == "wal_proposer_recovery" ||
464 : // set by safekeeper peer recovery
465 0 : appname.starts_with("safekeeper")
466 : }
467 : }
468 0 : }
469 : }
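
The appname rule in is_walproposer_recovery can be restated as a pure function (hypothetical, not part of this file):

// Hypothetical restatement: only walproposer recovery and safekeeper peer-recovery
// connections may read uncommitted WAL up to flush_lsn.
fn appname_requests_recovery(appname: Option<&str>) -> bool {
    matches!(appname, Some(a) if a == "wal_proposer_recovery" || a.starts_with("safekeeper"))
}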