LCOV - code coverage report
Current view: top level - safekeeper/src - handler.rs (source / functions) Coverage Total Hit
Test: f081ec316c96fa98335efd15ef501745aa4f015d.info Lines: 0.0 % 263 0
Test Date: 2024-06-25 15:11:17 Functions: 0.0 % 30 0

            Line data    Source code
       1              : //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
       2              : //! protocol commands.
       3              : 
       4              : use anyhow::Context;
       5              : use std::str::{self, FromStr};
       6              : use std::sync::Arc;
       7              : use tokio::io::{AsyncRead, AsyncWrite};
       8              : use tracing::{debug, info, info_span, Instrument};
       9              : 
      10              : use crate::auth::check_permission;
      11              : use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
      12              : 
      13              : use crate::metrics::{TrafficMetrics, PG_QUERIES_GAUGE};
      14              : use crate::safekeeper::Term;
      15              : use crate::timeline::TimelineError;
      16              : use crate::wal_service::ConnectionId;
      17              : use crate::{GlobalTimelines, SafeKeeperConf};
      18              : use postgres_backend::PostgresBackend;
      19              : use postgres_backend::QueryError;
      20              : use postgres_ffi::PG_TLI;
      21              : use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
      22              : use regex::Regex;
      23              : use utils::auth::{Claims, JwtAuth, Scope};
      24              : use utils::{
      25              :     id::{TenantId, TenantTimelineId, TimelineId},
      26              :     lsn::Lsn,
      27              : };
      28              : 
      29              : /// Safekeeper handler of postgres commands
      30              : pub struct SafekeeperPostgresHandler {
      31              :     pub conf: SafeKeeperConf,
      32              :     /// assigned application name
      33              :     pub appname: Option<String>,
      34              :     pub tenant_id: Option<TenantId>,
      35              :     pub timeline_id: Option<TimelineId>,
      36              :     pub ttid: TenantTimelineId,
      37              :     /// Unique connection id is logged in spans for observability.
      38              :     pub conn_id: ConnectionId,
      39              :     /// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured.
      40              :     auth: Option<(Scope, Arc<JwtAuth>)>,
      41              :     claims: Option<Claims>,
      42              :     io_metrics: Option<TrafficMetrics>,
      43              : }
      44              : 
      45              : /// Parsed Postgres command.
      46              : enum SafekeeperPostgresCommand {
      47              :     StartWalPush,
      48              :     StartReplication { start_lsn: Lsn, term: Option<Term> },
      49              :     IdentifySystem,
      50              :     TimelineStatus,
      51              :     JSONCtrl { cmd: AppendLogicalMessage },
      52              : }
      53              : 
      54            0 : fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
      55            0 :     if cmd.starts_with("START_WAL_PUSH") {
      56            0 :         Ok(SafekeeperPostgresCommand::StartWalPush)
      57            0 :     } else if cmd.starts_with("START_REPLICATION") {
      58            0 :         let re = Regex::new(
      59            0 :             // We follow postgres START_REPLICATION LOGICAL options to pass term.
      60            0 :             r"START_REPLICATION(?: SLOT [^ ]+)?(?: PHYSICAL)? ([[:xdigit:]]+/[[:xdigit:]]+)(?: \(term='(\d+)'\))?",
      61            0 :         )
      62            0 :         .unwrap();
      63            0 :         let caps = re
      64            0 :             .captures(cmd)
      65            0 :             .context(format!("failed to parse START_REPLICATION command {}", cmd))?;
      66            0 :         let start_lsn =
      67            0 :             Lsn::from_str(&caps[1]).context("parse start LSN from START_REPLICATION command")?;
      68            0 :         let term = if let Some(m) = caps.get(2) {
      69            0 :             Some(m.as_str().parse::<u64>().context("invalid term")?)
      70              :         } else {
      71            0 :             None
      72              :         };
      73            0 :         Ok(SafekeeperPostgresCommand::StartReplication { start_lsn, term })
      74            0 :     } else if cmd.starts_with("IDENTIFY_SYSTEM") {
      75            0 :         Ok(SafekeeperPostgresCommand::IdentifySystem)
      76            0 :     } else if cmd.starts_with("TIMELINE_STATUS") {
      77            0 :         Ok(SafekeeperPostgresCommand::TimelineStatus)
      78            0 :     } else if cmd.starts_with("JSON_CTRL") {
      79            0 :         let cmd = cmd.strip_prefix("JSON_CTRL").context("invalid prefix")?;
      80              :         Ok(SafekeeperPostgresCommand::JSONCtrl {
      81            0 :             cmd: serde_json::from_str(cmd)?,
      82              :         })
      83              :     } else {
      84            0 :         anyhow::bail!("unsupported command {cmd}");
      85              :     }
      86            0 : }
      87              : 
      88            0 : fn cmd_to_string(cmd: &SafekeeperPostgresCommand) -> &str {
      89            0 :     match cmd {
      90            0 :         SafekeeperPostgresCommand::StartWalPush => "START_WAL_PUSH",
      91            0 :         SafekeeperPostgresCommand::StartReplication { .. } => "START_REPLICATION",
      92            0 :         SafekeeperPostgresCommand::TimelineStatus => "TIMELINE_STATUS",
      93            0 :         SafekeeperPostgresCommand::IdentifySystem => "IDENTIFY_SYSTEM",
      94            0 :         SafekeeperPostgresCommand::JSONCtrl { .. } => "JSON_CTRL",
      95              :     }
      96            0 : }
      97              : 
      98              : #[async_trait::async_trait]
      99              : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
     100              :     for SafekeeperPostgresHandler
     101              : {
     102              :     // tenant_id and timeline_id are passed in connection string params
     103            0 :     fn startup(
     104            0 :         &mut self,
     105            0 :         _pgb: &mut PostgresBackend<IO>,
     106            0 :         sm: &FeStartupPacket,
     107            0 :     ) -> Result<(), QueryError> {
     108            0 :         if let FeStartupPacket::StartupMessage { params, .. } = sm {
     109            0 :             if let Some(options) = params.options_raw() {
     110            0 :                 for opt in options {
     111              :                     // FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy,
     112              :                     // remove these after the PR gets deployed:
     113              :                     // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
     114            0 :                     match opt.split_once('=') {
     115            0 :                         Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
     116            0 :                             self.tenant_id = Some(value.parse().with_context(|| {
     117            0 :                                 format!("Failed to parse {value} as tenant id")
     118            0 :                             })?);
     119              :                         }
     120            0 :                         Some(("ztimelineid", value)) | Some(("timeline_id", value)) => {
     121            0 :                             self.timeline_id = Some(value.parse().with_context(|| {
     122            0 :                                 format!("Failed to parse {value} as timeline id")
     123            0 :                             })?);
     124              :                         }
     125            0 :                         Some(("availability_zone", client_az)) => {
     126            0 :                             if let Some(metrics) = self.io_metrics.as_ref() {
     127            0 :                                 metrics.set_client_az(client_az)
     128            0 :                             }
     129              :                         }
     130            0 :                         _ => continue,
     131              :                     }
     132              :                 }
     133            0 :             }
     134              : 
     135            0 :             if let Some(app_name) = params.get("application_name") {
     136            0 :                 self.appname = Some(app_name.to_owned());
     137            0 :                 if let Some(metrics) = self.io_metrics.as_ref() {
     138            0 :                     metrics.set_app_name(app_name)
     139            0 :                 }
     140            0 :             }
     141              : 
     142            0 :             let ttid = TenantTimelineId::new(
     143            0 :                 self.tenant_id.unwrap_or(TenantId::from([0u8; 16])),
     144            0 :                 self.timeline_id.unwrap_or(TimelineId::from([0u8; 16])),
     145            0 :             );
     146            0 :             tracing::Span::current().record("ttid", tracing::field::display(ttid));
     147            0 : 
     148            0 :             Ok(())
     149              :         } else {
     150            0 :             Err(QueryError::Other(anyhow::anyhow!(
     151            0 :                 "Safekeeper received unexpected initial message: {sm:?}"
     152            0 :             )))
     153              :         }
     154            0 :     }
     155              : 
     156            0 :     fn check_auth_jwt(
     157            0 :         &mut self,
     158            0 :         _pgb: &mut PostgresBackend<IO>,
     159            0 :         jwt_response: &[u8],
     160            0 :     ) -> Result<(), QueryError> {
     161            0 :         // this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
     162            0 :         // which requires auth to be present
     163            0 :         let (allowed_auth_scope, auth) = self
     164            0 :             .auth
     165            0 :             .as_ref()
     166            0 :             .expect("auth_type is configured but .auth of handler is missing");
     167            0 :         let data = auth
     168            0 :             .decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)
     169            0 :             .map_err(|e| QueryError::Unauthorized(e.0))?;
     170              : 
     171              :         // The handler might be configured to allow only tenant scope tokens.
     172            0 :         if matches!(allowed_auth_scope, Scope::Tenant)
     173            0 :             && !matches!(data.claims.scope, Scope::Tenant)
     174              :         {
     175            0 :             return Err(QueryError::Unauthorized(
     176            0 :                 "passed JWT token is for full access, but only tenant scope is allowed".into(),
     177            0 :             ));
     178            0 :         }
     179              : 
     180            0 :         if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() {
     181            0 :             return Err(QueryError::Unauthorized(
     182            0 :                 "jwt token scope is Tenant, but tenant id is missing".into(),
     183            0 :             ));
     184            0 :         }
     185            0 : 
     186            0 :         debug!(
     187            0 :             "jwt scope check succeeded for scope: {:#?} by tenant id: {:?}",
     188              :             data.claims.scope, data.claims.tenant_id,
     189              :         );
     190              : 
     191            0 :         self.claims = Some(data.claims);
     192            0 :         Ok(())
     193            0 :     }
     194              : 
     195            0 :     async fn process_query(
     196            0 :         &mut self,
     197            0 :         pgb: &mut PostgresBackend<IO>,
     198            0 :         query_string: &str,
     199            0 :     ) -> Result<(), QueryError> {
     200            0 :         if query_string
     201            0 :             .to_ascii_lowercase()
     202            0 :             .starts_with("set datestyle to ")
     203            0 :         {
     204            0 :             // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
     205            0 :             pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
     206            0 :             return Ok(());
     207            0 :         }
     208            0 : 
     209            0 :         let cmd = parse_cmd(query_string)?;
     210            0 :         let cmd_str = cmd_to_string(&cmd);
     211            0 : 
     212            0 :         let _guard = PG_QUERIES_GAUGE.with_label_values(&[cmd_str]).guard();
     213            0 : 
     214            0 :         info!("got query {:?}", query_string);
     215            0 : 
     216            0 :         let tenant_id = self.tenant_id.context("tenantid is required")?;
     217            0 :         let timeline_id = self.timeline_id.context("timelineid is required")?;
     218            0 :         self.check_permission(Some(tenant_id))?;
     219            0 :         self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
     220            0 : 
     221            0 :         match cmd {
     222            0 :             SafekeeperPostgresCommand::StartWalPush => {
     223            0 :                 self.handle_start_wal_push(pgb)
     224            0 :                     .instrument(info_span!("WAL receiver"))
     225            0 :                     .await
     226            0 :             }
     227            0 :             SafekeeperPostgresCommand::StartReplication { start_lsn, term } => {
     228            0 :                 self.handle_start_replication(pgb, start_lsn, term)
     229            0 :                     .instrument(info_span!("WAL sender"))
     230            0 :                     .await
     231            0 :             }
     232            0 :             SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
     233            0 :             SafekeeperPostgresCommand::TimelineStatus => self.handle_timeline_status(pgb).await,
     234            0 :             SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
     235            0 :                 handle_json_ctrl(self, pgb, cmd).await
     236            0 :             }
     237            0 :         }
     238            0 :     }
     239              : }
     240              : 
     241              : impl SafekeeperPostgresHandler {
     242            0 :     pub fn new(
     243            0 :         conf: SafeKeeperConf,
     244            0 :         conn_id: u32,
     245            0 :         io_metrics: Option<TrafficMetrics>,
     246            0 :         auth: Option<(Scope, Arc<JwtAuth>)>,
     247            0 :     ) -> Self {
     248            0 :         SafekeeperPostgresHandler {
     249            0 :             conf,
     250            0 :             appname: None,
     251            0 :             tenant_id: None,
     252            0 :             timeline_id: None,
     253            0 :             ttid: TenantTimelineId::empty(),
     254            0 :             conn_id,
     255            0 :             claims: None,
     256            0 :             auth,
     257            0 :             io_metrics,
     258            0 :         }
     259            0 :     }
     260              : 
     261              :     // when accessing management api supply None as an argument
     262              :     // when using to authorize tenant pass corresponding tenant id
     263            0 :     fn check_permission(&self, tenant_id: Option<TenantId>) -> Result<(), QueryError> {
     264            0 :         if self.auth.is_none() {
     265              :             // auth is set to Trust, nothing to check so just return ok
     266            0 :             return Ok(());
     267            0 :         }
     268            0 :         // auth is some, just checked above, when auth is some
     269            0 :         // then claims are always present because of checks during connection init
     270            0 :         // so this expect won't trigger
     271            0 :         let claims = self
     272            0 :             .claims
     273            0 :             .as_ref()
     274            0 :             .expect("claims presence already checked");
     275            0 :         check_permission(claims, tenant_id).map_err(|e| QueryError::Unauthorized(e.0))
     276            0 :     }
     277              : 
     278            0 :     async fn handle_timeline_status<IO: AsyncRead + AsyncWrite + Unpin>(
     279            0 :         &mut self,
     280            0 :         pgb: &mut PostgresBackend<IO>,
     281            0 :     ) -> Result<(), QueryError> {
     282              :         // Get timeline, handling "not found" error
     283            0 :         let tli = match GlobalTimelines::get(self.ttid) {
     284            0 :             Ok(tli) => Ok(Some(tli)),
     285            0 :             Err(TimelineError::NotFound(_)) => Ok(None),
     286            0 :             Err(e) => Err(QueryError::Other(e.into())),
     287            0 :         }?;
     288              : 
     289              :         // Write row description
     290            0 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[
     291            0 :             RowDescriptor::text_col(b"flush_lsn"),
     292            0 :             RowDescriptor::text_col(b"commit_lsn"),
     293            0 :         ]))?;
     294              : 
     295              :         // Write row if timeline exists
     296            0 :         if let Some(tli) = tli {
     297            0 :             let (inmem, _state) = tli.get_state().await;
     298            0 :             let flush_lsn = tli.get_flush_lsn().await;
     299            0 :             let commit_lsn = inmem.commit_lsn;
     300            0 :             pgb.write_message_noflush(&BeMessage::DataRow(&[
     301            0 :                 Some(flush_lsn.to_string().as_bytes()),
     302            0 :                 Some(commit_lsn.to_string().as_bytes()),
     303            0 :             ]))?;
     304            0 :         }
     305              : 
     306            0 :         pgb.write_message_noflush(&BeMessage::CommandComplete(b"TIMELINE_STATUS"))?;
     307            0 :         Ok(())
     308            0 :     }
     309              : 
     310              :     ///
     311              :     /// Handle IDENTIFY_SYSTEM replication command
     312              :     ///
     313            0 :     async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
     314            0 :         &mut self,
     315            0 :         pgb: &mut PostgresBackend<IO>,
     316            0 :     ) -> Result<(), QueryError> {
     317            0 :         let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
     318              : 
     319            0 :         let lsn = if self.is_walproposer_recovery() {
     320              :             // walproposer should get all local WAL until flush_lsn
     321            0 :             tli.get_flush_lsn().await
     322              :         } else {
     323              :             // other clients shouldn't get any uncommitted WAL
     324            0 :             tli.get_state().await.0.commit_lsn
     325              :         }
     326            0 :         .to_string();
     327              : 
     328            0 :         let sysid = tli.get_state().await.1.server.system_id.to_string();
     329            0 :         let lsn_bytes = lsn.as_bytes();
     330            0 :         let tli = PG_TLI.to_string();
     331            0 :         let tli_bytes = tli.as_bytes();
     332            0 :         let sysid_bytes = sysid.as_bytes();
     333            0 : 
     334            0 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[
     335            0 :             RowDescriptor {
     336            0 :                 name: b"systemid",
     337            0 :                 typoid: TEXT_OID,
     338            0 :                 typlen: -1,
     339            0 :                 ..Default::default()
     340            0 :             },
     341            0 :             RowDescriptor {
     342            0 :                 name: b"timeline",
     343            0 :                 typoid: INT4_OID,
     344            0 :                 typlen: 4,
     345            0 :                 ..Default::default()
     346            0 :             },
     347            0 :             RowDescriptor {
     348            0 :                 name: b"xlogpos",
     349            0 :                 typoid: TEXT_OID,
     350            0 :                 typlen: -1,
     351            0 :                 ..Default::default()
     352            0 :             },
     353            0 :             RowDescriptor {
     354            0 :                 name: b"dbname",
     355            0 :                 typoid: TEXT_OID,
     356            0 :                 typlen: -1,
     357            0 :                 ..Default::default()
     358            0 :             },
     359            0 :         ]))?
     360            0 :         .write_message_noflush(&BeMessage::DataRow(&[
     361            0 :             Some(sysid_bytes),
     362            0 :             Some(tli_bytes),
     363            0 :             Some(lsn_bytes),
     364            0 :             None,
     365            0 :         ]))?
     366            0 :         .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
     367            0 :         Ok(())
     368            0 :     }
     369              : 
     370              :     /// Returns true if current connection is a replication connection, originating
     371              :     /// from a walproposer recovery function. This connection gets a special handling:
     372              :     /// safekeeper must stream all local WAL till the flush_lsn, whether committed or not.
     373            0 :     pub fn is_walproposer_recovery(&self) -> bool {
     374            0 :         match &self.appname {
     375            0 :             None => false,
     376            0 :             Some(appname) => {
     377            0 :                 appname == "wal_proposer_recovery" ||
     378              :                 // set by safekeeper peer recovery
     379            0 :                 appname.starts_with("safekeeper")
     380              :             }
     381              :         }
     382            0 :     }
     383              : }
        

Generated by: LCOV version 2.1-beta