LCOV - differential code coverage report
Current view: top level - safekeeper/src - handler.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 93.0 % 256 238 18 238
Current Date: 2024-01-09 02:06:09 Functions: 47.1 % 34 16 18 16
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
       2                 : //! protocol commands.
       3                 : 
       4                 : use anyhow::Context;
       5                 : use std::str::FromStr;
       6                 : use std::str::{self};
       7                 : use std::sync::Arc;
       8                 : use tokio::io::{AsyncRead, AsyncWrite};
       9                 : use tracing::{debug, info, info_span, Instrument};
      10                 : 
      11                 : use crate::auth::check_permission;
      12                 : use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
      13                 : 
      14                 : use crate::metrics::{TrafficMetrics, PG_QUERIES_GAUGE};
      15                 : use crate::safekeeper::Term;
      16                 : use crate::timeline::TimelineError;
      17                 : use crate::wal_service::ConnectionId;
      18                 : use crate::{GlobalTimelines, SafeKeeperConf};
      19                 : use postgres_backend::QueryError;
      20                 : use postgres_backend::{self, PostgresBackend};
      21                 : use postgres_ffi::PG_TLI;
      22                 : use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
      23                 : use regex::Regex;
      24                 : use utils::auth::{Claims, JwtAuth, Scope};
      25                 : use utils::{
      26                 :     id::{TenantId, TenantTimelineId, TimelineId},
      27                 :     lsn::Lsn,
      28                 : };
      29                 : 
      30                 : /// Safekeeper handler of postgres commands
      31                 : pub struct SafekeeperPostgresHandler {
      32                 :     pub conf: SafeKeeperConf,
      33                 :     /// assigned application name
      34                 :     pub appname: Option<String>,
      35                 :     pub tenant_id: Option<TenantId>,
      36                 :     pub timeline_id: Option<TimelineId>,
      37                 :     pub ttid: TenantTimelineId,
      38                 :     /// Unique connection id is logged in spans for observability.
      39                 :     pub conn_id: ConnectionId,
      40                 :     /// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured.
      41                 :     auth: Option<(Scope, Arc<JwtAuth>)>,
      42                 :     claims: Option<Claims>,
      43                 :     io_metrics: Option<TrafficMetrics>,
      44                 : }
      45                 : 
      46                 : /// Parsed Postgres command.
      47                 : enum SafekeeperPostgresCommand {
      48                 :     StartWalPush,
      49                 :     StartReplication { start_lsn: Lsn, term: Option<Term> },
      50                 :     IdentifySystem,
      51                 :     TimelineStatus,
      52                 :     JSONCtrl { cmd: AppendLogicalMessage },
      53                 : }
      54                 : 
      55 CBC        3833 : fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
      56            3833 :     if cmd.starts_with("START_WAL_PUSH") {
      57            1745 :         Ok(SafekeeperPostgresCommand::StartWalPush)
      58            2088 :     } else if cmd.starts_with("START_REPLICATION") {
      59             734 :         let re = Regex::new(
      60             734 :             // We follow postgres START_REPLICATION LOGICAL options to pass term.
      61             734 :             r"START_REPLICATION(?: SLOT [^ ]+)?(?: PHYSICAL)? ([[:xdigit:]]+/[[:xdigit:]]+)(?: \(term='(\d+)'\))?",
      62             734 :         )
      63             734 :         .unwrap();
      64             734 :         let caps = re
      65             734 :             .captures(cmd)
      66             734 :             .context(format!("failed to parse START_REPLICATION command {}", cmd))?;
      67             734 :         let start_lsn =
      68             734 :             Lsn::from_str(&caps[1]).context("parse start LSN from START_REPLICATION command")?;
      69             734 :         let term = if let Some(m) = caps.get(2) {
      70              16 :             Some(m.as_str().parse::<u64>().context("invalid term")?)
      71                 :         } else {
      72             718 :             None
      73                 :         };
      74             734 :         Ok(SafekeeperPostgresCommand::StartReplication { start_lsn, term })
      75            1354 :     } else if cmd.starts_with("IDENTIFY_SYSTEM") {
      76             727 :         Ok(SafekeeperPostgresCommand::IdentifySystem)
      77             627 :     } else if cmd.starts_with("TIMELINE_STATUS") {
      78             624 :         Ok(SafekeeperPostgresCommand::TimelineStatus)
      79               3 :     } else if cmd.starts_with("JSON_CTRL") {
      80               3 :         let cmd = cmd.strip_prefix("JSON_CTRL").context("invalid prefix")?;
      81                 :         Ok(SafekeeperPostgresCommand::JSONCtrl {
      82               3 :             cmd: serde_json::from_str(cmd)?,
      83                 :         })
      84                 :     } else {
      85 UBC           0 :         anyhow::bail!("unsupported command {cmd}");
      86                 :     }
      87 CBC        3833 : }
      88                 : 
      89            3833 : fn cmd_to_string(cmd: &SafekeeperPostgresCommand) -> &str {
      90            3833 :     match cmd {
      91            1745 :         SafekeeperPostgresCommand::StartWalPush => "START_WAL_PUSH",
      92             734 :         SafekeeperPostgresCommand::StartReplication { .. } => "START_REPLICATION",
      93             624 :         SafekeeperPostgresCommand::TimelineStatus => "TIMELINE_STATUS",
      94             727 :         SafekeeperPostgresCommand::IdentifySystem => "IDENTIFY_SYSTEM",
      95               3 :         SafekeeperPostgresCommand::JSONCtrl { .. } => "JSON_CTRL",
      96                 :     }
      97            3833 : }
      98                 : 
      99                 : #[async_trait::async_trait]
     100                 : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
     101                 :     for SafekeeperPostgresHandler
     102                 : {
     103                 :     // tenant_id and timeline_id are passed in connection string params
     104            3123 :     fn startup(
     105            3123 :         &mut self,
     106            3123 :         _pgb: &mut PostgresBackend<IO>,
     107            3123 :         sm: &FeStartupPacket,
     108            3123 :     ) -> Result<(), QueryError> {
     109            3123 :         if let FeStartupPacket::StartupMessage { params, .. } = sm {
     110            3123 :             if let Some(options) = params.options_raw() {
     111           12507 :                 for opt in options {
     112                 :                     // FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy,
     113                 :                     // remove these after the PR gets deployed:
     114                 :                     // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
     115            9384 :                     match opt.split_once('=') {
     116            6267 :                         Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
     117            3123 :                             self.tenant_id = Some(value.parse().with_context(|| {
     118 UBC           0 :                                 format!("Failed to parse {value} as tenant id")
     119 CBC        3123 :                             })?);
     120                 :                         }
     121            3144 :                         Some(("ztimelineid", value)) | Some(("timeline_id", value)) => {
     122            3123 :                             self.timeline_id = Some(value.parse().with_context(|| {
     123 UBC           0 :                                 format!("Failed to parse {value} as timeline id")
     124 CBC        3123 :                             })?);
     125                 :                         }
     126              21 :                         Some(("availability_zone", client_az)) => {
     127               5 :                             if let Some(metrics) = self.io_metrics.as_ref() {
     128               5 :                                 metrics.set_client_az(client_az)
     129 UBC           0 :                             }
     130                 :                         }
     131 CBC        3133 :                         _ => continue,
     132                 :                     }
     133                 :                 }
     134 UBC           0 :             }
     135                 : 
     136 CBC        3123 :             if let Some(app_name) = params.get("application_name") {
     137             720 :                 self.appname = Some(app_name.to_owned());
     138             720 :                 if let Some(metrics) = self.io_metrics.as_ref() {
     139             720 :                     metrics.set_app_name(app_name)
     140 UBC           0 :                 }
     141 CBC        2403 :             }
     142                 : 
     143            3123 :             let ttid = TenantTimelineId::new(
     144            3123 :                 self.tenant_id.unwrap_or(TenantId::from([0u8; 16])),
     145            3123 :                 self.timeline_id.unwrap_or(TimelineId::from([0u8; 16])),
     146            3123 :             );
     147            3123 :             tracing::Span::current().record("ttid", tracing::field::display(ttid));
     148            3123 : 
     149            3123 :             Ok(())
     150                 :         } else {
     151 UBC           0 :             Err(QueryError::Other(anyhow::anyhow!(
     152               0 :                 "Safekeeper received unexpected initial message: {sm:?}"
     153               0 :             )))
     154                 :         }
     155 CBC        3123 :     }
     156                 : 
     157             128 :     fn check_auth_jwt(
     158             128 :         &mut self,
     159             128 :         _pgb: &mut PostgresBackend<IO>,
     160             128 :         jwt_response: &[u8],
     161             128 :     ) -> Result<(), QueryError> {
     162             128 :         // this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
     163             128 :         // which requires auth to be present
     164             128 :         let (allowed_auth_scope, auth) = self
     165             128 :             .auth
     166             128 :             .as_ref()
     167             128 :             .expect("auth_type is configured but .auth of handler is missing");
     168             128 :         let data = auth
     169             128 :             .decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)
     170             128 :             .map_err(|e| QueryError::Unauthorized(e.0))?;
     171                 : 
     172                 :         // The handler might be configured to allow only tenant scope tokens.
     173             128 :         if matches!(allowed_auth_scope, Scope::Tenant)
     174             100 :             && !matches!(data.claims.scope, Scope::Tenant)
     175                 :         {
     176               1 :             return Err(QueryError::Unauthorized(
     177               1 :                 "passed JWT token is for full access, but only tenant scope is allowed".into(),
     178               1 :             ));
     179             127 :         }
     180                 : 
     181             127 :         if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() {
     182 UBC           0 :             return Err(QueryError::Unauthorized(
     183               0 :                 "jwt token scope is Tenant, but tenant id is missing".into(),
     184               0 :             ));
     185 CBC         127 :         }
     186             127 : 
     187             127 :         debug!(
     188 UBC           0 :             "jwt scope check succeeded for scope: {:#?} by tenant id: {:?}",
     189               0 :             data.claims.scope, data.claims.tenant_id,
     190               0 :         );
     191                 : 
     192 CBC         127 :         self.claims = Some(data.claims);
     193             127 :         Ok(())
     194             128 :     }
     195                 : 
     196            3844 :     async fn process_query(
     197            3844 :         &mut self,
     198            3844 :         pgb: &mut PostgresBackend<IO>,
     199            3844 :         query_string: &str,
     200            3844 :     ) -> Result<(), QueryError> {
     201            3844 :         if query_string
     202            3844 :             .to_ascii_lowercase()
     203            3844 :             .starts_with("set datestyle to ")
     204                 :         {
     205                 :             // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
     206              11 :             pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
     207              11 :             return Ok(());
     208            3833 :         }
     209                 : 
     210            3833 :         let cmd = parse_cmd(query_string)?;
     211            3833 :         let cmd_str = cmd_to_string(&cmd);
     212            3833 : 
     213            3833 :         let _guard = PG_QUERIES_GAUGE.with_label_values(&[cmd_str]).guard();
     214            3833 : 
     215            3833 :         info!("got query {:?}", query_string);
     216                 : 
     217            3833 :         let tenant_id = self.tenant_id.context("tenantid is required")?;
     218            3833 :         let timeline_id = self.timeline_id.context("timelineid is required")?;
     219            3833 :         self.check_permission(Some(tenant_id))?;
     220            3831 :         self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
     221            3831 : 
     222            3831 :         match cmd {
     223                 :             SafekeeperPostgresCommand::StartWalPush => {
     224            1745 :                 self.handle_start_wal_push(pgb)
     225            1745 :                     .instrument(info_span!("WAL receiver"))
     226         2711470 :                     .await
     227                 :             }
     228             734 :             SafekeeperPostgresCommand::StartReplication { start_lsn, term } => {
     229             734 :                 self.handle_start_replication(pgb, start_lsn, term)
     230             734 :                     .instrument(info_span!("WAL sender"))
     231         2111088 :                     .await
     232                 :             }
     233             725 :             SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
     234             624 :             SafekeeperPostgresCommand::TimelineStatus => self.handle_timeline_status(pgb).await,
     235               3 :             SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
     236            9720 :                 handle_json_ctrl(self, pgb, cmd).await
     237                 :             }
     238                 :         }
     239            7298 :     }
     240                 : }
     241                 : 
     242                 : impl SafekeeperPostgresHandler {
     243            3123 :     pub fn new(
     244            3123 :         conf: SafeKeeperConf,
     245            3123 :         conn_id: u32,
     246            3123 :         io_metrics: Option<TrafficMetrics>,
     247            3123 :         auth: Option<(Scope, Arc<JwtAuth>)>,
     248            3123 :     ) -> Self {
     249            3123 :         SafekeeperPostgresHandler {
     250            3123 :             conf,
     251            3123 :             appname: None,
     252            3123 :             tenant_id: None,
     253            3123 :             timeline_id: None,
     254            3123 :             ttid: TenantTimelineId::empty(),
     255            3123 :             conn_id,
     256            3123 :             claims: None,
     257            3123 :             auth,
     258            3123 :             io_metrics,
     259            3123 :         }
     260            3123 :     }
     261                 : 
     262                 :     // when accessing management api supply None as an argument
     263                 :     // when using to authorize tenant pass corresponding tenant id
     264            3833 :     fn check_permission(&self, tenant_id: Option<TenantId>) -> Result<(), QueryError> {
     265            3833 :         if self.auth.is_none() {
     266                 :             // auth is set to Trust, nothing to check so just return ok
     267            3685 :             return Ok(());
     268             148 :         }
     269             148 :         // auth is some, just checked above, when auth is some
     270             148 :         // then claims are always present because of checks during connection init
     271             148 :         // so this expect won't trigger
     272             148 :         let claims = self
     273             148 :             .claims
     274             148 :             .as_ref()
     275             148 :             .expect("claims presence already checked");
     276             148 :         check_permission(claims, tenant_id).map_err(|e| QueryError::Unauthorized(e.0))
     277            3833 :     }
     278                 : 
     279             624 :     async fn handle_timeline_status<IO: AsyncRead + AsyncWrite + Unpin>(
     280             624 :         &mut self,
     281             624 :         pgb: &mut PostgresBackend<IO>,
     282             624 :     ) -> Result<(), QueryError> {
     283                 :         // Get timeline, handling "not found" error
     284             624 :         let tli = match GlobalTimelines::get(self.ttid) {
     285             181 :             Ok(tli) => Ok(Some(tli)),
     286             443 :             Err(TimelineError::NotFound(_)) => Ok(None),
     287 UBC           0 :             Err(e) => Err(QueryError::Other(e.into())),
     288               0 :         }?;
     289                 : 
     290                 :         // Write row description
     291 CBC         624 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[
     292             624 :             RowDescriptor::text_col(b"flush_lsn"),
     293             624 :             RowDescriptor::text_col(b"commit_lsn"),
     294             624 :         ]))?;
     295                 : 
     296                 :         // Write row if timeline exists
     297             624 :         if let Some(tli) = tli {
     298             181 :             let (inmem, _state) = tli.get_state().await;
     299             181 :             let flush_lsn = tli.get_flush_lsn().await;
     300             181 :             let commit_lsn = inmem.commit_lsn;
     301             181 :             pgb.write_message_noflush(&BeMessage::DataRow(&[
     302             181 :                 Some(flush_lsn.to_string().as_bytes()),
     303             181 :                 Some(commit_lsn.to_string().as_bytes()),
     304             181 :             ]))?;
     305             443 :         }
     306                 : 
     307             624 :         pgb.write_message_noflush(&BeMessage::CommandComplete(b"TIMELINE_STATUS"))?;
     308             624 :         Ok(())
     309             624 :     }
     310                 : 
     311                 :     ///
     312                 :     /// Handle IDENTIFY_SYSTEM replication command
     313                 :     ///
     314             725 :     async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
     315             725 :         &mut self,
     316             725 :         pgb: &mut PostgresBackend<IO>,
     317             725 :     ) -> Result<(), QueryError> {
     318             725 :         let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
     319                 : 
     320             724 :         let lsn = if self.is_walproposer_recovery() {
     321                 :             // walproposer should get all local WAL until flush_lsn
     322 UBC           0 :             tli.get_flush_lsn().await
     323                 :         } else {
     324                 :             // other clients shouldn't get any uncommitted WAL
     325 CBC         724 :             tli.get_state().await.0.commit_lsn
     326                 :         }
     327             724 :         .to_string();
     328                 : 
     329             724 :         let sysid = tli.get_state().await.1.server.system_id.to_string();
     330             724 :         let lsn_bytes = lsn.as_bytes();
     331             724 :         let tli = PG_TLI.to_string();
     332             724 :         let tli_bytes = tli.as_bytes();
     333             724 :         let sysid_bytes = sysid.as_bytes();
     334             724 : 
     335             724 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[
     336             724 :             RowDescriptor {
     337             724 :                 name: b"systemid",
     338             724 :                 typoid: TEXT_OID,
     339             724 :                 typlen: -1,
     340             724 :                 ..Default::default()
     341             724 :             },
     342             724 :             RowDescriptor {
     343             724 :                 name: b"timeline",
     344             724 :                 typoid: INT4_OID,
     345             724 :                 typlen: 4,
     346             724 :                 ..Default::default()
     347             724 :             },
     348             724 :             RowDescriptor {
     349             724 :                 name: b"xlogpos",
     350             724 :                 typoid: TEXT_OID,
     351             724 :                 typlen: -1,
     352             724 :                 ..Default::default()
     353             724 :             },
     354             724 :             RowDescriptor {
     355             724 :                 name: b"dbname",
     356             724 :                 typoid: TEXT_OID,
     357             724 :                 typlen: -1,
     358             724 :                 ..Default::default()
     359             724 :             },
     360             724 :         ]))?
     361             724 :         .write_message_noflush(&BeMessage::DataRow(&[
     362             724 :             Some(sysid_bytes),
     363             724 :             Some(tli_bytes),
     364             724 :             Some(lsn_bytes),
     365             724 :             None,
     366             724 :         ]))?
     367             724 :         .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
     368             724 :         Ok(())
     369             725 :     }
     370                 : 
     371                 :     /// Returns true if current connection is a replication connection, originating
     372                 :     /// from a walproposer recovery function. This connection gets a special handling:
     373                 :     /// safekeeper must stream all local WAL till the flush_lsn, whether committed or not.
     374             724 :     pub fn is_walproposer_recovery(&self) -> bool {
     375             724 :         match &self.appname {
     376               9 :             None => false,
     377             715 :             Some(appname) => {
     378             715 :                 appname == "wal_proposer_recovery" ||
     379                 :                 // set by safekeeper peer recovery
     380             715 :                 appname.starts_with("safekeeper")
     381                 :             }
     382                 :         }
     383             724 :     }
     384                 : }
        

Generated by: LCOV version 2.1-beta