LCOV - differential code coverage report
Current view: top level - libs/pq_proto/src - lib.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 84.9 % 584 496 88 496
Current Date: 2023-10-19 02:04:12 Functions: 76.7 % 120 92 28 92
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : //! Postgres protocol messages serialization-deserialization. See
       2                 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
       3                 : //! on message formats.
       4                 : 
       5                 : pub mod framed;
       6                 : 
       7                 : use byteorder::{BigEndian, ReadBytesExt};
       8                 : use bytes::{Buf, BufMut, Bytes, BytesMut};
       9                 : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
      10                 : 
      11                 : // re-export for use in utils pageserver_feedback.rs
      12                 : pub use postgres_protocol::PG_EPOCH;
      13                 : 
      14                 : pub type Oid = u32;
      15                 : pub type SystemId = u64;
      16                 : 
      17                 : pub const INT8_OID: Oid = 20;
      18                 : pub const INT4_OID: Oid = 23;
      19                 : pub const TEXT_OID: Oid = 25;
      20                 : 
      21 CBC          59 : #[derive(Debug)]
      22                 : pub enum FeMessage {
      23                 :     // Simple query.
      24                 :     Query(Bytes),
      25                 :     // Extended query protocol.
      26                 :     Parse(FeParseMessage),
      27                 :     Describe(FeDescribeMessage),
      28                 :     Bind(FeBindMessage),
      29                 :     Execute(FeExecuteMessage),
      30                 :     Close(FeCloseMessage),
      31                 :     Sync,
      32                 :     Terminate,
      33                 :     CopyData(Bytes),
      34                 :     CopyDone,
      35                 :     CopyFail,
      36                 :     PasswordMessage(Bytes),
      37                 : }
      38                 : 
      39             140 : #[derive(Debug)]
      40                 : pub enum FeStartupPacket {
      41                 :     CancelRequest(CancelKeyData),
      42                 :     SslRequest,
      43                 :     GssEncRequest,
      44                 :     StartupMessage {
      45                 :         major_version: u32,
      46                 :         minor_version: u32,
      47                 :         params: StartupMessageParams,
      48                 :     },
      49                 : }
      50                 : 
      51              74 : #[derive(Debug)]
      52                 : pub struct StartupMessageParams {
      53                 :     params: HashMap<String, String>,
      54                 : }
      55                 : 
      56                 : impl StartupMessageParams {
      57                 :     /// Get parameter's value by its name.
      58            7454 :     pub fn get(&self, name: &str) -> Option<&str> {
      59            7454 :         self.params.get(name).map(|s| s.as_str())
      60            7454 :     }
      61                 : 
      62                 :     /// Split command-line options according to PostgreSQL's logic,
      63                 :     /// taking into account all escape sequences but leaving them as-is.
      64                 :     /// [`None`] means that there's no `options` in [`Self`].
      65            3663 :     pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
      66            3663 :         self.get("options").map(Self::parse_options_raw)
      67            3663 :     }
      68                 : 
      69                 :     /// Split command-line options according to PostgreSQL's logic,
      70                 :     /// applying all escape sequences (using owned strings as needed).
      71                 :     /// [`None`] means that there's no `options` in [`Self`].
      72               5 :     pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
      73               5 :         self.get("options").map(Self::parse_options_escaped)
      74               5 :     }
      75                 : 
      76                 :     /// Split command-line options according to PostgreSQL's logic,
      77                 :     /// taking into account all escape sequences but leaving them as-is.
      78            3634 :     pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
      79            3634 :         // See `postgres: pg_split_opts`.
      80            3634 :         let mut last_was_escape = false;
      81            3634 :         input
      82          323264 :             .split(move |c: char| {
      83                 :                 // We split by non-escaped whitespace symbols.
      84          323264 :                 let should_split = c.is_ascii_whitespace() && !last_was_escape;
      85          323264 :                 last_was_escape = c == '\\' && !last_was_escape;
      86          323264 :                 should_split
      87          323264 :             })
      88           10863 :             .filter(|s| !s.is_empty())
      89            3634 :     }
      90                 : 
      91                 :     /// Split command-line options according to PostgreSQL's logic,
      92                 :     /// applying all escape sequences (using owned strings as needed).
      93               4 :     pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
      94               4 :         // See `postgres: pg_split_opts`.
      95               7 :         Self::parse_options_raw(input).map(|s| {
      96               7 :             let mut preserve_next_escape = false;
      97              17 :             let escape = |c| {
      98                 :                 // We should remove '\\' unless it's preceded by '\\'.
      99              17 :                 let should_remove = c == '\\' && !preserve_next_escape;
     100              17 :                 preserve_next_escape = should_remove;
     101              17 :                 should_remove
     102              17 :             };
     103                 : 
     104               7 :             match s.contains('\\') {
     105               3 :                 true => Cow::Owned(s.replace(escape, "")),
     106               4 :                 false => Cow::Borrowed(s),
     107                 :             }
     108               7 :         })
     109               4 :     }
     110                 : 
     111                 :     /// Iterate through key-value pairs in an arbitrary order.
     112 UBC           0 :     pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
     113               0 :         self.params.iter().map(|(k, v)| (k.as_str(), v.as_str()))
     114               0 :     }
     115                 : 
     116                 :     // This function is mostly useful in tests.
     117                 :     #[doc(hidden)]
     118 CBC          48 :     pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
     119              48 :         Self {
     120             109 :             params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
     121              48 :         }
     122              48 :     }
     123                 : }
     124                 : 
     125              98 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
     126                 : pub struct CancelKeyData {
     127                 :     pub backend_pid: i32,
     128                 :     pub cancel_key: i32,
     129                 : }
     130                 : 
     131                 : impl fmt::Display for CancelKeyData {
     132             132 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     133             132 :         let hi = (self.backend_pid as u64) << 32;
     134             132 :         let lo = self.cancel_key as u64;
     135             132 :         let id = hi | lo;
     136             132 : 
     137             132 :         // This format is more compact and might work better for logs.
     138             132 :         f.debug_tuple("CancelKeyData")
     139             132 :             .field(&format_args!("{:x}", id))
     140             132 :             .finish()
     141             132 :     }
     142                 : }
     143                 : 
     144                 : use rand::distributions::{Distribution, Standard};
     145                 : impl Distribution<CancelKeyData> for Standard {
     146              34 :     fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
     147              34 :         CancelKeyData {
     148              34 :             backend_pid: rng.gen(),
     149              34 :             cancel_key: rng.gen(),
     150              34 :         }
     151              34 :     }
     152                 : }
     153                 : 
     154                 : // We only support the simple case of Parse on unnamed prepared statement and
     155                 : // no params
     156 UBC           0 : #[derive(Debug)]
     157                 : pub struct FeParseMessage {
     158                 :     pub query_string: Bytes,
     159                 : }
     160                 : 
     161               0 : #[derive(Debug)]
     162                 : pub struct FeDescribeMessage {
     163                 :     pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
     164                 :                   // we only support unnamed prepared stmt or portal
     165                 : }
     166                 : 
     167                 : // we only support unnamed prepared stmt and portal
     168               0 : #[derive(Debug)]
     169                 : pub struct FeBindMessage;
     170                 : 
     171                 : // we only support unnamed prepared stmt or portal
     172               0 : #[derive(Debug)]
     173                 : pub struct FeExecuteMessage {
     174                 :     /// max # of rows
     175                 :     pub maxrows: i32,
     176                 : }
     177                 : 
     178                 : // we only support unnamed prepared stmt and portal
     179               0 : #[derive(Debug)]
     180                 : pub struct FeCloseMessage;
     181                 : 
     182                 : /// An error occurred while parsing or serializing raw stream into Postgres
     183                 : /// messages.
     184 CBC           7 : #[derive(thiserror::Error, Debug)]
     185                 : pub enum ProtocolError {
     186                 :     /// Invalid packet was received from the client (e.g. unexpected message
     187                 :     /// type or broken len).
     188                 :     #[error("Protocol error: {0}")]
     189                 :     Protocol(String),
     190                 :     /// Failed to parse or, (unlikely), serialize a protocol message.
     191                 :     #[error("Message parse error: {0}")]
     192                 :     BadMessage(String),
     193                 : }
     194                 : 
     195                 : impl ProtocolError {
     196                 :     /// Proxy stream.rs uses only io::Error; provide it.
     197 UBC           0 :     pub fn into_io_error(self) -> io::Error {
     198               0 :         io::Error::new(io::ErrorKind::Other, self.to_string())
     199               0 :     }
     200                 : }
     201                 : 
     202                 : impl FeMessage {
     203                 :     /// Read and parse one message from the `buf` input buffer. If there is at
     204                 :     /// least one valid message, returns it, advancing `buf`; redundant copies
     205                 :     /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
     206                 :     /// directly into the `buf` (processed data is garbage collected after
     207                 :     /// parsed message is dropped).
     208                 :     ///
     209                 :     /// Returns None if `buf` doesn't contain enough data for a single message.
     210                 :     /// For efficiency, tries to reserve large enough space in `buf` for the
     211                 :     /// next message in this case to save the repeated calls.
     212                 :     ///
     213                 :     /// Returns Error if message is malformed, the only possible ErrorKind is
     214                 :     /// InvalidInput.
     215                 :     //
     216                 :     // Inspired by rust-postgres Message::parse.
     217 CBC    14442968 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
     218        14442968 :         // Every message contains message type byte and 4 bytes len; can't do
     219        14442968 :         // much without them.
     220        14442968 :         if buf.len() < 5 {
     221         6811831 :             let to_read = 5 - buf.len();
     222         6811831 :             buf.reserve(to_read);
     223         6811831 :             return Ok(None);
     224         7631137 :         }
     225         7631137 : 
     226         7631137 :         // We shouldn't advance `buf` as probably full message is not there yet,
     227         7631137 :         // so can't directly use Bytes::get_u32 etc.
     228         7631137 :         let tag = buf[0];
     229         7631137 :         let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
     230         7631137 :         if len < 4 {
     231 UBC           0 :             return Err(ProtocolError::Protocol(format!(
     232               0 :                 "invalid message length {}",
     233               0 :                 len
     234               0 :             )));
     235 CBC     7631137 :         }
     236         7631137 : 
     237         7631137 :         // length field includes itself, but not message type.
     238         7631137 :         let total_len = len as usize + 1;
     239         7631137 :         if buf.len() < total_len {
     240                 :             // Don't have full message yet.
     241          112724 :             let to_read = total_len - buf.len();
     242          112724 :             buf.reserve(to_read);
     243          112724 :             return Ok(None);
     244         7518413 :         }
     245         7518413 : 
     246         7518413 :         // got the message, advance buffer
     247         7518413 :         let mut msg = buf.split_to(total_len).freeze();
     248         7518413 :         msg.advance(5); // consume message type and len
     249         7518413 : 
     250         7518413 :         match tag {
     251            8744 :             b'Q' => Ok(Some(FeMessage::Query(msg))),
     252             640 :             b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
     253             640 :             b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
     254             640 :             b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
     255             640 :             b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
     256             639 :             b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
     257            1934 :             b'S' => Ok(Some(FeMessage::Sync)),
     258            1525 :             b'X' => Ok(Some(FeMessage::Terminate)),
     259         7502651 :             b'd' => Ok(Some(FeMessage::CopyData(msg))),
     260               6 :             b'c' => Ok(Some(FeMessage::CopyDone)),
     261              70 :             b'f' => Ok(Some(FeMessage::CopyFail)),
     262             284 :             b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
     263 UBC           0 :             tag => Err(ProtocolError::Protocol(format!(
     264               0 :                 "unknown message tag: {tag},'{msg:?}'"
     265               0 :             ))),
     266                 :         }
     267 CBC    14442968 :     }
     268                 : }
     269                 : 
     270                 : impl FeStartupPacket {
     271                 :     /// Read and parse startup message from the `buf` input buffer. It is
     272                 :     /// different from [`FeMessage::parse`] because startup messages don't have
     273                 :     /// message type byte; otherwise, its comments apply.
     274           30764 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
     275           30764 :         const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
     276           30764 :         const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
     277           30764 :         const CANCEL_REQUEST_CODE: u32 = 5678;
     278           30764 :         const NEGOTIATE_SSL_CODE: u32 = 5679;
     279           30764 :         const NEGOTIATE_GSS_CODE: u32 = 5680;
     280           30764 : 
     281           30764 :         // need at least 4 bytes with packet len
     282           30764 :         if buf.len() < 4 {
     283           15383 :             let to_read = 4 - buf.len();
     284           15383 :             buf.reserve(to_read);
     285           15383 :             return Ok(None);
     286           15381 :         }
     287           15381 : 
     288           15381 :         // We shouldn't advance `buf` as probably full message is not there yet,
     289           15381 :         // so can't directly use Bytes::get_u32 etc.
     290           15381 :         let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
     291           15381 :         // The proposed replacement is `!(4..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
     292           15381 :         // which is less readable
     293           15381 :         #[allow(clippy::manual_range_contains)]
     294           15381 :         if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
     295 UBC           0 :             return Err(ProtocolError::Protocol(format!(
     296               0 :                 "invalid startup packet message length {}",
     297               0 :                 len
     298               0 :             )));
     299 CBC       15381 :         }
     300           15381 : 
     301           15381 :         if buf.len() < len {
     302                 :             // Don't have full message yet.
     303 UBC           0 :             let to_read = len - buf.len();
     304               0 :             buf.reserve(to_read);
     305               0 :             return Ok(None);
     306 CBC       15381 :         }
     307           15381 : 
     308           15381 :         // got the message, advance buffer
     309           15381 :         let mut msg = buf.split_to(len).freeze();
     310           15381 :         msg.advance(4); // consume len
     311           15381 : 
     312           15381 :         let request_code = msg.get_u32();
     313           15381 :         let req_hi = request_code >> 16;
     314           15381 :         let req_lo = request_code & ((1 << 16) - 1);
     315                 :         // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
     316           15381 :         let message = match (req_hi, req_lo) {
     317                 :             (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
     318 UBC           0 :                 if msg.remaining() != 8 {
     319               0 :                     return Err(ProtocolError::BadMessage(
     320               0 :                         "CancelRequest message is malformed, backend PID / secret key missing"
     321               0 :                             .to_owned(),
     322               0 :                     ));
     323               0 :                 }
     324               0 :                 FeStartupPacket::CancelRequest(CancelKeyData {
     325               0 :                     backend_pid: msg.get_i32(),
     326               0 :                     cancel_key: msg.get_i32(),
     327               0 :                 })
     328                 :             }
     329                 :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
     330                 :                 // Requested upgrade to SSL (aka TLS)
     331 CBC        6653 :                 FeStartupPacket::SslRequest
     332                 :             }
     333                 :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
     334                 :                 // Requested upgrade to GSSAPI
     335 UBC           0 :                 FeStartupPacket::GssEncRequest
     336                 :             }
     337               0 :             (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
     338               0 :                 return Err(ProtocolError::Protocol(format!(
     339               0 :                     "Unrecognized request code {unrecognized_code}"
     340               0 :                 )));
     341                 :             }
     342                 :             // TODO bail if protocol major_version is not 3?
     343 CBC        8728 :             (major_version, minor_version) => {
     344                 :                 // StartupMessage
     345                 : 
     346                 :                 // Parse pairs of null-terminated strings (key, value).
     347                 :                 // See `postgres: ProcessStartupPacket, build_startup_packet`.
     348            8728 :                 let mut tokens = str::from_utf8(&msg)
     349            8728 :                     .map_err(|_e| {
     350 UBC           0 :                         ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
     351 CBC        8728 :                     })?
     352            8728 :                     .strip_suffix('\0') // drop packet's own null
     353            8728 :                     .ok_or_else(|| {
     354 UBC           0 :                         ProtocolError::Protocol(
     355               0 :                             "StartupMessage params: missing null terminator".to_string(),
     356               0 :                         )
     357 CBC        8728 :                     })?
     358            8728 :                     .split_terminator('\0');
     359            8728 : 
     360            8728 :                 let mut params = HashMap::new();
     361           30726 :                 while let Some(name) = tokens.next() {
     362           21998 :                     let value = tokens.next().ok_or_else(|| {
     363 UBC           0 :                         ProtocolError::Protocol(
     364               0 :                             "StartupMessage params: key without value".to_string(),
     365               0 :                         )
     366 CBC       21998 :                     })?;
     367                 : 
     368           21998 :                     params.insert(name.to_owned(), value.to_owned());
     369                 :                 }
     370                 : 
     371            8728 :                 FeStartupPacket::StartupMessage {
     372            8728 :                     major_version,
     373            8728 :                     minor_version,
     374            8728 :                     params: StartupMessageParams { params },
     375            8728 :                 }
     376                 :             }
     377                 :         };
     378           15381 :         Ok(Some(message))
     379           30764 :     }
     380                 : }
     381                 : 
     382                 : impl FeParseMessage {
     383             640 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     384                 :         // FIXME: the rust-postgres driver uses a named prepared statement
     385                 :         // for copy_out(). We're not prepared to handle that correctly. For
     386                 :         // now, just ignore the statement name, assuming that the client never
     387                 :         // uses more than one prepared statement at a time.
     388                 : 
     389             640 :         let _pstmt_name = read_cstr(&mut buf)?;
     390             640 :         let query_string = read_cstr(&mut buf)?;
     391             640 :         if buf.remaining() < 2 {
     392 UBC           0 :             return Err(ProtocolError::BadMessage(
     393               0 :                 "Parse message is malformed, nparams missing".to_string(),
     394               0 :             ));
     395 CBC         640 :         }
     396             640 :         let nparams = buf.get_i16();
     397             640 : 
     398             640 :         if nparams != 0 {
     399 UBC           0 :             return Err(ProtocolError::BadMessage(
     400               0 :                 "query params not implemented".to_string(),
     401               0 :             ));
     402 CBC         640 :         }
     403             640 : 
     404             640 :         Ok(FeMessage::Parse(FeParseMessage { query_string }))
     405             640 :     }
     406                 : }
     407                 : 
     408                 : impl FeDescribeMessage {
     409             640 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     410             640 :         let kind = buf.get_u8();
     411             640 :         let _pstmt_name = read_cstr(&mut buf)?;
     412                 : 
     413                 :         // FIXME: see FeParseMessage::parse
     414             640 :         if kind != b'S' {
     415 UBC           0 :             return Err(ProtocolError::BadMessage(
     416               0 :                 "only prepared statemement Describe is implemented".to_string(),
     417               0 :             ));
     418 CBC         640 :         }
     419             640 : 
     420             640 :         Ok(FeMessage::Describe(FeDescribeMessage { kind }))
     421             640 :     }
     422                 : }
     423                 : 
     424                 : impl FeExecuteMessage {
     425             640 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     426             640 :         let portal_name = read_cstr(&mut buf)?;
     427             640 :         if buf.remaining() < 4 {
     428 UBC           0 :             return Err(ProtocolError::BadMessage(
     429               0 :                 "FeExecuteMessage message is malformed, maxrows missing".to_string(),
     430               0 :             ));
     431 CBC         640 :         }
     432             640 :         let maxrows = buf.get_i32();
     433             640 : 
     434             640 :         if !portal_name.is_empty() {
     435 UBC           0 :             return Err(ProtocolError::BadMessage(
     436               0 :                 "named portals not implemented".to_string(),
     437               0 :             ));
     438 CBC         640 :         }
     439             640 :         if maxrows != 0 {
     440 UBC           0 :             return Err(ProtocolError::BadMessage(
     441               0 :                 "row limit in Execute message not implemented".to_string(),
     442               0 :             ));
     443 CBC         640 :         }
     444             640 : 
     445             640 :         Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
     446             640 :     }
     447                 : }
     448                 : 
     449                 : impl FeBindMessage {
     450             640 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     451             640 :         let portal_name = read_cstr(&mut buf)?;
     452             640 :         let _pstmt_name = read_cstr(&mut buf)?;
     453                 : 
     454                 :         // FIXME: see FeParseMessage::parse
     455             640 :         if !portal_name.is_empty() {
     456 UBC           0 :             return Err(ProtocolError::BadMessage(
     457               0 :                 "named portals not implemented".to_string(),
     458               0 :             ));
     459 CBC         640 :         }
     460             640 : 
     461             640 :         Ok(FeMessage::Bind(FeBindMessage))
     462             640 :     }
     463                 : }
     464                 : 
     465                 : impl FeCloseMessage {
     466             639 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     467             639 :         let _kind = buf.get_u8();
     468             639 :         let _pstmt_or_portal_name = read_cstr(&mut buf)?;
     469                 : 
     470                 :         // FIXME: we do nothing with Close
     471             639 :         Ok(FeMessage::Close(FeCloseMessage))
     472             639 :     }
     473                 : }
     474                 : 
     475                 : // Backend
     476                 : 
     477 UBC           0 : #[derive(Debug)]
     478                 : pub enum BeMessage<'a> {
     479                 :     AuthenticationOk,
     480                 :     AuthenticationMD5Password([u8; 4]),
     481                 :     AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
     482                 :     AuthenticationCleartextPassword,
     483                 :     BackendKeyData(CancelKeyData),
     484                 :     BindComplete,
     485                 :     CommandComplete(&'a [u8]),
     486                 :     CopyData(&'a [u8]),
     487                 :     CopyDone,
     488                 :     CopyFail,
     489                 :     CopyInResponse,
     490                 :     CopyOutResponse,
     491                 :     CopyBothResponse,
     492                 :     CloseComplete,
     493                 :     // None means column is NULL
     494                 :     DataRow(&'a [Option<&'a [u8]>]),
     495                 :     // None errcode means internal_error will be sent.
     496                 :     ErrorResponse(&'a str, Option<&'a [u8; 5]>),
     497                 :     /// Single byte - used in response to SSLRequest/GSSENCRequest.
     498                 :     EncryptionResponse(bool),
     499                 :     NoData,
     500                 :     ParameterDescription,
     501                 :     ParameterStatus {
     502                 :         name: &'a [u8],
     503                 :         value: &'a [u8],
     504                 :     },
     505                 :     ParseComplete,
     506                 :     ReadyForQuery,
     507                 :     RowDescription(&'a [RowDescriptor<'a>]),
     508                 :     XLogData(XLogDataBody<'a>),
     509                 :     NoticeResponse(&'a str),
     510                 :     KeepAlive(WalSndKeepAlive),
     511                 : }
     512                 : 
     513                 : /// Common shorthands.
     514                 : impl<'a> BeMessage<'a> {
     515                 :     /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
     516                 :     /// This is a sensible default, given that:
     517                 :     ///  * rust strings only support this encoding out of the box.
     518                 :     ///  * tokio-postgres, postgres-jdbc (and probably more) mandate it.
     519                 :     ///
     520                 :     /// TODO: do we need to report `server_encoding` as well?
     521                 :     pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
     522                 :         name: b"client_encoding",
     523                 :         value: b"UTF8",
     524                 :     };
     525                 : 
     526                 :     pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
     527                 :         name: b"integer_datetimes",
     528                 :         value: b"on",
     529                 :     };
     530                 : 
     531                 :     /// Build a [`BeMessage::ParameterStatus`] holding the server version.
     532 CBC        8459 :     pub fn server_version(version: &'a str) -> Self {
     533            8459 :         Self::ParameterStatus {
     534            8459 :             name: b"server_version",
     535            8459 :             value: version.as_bytes(),
     536            8459 :         }
     537            8459 :     }
     538                 : }
     539                 : 
     540 UBC           0 : #[derive(Debug)]
     541                 : pub enum BeAuthenticationSaslMessage<'a> {
     542                 :     Methods(&'a [&'a str]),
     543                 :     Continue(&'a [u8]),
     544                 :     Final(&'a [u8]),
     545                 : }
     546                 : 
     547               0 : #[derive(Debug)]
     548                 : pub enum BeParameterStatusMessage<'a> {
     549                 :     Encoding(&'a str),
     550                 :     ServerVersion(&'a str),
     551                 : }
     552                 : 
     553                 : // One row description in RowDescription packet.
     554               0 : #[derive(Debug)]
     555                 : pub struct RowDescriptor<'a> {
     556                 :     pub name: &'a [u8],
     557                 :     pub tableoid: Oid,
     558                 :     pub attnum: i16,
     559                 :     pub typoid: Oid,
     560                 :     pub typlen: i16,
     561                 :     pub typmod: i32,
     562                 :     pub formatcode: i16,
     563                 : }
     564                 : 
     565                 : impl Default for RowDescriptor<'_> {
     566 CBC        2791 :     fn default() -> RowDescriptor<'static> {
     567            2791 :         RowDescriptor {
     568            2791 :             name: b"",
     569            2791 :             tableoid: 0,
     570            2791 :             attnum: 0,
     571            2791 :             typoid: 0,
     572            2791 :             typlen: 0,
     573            2791 :             typmod: 0,
     574            2791 :             formatcode: 0,
     575            2791 :         }
     576            2791 :     }
     577                 : }
     578                 : 
     579                 : impl RowDescriptor<'_> {
     580                 :     /// Convenience function to create a RowDescriptor message for an int8 column
     581              45 :     pub const fn int8_col(name: &[u8]) -> RowDescriptor {
     582              45 :         RowDescriptor {
     583              45 :             name,
     584              45 :             tableoid: 0,
     585              45 :             attnum: 0,
     586              45 :             typoid: INT8_OID,
     587              45 :             typlen: 8,
     588              45 :             typmod: 0,
     589              45 :             formatcode: 0,
     590              45 :         }
     591              45 :     }
     592                 : 
     593            1508 :     pub const fn text_col(name: &[u8]) -> RowDescriptor {
     594            1508 :         RowDescriptor {
     595            1508 :             name,
     596            1508 :             tableoid: 0,
     597            1508 :             attnum: 0,
     598            1508 :             typoid: TEXT_OID,
     599            1508 :             typlen: -1,
     600            1508 :             typmod: 0,
     601            1508 :             formatcode: 0,
     602            1508 :         }
     603            1508 :     }
     604                 : }
     605                 : 
     606 UBC           0 : #[derive(Debug)]
     607                 : pub struct XLogDataBody<'a> {
     608                 :     pub wal_start: u64,
     609                 :     pub wal_end: u64, // current end of WAL on the server
     610                 :     pub timestamp: i64,
     611                 :     pub data: &'a [u8],
     612                 : }
     613                 : 
     614               0 : #[derive(Debug)]
     615                 : pub struct WalSndKeepAlive {
     616                 :     pub wal_end: u64, // current end of WAL on the server
     617                 :     pub timestamp: i64,
     618                 :     pub request_reply: bool,
     619                 : }
     620                 : 
     621                 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
     622                 : 
     623                 : // single text column
     624                 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
     625                 :     name: b"data",
     626                 :     tableoid: 0,
     627                 :     attnum: 0,
     628                 :     typoid: TEXT_OID,
     629                 :     typlen: -1,
     630                 :     typmod: 0,
     631                 :     formatcode: 0,
     632                 : }]);
     633                 : 
     634                 : /// Call f() to write body of the message and prepend it with 4-byte len as
     635                 : /// prescribed by the protocol.
     636 CBC     8027128 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
     637         8027128 :     let base = buf.len();
     638         8027128 :     buf.extend_from_slice(&[0; 4]);
     639         8027128 : 
     640         8027128 :     let res = f(buf);
     641         8027128 : 
     642         8027128 :     let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
     643         8027128 :     (&mut buf[base..]).put_slice(&size.to_be_bytes());
     644         8027128 : 
     645         8027128 :     res
     646         8027128 : }
     647                 : 
     648                 : /// Safe write of s into buf as cstring (String in the protocol).
     649           58728 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
     650           58728 :     let bytes = s.as_ref();
     651           58728 :     if bytes.contains(&0) {
     652 UBC           0 :         return Err(ProtocolError::BadMessage(
     653               0 :             "string contains embedded null".to_owned(),
     654               0 :         ));
     655 CBC       58728 :     }
     656           58728 :     buf.put_slice(bytes);
     657           58728 :     buf.put_u8(0);
     658           58728 :     Ok(())
     659           58728 : }
     660                 : 
     661                 : /// Read cstring from buf, advancing it.
     662         3901695 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
     663         3901695 :     let pos = buf
     664         3901695 :         .iter()
     665        55402786 :         .position(|x| *x == 0)
     666         3901695 :         .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
     667         3901695 :     let result = buf.split_to(pos);
     668         3901695 :     buf.advance(1); // drop the null terminator
     669         3901695 :     Ok(result)
     670         3901695 : }
     671                 : 
     672                 : pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
     673                 : pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
     674                 : pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
     675                 : 
     676                 : impl<'a> BeMessage<'a> {
     677                 :     /// Serialize `message` to the given `buf`.
     678                 :     /// Apart from smart memory managemet, BytesMut is good here as msg len
     679                 :     /// precedes its body and it is handy to write it down first and then fill
     680                 :     /// the length. With Write we would have to either calc it manually or have
     681                 :     /// one more buffer.
     682         8033781 :     pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
     683         8033781 :         match message {
     684            8711 :             BeMessage::AuthenticationOk => {
     685            8711 :                 buf.put_u8(b'R');
     686            8711 :                 write_body(buf, |buf| {
     687            8711 :                     buf.put_i32(0); // Specifies that the authentication was successful.
     688            8711 :                 });
     689            8711 :             }
     690                 : 
     691             228 :             BeMessage::AuthenticationCleartextPassword => {
     692             228 :                 buf.put_u8(b'R');
     693             228 :                 write_body(buf, |buf| {
     694             228 :                     buf.put_i32(3); // Specifies that clear text password is required.
     695             228 :                 });
     696             228 :             }
     697                 : 
     698 UBC           0 :             BeMessage::AuthenticationMD5Password(salt) => {
     699               0 :                 buf.put_u8(b'R');
     700               0 :                 write_body(buf, |buf| {
     701               0 :                     buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
     702               0 :                     buf.put_slice(&salt[..]);
     703               0 :                 });
     704               0 :             }
     705                 : 
     706 CBC          89 :             BeMessage::AuthenticationSasl(msg) => {
     707              89 :                 buf.put_u8(b'R');
     708              89 :                 write_body(buf, |buf| {
     709              89 :                     use BeAuthenticationSaslMessage::*;
     710              89 :                     match msg {
     711              31 :                         Methods(methods) => {
     712              31 :                             buf.put_i32(10); // Specifies that SASL auth method is used.
     713              31 :                             for method in methods.iter() {
     714              31 :                                 write_cstr(method, buf)?;
     715                 :                             }
     716              31 :                             buf.put_u8(0); // zero terminator for the list
     717                 :                         }
     718              31 :                         Continue(extra) => {
     719              31 :                             buf.put_i32(11); // Continue SASL auth.
     720              31 :                             buf.put_slice(extra);
     721              31 :                         }
     722              27 :                         Final(extra) => {
     723              27 :                             buf.put_i32(12); // Send final SASL message.
     724              27 :                             buf.put_slice(extra);
     725              27 :                         }
     726                 :                     }
     727              89 :                     Ok(())
     728              89 :                 })?;
     729                 :             }
     730                 : 
     731              29 :             BeMessage::BackendKeyData(key_data) => {
     732              29 :                 buf.put_u8(b'K');
     733              29 :                 write_body(buf, |buf| {
     734              29 :                     buf.put_i32(key_data.backend_pid);
     735              29 :                     buf.put_i32(key_data.cancel_key);
     736              29 :                 });
     737              29 :             }
     738                 : 
     739             640 :             BeMessage::BindComplete => {
     740             640 :                 buf.put_u8(b'2');
     741             640 :                 write_body(buf, |_| {});
     742             640 :             }
     743                 : 
     744             639 :             BeMessage::CloseComplete => {
     745             639 :                 buf.put_u8(b'3');
     746             639 :                 write_body(buf, |_| {});
     747             639 :             }
     748                 : 
     749            2137 :             BeMessage::CommandComplete(cmd) => {
     750            2137 :                 buf.put_u8(b'C');
     751            2137 :                 write_body(buf, |buf| write_cstr(cmd, buf))?;
     752                 :             }
     753                 : 
     754         7171073 :             BeMessage::CopyData(data) => {
     755         7171073 :                 buf.put_u8(b'd');
     756         7171073 :                 write_body(buf, |buf| {
     757         7171073 :                     buf.put_slice(data);
     758         7171073 :                 });
     759         7171073 :             }
     760                 : 
     761             638 :             BeMessage::CopyDone => {
     762             638 :                 buf.put_u8(b'c');
     763             638 :                 write_body(buf, |_| {});
     764             638 :             }
     765                 : 
     766 UBC           0 :             BeMessage::CopyFail => {
     767               0 :                 buf.put_u8(b'f');
     768               0 :                 write_body(buf, |_| {});
     769               0 :             }
     770                 : 
     771 CBC           7 :             BeMessage::CopyInResponse => {
     772               7 :                 buf.put_u8(b'G');
     773               7 :                 write_body(buf, |buf| {
     774               7 :                     buf.put_u8(1); // copy_is_binary
     775               7 :                     buf.put_i16(0); // numAttributes
     776               7 :                 });
     777               7 :             }
     778                 : 
     779             640 :             BeMessage::CopyOutResponse => {
     780             640 :                 buf.put_u8(b'H');
     781             640 :                 write_body(buf, |buf| {
     782             640 :                     buf.put_u8(0); // copy_is_binary
     783             640 :                     buf.put_i16(0); // numAttributes
     784             640 :                 });
     785             640 :             }
     786                 : 
     787            7197 :             BeMessage::CopyBothResponse => {
     788            7197 :                 buf.put_u8(b'W');
     789            7197 :                 write_body(buf, |buf| {
     790            7197 :                     // doesn't matter, used only for replication
     791            7197 :                     buf.put_u8(0); // copy_is_binary
     792            7197 :                     buf.put_i16(0); // numAttributes
     793            7197 :                 });
     794            7197 :             }
     795                 : 
     796             994 :             BeMessage::DataRow(vals) => {
     797             994 :                 buf.put_u8(b'D');
     798             994 :                 write_body(buf, |buf| {
     799             994 :                     buf.put_u16(vals.len() as u16); // num of cols
     800            3409 :                     for val_opt in vals.iter() {
     801            3409 :                         if let Some(val) = val_opt {
     802            2712 :                             buf.put_u32(val.len() as u32);
     803            2712 :                             buf.put_slice(val);
     804            2712 :                         } else {
     805             697 :                             buf.put_i32(-1);
     806             697 :                         }
     807                 :                     }
     808             994 :                 });
     809             994 :             }
     810                 : 
     811                 :             // ErrorResponse is a zero-terminated array of zero-terminated fields.
     812                 :             // First byte of each field represents type of this field. Set just enough fields
     813                 :             // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
     814                 :             // message text.
     815             247 :             BeMessage::ErrorResponse(error_msg, pg_error_code) => {
     816             247 :                 // 'E' signalizes ErrorResponse messages
     817             247 :                 buf.put_u8(b'E');
     818             247 :                 write_body(buf, |buf| {
     819             247 :                     buf.put_u8(b'S'); // severity
     820             247 :                     buf.put_slice(b"ERROR\0");
     821             247 : 
     822             247 :                     buf.put_u8(b'C'); // SQLSTATE error code
     823             247 :                     buf.put_slice(&terminate_code(
     824             247 :                         pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
     825             247 :                     ));
     826             247 : 
     827             247 :                     buf.put_u8(b'M'); // the message
     828             247 :                     write_cstr(error_msg, buf)?;
     829                 : 
     830             247 :                     buf.put_u8(0); // terminator
     831             247 :                     Ok(())
     832             247 :                 })?;
     833                 :             }
     834                 : 
     835                 :             // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
     836                 :             // message but continue listening for ReadyForQuery or ErrorResponse"
     837               6 :             BeMessage::NoticeResponse(error_msg) => {
     838               6 :                 // For all the errors set Severity to Error and error code to
     839               6 :                 // 'internal error'.
     840               6 : 
     841               6 :                 // 'N' signalizes NoticeResponse messages
     842               6 :                 buf.put_u8(b'N');
     843               6 :                 write_body(buf, |buf| {
     844               6 :                     buf.put_u8(b'S'); // severity
     845               6 :                     buf.put_slice(b"NOTICE\0");
     846               6 : 
     847               6 :                     buf.put_u8(b'C'); // SQLSTATE error code
     848               6 :                     buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
     849               6 : 
     850               6 :                     buf.put_u8(b'M'); // the message
     851               6 :                     write_cstr(error_msg.as_bytes(), buf)?;
     852                 : 
     853               6 :                     buf.put_u8(0); // terminator
     854               6 :                     Ok(())
     855               6 :                 })?;
     856                 :             }
     857                 : 
     858             640 :             BeMessage::NoData => {
     859             640 :                 buf.put_u8(b'n');
     860             640 :                 write_body(buf, |_| {});
     861             640 :             }
     862                 : 
     863            6653 :             BeMessage::EncryptionResponse(should_negotiate) => {
     864            6653 :                 let response = if *should_negotiate { b'S' } else { b'N' };
     865            6653 :                 buf.put_u8(response);
     866                 :             }
     867                 : 
     868           25980 :             BeMessage::ParameterStatus { name, value } => {
     869           25980 :                 buf.put_u8(b'S');
     870           25980 :                 write_body(buf, |buf| {
     871           25980 :                     write_cstr(name, buf)?;
     872           25980 :                     write_cstr(value, buf)
     873           25980 :                 })?;
     874                 :             }
     875                 : 
     876             640 :             BeMessage::ParameterDescription => {
     877             640 :                 buf.put_u8(b't');
     878             640 :                 write_body(buf, |buf| {
     879             640 :                     // we don't support params, so always 0
     880             640 :                     buf.put_i16(0);
     881             640 :                 });
     882             640 :             }
     883                 : 
     884             640 :             BeMessage::ParseComplete => {
     885             640 :                 buf.put_u8(b'1');
     886             640 :                 write_body(buf, |_| {});
     887             640 :             }
     888                 : 
     889           18909 :             BeMessage::ReadyForQuery => {
     890           18909 :                 buf.put_u8(b'Z');
     891           18909 :                 write_body(buf, |buf| {
     892           18909 :                     buf.put_u8(b'I');
     893           18909 :                 });
     894           18909 :             }
     895                 : 
     896            1463 :             BeMessage::RowDescription(rows) => {
     897            1463 :                 buf.put_u8(b'T');
     898            1463 :                 write_body(buf, |buf| {
     899            1463 :                     buf.put_i16(rows.len() as i16); // # of fields
     900            4347 :                     for row in rows.iter() {
     901            4347 :                         write_cstr(row.name, buf)?;
     902            4347 :                         buf.put_i32(0); /* table oid */
     903            4347 :                         buf.put_i16(0); /* attnum */
     904            4347 :                         buf.put_u32(row.typoid);
     905            4347 :                         buf.put_i16(row.typlen);
     906            4347 :                         buf.put_i32(-1); /* typmod */
     907            4347 :                         buf.put_i16(0); /* format code */
     908                 :                     }
     909            1463 :                     Ok(())
     910            1463 :                 })?;
     911                 :             }
     912                 : 
     913          783639 :             BeMessage::XLogData(body) => {
     914          783639 :                 buf.put_u8(b'd');
     915          783639 :                 write_body(buf, |buf| {
     916          783639 :                     buf.put_u8(b'w');
     917          783639 :                     buf.put_u64(body.wal_start);
     918          783639 :                     buf.put_u64(body.wal_end);
     919          783639 :                     buf.put_i64(body.timestamp);
     920          783639 :                     buf.put_slice(body.data);
     921          783639 :                 });
     922          783639 :             }
     923                 : 
     924            1942 :             BeMessage::KeepAlive(req) => {
     925            1942 :                 buf.put_u8(b'd');
     926            1942 :                 write_body(buf, |buf| {
     927            1942 :                     buf.put_u8(b'k');
     928            1942 :                     buf.put_u64(req.wal_end);
     929            1942 :                     buf.put_i64(req.timestamp);
     930            1942 :                     buf.put_u8(u8::from(req.request_reply));
     931            1942 :                 });
     932            1942 :             }
     933                 :         }
     934         8033781 :         Ok(())
     935         8033781 :     }
     936                 : }
     937                 : 
     938             253 : fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
     939             253 :     let mut terminated = [0; 6];
     940            1265 :     for (i, &elem) in code.iter().enumerate() {
     941            1265 :         terminated[i] = elem;
     942            1265 :     }
     943                 : 
     944             253 :     terminated
     945             253 : }
     946                 : 
     947                 : #[cfg(test)]
     948                 : mod tests {
     949                 :     use super::*;
     950                 : 
     951               1 :     #[test]
     952               1 :     fn test_startup_message_params_options_escaped() {
     953               4 :         fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
     954               4 :             params
     955               4 :                 .options_escaped()
     956               4 :                 .expect("options are None")
     957               4 :                 .collect()
     958               4 :         }
     959               1 : 
     960               4 :         let make_params = |options| StartupMessageParams::new([("options", options)]);
     961                 : 
     962               1 :         let params = StartupMessageParams::new([]);
     963               1 :         assert!(params.options_escaped().is_none());
     964                 : 
     965               1 :         let params = make_params("");
     966               1 :         assert!(split_options(&params).is_empty());
     967                 : 
     968               1 :         let params = make_params("foo");
     969               1 :         assert_eq!(split_options(&params), ["foo"]);
     970                 : 
     971               1 :         let params = make_params(" foo  bar ");
     972               1 :         assert_eq!(split_options(&params), ["foo", "bar"]);
     973                 : 
     974               1 :         let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
     975               1 :         assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
     976               1 :     }
     977                 : }
        

Generated by: LCOV version 2.1-beta