LCOV - code coverage report
Current view: top level - safekeeper/src - handler.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 93.9 % 247 232
Test Date: 2023-09-06 10:18:01 Functions: 48.5 % 33 16

            Line data    Source code
       1              : //! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
       2              : //! protocol commands.
       3              : 
       4              : use anyhow::Context;
       5              : use std::str::FromStr;
       6              : use std::str::{self};
       7              : use std::sync::Arc;
       8              : use tokio::io::{AsyncRead, AsyncWrite};
       9              : use tracing::{info, info_span, Instrument};
      10              : 
      11              : use crate::auth::check_permission;
      12              : use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
      13              : 
      14              : use crate::metrics::{TrafficMetrics, PG_QUERIES_FINISHED, PG_QUERIES_RECEIVED};
      15              : use crate::safekeeper::Term;
      16              : use crate::timeline::TimelineError;
      17              : use crate::wal_service::ConnectionId;
      18              : use crate::{GlobalTimelines, SafeKeeperConf};
      19              : use postgres_backend::QueryError;
      20              : use postgres_backend::{self, PostgresBackend};
      21              : use postgres_ffi::PG_TLI;
      22              : use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
      23              : use regex::Regex;
      24              : use utils::auth::{Claims, JwtAuth, Scope};
      25              : use utils::{
      26              :     id::{TenantId, TenantTimelineId, TimelineId},
      27              :     lsn::Lsn,
      28              : };
      29              : 
      30              : /// Safekeeper handler of postgres commands
      31              : pub struct SafekeeperPostgresHandler {
      32              :     pub conf: SafeKeeperConf,
      33              :     /// assigned application name
      34              :     pub appname: Option<String>,
      35              :     pub tenant_id: Option<TenantId>,
      36              :     pub timeline_id: Option<TimelineId>,
      37              :     pub ttid: TenantTimelineId,
      38              :     /// Unique connection id is logged in spans for observability.
      39              :     pub conn_id: ConnectionId,
      40              :     /// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured.
      41              :     auth: Option<(Scope, Arc<JwtAuth>)>,
      42              :     claims: Option<Claims>,
      43              :     io_metrics: Option<TrafficMetrics>,
      44              : }
      45              : 
      46              : /// Parsed Postgres command.
      47              : enum SafekeeperPostgresCommand {
      48              :     StartWalPush,
      49              :     StartReplication { start_lsn: Lsn, term: Option<Term> },
      50              :     IdentifySystem,
      51              :     TimelineStatus,
      52              :     JSONCtrl { cmd: AppendLogicalMessage },
      53              : }
      54              : 
      55         4484 : fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
      56         4484 :     if cmd.starts_with("START_WAL_PUSH") {
      57         2105 :         Ok(SafekeeperPostgresCommand::StartWalPush)
      58         2379 :     } else if cmd.starts_with("START_REPLICATION") {
      59          830 :         let re = Regex::new(
      60          830 :             // We follow postgres START_REPLICATION LOGICAL options to pass term.
      61          830 :             r"START_REPLICATION(?: SLOT [^ ]+)?(?: PHYSICAL)? ([[:xdigit:]]+/[[:xdigit:]]+)(?: \(term='(\d+)'\))?",
      62          830 :         )
      63          830 :         .unwrap();
      64          830 :         let caps = re
      65          830 :             .captures(cmd)
      66          830 :             .context(format!("failed to parse START_REPLICATION command {}", cmd))?;
      67          830 :         let start_lsn =
      68          830 :             Lsn::from_str(&caps[1]).context("parse start LSN from START_REPLICATION command")?;
      69          830 :         let term = if let Some(m) = caps.get(2) {
      70            1 :             Some(m.as_str().parse::<u64>().context("invalid term")?)
      71              :         } else {
      72          829 :             None
      73              :         };
      74          830 :         Ok(SafekeeperPostgresCommand::StartReplication { start_lsn, term })
      75         1549 :     } else if cmd.starts_with("IDENTIFY_SYSTEM") {
      76          776 :         Ok(SafekeeperPostgresCommand::IdentifySystem)
      77          773 :     } else if cmd.starts_with("TIMELINE_STATUS") {
      78          770 :         Ok(SafekeeperPostgresCommand::TimelineStatus)
      79            3 :     } else if cmd.starts_with("JSON_CTRL") {
      80            3 :         let cmd = cmd.strip_prefix("JSON_CTRL").context("invalid prefix")?;
      81              :         Ok(SafekeeperPostgresCommand::JSONCtrl {
      82            3 :             cmd: serde_json::from_str(cmd)?,
      83              :         })
      84              :     } else {
      85            0 :         anyhow::bail!("unsupported command {cmd}");
      86              :     }
      87         4484 : }
      88              : 
      89         4484 : fn cmd_to_string(cmd: &SafekeeperPostgresCommand) -> &str {
      90         4484 :     match cmd {
      91         2105 :         SafekeeperPostgresCommand::StartWalPush => "START_WAL_PUSH",
      92          830 :         SafekeeperPostgresCommand::StartReplication { .. } => "START_REPLICATION",
      93          770 :         SafekeeperPostgresCommand::TimelineStatus => "TIMELINE_STATUS",
      94          776 :         SafekeeperPostgresCommand::IdentifySystem => "IDENTIFY_SYSTEM",
      95            3 :         SafekeeperPostgresCommand::JSONCtrl { .. } => "JSON_CTRL",
      96              :     }
      97         4484 : }
      98              : 
      99              : #[async_trait::async_trait]
     100              : impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
     101              :     for SafekeeperPostgresHandler
     102              : {
     103              :     // tenant_id and timeline_id are passed in connection string params
     104              :     fn startup(
     105              :         &mut self,
     106              :         _pgb: &mut PostgresBackend<IO>,
     107              :         sm: &FeStartupPacket,
     108              :     ) -> Result<(), QueryError> {
     109         3727 :         if let FeStartupPacket::StartupMessage { params, .. } = sm {
     110         3727 :             if let Some(options) = params.options_raw() {
     111        14922 :                 for opt in options {
     112              :                     // FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy,
     113              :                     // remove these after the PR gets deployed:
     114              :                     // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
     115        11195 :                     match opt.split_once('=') {
     116         7474 :                         Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
     117         3727 :                             self.tenant_id = Some(value.parse().with_context(|| {
     118            0 :                                 format!("Failed to parse {value} as tenant id")
     119         3727 :                             })?);
     120              :                         }
     121         3747 :                         Some(("ztimelineid", value)) | Some(("timeline_id", value)) => {
     122         3727 :                             self.timeline_id = Some(value.parse().with_context(|| {
     123            0 :                                 format!("Failed to parse {value} as timeline id")
     124         3727 :                             })?);
     125              :                         }
     126           20 :                         Some(("availability_zone", client_az)) => {
     127            4 :                             if let Some(metrics) = self.io_metrics.as_ref() {
     128            4 :                                 metrics.set_client_az(client_az)
     129            0 :                             }
     130              :                         }
     131         3737 :                         _ => continue,
     132              :                     }
     133              :                 }
     134            0 :             }
     135              : 
     136         3727 :             if let Some(app_name) = params.get("application_name") {
     137          829 :                 self.appname = Some(app_name.to_owned());
     138          829 :                 if let Some(metrics) = self.io_metrics.as_ref() {
     139          829 :                     metrics.set_app_name(app_name)
     140            0 :                 }
     141         2898 :             }
     142              : 
     143         3727 :             Ok(())
     144              :         } else {
     145            0 :             Err(QueryError::Other(anyhow::anyhow!(
     146            0 :                 "Safekeeper received unexpected initial message: {sm:?}"
     147            0 :             )))
     148              :         }
     149         3727 :     }
     150              : 
     151          131 :     fn check_auth_jwt(
     152          131 :         &mut self,
     153          131 :         _pgb: &mut PostgresBackend<IO>,
     154          131 :         jwt_response: &[u8],
     155          131 :     ) -> Result<(), QueryError> {
     156          131 :         // this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
     157          131 :         // which requires auth to be present
     158          131 :         let (allowed_auth_scope, auth) = self
     159          131 :             .auth
     160          131 :             .as_ref()
     161          131 :             .expect("auth_type is configured but .auth of handler is missing");
     162          131 :         let data =
     163          131 :             auth.decode(str::from_utf8(jwt_response).context("jwt response is not UTF-8")?)?;
     164              : 
     165              :         // The handler might be configured to allow only tenant scope tokens.
     166          131 :         if matches!(allowed_auth_scope, Scope::Tenant)
     167          104 :             && !matches!(data.claims.scope, Scope::Tenant)
     168              :         {
     169            1 :             return Err(QueryError::Other(anyhow::anyhow!(
     170            1 :                 "passed JWT token is for full access, but only tenant scope is allowed"
     171            1 :             )));
     172          130 :         }
     173              : 
     174          130 :         if matches!(data.claims.scope, Scope::Tenant) && data.claims.tenant_id.is_none() {
     175            0 :             return Err(QueryError::Other(anyhow::anyhow!(
     176            0 :                 "jwt token scope is Tenant, but tenant id is missing"
     177            0 :             )));
     178          130 :         }
     179          130 : 
     180          130 :         info!(
     181          130 :             "jwt auth succeeded for scope: {:#?} by tenant id: {:?}",
     182          130 :             data.claims.scope, data.claims.tenant_id,
     183          130 :         );
     184              : 
     185          130 :         self.claims = Some(data.claims);
     186          130 :         Ok(())
     187          131 :     }
     188              : 
     189         4495 :     async fn process_query(
     190         4495 :         &mut self,
     191         4495 :         pgb: &mut PostgresBackend<IO>,
     192         4495 :         query_string: &str,
     193         4495 :     ) -> Result<(), QueryError> {
     194         4495 :         if query_string
     195         4495 :             .to_ascii_lowercase()
     196         4495 :             .starts_with("set datestyle to ")
     197              :         {
     198              :             // important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
     199           11 :             pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
     200           11 :             return Ok(());
     201         4484 :         }
     202              : 
     203         4484 :         let cmd = parse_cmd(query_string)?;
     204         4484 :         let cmd_str = cmd_to_string(&cmd);
     205         4484 : 
     206         4484 :         PG_QUERIES_RECEIVED.with_label_values(&[cmd_str]).inc();
     207         4484 :         scopeguard::defer! {
     208         4114 :             PG_QUERIES_FINISHED.with_label_values(&[cmd_str]).inc();
     209         4114 :         }
     210         4484 : 
     211         4484 :         info!(
     212         4484 :             "got query {:?} in timeline {:?}",
     213         4484 :             query_string, self.timeline_id
     214         4484 :         );
     215              : 
     216         4484 :         let tenant_id = self.tenant_id.context("tenantid is required")?;
     217         4484 :         let timeline_id = self.timeline_id.context("timelineid is required")?;
     218         4484 :         self.check_permission(Some(tenant_id))?;
     219         4482 :         self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
     220         4482 :         let span_ttid = self.ttid; // satisfy borrow checker
     221         4482 : 
     222         4482 :         match cmd {
     223              :             SafekeeperPostgresCommand::StartWalPush => {
     224         2105 :                 self.handle_start_wal_push(pgb)
     225         2105 :                     .instrument(info_span!("WAL receiver", ttid = %span_ttid))
     226      4548274 :                     .await
     227              :             }
     228          830 :             SafekeeperPostgresCommand::StartReplication { start_lsn, term } => {
     229          830 :                 self.handle_start_replication(pgb, start_lsn, term)
     230          830 :                     .instrument(info_span!("WAL sender", ttid = %span_ttid))
     231      2769161 :                     .await
     232              :             }
     233          774 :             SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
     234          770 :             SafekeeperPostgresCommand::TimelineStatus => self.handle_timeline_status(pgb).await,
     235            3 :             SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
     236         9782 :                 handle_json_ctrl(self, pgb, cmd).await
     237              :             }
     238              :         }
     239         8620 :     }
     240              : }
     241              : 
     242              : impl SafekeeperPostgresHandler {
     243         3727 :     pub fn new(
     244         3727 :         conf: SafeKeeperConf,
     245         3727 :         conn_id: u32,
     246         3727 :         io_metrics: Option<TrafficMetrics>,
     247         3727 :         auth: Option<(Scope, Arc<JwtAuth>)>,
     248         3727 :     ) -> Self {
     249         3727 :         SafekeeperPostgresHandler {
     250         3727 :             conf,
     251         3727 :             appname: None,
     252         3727 :             tenant_id: None,
     253         3727 :             timeline_id: None,
     254         3727 :             ttid: TenantTimelineId::empty(),
     255         3727 :             conn_id,
     256         3727 :             claims: None,
     257         3727 :             auth,
     258         3727 :             io_metrics,
     259         3727 :         }
     260         3727 :     }
     261              : 
     262              :     // when accessing management api supply None as an argument
     263              :     // when using to authorize tenant pass corresponding tenant id
     264         4484 :     fn check_permission(&self, tenant_id: Option<TenantId>) -> anyhow::Result<()> {
     265         4484 :         if self.auth.is_none() {
     266              :             // auth is set to Trust, nothing to check so just return ok
     267         4334 :             return Ok(());
     268          150 :         }
     269          150 :         // auth is some, just checked above, when auth is some
     270          150 :         // then claims are always present because of checks during connection init
     271          150 :         // so this expect won't trigger
     272          150 :         let claims = self
     273          150 :             .claims
     274          150 :             .as_ref()
     275          150 :             .expect("claims presence already checked");
     276          150 :         check_permission(claims, tenant_id)
     277         4484 :     }
     278              : 
     279          770 :     async fn handle_timeline_status<IO: AsyncRead + AsyncWrite + Unpin>(
     280          770 :         &mut self,
     281          770 :         pgb: &mut PostgresBackend<IO>,
     282          770 :     ) -> Result<(), QueryError> {
     283              :         // Get timeline, handling "not found" error
     284          770 :         let tli = match GlobalTimelines::get(self.ttid) {
     285          260 :             Ok(tli) => Ok(Some(tli)),
     286          510 :             Err(TimelineError::NotFound(_)) => Ok(None),
     287            0 :             Err(e) => Err(QueryError::Other(e.into())),
     288            0 :         }?;
     289              : 
     290              :         // Write row description
     291          770 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[
     292          770 :             RowDescriptor::text_col(b"flush_lsn"),
     293          770 :             RowDescriptor::text_col(b"commit_lsn"),
     294          770 :         ]))?;
     295              : 
     296              :         // Write row if timeline exists
     297          770 :         if let Some(tli) = tli {
     298          260 :             let (inmem, _state) = tli.get_state().await;
     299          260 :             let flush_lsn = tli.get_flush_lsn().await;
     300          260 :             let commit_lsn = inmem.commit_lsn;
     301          260 :             pgb.write_message_noflush(&BeMessage::DataRow(&[
     302          260 :                 Some(flush_lsn.to_string().as_bytes()),
     303          260 :                 Some(commit_lsn.to_string().as_bytes()),
     304          260 :             ]))?;
     305          510 :         }
     306              : 
     307          770 :         pgb.write_message_noflush(&BeMessage::CommandComplete(b"TIMELINE_STATUS"))?;
     308          770 :         Ok(())
     309          770 :     }
     310              : 
     311              :     ///
     312              :     /// Handle IDENTIFY_SYSTEM replication command
     313              :     ///
     314          774 :     async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
     315          774 :         &mut self,
     316          774 :         pgb: &mut PostgresBackend<IO>,
     317          774 :     ) -> Result<(), QueryError> {
     318          774 :         let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
     319              : 
     320          774 :         let lsn = if self.is_walproposer_recovery() {
     321              :             // walproposer should get all local WAL until flush_lsn
     322            0 :             tli.get_flush_lsn().await
     323              :         } else {
     324              :             // other clients shouldn't get any uncommitted WAL
     325          774 :             tli.get_state().await.0.commit_lsn
     326              :         }
     327          774 :         .to_string();
     328              : 
     329          774 :         let sysid = tli.get_state().await.1.server.system_id.to_string();
     330          774 :         let lsn_bytes = lsn.as_bytes();
     331          774 :         let tli = PG_TLI.to_string();
     332          774 :         let tli_bytes = tli.as_bytes();
     333          774 :         let sysid_bytes = sysid.as_bytes();
     334          774 : 
     335          774 :         pgb.write_message_noflush(&BeMessage::RowDescription(&[
     336          774 :             RowDescriptor {
     337          774 :                 name: b"systemid",
     338          774 :                 typoid: TEXT_OID,
     339          774 :                 typlen: -1,
     340          774 :                 ..Default::default()
     341          774 :             },
     342          774 :             RowDescriptor {
     343          774 :                 name: b"timeline",
     344          774 :                 typoid: INT4_OID,
     345          774 :                 typlen: 4,
     346          774 :                 ..Default::default()
     347          774 :             },
     348          774 :             RowDescriptor {
     349          774 :                 name: b"xlogpos",
     350          774 :                 typoid: TEXT_OID,
     351          774 :                 typlen: -1,
     352          774 :                 ..Default::default()
     353          774 :             },
     354          774 :             RowDescriptor {
     355          774 :                 name: b"dbname",
     356          774 :                 typoid: TEXT_OID,
     357          774 :                 typlen: -1,
     358          774 :                 ..Default::default()
     359          774 :             },
     360          774 :         ]))?
     361          774 :         .write_message_noflush(&BeMessage::DataRow(&[
     362          774 :             Some(sysid_bytes),
     363          774 :             Some(tli_bytes),
     364          774 :             Some(lsn_bytes),
     365          774 :             None,
     366          774 :         ]))?
     367          774 :         .write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
     368          774 :         Ok(())
     369          774 :     }
     370              : 
     371              :     /// Returns true if current connection is a replication connection, originating
     372              :     /// from a walproposer recovery function. This connection gets a special handling:
     373              :     /// safekeeper must stream all local WAL till the flush_lsn, whether committed or not.
     374         1604 :     pub fn is_walproposer_recovery(&self) -> bool {
     375         1604 :         self.appname == Some("wal_proposer_recovery".to_string())
     376         1604 :     }
     377              : }
        

Generated by: LCOV version 2.1-beta