LCOV - code coverage report
Current view: top level - libs/pq_proto/src - lib.rs (source / functions) Coverage Total Hit
Test: 86c536b7fe84b2afe03c3bb264199e9c319ae0f8.info Lines: 54.9 % 577 317
Test Date: 2024-06-24 16:38:41 Functions: 47.0 % 115 54

            Line data    Source code
       1              : //! Postgres protocol messages serialization-deserialization. See
       2              : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
       3              : //! on message formats.
       4              : #![deny(clippy::undocumented_unsafe_blocks)]
       5              : 
       6              : pub mod framed;
       7              : 
       8              : use byteorder::{BigEndian, ReadBytesExt};
       9              : use bytes::{Buf, BufMut, Bytes, BytesMut};
      10              : use itertools::Itertools;
      11              : use serde::{Deserialize, Serialize};
      12              : use std::{borrow::Cow, fmt, io, str};
      13              : 
      14              : // re-export for use in utils pageserver_feedback.rs
      15              : pub use postgres_protocol::PG_EPOCH;
      16              : 
      17              : pub type Oid = u32;
      18              : pub type SystemId = u64;
      19              : 
      20              : pub const INT8_OID: Oid = 20;
      21              : pub const INT4_OID: Oid = 23;
      22              : pub const TEXT_OID: Oid = 25;
      23              : 
      24              : #[derive(Debug)]
      25              : pub enum FeMessage {
      26              :     // Simple query.
      27              :     Query(Bytes),
      28              :     // Extended query protocol.
      29              :     Parse(FeParseMessage),
      30              :     Describe(FeDescribeMessage),
      31              :     Bind(FeBindMessage),
      32              :     Execute(FeExecuteMessage),
      33              :     Close(FeCloseMessage),
      34              :     Sync,
      35              :     Terminate,
      36              :     CopyData(Bytes),
      37              :     CopyDone,
      38              :     CopyFail,
      39              :     PasswordMessage(Bytes),
      40              : }
      41              : 
      42              : #[derive(Debug)]
      43              : pub enum FeStartupPacket {
      44              :     CancelRequest(CancelKeyData),
      45              :     SslRequest,
      46              :     GssEncRequest,
      47              :     StartupMessage {
      48              :         major_version: u32,
      49              :         minor_version: u32,
      50              :         params: StartupMessageParams,
      51              :     },
      52              : }
      53              : 
      54              : #[derive(Debug, Clone, Default)]
      55              : pub struct StartupMessageParamsBuilder {
      56              :     params: BytesMut,
      57              : }
      58              : 
      59              : impl StartupMessageParamsBuilder {
      60              :     /// Set parameter's value by its name.
      61              :     /// name and value must not contain a \0 byte
      62           50 :     pub fn insert(&mut self, name: &str, value: &str) {
      63           50 :         self.params.put(name.as_bytes());
      64           50 :         self.params.put(&b"\0"[..]);
      65           50 :         self.params.put(value.as_bytes());
      66           50 :         self.params.put(&b"\0"[..]);
      67           50 :     }
      68              : 
      69           34 :     pub fn freeze(self) -> StartupMessageParams {
      70           34 :         StartupMessageParams {
      71           34 :             params: self.params.freeze(),
      72           34 :         }
      73           34 :     }
      74              : }
      75              : 
      76              : #[derive(Debug, Clone, Default)]
      77              : pub struct StartupMessageParams {
      78              :     params: Bytes,
      79              : }
      80              : 
      81              : impl StartupMessageParams {
      82              :     /// Get parameter's value by its name.
      83           84 :     pub fn get(&self, name: &str) -> Option<&str> {
      84          116 :         self.iter().find_map(|(k, v)| (k == name).then_some(v))
      85           84 :     }
      86              : 
      87              :     /// Split command-line options according to PostgreSQL's logic,
      88              :     /// taking into account all escape sequences but leaving them as-is.
      89              :     /// [`None`] means that there's no `options` in [`Self`].
      90           48 :     pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
      91           48 :         self.get("options").map(Self::parse_options_raw)
      92           48 :     }
      93              : 
      94              :     /// Split command-line options according to PostgreSQL's logic,
      95              :     /// applying all escape sequences (using owned strings as needed).
      96              :     /// [`None`] means that there's no `options` in [`Self`].
      97           10 :     pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
      98           10 :         self.get("options").map(Self::parse_options_escaped)
      99           10 :     }
     100              : 
     101              :     /// Split command-line options according to PostgreSQL's logic,
     102              :     /// taking into account all escape sequences but leaving them as-is.
     103           60 :     pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
     104           60 :         // See `postgres: pg_split_opts`.
     105           60 :         let mut last_was_escape = false;
     106           60 :         input
     107         1128 :             .split(move |c: char| {
     108              :                 // We split by non-escaped whitespace symbols.
     109         1128 :                 let should_split = c.is_ascii_whitespace() && !last_was_escape;
     110         1128 :                 last_was_escape = c == '\\' && !last_was_escape;
     111         1128 :                 should_split
     112         1128 :             })
     113          152 :             .filter(|s| !s.is_empty())
     114           60 :     }
     115              : 
     116              :     /// Split command-line options according to PostgreSQL's logic,
     117              :     /// applying all escape sequences (using owned strings as needed).
     118            8 :     pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
     119            8 :         // See `postgres: pg_split_opts`.
     120           14 :         Self::parse_options_raw(input).map(|s| {
     121           14 :             let mut preserve_next_escape = false;
     122           34 :             let escape = |c| {
     123              :                 // We should remove '\\' unless it's preceded by '\\'.
     124           34 :                 let should_remove = c == '\\' && !preserve_next_escape;
     125           34 :                 preserve_next_escape = should_remove;
     126           34 :                 should_remove
     127           34 :             };
     128              : 
     129           14 :             match s.contains('\\') {
     130            6 :                 true => Cow::Owned(s.replace(escape, "")),
     131            8 :                 false => Cow::Borrowed(s),
     132              :             }
     133           14 :         })
     134            8 :     }
     135              : 
     136              :     /// Iterate through key-value pairs in an arbitrary order.
     137           98 :     pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
     138           98 :         let params =
     139           98 :             std::str::from_utf8(&self.params).expect("should be validated as utf8 already");
     140           98 :         params.split_terminator('\0').tuples()
     141           98 :     }
     142              : 
     143              :     // This function is mostly useful in tests.
     144              :     #[doc(hidden)]
     145           34 :     pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
     146           34 :         let mut b = StartupMessageParamsBuilder::default();
     147           84 :         for (k, v) in pairs {
     148           50 :             b.insert(k, v)
     149              :         }
     150           34 :         b.freeze()
     151           34 :     }
     152              : }
     153              : 
     154           12 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
     155              : pub struct CancelKeyData {
     156              :     pub backend_pid: i32,
     157              :     pub cancel_key: i32,
     158              : }
     159              : 
     160              : impl fmt::Display for CancelKeyData {
     161            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     162            0 :         let hi = (self.backend_pid as u64) << 32;
     163            0 :         let lo = self.cancel_key as u64;
     164            0 :         let id = hi | lo;
     165            0 : 
     166            0 :         // This format is more compact and might work better for logs.
     167            0 :         f.debug_tuple("CancelKeyData")
     168            0 :             .field(&format_args!("{:x}", id))
     169            0 :             .finish()
     170            0 :     }
     171              : }
     172              : 
     173              : use rand::distributions::{Distribution, Standard};
     174              : impl Distribution<CancelKeyData> for Standard {
     175            2 :     fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
     176            2 :         CancelKeyData {
     177            2 :             backend_pid: rng.gen(),
     178            2 :             cancel_key: rng.gen(),
     179            2 :         }
     180            2 :     }
     181              : }
     182              : 
     183              : // We only support the simple case of Parse on unnamed prepared statement and
     184              : // no params
     185              : #[derive(Debug)]
     186              : pub struct FeParseMessage {
     187              :     pub query_string: Bytes,
     188              : }
     189              : 
     190              : #[derive(Debug)]
     191              : pub struct FeDescribeMessage {
     192              :     pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
     193              :                   // we only support unnamed prepared stmt or portal
     194              : }
     195              : 
     196              : // we only support unnamed prepared stmt and portal
     197              : #[derive(Debug)]
     198              : pub struct FeBindMessage;
     199              : 
     200              : // we only support unnamed prepared stmt or portal
     201              : #[derive(Debug)]
     202              : pub struct FeExecuteMessage {
     203              :     /// max # of rows
     204              :     pub maxrows: i32,
     205              : }
     206              : 
     207              : // we only support unnamed prepared stmt and portal
     208              : #[derive(Debug)]
     209              : pub struct FeCloseMessage;
     210              : 
     211              : /// An error occurred while parsing or serializing raw stream into Postgres
     212              : /// messages.
     213            0 : #[derive(thiserror::Error, Debug)]
     214              : pub enum ProtocolError {
     215              :     /// Invalid packet was received from the client (e.g. unexpected message
     216              :     /// type or broken len).
     217              :     #[error("Protocol error: {0}")]
     218              :     Protocol(String),
     219              :     /// Failed to parse or, (unlikely), serialize a protocol message.
     220              :     #[error("Message parse error: {0}")]
     221              :     BadMessage(String),
     222              : }
     223              : 
     224              : impl ProtocolError {
     225              :     /// Proxy stream.rs uses only io::Error; provide it.
     226            0 :     pub fn into_io_error(self) -> io::Error {
     227            0 :         io::Error::new(io::ErrorKind::Other, self.to_string())
     228            0 :     }
     229              : }
     230              : 
     231              : impl FeMessage {
     232              :     /// Read and parse one message from the `buf` input buffer. If there is at
     233              :     /// least one valid message, returns it, advancing `buf`; redundant copies
     234              :     /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
     235              :     /// directly into the `buf` (processed data is garbage collected after
     236              :     /// parsed message is dropped).
     237              :     ///
     238              :     /// Returns None if `buf` doesn't contain enough data for a single message.
     239              :     /// For efficiency, tries to reserve large enough space in `buf` for the
     240              :     /// next message in this case to save the repeated calls.
     241              :     ///
     242              :     /// Returns Error if message is malformed, the only possible ErrorKind is
     243              :     /// InvalidInput.
     244              :     //
     245              :     // Inspired by rust-postgres Message::parse.
     246          114 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
     247          114 :         // Every message contains message type byte and 4 bytes len; can't do
     248          114 :         // much without them.
     249          114 :         if buf.len() < 5 {
     250           60 :             let to_read = 5 - buf.len();
     251           60 :             buf.reserve(to_read);
     252           60 :             return Ok(None);
     253           54 :         }
     254           54 : 
     255           54 :         // We shouldn't advance `buf` as probably full message is not there yet,
     256           54 :         // so can't directly use Bytes::get_u32 etc.
     257           54 :         let tag = buf[0];
     258           54 :         let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
     259           54 :         if len < 4 {
     260            0 :             return Err(ProtocolError::Protocol(format!(
     261            0 :                 "invalid message length {}",
     262            0 :                 len
     263            0 :             )));
     264           54 :         }
     265           54 : 
     266           54 :         // length field includes itself, but not message type.
     267           54 :         let total_len = len as usize + 1;
     268           54 :         if buf.len() < total_len {
     269              :             // Don't have full message yet.
     270            0 :             let to_read = total_len - buf.len();
     271            0 :             buf.reserve(to_read);
     272            0 :             return Ok(None);
     273           54 :         }
     274           54 : 
     275           54 :         // got the message, advance buffer
     276           54 :         let mut msg = buf.split_to(total_len).freeze();
     277           54 :         msg.advance(5); // consume message type and len
     278           54 : 
     279           54 :         match tag {
     280            4 :             b'Q' => Ok(Some(FeMessage::Query(msg))),
     281            0 :             b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
     282            0 :             b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
     283            0 :             b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
     284            0 :             b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
     285            0 :             b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
     286            0 :             b'S' => Ok(Some(FeMessage::Sync)),
     287            0 :             b'X' => Ok(Some(FeMessage::Terminate)),
     288            0 :             b'd' => Ok(Some(FeMessage::CopyData(msg))),
     289            0 :             b'c' => Ok(Some(FeMessage::CopyDone)),
     290            0 :             b'f' => Ok(Some(FeMessage::CopyFail)),
     291           50 :             b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
     292            0 :             tag => Err(ProtocolError::Protocol(format!(
     293            0 :                 "unknown message tag: {tag},'{msg:?}'"
     294            0 :             ))),
     295              :         }
     296          114 :     }
     297              : }
     298              : 
     299              : impl FeStartupPacket {
     300              :     /// Read and parse startup message from the `buf` input buffer. It is
     301              :     /// different from [`FeMessage::parse`] because startup messages don't have
     302              :     /// message type byte; otherwise, its comments apply.
     303          182 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
     304          182 :         const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
     305          182 :         const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
     306          182 :         const CANCEL_REQUEST_CODE: u32 = 5678;
     307          182 :         const NEGOTIATE_SSL_CODE: u32 = 5679;
     308          182 :         const NEGOTIATE_GSS_CODE: u32 = 5680;
     309          182 : 
     310          182 :         // need at least 4 bytes with packet len
     311          182 :         if buf.len() < 4 {
     312           90 :             let to_read = 4 - buf.len();
     313           90 :             buf.reserve(to_read);
     314           90 :             return Ok(None);
     315           92 :         }
     316           92 : 
     317           92 :         // We shouldn't advance `buf` as probably full message is not there yet,
     318           92 :         // so can't directly use Bytes::get_u32 etc.
     319           92 :         let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
     320           92 :         // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
     321           92 :         // which is less readable
     322           92 :         #[allow(clippy::manual_range_contains)]
     323           92 :         if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
     324            2 :             return Err(ProtocolError::Protocol(format!(
     325            2 :                 "invalid startup packet message length {}",
     326            2 :                 len
     327            2 :             )));
     328           90 :         }
     329           90 : 
     330           90 :         if buf.len() < len {
     331              :             // Don't have full message yet.
     332            0 :             let to_read = len - buf.len();
     333            0 :             buf.reserve(to_read);
     334            0 :             return Ok(None);
     335           90 :         }
     336           90 : 
     337           90 :         // got the message, advance buffer
     338           90 :         let mut msg = buf.split_to(len).freeze();
     339           90 :         msg.advance(4); // consume len
     340           90 : 
     341           90 :         let request_code = msg.get_u32();
     342           90 :         let req_hi = request_code >> 16;
     343           90 :         let req_lo = request_code & ((1 << 16) - 1);
     344              :         // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
     345           90 :         let message = match (req_hi, req_lo) {
     346              :             (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
     347            0 :                 if msg.remaining() != 8 {
     348            0 :                     return Err(ProtocolError::BadMessage(
     349            0 :                         "CancelRequest message is malformed, backend PID / secret key missing"
     350            0 :                             .to_owned(),
     351            0 :                     ));
     352            0 :                 }
     353            0 :                 FeStartupPacket::CancelRequest(CancelKeyData {
     354            0 :                     backend_pid: msg.get_i32(),
     355            0 :                     cancel_key: msg.get_i32(),
     356            0 :                 })
     357              :             }
     358              :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
     359              :                 // Requested upgrade to SSL (aka TLS)
     360           42 :                 FeStartupPacket::SslRequest
     361              :             }
     362              :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
     363              :                 // Requested upgrade to GSSAPI
     364            0 :                 FeStartupPacket::GssEncRequest
     365              :             }
     366            0 :             (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
     367            0 :                 return Err(ProtocolError::Protocol(format!(
     368            0 :                     "Unrecognized request code {unrecognized_code}"
     369            0 :                 )));
     370              :             }
     371              :             // TODO bail if protocol major_version is not 3?
     372           48 :             (major_version, minor_version) => {
     373              :                 // StartupMessage
     374              : 
     375           48 :                 let s = str::from_utf8(&msg).map_err(|_e| {
     376            0 :                     ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
     377           48 :                 })?;
     378           48 :                 let s = s.strip_suffix('\0').ok_or_else(|| {
     379            0 :                     ProtocolError::Protocol(
     380            0 :                         "StartupMessage params: missing null terminator".to_string(),
     381            0 :                     )
     382           48 :                 })?;
     383              : 
     384           48 :                 FeStartupPacket::StartupMessage {
     385           48 :                     major_version,
     386           48 :                     minor_version,
     387           48 :                     params: StartupMessageParams {
     388           48 :                         params: msg.slice_ref(s.as_bytes()),
     389           48 :                     },
     390           48 :                 }
     391              :             }
     392              :         };
     393           90 :         Ok(Some(message))
     394          182 :     }
     395              : }
     396              : 
     397              : impl FeParseMessage {
     398            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     399              :         // FIXME: the rust-postgres driver uses a named prepared statement
     400              :         // for copy_out(). We're not prepared to handle that correctly. For
     401              :         // now, just ignore the statement name, assuming that the client never
     402              :         // uses more than one prepared statement at a time.
     403              : 
     404            0 :         let _pstmt_name = read_cstr(&mut buf)?;
     405            0 :         let query_string = read_cstr(&mut buf)?;
     406            0 :         if buf.remaining() < 2 {
     407            0 :             return Err(ProtocolError::BadMessage(
     408            0 :                 "Parse message is malformed, nparams missing".to_string(),
     409            0 :             ));
     410            0 :         }
     411            0 :         let nparams = buf.get_i16();
     412            0 : 
     413            0 :         if nparams != 0 {
     414            0 :             return Err(ProtocolError::BadMessage(
     415            0 :                 "query params not implemented".to_string(),
     416            0 :             ));
     417            0 :         }
     418            0 : 
     419            0 :         Ok(FeMessage::Parse(FeParseMessage { query_string }))
     420            0 :     }
     421              : }
     422              : 
     423              : impl FeDescribeMessage {
     424            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     425            0 :         let kind = buf.get_u8();
     426            0 :         let _pstmt_name = read_cstr(&mut buf)?;
     427              : 
     428              :         // FIXME: see FeParseMessage::parse
     429            0 :         if kind != b'S' {
     430            0 :             return Err(ProtocolError::BadMessage(
     431            0 :                 "only prepared statemement Describe is implemented".to_string(),
     432            0 :             ));
     433            0 :         }
     434            0 : 
     435            0 :         Ok(FeMessage::Describe(FeDescribeMessage { kind }))
     436            0 :     }
     437              : }
     438              : 
     439              : impl FeExecuteMessage {
     440            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     441            0 :         let portal_name = read_cstr(&mut buf)?;
     442            0 :         if buf.remaining() < 4 {
     443            0 :             return Err(ProtocolError::BadMessage(
     444            0 :                 "FeExecuteMessage message is malformed, maxrows missing".to_string(),
     445            0 :             ));
     446            0 :         }
     447            0 :         let maxrows = buf.get_i32();
     448            0 : 
     449            0 :         if !portal_name.is_empty() {
     450            0 :             return Err(ProtocolError::BadMessage(
     451            0 :                 "named portals not implemented".to_string(),
     452            0 :             ));
     453            0 :         }
     454            0 :         if maxrows != 0 {
     455            0 :             return Err(ProtocolError::BadMessage(
     456            0 :                 "row limit in Execute message not implemented".to_string(),
     457            0 :             ));
     458            0 :         }
     459            0 : 
     460            0 :         Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
     461            0 :     }
     462              : }
     463              : 
     464              : impl FeBindMessage {
     465            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     466            0 :         let portal_name = read_cstr(&mut buf)?;
     467            0 :         let _pstmt_name = read_cstr(&mut buf)?;
     468              : 
     469              :         // FIXME: see FeParseMessage::parse
     470            0 :         if !portal_name.is_empty() {
     471            0 :             return Err(ProtocolError::BadMessage(
     472            0 :                 "named portals not implemented".to_string(),
     473            0 :             ));
     474            0 :         }
     475            0 : 
     476            0 :         Ok(FeMessage::Bind(FeBindMessage))
     477            0 :     }
     478              : }
     479              : 
     480              : impl FeCloseMessage {
     481            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     482            0 :         let _kind = buf.get_u8();
     483            0 :         let _pstmt_or_portal_name = read_cstr(&mut buf)?;
     484              : 
     485              :         // FIXME: we do nothing with Close
     486            0 :         Ok(FeMessage::Close(FeCloseMessage))
     487            0 :     }
     488              : }
     489              : 
     490              : // Backend
     491              : 
     492              : #[derive(Debug)]
     493              : pub enum BeMessage<'a> {
     494              :     AuthenticationOk,
     495              :     AuthenticationMD5Password([u8; 4]),
     496              :     AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
     497              :     AuthenticationCleartextPassword,
     498              :     BackendKeyData(CancelKeyData),
     499              :     BindComplete,
     500              :     CommandComplete(&'a [u8]),
     501              :     CopyData(&'a [u8]),
     502              :     CopyDone,
     503              :     CopyFail,
     504              :     CopyInResponse,
     505              :     CopyOutResponse,
     506              :     CopyBothResponse,
     507              :     CloseComplete,
     508              :     // None means column is NULL
     509              :     DataRow(&'a [Option<&'a [u8]>]),
     510              :     // None errcode means internal_error will be sent.
     511              :     ErrorResponse(&'a str, Option<&'a [u8; 5]>),
     512              :     /// Single byte - used in response to SSLRequest/GSSENCRequest.
     513              :     EncryptionResponse(bool),
     514              :     NoData,
     515              :     ParameterDescription,
     516              :     ParameterStatus {
     517              :         name: &'a [u8],
     518              :         value: &'a [u8],
     519              :     },
     520              :     ParseComplete,
     521              :     ReadyForQuery,
     522              :     RowDescription(&'a [RowDescriptor<'a>]),
     523              :     XLogData(XLogDataBody<'a>),
     524              :     NoticeResponse(&'a str),
     525              :     KeepAlive(WalSndKeepAlive),
     526              : }
     527              : 
     528              : /// Common shorthands.
     529              : impl<'a> BeMessage<'a> {
     530              :     /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
     531              :     /// This is a sensible default, given that:
     532              :     ///  * rust strings only support this encoding out of the box.
     533              :     ///  * tokio-postgres, postgres-jdbc (and probably more) mandate it.
     534              :     ///
     535              :     /// TODO: do we need to report `server_encoding` as well?
     536              :     pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
     537              :         name: b"client_encoding",
     538              :         value: b"UTF8",
     539              :     };
     540              : 
     541              :     pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
     542              :         name: b"integer_datetimes",
     543              :         value: b"on",
     544              :     };
     545              : 
     546              :     /// Build a [`BeMessage::ParameterStatus`] holding the server version.
     547            4 :     pub fn server_version(version: &'a str) -> Self {
     548            4 :         Self::ParameterStatus {
     549            4 :             name: b"server_version",
     550            4 :             value: version.as_bytes(),
     551            4 :         }
     552            4 :     }
     553              : }
     554              : 
/// Payload of an `AuthenticationSASL*` backend message ('R'), as serialized by
/// the `AuthenticationSasl` arm of `BeMessage::write`.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Auth code 10: the list of offered SASL mechanism names. Each name is
    /// written as a NUL-terminated string, and the list is closed by one more
    /// NUL byte.
    Methods(&'a [&'a str]),
    /// Auth code 11: SASL continuation data, written verbatim.
    Continue(&'a [u8]),
    /// Auth code 12: final SASL data, written verbatim.
    Final(&'a [u8]),
}
     561              : 
/// Typed form of a server parameter report (client encoding or server version).
///
/// NOTE(review): not referenced in this part of the file. `BeMessage::ParameterStatus`
/// carries raw name/value bytes instead. Confirm whether this type is still used.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
     567              : 
/// One field (column) description in a RowDescription packet.
///
/// NOTE: `BeMessage::write` currently takes only `name`, `typoid` and `typlen`
/// from this struct. It writes fixed values for the table OID (0), attnum (0),
/// typmod (-1) and format code (0), regardless of what is stored here.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
    /// Column name, written as a NUL-terminated string.
    pub name: &'a [u8],
    /// Table OID field of the description entry.
    pub tableoid: Oid,
    /// Column attribute number field of the description entry.
    pub attnum: i16,
    /// Data type OID of the column, e.g. `INT8_OID` or `TEXT_OID`.
    pub typoid: Oid,
    /// Data type size. The constructors use 8 for int8 and -1 for text
    /// (PostgreSQL uses negative `typlen` for variable-width types).
    pub typlen: i16,
    /// Type modifier field of the description entry.
    pub typmod: i32,
    /// Format code field of the description entry.
    pub formatcode: i16,
}
     579              : 
     580              : impl Default for RowDescriptor<'_> {
     581            0 :     fn default() -> RowDescriptor<'static> {
     582            0 :         RowDescriptor {
     583            0 :             name: b"",
     584            0 :             tableoid: 0,
     585            0 :             attnum: 0,
     586            0 :             typoid: 0,
     587            0 :             typlen: 0,
     588            0 :             typmod: 0,
     589            0 :             formatcode: 0,
     590            0 :         }
     591            0 :     }
     592              : }
     593              : 
     594              : impl RowDescriptor<'_> {
     595              :     /// Convenience function to create a RowDescriptor message for an int8 column
     596            0 :     pub const fn int8_col(name: &[u8]) -> RowDescriptor {
     597            0 :         RowDescriptor {
     598            0 :             name,
     599            0 :             tableoid: 0,
     600            0 :             attnum: 0,
     601            0 :             typoid: INT8_OID,
     602            0 :             typlen: 8,
     603            0 :             typmod: 0,
     604            0 :             formatcode: 0,
     605            0 :         }
     606            0 :     }
     607              : 
     608            4 :     pub const fn text_col(name: &[u8]) -> RowDescriptor {
     609            4 :         RowDescriptor {
     610            4 :             name,
     611            4 :             tableoid: 0,
     612            4 :             attnum: 0,
     613            4 :             typoid: TEXT_OID,
     614            4 :             typlen: -1,
     615            4 :             typmod: 0,
     616            4 :             formatcode: 0,
     617            4 :         }
     618            4 :     }
     619              : }
     620              : 
/// Body of an XLogData ('w') message. `BeMessage::XLogData` sends it inside a
/// CopyData ('d') message as `wal_start`, `wal_end`, `timestamp` (each 8 bytes,
/// big-endian), followed by `data`.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    /// Presumably the WAL position at which `data` starts. Confirm with callers.
    pub wal_start: u64,
    pub wal_end: u64, // current end of WAL on the server
    /// Server timestamp.
    /// NOTE(review): presumably microseconds since `PG_EPOCH`. Confirm with callers.
    pub timestamp: i64,
    /// WAL payload, appended verbatim after the header fields.
    pub data: &'a [u8],
}
     628              : 
/// Keepalive ('k') message. `BeMessage::KeepAlive` sends it inside a CopyData
/// ('d') message as `wal_end` (8 bytes), `timestamp` (8 bytes) and
/// `request_reply` (1 byte).
#[derive(Debug)]
pub struct WalSndKeepAlive {
    pub wal_end: u64, // current end of WAL on the server
    /// Server timestamp.
    /// NOTE(review): presumably microseconds since `PG_EPOCH`. Confirm with callers.
    pub timestamp: i64,
    /// Written as a single byte, 1 for `true` and 0 for `false`. Presumably asks
    /// the receiver to reply immediately.
    pub request_reply: bool,
}
     635              : 
/// Canned DataRow with a single non-NULL column containing `hello world`.
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);

/// Canned RowDescription for a single text column named `data`.
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
    name: b"data",
    tableoid: 0,
    attnum: 0,
    typoid: TEXT_OID,
    typlen: -1,
    typmod: 0,
    formatcode: 0,
}]);
     648              : 
     649              : /// Call f() to write body of the message and prepend it with 4-byte len as
     650              : /// prescribed by the protocol.
     651          150 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
     652          150 :     let base = buf.len();
     653          150 :     buf.extend_from_slice(&[0; 4]);
     654          150 : 
     655          150 :     let res = f(buf);
     656          150 : 
     657          150 :     let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
     658          150 :     (&mut buf[base..]).put_slice(&size.to_be_bytes());
     659          150 : 
     660          150 :     res
     661          150 : }
     662              : 
     663              : /// Safe write of s into buf as cstring (String in the protocol).
     664          112 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
     665          112 :     let bytes = s.as_ref();
     666          112 :     if bytes.contains(&0) {
     667            0 :         return Err(ProtocolError::BadMessage(
     668            0 :             "string contains embedded null".to_owned(),
     669            0 :         ));
     670          112 :     }
     671          112 :     buf.put_slice(bytes);
     672          112 :     buf.put_u8(0);
     673          112 :     Ok(())
     674          112 : }
     675              : 
     676              : /// Read cstring from buf, advancing it.
     677           22 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
     678           22 :     let pos = buf
     679           22 :         .iter()
     680          312 :         .position(|x| *x == 0)
     681           22 :         .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
     682           22 :     let result = buf.split_to(pos);
     683           22 :     buf.advance(1); // drop the null terminator
     684           22 :     Ok(result)
     685           22 : }
     686              : 
/// SQLSTATE `XX000` (internal error). This is the code used when an
/// `ErrorResponse` carries none, and the code every `NoticeResponse` carries.
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE `57P01` (admin shutdown).
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE `00000` (successful completion).
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
     690              : 
     691              : impl<'a> BeMessage<'a> {
     692              :     /// Serialize `message` to the given `buf`.
     693              :     /// Apart from smart memory managemet, BytesMut is good here as msg len
     694              :     /// precedes its body and it is handy to write it down first and then fill
     695              :     /// the length. With Write we would have to either calc it manually or have
     696              :     /// one more buffer.
     697          192 :     pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
     698          192 :         match message {
     699           24 :             BeMessage::AuthenticationOk => {
     700           24 :                 buf.put_u8(b'R');
     701           24 :                 write_body(buf, |buf| {
     702           24 :                     buf.put_i32(0); // Specifies that the authentication was successful.
     703           24 :                 });
     704           24 :             }
     705              : 
     706            4 :             BeMessage::AuthenticationCleartextPassword => {
     707            4 :                 buf.put_u8(b'R');
     708            4 :                 write_body(buf, |buf| {
     709            4 :                     buf.put_i32(3); // Specifies that clear text password is required.
     710            4 :                 });
     711            4 :             }
     712              : 
     713            0 :             BeMessage::AuthenticationMD5Password(salt) => {
     714            0 :                 buf.put_u8(b'R');
     715            0 :                 write_body(buf, |buf| {
     716            0 :                     buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
     717            0 :                     buf.put_slice(&salt[..]);
     718            0 :                 });
     719            0 :             }
     720              : 
     721           60 :             BeMessage::AuthenticationSasl(msg) => {
     722           60 :                 buf.put_u8(b'R');
     723           60 :                 write_body(buf, |buf| {
     724           60 :                     use BeAuthenticationSaslMessage::*;
     725           60 :                     match msg {
     726           26 :                         Methods(methods) => {
     727           26 :                             buf.put_i32(10); // Specifies that SASL auth method is used.
     728           50 :                             for method in methods.iter() {
     729           50 :                                 write_cstr(method, buf)?;
     730              :                             }
     731           26 :                             buf.put_u8(0); // zero terminator for the list
     732              :                         }
     733           22 :                         Continue(extra) => {
     734           22 :                             buf.put_i32(11); // Continue SASL auth.
     735           22 :                             buf.put_slice(extra);
     736           22 :                         }
     737           12 :                         Final(extra) => {
     738           12 :                             buf.put_i32(12); // Send final SASL message.
     739           12 :                             buf.put_slice(extra);
     740           12 :                         }
     741              :                     }
     742           60 :                     Ok(())
     743           60 :                 })?;
     744              :             }
     745              : 
     746            0 :             BeMessage::BackendKeyData(key_data) => {
     747            0 :                 buf.put_u8(b'K');
     748            0 :                 write_body(buf, |buf| {
     749            0 :                     buf.put_i32(key_data.backend_pid);
     750            0 :                     buf.put_i32(key_data.cancel_key);
     751            0 :                 });
     752            0 :             }
     753              : 
     754            0 :             BeMessage::BindComplete => {
     755            0 :                 buf.put_u8(b'2');
     756            0 :                 write_body(buf, |_| {});
     757            0 :             }
     758              : 
     759            0 :             BeMessage::CloseComplete => {
     760            0 :                 buf.put_u8(b'3');
     761            0 :                 write_body(buf, |_| {});
     762            0 :             }
     763              : 
     764            4 :             BeMessage::CommandComplete(cmd) => {
     765            4 :                 buf.put_u8(b'C');
     766            4 :                 write_body(buf, |buf| write_cstr(cmd, buf))?;
     767              :             }
     768              : 
     769            0 :             BeMessage::CopyData(data) => {
     770            0 :                 buf.put_u8(b'd');
     771            0 :                 write_body(buf, |buf| {
     772            0 :                     buf.put_slice(data);
     773            0 :                 });
     774            0 :             }
     775              : 
     776            0 :             BeMessage::CopyDone => {
     777            0 :                 buf.put_u8(b'c');
     778            0 :                 write_body(buf, |_| {});
     779            0 :             }
     780              : 
     781            0 :             BeMessage::CopyFail => {
     782            0 :                 buf.put_u8(b'f');
     783            0 :                 write_body(buf, |_| {});
     784            0 :             }
     785              : 
     786            0 :             BeMessage::CopyInResponse => {
     787            0 :                 buf.put_u8(b'G');
     788            0 :                 write_body(buf, |buf| {
     789            0 :                     buf.put_u8(1); // copy_is_binary
     790            0 :                     buf.put_i16(0); // numAttributes
     791            0 :                 });
     792            0 :             }
     793              : 
     794            0 :             BeMessage::CopyOutResponse => {
     795            0 :                 buf.put_u8(b'H');
     796            0 :                 write_body(buf, |buf| {
     797            0 :                     buf.put_u8(0); // copy_is_binary
     798            0 :                     buf.put_i16(0); // numAttributes
     799            0 :                 });
     800            0 :             }
     801              : 
     802            0 :             BeMessage::CopyBothResponse => {
     803            0 :                 buf.put_u8(b'W');
     804            0 :                 write_body(buf, |buf| {
     805            0 :                     // doesn't matter, used only for replication
     806            0 :                     buf.put_u8(0); // copy_is_binary
     807            0 :                     buf.put_i16(0); // numAttributes
     808            0 :                 });
     809            0 :             }
     810              : 
     811            4 :             BeMessage::DataRow(vals) => {
     812            4 :                 buf.put_u8(b'D');
     813            4 :                 write_body(buf, |buf| {
     814            4 :                     buf.put_u16(vals.len() as u16); // num of cols
     815            4 :                     for val_opt in vals.iter() {
     816            4 :                         if let Some(val) = val_opt {
     817            4 :                             buf.put_u32(val.len() as u32);
     818            4 :                             buf.put_slice(val);
     819            4 :                         } else {
     820            0 :                             buf.put_i32(-1);
     821            0 :                         }
     822              :                     }
     823            4 :                 });
     824            4 :             }
     825              : 
     826              :             // ErrorResponse is a zero-terminated array of zero-terminated fields.
     827              :             // First byte of each field represents type of this field. Set just enough fields
     828              :             // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
     829              :             // message text.
     830            2 :             BeMessage::ErrorResponse(error_msg, pg_error_code) => {
     831            2 :                 // 'E' signalizes ErrorResponse messages
     832            2 :                 buf.put_u8(b'E');
     833            2 :                 write_body(buf, |buf| {
     834            2 :                     buf.put_u8(b'S'); // severity
     835            2 :                     buf.put_slice(b"ERROR\0");
     836            2 : 
     837            2 :                     buf.put_u8(b'C'); // SQLSTATE error code
     838            2 :                     buf.put_slice(&terminate_code(
     839            2 :                         pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
     840            2 :                     ));
     841            2 : 
     842            2 :                     buf.put_u8(b'M'); // the message
     843            2 :                     write_cstr(error_msg, buf)?;
     844              : 
     845            2 :                     buf.put_u8(0); // terminator
     846            2 :                     Ok(())
     847            2 :                 })?;
     848              :             }
     849              : 
     850              :             // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
     851              :             // message but continue listening for ReadyForQuery or ErrorResponse"
     852            0 :             BeMessage::NoticeResponse(error_msg) => {
     853            0 :                 // For all the errors set Severity to Error and error code to
     854            0 :                 // 'internal error'.
     855            0 : 
     856            0 :                 // 'N' signalizes NoticeResponse messages
     857            0 :                 buf.put_u8(b'N');
     858            0 :                 write_body(buf, |buf| {
     859            0 :                     buf.put_u8(b'S'); // severity
     860            0 :                     buf.put_slice(b"NOTICE\0");
     861            0 : 
     862            0 :                     buf.put_u8(b'C'); // SQLSTATE error code
     863            0 :                     buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
     864            0 : 
     865            0 :                     buf.put_u8(b'M'); // the message
     866            0 :                     write_cstr(error_msg.as_bytes(), buf)?;
     867              : 
     868            0 :                     buf.put_u8(0); // terminator
     869            0 :                     Ok(())
     870            0 :                 })?;
     871              :             }
     872              : 
     873            0 :             BeMessage::NoData => {
     874            0 :                 buf.put_u8(b'n');
     875            0 :                 write_body(buf, |_| {});
     876            0 :             }
     877              : 
     878           42 :             BeMessage::EncryptionResponse(should_negotiate) => {
     879           42 :                 let response = if *should_negotiate { b'S' } else { b'N' };
     880           42 :                 buf.put_u8(response);
     881              :             }
     882              : 
     883           26 :             BeMessage::ParameterStatus { name, value } => {
     884           26 :                 buf.put_u8(b'S');
     885           26 :                 write_body(buf, |buf| {
     886           26 :                     write_cstr(name, buf)?;
     887           26 :                     write_cstr(value, buf)
     888           26 :                 })?;
     889              :             }
     890              : 
     891            0 :             BeMessage::ParameterDescription => {
     892            0 :                 buf.put_u8(b't');
     893            0 :                 write_body(buf, |buf| {
     894            0 :                     // we don't support params, so always 0
     895            0 :                     buf.put_i16(0);
     896            0 :                 });
     897            0 :             }
     898              : 
     899            0 :             BeMessage::ParseComplete => {
     900            0 :                 buf.put_u8(b'1');
     901            0 :                 write_body(buf, |_| {});
     902            0 :             }
     903              : 
     904           22 :             BeMessage::ReadyForQuery => {
     905           22 :                 buf.put_u8(b'Z');
     906           22 :                 write_body(buf, |buf| {
     907           22 :                     buf.put_u8(b'I');
     908           22 :                 });
     909           22 :             }
     910              : 
     911            4 :             BeMessage::RowDescription(rows) => {
     912            4 :                 buf.put_u8(b'T');
     913            4 :                 write_body(buf, |buf| {
     914            4 :                     buf.put_i16(rows.len() as i16); // # of fields
     915            4 :                     for row in rows.iter() {
     916            4 :                         write_cstr(row.name, buf)?;
     917            4 :                         buf.put_i32(0); /* table oid */
     918            4 :                         buf.put_i16(0); /* attnum */
     919            4 :                         buf.put_u32(row.typoid);
     920            4 :                         buf.put_i16(row.typlen);
     921            4 :                         buf.put_i32(-1); /* typmod */
     922            4 :                         buf.put_i16(0); /* format code */
     923              :                     }
     924            4 :                     Ok(())
     925            4 :                 })?;
     926              :             }
     927              : 
     928            0 :             BeMessage::XLogData(body) => {
     929            0 :                 buf.put_u8(b'd');
     930            0 :                 write_body(buf, |buf| {
     931            0 :                     buf.put_u8(b'w');
     932            0 :                     buf.put_u64(body.wal_start);
     933            0 :                     buf.put_u64(body.wal_end);
     934            0 :                     buf.put_i64(body.timestamp);
     935            0 :                     buf.put_slice(body.data);
     936            0 :                 });
     937            0 :             }
     938              : 
     939            0 :             BeMessage::KeepAlive(req) => {
     940            0 :                 buf.put_u8(b'd');
     941            0 :                 write_body(buf, |buf| {
     942            0 :                     buf.put_u8(b'k');
     943            0 :                     buf.put_u64(req.wal_end);
     944            0 :                     buf.put_i64(req.timestamp);
     945            0 :                     buf.put_u8(u8::from(req.request_reply));
     946            0 :                 });
     947            0 :             }
     948              :         }
     949          192 :         Ok(())
     950          192 :     }
     951              : }
     952              : 
     953            2 : fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
     954            2 :     let mut terminated = [0; 6];
     955           10 :     for (i, &elem) in code.iter().enumerate() {
     956           10 :         terminated[i] = elem;
     957           10 :     }
     958              : 
     959            2 :     terminated
     960            2 : }
     961              : 
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `options_escaped`:
    /// * it returns `None` when no `options` parameter is present;
    /// * otherwise it splits the value on unescaped whitespace and drops empty
    ///   pieces;
    /// * `\ ` becomes a literal space and `\\` a literal backslash.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        // Leading, trailing and repeated spaces produce no empty options.
        let params = make_params(" foo  bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // Raw value: `foo\ bar \ \\ baz\  lol`.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    /// A startup packet whose declared length (7) is smaller than its own
    /// 8-byte header must be rejected with an error. Presumably this once
    /// misbehaved (panicked); confirm against the commit history.
    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }
}
        

Generated by: LCOV version 2.1-beta