Line data Source code
1 : //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
2 : //! protocol commands.
3 :
4 : use std::future::Future;
5 : use std::str::{self, FromStr};
6 : use std::sync::Arc;
7 :
8 : use anyhow::Context;
9 : use jsonwebtoken::TokenData;
10 : use pageserver_api::models::ShardParameters;
11 : use pageserver_api::shard::{ShardIdentity, ShardStripeSize};
12 : use postgres_backend::{PostgresBackend, QueryError};
13 : use postgres_ffi::PG_TLI;
14 : use pq_proto::{BeMessage, FeStartupPacket, INT4_OID, RowDescriptor, TEXT_OID};
15 : use regex::Regex;
16 : use safekeeper_api::Term;
17 : use safekeeper_api::models::ConnectionId;
18 : use tokio::io::{AsyncRead, AsyncWrite};
19 : use tracing::{Instrument, debug, info, info_span};
20 : use utils::auth::{Claims, JwtAuth, Scope};
21 : use utils::id::{TenantId, TenantTimelineId, TimelineId};
22 : use utils::lsn::Lsn;
23 : use utils::postgres_client::PostgresClientProtocol;
24 : use utils::shard::{ShardCount, ShardNumber};
25 :
26 : use crate::auth::check_permission;
27 : use crate::metrics::{PG_QUERIES_GAUGE, TrafficMetrics};
28 : use crate::timeline::TimelineError;
29 : use crate::{GlobalTimelines, SafeKeeperConf};
30 :
/// Safekeeper handler of postgres commands
///
/// Holds the per-connection state: values learned from the startup packet
/// and the JWT handshake are stored here and consulted later when queries
/// are processed.
pub struct SafekeeperPostgresHandler {
    /// Safekeeper configuration handed to the handler at construction time.
    pub conf: Arc<SafeKeeperConf>,
    /// assigned application name
    pub appname: Option<String>,
    /// Tenant id parsed from the `tenant_id` (or legacy `ztenantid`) startup
    /// option; `None` until such an option is seen.
    pub tenant_id: Option<TenantId>,
    /// Timeline id parsed from the `timeline_id` (or legacy `ztimelineid`)
    /// startup option; `None` until such an option is seen.
    pub timeline_id: Option<TimelineId>,
    /// Combined tenant/timeline id. Starts out empty and is filled in by
    /// `process_query` once both ids are known to be present.
    pub ttid: TenantTimelineId,
    /// Shard identity built from the `shard_*` startup options; only set for
    /// the interpreted protocol.
    pub shard: Option<ShardIdentity>,
    /// Protocol requested through the `protocol` startup option. `None` is
    /// treated as vanilla by [`SafekeeperPostgresHandler::protocol`].
    pub protocol: Option<PostgresClientProtocol>,
    /// Unique connection id is logged in spans for observability.
    pub conn_id: ConnectionId,
    /// Registry used to look up timelines by their tenant/timeline id.
    pub global_timelines: Arc<GlobalTimelines>,
    /// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured.
    auth: Option<(Scope, Arc<JwtAuth>)>,
    /// Claims of the JWT accepted by `check_auth_jwt`; `None` before that.
    claims: Option<Claims>,
    /// Optional traffic metrics; when present they are labelled with the
    /// client's availability zone and application name during startup.
    io_metrics: Option<TrafficMetrics>,
}
49 :
/// Parsed Postgres command.
///
/// Produced by `parse_cmd` from the raw query text.
enum SafekeeperPostgresCommand {
    /// `START_WAL_PUSH`, optionally followed by a parenthesised option list.
    StartWalPush {
        /// Value of the `proto_version` option; the parser defaults it to 2.
        proto_version: u32,
        // Eventually timelines will be always created explicitly by storcon.
        // This option allows legacy behaviour for compute to do that until we
        // fully migrate.
        /// Value of the `allow_timeline_creation` option; the parser defaults
        /// it to `true`.
        allow_timeline_creation: bool,
    },
    /// `START_REPLICATION`, carrying the position to start from.
    StartReplication {
        /// LSN given in the command.
        start_lsn: Lsn,
        /// Term from the optional `(term='N')` suffix, if it was supplied.
        term: Option<Term>,
    },
    /// `IDENTIFY_SYSTEM`.
    IdentifySystem,
    /// `TIMELINE_STATUS`.
    TimelineStatus,
}
66 :
67 2 : fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
68 2 : if cmd.starts_with("START_WAL_PUSH") {
69 : // Allow additional options in postgres START_REPLICATION style like
70 : // START_WAL_PUSH (proto_version '3', allow_timeline_creation 'false').
71 : // Parsing here is very naive and breaks in case of commas or
72 : // whitespaces in values, but enough for our purposes.
73 2 : let re = Regex::new(r"START_WAL_PUSH(\s+?\((.*)\))?").unwrap();
74 2 : let caps = re
75 2 : .captures(cmd)
76 2 : .context(format!("failed to parse START_WAL_PUSH command {}", cmd))?;
77 : // capture () content
78 2 : let options = caps.get(2).map(|m| m.as_str()).unwrap_or("");
79 2 : // default values
80 2 : let mut proto_version = 2;
81 2 : let mut allow_timeline_creation = true;
82 4 : for kvstr in options.split(",") {
83 4 : if kvstr.is_empty() {
84 1 : continue;
85 3 : }
86 3 : let mut kvit = kvstr.split_whitespace();
87 3 : let key = kvit.next().context(format!(
88 3 : "failed to parse key in kv {} in command {}",
89 3 : kvstr, cmd
90 3 : ))?;
91 3 : let value = kvit.next().context(format!(
92 3 : "failed to parse value in kv {} in command {}",
93 3 : kvstr, cmd
94 3 : ))?;
95 3 : let value_trimmed = value.trim_matches('\'');
96 3 : if key == "proto_version" {
97 1 : proto_version = value_trimmed.parse::<u32>().context(format!(
98 1 : "failed to parse proto_version value {} in command {}",
99 1 : value, cmd
100 1 : ))?;
101 2 : }
102 3 : if key == "allow_timeline_creation" {
103 1 : allow_timeline_creation = value_trimmed.parse::<bool>().context(format!(
104 1 : "failed to parse allow_timeline_creation value {} in command {}",
105 1 : value, cmd
106 1 : ))?;
107 2 : }
108 : }
109 2 : Ok(SafekeeperPostgresCommand::StartWalPush {
110 2 : proto_version,
111 2 : allow_timeline_creation,
112 2 : })
113 0 : } else if cmd.starts_with("START_REPLICATION") {
114 0 : let re = Regex::new(
115 0 : // We follow postgres START_REPLICATION LOGICAL options to pass term.
116 0 : r"START_REPLICATION(?: SLOT [^ ]+)?(?: PHYSICAL)? ([[:xdigit:]]+/[[:xdigit:]]+)(?: \(term='(\d+)'\))?",
117 0 : )
118 0 : .unwrap();
119 0 : let caps = re
120 0 : .captures(cmd)
121 0 : .context(format!("failed to parse START_REPLICATION command {}", cmd))?;
122 0 : let start_lsn =
123 0 : Lsn::from_str(&caps[1]).context("parse start LSN from START_REPLICATION command")?;
124 0 : let term = if let Some(m) = caps.get(2) {
125 0 : Some(m.as_str().parse::<u64>().context("invalid term")?)
126 : } else {
127 0 : None
128 : };
129 0 : Ok(SafekeeperPostgresCommand::StartReplication { start_lsn, term })
130 0 : } else if cmd.starts_with("IDENTIFY_SYSTEM") {
131 0 : Ok(SafekeeperPostgresCommand::IdentifySystem)
132 0 : } else if cmd.starts_with("TIMELINE_STATUS") {
133 0 : Ok(SafekeeperPostgresCommand::TimelineStatus)
134 : } else {
135 0 : anyhow::bail!("unsupported command {cmd}");
136 : }
137 2 : }
138 :
139 0 : fn cmd_to_string(cmd: &SafekeeperPostgresCommand) -> &str {
140 0 : match cmd {
141 0 : SafekeeperPostgresCommand::StartWalPush { .. } => "START_WAL_PUSH",
142 0 : SafekeeperPostgresCommand::StartReplication { .. } => "START_REPLICATION",
143 0 : SafekeeperPostgresCommand::TimelineStatus => "TIMELINE_STATUS",
144 0 : SafekeeperPostgresCommand::IdentifySystem => "IDENTIFY_SYSTEM",
145 : }
146 0 : }
147 :
148 : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
149 : for SafekeeperPostgresHandler
150 : {
151 : // tenant_id and timeline_id are passed in connection string params
152 0 : fn startup(
153 0 : &mut self,
154 0 : _pgb: &mut PostgresBackend<IO>,
155 0 : sm: &FeStartupPacket,
156 0 : ) -> Result<(), QueryError> {
157 0 : if let FeStartupPacket::StartupMessage { params, .. } = sm {
158 0 : if let Some(options) = params.options_raw() {
159 0 : let mut shard_count: Option<u8> = None;
160 0 : let mut shard_number: Option<u8> = None;
161 0 : let mut shard_stripe_size: Option<u32> = None;
162 :
163 0 : for opt in options {
164 : // FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy,
165 : // remove these after the PR gets deployed:
166 : // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
167 0 : match opt.split_once('=') {
168 0 : Some(("protocol", value)) => {
169 0 : self.protocol =
170 0 : Some(serde_json::from_str(value).with_context(|| {
171 0 : format!("Failed to parse {value} as protocol")
172 0 : })?);
173 : }
174 0 : Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
175 0 : self.tenant_id = Some(value.parse().with_context(|| {
176 0 : format!("Failed to parse {value} as tenant id")
177 0 : })?);
178 : }
179 0 : Some(("ztimelineid", value)) | Some(("timeline_id", value)) => {
180 0 : self.timeline_id = Some(value.parse().with_context(|| {
181 0 : format!("Failed to parse {value} as timeline id")
182 0 : })?);
183 : }
184 0 : Some(("availability_zone", client_az)) => {
185 0 : if let Some(metrics) = self.io_metrics.as_ref() {
186 0 : metrics.set_client_az(client_az)
187 0 : }
188 : }
189 0 : Some(("shard_count", value)) => {
190 0 : shard_count = Some(value.parse::<u8>().with_context(|| {
191 0 : format!("Failed to parse {value} as shard count")
192 0 : })?);
193 : }
194 0 : Some(("shard_number", value)) => {
195 0 : shard_number = Some(value.parse::<u8>().with_context(|| {
196 0 : format!("Failed to parse {value} as shard number")
197 0 : })?);
198 : }
199 0 : Some(("shard_stripe_size", value)) => {
200 0 : shard_stripe_size = Some(value.parse::<u32>().with_context(|| {
201 0 : format!("Failed to parse {value} as shard stripe size")
202 0 : })?);
203 : }
204 0 : _ => continue,
205 : }
206 : }
207 :
208 0 : match self.protocol() {
209 : PostgresClientProtocol::Vanilla => {
210 0 : if shard_count.is_some()
211 0 : || shard_number.is_some()
212 0 : || shard_stripe_size.is_some()
213 : {
214 0 : return Err(QueryError::Other(anyhow::anyhow!(
215 0 : "Shard params specified for vanilla protocol"
216 0 : )));
217 0 : }
218 : }
219 : PostgresClientProtocol::Interpreted { .. } => {
220 0 : match (shard_count, shard_number, shard_stripe_size) {
221 0 : (Some(count), Some(number), Some(stripe_size)) => {
222 0 : let params = ShardParameters {
223 0 : count: ShardCount(count),
224 0 : stripe_size: ShardStripeSize(stripe_size),
225 0 : };
226 0 : self.shard =
227 0 : Some(ShardIdentity::from_params(ShardNumber(number), ¶ms));
228 0 : }
229 : _ => {
230 0 : return Err(QueryError::Other(anyhow::anyhow!(
231 0 : "Shard params were not specified"
232 0 : )));
233 : }
234 : }
235 : }
236 : }
237 0 : }
238 :
239 0 : if let Some(app_name) = params.get("application_name") {
240 0 : self.appname = Some(app_name.to_owned());
241 0 : if let Some(metrics) = self.io_metrics.as_ref() {
242 0 : metrics.set_app_name(app_name)
243 0 : }
244 0 : }
245 :
246 0 : let ttid = TenantTimelineId::new(
247 0 : self.tenant_id.unwrap_or(TenantId::from([0u8; 16])),
248 0 : self.timeline_id.unwrap_or(TimelineId::from([0u8; 16])),
249 0 : );
250 0 : tracing::Span::current()
251 0 : .record("ttid", tracing::field::display(ttid))
252 0 : .record(
253 0 : "application_name",
254 0 : tracing::field::debug(self.appname.clone()),
255 0 : );
256 :
257 0 : if let Some(shard) = self.shard.as_ref() {
258 0 : if let Some(slug) = shard.shard_slug().strip_prefix("-") {
259 0 : tracing::Span::current().record("shard", tracing::field::display(slug));
260 0 : }
261 0 : }
262 :
263 0 : Ok(())
264 : } else {
265 0 : Err(QueryError::Other(anyhow::anyhow!(
266 0 : "Safekeeper received unexpected initial message: {sm:?}"
267 0 : )))
268 : }
269 0 : }
270 :
271 0 : fn check_auth_jwt(
272 0 : &mut self,
273 0 : _pgb: &mut PostgresBackend<IO>,
274 0 : jwt_response: &[u8],
275 0 : ) -> Result<(), QueryError> {
276 0 : // this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
277 0 : // which requires auth to be present
278 0 : let (allowed_auth_scope, auth) = self
279 0 : .auth
280 0 : .as_ref()
281 0 : .expect("auth_type is configured but .auth of handler is missing");
282 0 : let data: TokenData<Claims> = auth
283 0 : .decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)
284 0 : .map_err(|e| QueryError::Unauthorized(e.0))?;
285 :
286 : // The handler might be configured to allow only tenant scope tokens.
287 0 : if matches!(allowed_auth_scope, Scope::Tenant)
288 0 : && !matches!(data.claims.scope, Scope::Tenant)
289 : {
290 0 : return Err(QueryError::Unauthorized(
291 0 : "passed JWT token is for full access, but only tenant scope is allowed".into(),
292 0 : ));
293 0 : }
294 :
295 0 : if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() {
296 0 : return Err(QueryError::Unauthorized(
297 0 : "jwt token scope is Tenant, but tenant id is missing".into(),
298 0 : ));
299 0 : }
300 0 :
301 0 : debug!(
302 0 : "jwt scope check succeeded for scope: {:#?} by tenant id: {:?}",
303 : data.claims.scope, data.claims.tenant_id,
304 : );
305 :
306 0 : self.claims = Some(data.claims);
307 0 : Ok(())
308 0 : }
309 :
310 0 : fn process_query(
311 0 : &mut self,
312 0 : pgb: &mut PostgresBackend<IO>,
313 0 : query_string: &str,
314 0 : ) -> impl Future<Output = Result<(), QueryError>> {
315 0 : Box::pin(async move {
316 0 : if query_string
317 0 : .to_ascii_lowercase()
318 0 : .starts_with("set datestyle to ")
319 : {
320 : // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
321 0 : pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
322 0 : return Ok(());
323 0 : }
324 :
325 0 : let cmd = parse_cmd(query_string)?;
326 0 : let cmd_str = cmd_to_string(&cmd);
327 0 :
328 0 : let _guard = PG_QUERIES_GAUGE.with_label_values(&[cmd_str]).guard();
329 0 :
330 0 : info!("got query {:?}", query_string);
331 :
332 0 : let tenant_id = self.tenant_id.context("tenantid is required")?;
333 0 : let timeline_id = self.timeline_id.context("timelineid is required")?;
334 0 : self.check_permission(Some(tenant_id))?;
335 0 : self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
336 0 :
337 0 : match cmd {
338 : SafekeeperPostgresCommand::StartWalPush {
339 0 : proto_version,
340 0 : allow_timeline_creation,
341 0 : } => {
342 0 : self.handle_start_wal_push(pgb, proto_version, allow_timeline_creation)
343 0 : .instrument(info_span!("WAL receiver"))
344 0 : .await
345 : }
346 0 : SafekeeperPostgresCommand::StartReplication { start_lsn, term } => {
347 0 : self.handle_start_replication(pgb, start_lsn, term)
348 0 : .instrument(info_span!("WAL sender"))
349 0 : .await
350 : }
351 0 : SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
352 0 : SafekeeperPostgresCommand::TimelineStatus => self.handle_timeline_status(pgb).await,
353 : }
354 0 : })
355 0 : }
356 : }
357 :
358 : impl SafekeeperPostgresHandler {
359 0 : pub fn new(
360 0 : conf: Arc<SafeKeeperConf>,
361 0 : conn_id: u32,
362 0 : io_metrics: Option<TrafficMetrics>,
363 0 : auth: Option<(Scope, Arc<JwtAuth>)>,
364 0 : global_timelines: Arc<GlobalTimelines>,
365 0 : ) -> Self {
366 0 : SafekeeperPostgresHandler {
367 0 : conf,
368 0 : appname: None,
369 0 : tenant_id: None,
370 0 : timeline_id: None,
371 0 : ttid: TenantTimelineId::empty(),
372 0 : shard: None,
373 0 : protocol: None,
374 0 : conn_id,
375 0 : claims: None,
376 0 : auth,
377 0 : io_metrics,
378 0 : global_timelines,
379 0 : }
380 0 : }
381 :
382 0 : pub fn protocol(&self) -> PostgresClientProtocol {
383 0 : self.protocol.unwrap_or(PostgresClientProtocol::Vanilla)
384 0 : }
385 :
386 : // when accessing management api supply None as an argument
387 : // when using to authorize tenant pass corresponding tenant id
388 0 : fn check_permission(&self, tenant_id: Option<TenantId>) -> Result<(), QueryError> {
389 0 : if self.auth.is_none() {
390 : // auth is set to Trust, nothing to check so just return ok
391 0 : return Ok(());
392 0 : }
393 0 : // auth is some, just checked above, when auth is some
394 0 : // then claims are always present because of checks during connection init
395 0 : // so this expect won't trigger
396 0 : let claims = self
397 0 : .claims
398 0 : .as_ref()
399 0 : .expect("claims presence already checked");
400 0 : check_permission(claims, tenant_id).map_err(|e| QueryError::Unauthorized(e.0))
401 0 : }
402 :
403 0 : async fn handle_timeline_status<IO: AsyncRead + AsyncWrite + Unpin>(
404 0 : &mut self,
405 0 : pgb: &mut PostgresBackend<IO>,
406 0 : ) -> Result<(), QueryError> {
407 : // Get timeline, handling "not found" error
408 0 : let tli = match self.global_timelines.get(self.ttid) {
409 0 : Ok(tli) => Ok(Some(tli)),
410 0 : Err(TimelineError::NotFound(_)) => Ok(None),
411 0 : Err(e) => Err(QueryError::Other(e.into())),
412 0 : }?;
413 :
414 : // Write row description
415 0 : pgb.write_message_noflush(&BeMessage::RowDescription(&[
416 0 : RowDescriptor::text_col(b"flush_lsn"),
417 0 : RowDescriptor::text_col(b"commit_lsn"),
418 0 : ]))?;
419 :
420 : // Write row if timeline exists
421 0 : if let Some(tli) = tli {
422 0 : let (inmem, _state) = tli.get_state().await;
423 0 : let flush_lsn = tli.get_flush_lsn().await;
424 0 : let commit_lsn = inmem.commit_lsn;
425 0 : pgb.write_message_noflush(&BeMessage::DataRow(&[
426 0 : Some(flush_lsn.to_string().as_bytes()),
427 0 : Some(commit_lsn.to_string().as_bytes()),
428 0 : ]))?;
429 0 : }
430 :
431 0 : pgb.write_message_noflush(&BeMessage::CommandComplete(b"TIMELINE_STATUS"))?;
432 0 : Ok(())
433 0 : }
434 :
435 : ///
436 : /// Handle IDENTIFY_SYSTEM replication command
437 : ///
438 0 : async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
439 0 : &mut self,
440 0 : pgb: &mut PostgresBackend<IO>,
441 0 : ) -> Result<(), QueryError> {
442 0 : let tli = self
443 0 : .global_timelines
444 0 : .get(self.ttid)
445 0 : .map_err(|e| QueryError::Other(e.into()))?;
446 :
447 0 : let lsn = if self.is_walproposer_recovery() {
448 : // walproposer should get all local WAL until flush_lsn
449 0 : tli.get_flush_lsn().await
450 : } else {
451 : // other clients shouldn't get any uncommitted WAL
452 0 : tli.get_state().await.0.commit_lsn
453 : }
454 0 : .to_string();
455 :
456 0 : let sysid = tli.get_state().await.1.server.system_id.to_string();
457 0 : let lsn_bytes = lsn.as_bytes();
458 0 : let tli = PG_TLI.to_string();
459 0 : let tli_bytes = tli.as_bytes();
460 0 : let sysid_bytes = sysid.as_bytes();
461 0 :
462 0 : pgb.write_message_noflush(&BeMessage::RowDescription(&[
463 0 : RowDescriptor {
464 0 : name: b"systemid",
465 0 : typoid: TEXT_OID,
466 0 : typlen: -1,
467 0 : ..Default::default()
468 0 : },
469 0 : RowDescriptor {
470 0 : name: b"timeline",
471 0 : typoid: INT4_OID,
472 0 : typlen: 4,
473 0 : ..Default::default()
474 0 : },
475 0 : RowDescriptor {
476 0 : name: b"xlogpos",
477 0 : typoid: TEXT_OID,
478 0 : typlen: -1,
479 0 : ..Default::default()
480 0 : },
481 0 : RowDescriptor {
482 0 : name: b"dbname",
483 0 : typoid: TEXT_OID,
484 0 : typlen: -1,
485 0 : ..Default::default()
486 0 : },
487 0 : ]))?
488 0 : .write_message_noflush(&BeMessage::DataRow(&[
489 0 : Some(sysid_bytes),
490 0 : Some(tli_bytes),
491 0 : Some(lsn_bytes),
492 0 : None,
493 0 : ]))?
494 0 : .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
495 0 : Ok(())
496 0 : }
497 :
498 : /// Returns true if current connection is a replication connection, originating
499 : /// from a walproposer recovery function. This connection gets a special handling:
500 : /// safekeeper must stream all local WAL till the flush_lsn, whether committed or not.
501 0 : pub fn is_walproposer_recovery(&self) -> bool {
502 0 : match &self.appname {
503 0 : None => false,
504 0 : Some(appname) => {
505 0 : appname == "wal_proposer_recovery" ||
506 : // set by safekeeper peer recovery
507 0 : appname.starts_with("safekeeper")
508 : }
509 : }
510 0 : }
511 : }
512 :
#[cfg(test)]
mod tests {
    use super::{SafekeeperPostgresCommand, parse_cmd};

    /// Parses `cmd` and returns `(proto_version, allow_timeline_creation)`,
    /// panicking if parsing fails or yields anything but START_WAL_PUSH.
    fn parse_start_wal_push(cmd: &str) -> (u32, bool) {
        let SafekeeperPostgresCommand::StartWalPush {
            proto_version,
            allow_timeline_creation,
        } = parse_cmd(cmd).expect("failed to parse")
        else {
            panic!("unexpected command");
        };
        (proto_version, allow_timeline_creation)
    }

    /// Test parsing of START_WAL_PUSH command
    #[test]
    fn test_start_wal_push_parse() {
        // Bare command: defaults apply.
        let (proto_version, allow_timeline_creation) = parse_start_wal_push("START_WAL_PUSH");
        assert_eq!(proto_version, 2);
        assert!(allow_timeline_creation);

        // Explicit options override the defaults; unknown ones are ignored.
        let (proto_version, allow_timeline_creation) = parse_start_wal_push(
            "START_WAL_PUSH (proto_version '3', allow_timeline_creation 'false', unknown 'hoho')",
        );
        assert_eq!(proto_version, 3);
        assert!(!allow_timeline_creation);
    }
}
|