LCOV - differential code coverage report
Current view: top level - libs/pq_proto/src - lib.rs (source / functions) Coverage Total Hit UBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 86.2 % 589 508 81 508
Current Date: 2024-01-09 02:06:09 Functions: 77.0 % 126 97 29 97
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Postgres protocol messages serialization-deserialization. See
       2                 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
       3                 : //! on message formats.
       4                 : #![deny(clippy::undocumented_unsafe_blocks)]
       5                 : 
       6                 : pub mod framed;
       7                 : 
       8                 : use byteorder::{BigEndian, ReadBytesExt};
       9                 : use bytes::{Buf, BufMut, Bytes, BytesMut};
      10                 : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
      11                 : 
      12                 : // re-export for use in utils pageserver_feedback.rs
      13                 : pub use postgres_protocol::PG_EPOCH;
      14                 : 
      15                 : pub type Oid = u32;
      16                 : pub type SystemId = u64;
      17                 : 
      18                 : pub const INT8_OID: Oid = 20;
      19                 : pub const INT4_OID: Oid = 23;
      20                 : pub const TEXT_OID: Oid = 25;
      21                 : 
      22 CBC          50 : #[derive(Debug)]
      23                 : pub enum FeMessage {
      24                 :     // Simple query.
      25                 :     Query(Bytes),
      26                 :     // Extended query protocol.
      27                 :     Parse(FeParseMessage),
      28                 :     Describe(FeDescribeMessage),
      29                 :     Bind(FeBindMessage),
      30                 :     Execute(FeExecuteMessage),
      31                 :     Close(FeCloseMessage),
      32                 :     Sync,
      33                 :     Terminate,
      34                 :     CopyData(Bytes),
      35                 :     CopyDone,
      36                 :     CopyFail,
      37                 :     PasswordMessage(Bytes),
      38                 : }
      39                 : 
      40             218 : #[derive(Debug)]
      41                 : pub enum FeStartupPacket {
      42                 :     CancelRequest(CancelKeyData),
      43                 :     SslRequest,
      44                 :     GssEncRequest,
      45                 :     StartupMessage {
      46                 :         major_version: u32,
      47                 :         minor_version: u32,
      48                 :         params: StartupMessageParams,
      49                 :     },
      50                 : }
      51                 : 
      52             120 : #[derive(Debug)]
      53                 : pub struct StartupMessageParams {
      54                 :     params: HashMap<String, String>,
      55                 : }
      56                 : 
      57                 : impl StartupMessageParams {
      58                 :     /// Get parameter's value by its name.
      59            6957 :     pub fn get(&self, name: &str) -> Option<&str> {
      60            6957 :         self.params.get(name).map(|s| s.as_str())
      61            6957 :     }
      62                 : 
      63                 :     /// Split command-line options according to PostgreSQL's logic,
      64                 :     /// taking into account all escape sequences but leaving them as-is.
      65                 :     /// [`None`] means that there's no `options` in [`Self`].
      66            3455 :     pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
      67            3455 :         self.get("options").map(Self::parse_options_raw)
      68            3455 :     }
      69                 : 
      70                 :     /// Split command-line options according to PostgreSQL's logic,
      71                 :     /// applying all escape sequences (using owned strings as needed).
      72                 :     /// [`None`] means that there's no `options` in [`Self`].
      73               5 :     pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
      74               5 :         self.get("options").map(Self::parse_options_escaped)
      75               5 :     }
      76                 : 
      77                 :     /// Split command-line options according to PostgreSQL's logic,
      78                 :     /// taking into account all escape sequences but leaving them as-is.
      79            3448 :     pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
      80            3448 :         // See `postgres: pg_split_opts`.
      81            3448 :         let mut last_was_escape = false;
      82            3448 :         input
      83          288295 :             .split(move |c: char| {
      84                 :                 // We split by non-escaped whitespace symbols.
      85          288295 :                 let should_split = c.is_ascii_whitespace() && !last_was_escape;
      86          288295 :                 last_was_escape = c == '\\' && !last_was_escape;
      87          288295 :                 should_split
      88          288295 :             })
      89            9948 :             .filter(|s| !s.is_empty())
      90            3448 :     }
      91                 : 
      92                 :     /// Split command-line options according to PostgreSQL's logic,
      93                 :     /// applying all escape sequences (using owned strings as needed).
      94               4 :     pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
      95               4 :         // See `postgres: pg_split_opts`.
      96               7 :         Self::parse_options_raw(input).map(|s| {
      97               7 :             let mut preserve_next_escape = false;
      98              17 :             let escape = |c| {
      99                 :                 // We should remove '\\' unless it's preceded by '\\'.
     100              17 :                 let should_remove = c == '\\' && !preserve_next_escape;
     101              17 :                 preserve_next_escape = should_remove;
     102              17 :                 should_remove
     103              17 :             };
     104                 : 
     105               7 :             match s.contains('\\') {
     106               3 :                 true => Cow::Owned(s.replace(escape, "")),
     107               4 :                 false => Cow::Borrowed(s),
     108                 :             }
     109               7 :         })
     110               4 :     }
     111                 : 
     112                 :     /// Iterate through key-value pairs in an arbitrary order.
     113               7 :     pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
     114              21 :         self.params.iter().map(|(k, v)| (k.as_str(), v.as_str()))
     115               7 :     }
     116                 : 
     117                 :     // This function is mostly useful in tests.
     118                 :     #[doc(hidden)]
     119              64 :     pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
     120              64 :         Self {
     121             195 :             params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
     122              64 :         }
     123              64 :     }
     124                 : }
     125                 : 
     126             328 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
     127                 : pub struct CancelKeyData {
     128                 :     pub backend_pid: i32,
     129                 :     pub cancel_key: i32,
     130                 : }
     131                 : 
     132                 : impl fmt::Display for CancelKeyData {
     133             196 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     134             196 :         let hi = (self.backend_pid as u64) << 32;
     135             196 :         let lo = self.cancel_key as u64;
     136             196 :         let id = hi | lo;
     137             196 : 
     138             196 :         // This format is more compact and might work better for logs.
     139             196 :         f.debug_tuple("CancelKeyData")
     140             196 :             .field(&format_args!("{:x}", id))
     141             196 :             .finish()
     142             196 :     }
     143                 : }
     144                 : 
     145                 : use rand::distributions::{Distribution, Standard};
     146                 : impl Distribution<CancelKeyData> for Standard {
     147              50 :     fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
     148              50 :         CancelKeyData {
     149              50 :             backend_pid: rng.gen(),
     150              50 :             cancel_key: rng.gen(),
     151              50 :         }
     152              50 :     }
     153                 : }
     154                 : 
     155                 : // We only support the simple case of Parse on unnamed prepared statement and
     156                 : // no params
     157 UBC           0 : #[derive(Debug)]
     158                 : pub struct FeParseMessage {
     159                 :     pub query_string: Bytes,
     160                 : }
     161                 : 
     162               0 : #[derive(Debug)]
     163                 : pub struct FeDescribeMessage {
     164                 :     pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
     165                 :                   // we only support unnamed prepared stmt or portal
     166                 : }
     167                 : 
     168                 : // we only support unnamed prepared stmt and portal
     169               0 : #[derive(Debug)]
     170                 : pub struct FeBindMessage;
     171                 : 
     172                 : // we only support unnamed prepared stmt or portal
     173               0 : #[derive(Debug)]
     174                 : pub struct FeExecuteMessage {
     175                 :     /// max # of rows
     176                 :     pub maxrows: i32,
     177                 : }
     178                 : 
     179                 : // we only support unnamed prepared stmt and portal
     180               0 : #[derive(Debug)]
     181                 : pub struct FeCloseMessage;
     182                 : 
     183                 : /// An error occurred while parsing or serializing raw stream into Postgres
     184                 : /// messages.
     185 CBC           5 : #[derive(thiserror::Error, Debug)]
     186                 : pub enum ProtocolError {
     187                 :     /// Invalid packet was received from the client (e.g. unexpected message
     188                 :     /// type or broken len).
     189                 :     #[error("Protocol error: {0}")]
     190                 :     Protocol(String),
     191                 :     /// Failed to parse or, (unlikely), serialize a protocol message.
     192                 :     #[error("Message parse error: {0}")]
     193                 :     BadMessage(String),
     194                 : }
     195                 : 
     196                 : impl ProtocolError {
     197                 :     /// Proxy stream.rs uses only io::Error; provide it.
     198 UBC           0 :     pub fn into_io_error(self) -> io::Error {
     199               0 :         io::Error::new(io::ErrorKind::Other, self.to_string())
     200               0 :     }
     201                 : }
     202                 : 
     203                 : impl FeMessage {
     204                 :     /// Read and parse one message from the `buf` input buffer. If there is at
     205                 :     /// least one valid message, returns it, advancing `buf`; redundant copies
     206                 :     /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
     207                 :     /// directly into the `buf` (processed data is garbage collected after
     208                 :     /// parsed message is dropped).
     209                 :     ///
     210                 :     /// Returns None if `buf` doesn't contain enough data for a single message.
     211                 :     /// For efficiency, tries to reserve large enough space in `buf` for the
     212                 :     /// next message in this case to save the repeated calls.
     213                 :     ///
     214                 :     /// Returns Error if message is malformed, the only possible ErrorKind is
     215                 :     /// InvalidInput.
     216                 :     //
     217                 :     // Inspired by rust-postgres Message::parse.
     218 CBC    11812587 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
     219        11812587 :         // Every message contains message type byte and 4 bytes len; can't do
     220        11812587 :         // much without them.
     221        11812587 :         if buf.len() < 5 {
     222         5371100 :             let to_read = 5 - buf.len();
     223         5371100 :             buf.reserve(to_read);
     224         5371100 :             return Ok(None);
     225         6441487 :         }
     226         6441487 : 
     227         6441487 :         // We shouldn't advance `buf` as probably full message is not there yet,
     228         6441487 :         // so can't directly use Bytes::get_u32 etc.
     229         6441487 :         let tag = buf[0];
     230         6441487 :         let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
     231         6441487 :         if len < 4 {
     232 UBC           0 :             return Err(ProtocolError::Protocol(format!(
     233               0 :                 "invalid message length {}",
     234               0 :                 len
     235               0 :             )));
     236 CBC     6441487 :         }
     237         6441487 : 
     238         6441487 :         // length field includes itself, but not message type.
     239         6441487 :         let total_len = len as usize + 1;
     240         6441487 :         if buf.len() < total_len {
     241                 :             // Don't have full message yet.
     242          149702 :             let to_read = total_len - buf.len();
     243          149702 :             buf.reserve(to_read);
     244          149702 :             return Ok(None);
     245         6291785 :         }
     246         6291785 : 
     247         6291785 :         // got the message, advance buffer
     248         6291785 :         let mut msg = buf.split_to(total_len).freeze();
     249         6291785 :         msg.advance(5); // consume message type and len
     250         6291785 : 
     251         6291785 :         match tag {
     252           11393 :             b'Q' => Ok(Some(FeMessage::Query(msg))),
     253             552 :             b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
     254             552 :             b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
     255             552 :             b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
     256             552 :             b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
     257             551 :             b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
     258            1669 :             b'S' => Ok(Some(FeMessage::Sync)),
     259            3391 :             b'X' => Ok(Some(FeMessage::Terminate)),
     260         6272188 :             b'd' => Ok(Some(FeMessage::CopyData(msg))),
     261              12 :             b'c' => Ok(Some(FeMessage::CopyDone)),
     262              67 :             b'f' => Ok(Some(FeMessage::CopyFail)),
     263             306 :             b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
     264 UBC           0 :             tag => Err(ProtocolError::Protocol(format!(
     265               0 :                 "unknown message tag: {tag},'{msg:?}'"
     266               0 :             ))),
     267                 :         }
     268 CBC    11812587 :     }
     269                 : }
     270                 : 
     271                 : impl FeStartupPacket {
     272                 :     /// Read and parse startup message from the `buf` input buffer. It is
     273                 :     /// different from [`FeMessage::parse`] because startup messages don't have
     274                 :     /// message type byte; otherwise, its comments apply.
     275           41362 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
     276           41362 :         const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
     277           41362 :         const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
     278           41362 :         const CANCEL_REQUEST_CODE: u32 = 5678;
     279           41362 :         const NEGOTIATE_SSL_CODE: u32 = 5679;
     280           41362 :         const NEGOTIATE_GSS_CODE: u32 = 5680;
     281           41362 : 
     282           41362 :         // need at least 4 bytes with packet len
     283           41362 :         if buf.len() < 4 {
     284           20681 :             let to_read = 4 - buf.len();
     285           20681 :             buf.reserve(to_read);
     286           20681 :             return Ok(None);
     287           20681 :         }
     288           20681 : 
     289           20681 :         // We shouldn't advance `buf` as probably full message is not there yet,
     290           20681 :         // so can't directly use Bytes::get_u32 etc.
     291           20681 :         let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
     292           20681 :         // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
     293           20681 :         // which is less readable
     294           20681 :         #[allow(clippy::manual_range_contains)]
     295           20681 :         if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
     296               1 :             return Err(ProtocolError::Protocol(format!(
     297               1 :                 "invalid startup packet message length {}",
     298               1 :                 len
     299               1 :             )));
     300           20680 :         }
     301           20680 : 
     302           20680 :         if buf.len() < len {
     303                 :             // Don't have full message yet.
     304 UBC           0 :             let to_read = len - buf.len();
     305               0 :             buf.reserve(to_read);
     306               0 :             return Ok(None);
     307 CBC       20680 :         }
     308           20680 : 
     309           20680 :         // got the message, advance buffer
     310           20680 :         let mut msg = buf.split_to(len).freeze();
     311           20680 :         msg.advance(4); // consume len
     312           20680 : 
     313           20680 :         let request_code = msg.get_u32();
     314           20680 :         let req_hi = request_code >> 16;
     315           20680 :         let req_lo = request_code & ((1 << 16) - 1);
     316                 :         // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
     317           20680 :         let message = match (req_hi, req_lo) {
     318                 :             (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
     319 UBC           0 :                 if msg.remaining() != 8 {
     320               0 :                     return Err(ProtocolError::BadMessage(
     321               0 :                         "CancelRequest message is malformed, backend PID / secret key missing"
     322               0 :                             .to_owned(),
     323               0 :                     ));
     324               0 :                 }
     325               0 :                 FeStartupPacket::CancelRequest(CancelKeyData {
     326               0 :                     backend_pid: msg.get_i32(),
     327               0 :                     cancel_key: msg.get_i32(),
     328               0 :                 })
     329                 :             }
     330                 :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
     331                 :                 // Requested upgrade to SSL (aka TLS)
     332 CBC        9389 :                 FeStartupPacket::SslRequest
     333                 :             }
     334                 :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
     335                 :                 // Requested upgrade to GSSAPI
     336 UBC           0 :                 FeStartupPacket::GssEncRequest
     337                 :             }
     338               0 :             (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
     339               0 :                 return Err(ProtocolError::Protocol(format!(
     340               0 :                     "Unrecognized request code {unrecognized_code}"
     341               0 :                 )));
     342                 :             }
     343                 :             // TODO bail if protocol major_version is not 3?
     344 CBC       11291 :             (major_version, minor_version) => {
     345                 :                 // StartupMessage
     346                 : 
     347                 :                 // Parse pairs of null-terminated strings (key, value).
     348                 :                 // See `postgres: ProcessStartupPacket, build_startup_packet`.
     349           11291 :                 let mut tokens = str::from_utf8(&msg)
     350           11291 :                     .map_err(|_e| {
     351 UBC           0 :                         ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
     352 CBC       11291 :                     })?
     353           11291 :                     .strip_suffix('\0') // drop packet's own null
     354           11291 :                     .ok_or_else(|| {
     355 UBC           0 :                         ProtocolError::Protocol(
     356               0 :                             "StartupMessage params: missing null terminator".to_string(),
     357               0 :                         )
     358 CBC       11291 :                     })?
     359           11291 :                     .split_terminator('\0');
     360           11291 : 
     361           11291 :                 let mut params = HashMap::new();
     362           37868 :                 while let Some(name) = tokens.next() {
     363           26577 :                     let value = tokens.next().ok_or_else(|| {
     364 UBC           0 :                         ProtocolError::Protocol(
     365               0 :                             "StartupMessage params: key without value".to_string(),
     366               0 :                         )
     367 CBC       26577 :                     })?;
     368                 : 
     369           26577 :                     params.insert(name.to_owned(), value.to_owned());
     370                 :                 }
     371                 : 
     372           11291 :                 FeStartupPacket::StartupMessage {
     373           11291 :                     major_version,
     374           11291 :                     minor_version,
     375           11291 :                     params: StartupMessageParams { params },
     376           11291 :                 }
     377                 :             }
     378                 :         };
     379           20680 :         Ok(Some(message))
     380           41362 :     }
     381                 : }
     382                 : 
     383                 : impl FeParseMessage {
     384             552 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     385                 :         // FIXME: the rust-postgres driver uses a named prepared statement
     386                 :         // for copy_out(). We're not prepared to handle that correctly. For
     387                 :         // now, just ignore the statement name, assuming that the client never
     388                 :         // uses more than one prepared statement at a time.
     389                 : 
     390             552 :         let _pstmt_name = read_cstr(&mut buf)?;
     391             552 :         let query_string = read_cstr(&mut buf)?;
     392             552 :         if buf.remaining() < 2 {
     393 UBC           0 :             return Err(ProtocolError::BadMessage(
     394               0 :                 "Parse message is malformed, nparams missing".to_string(),
     395               0 :             ));
     396 CBC         552 :         }
     397             552 :         let nparams = buf.get_i16();
     398             552 : 
     399             552 :         if nparams != 0 {
     400 UBC           0 :             return Err(ProtocolError::BadMessage(
     401               0 :                 "query params not implemented".to_string(),
     402               0 :             ));
     403 CBC         552 :         }
     404             552 : 
     405             552 :         Ok(FeMessage::Parse(FeParseMessage { query_string }))
     406             552 :     }
     407                 : }
     408                 : 
     409                 : impl FeDescribeMessage {
     410             552 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     411             552 :         let kind = buf.get_u8();
     412             552 :         let _pstmt_name = read_cstr(&mut buf)?;
     413                 : 
     414                 :         // FIXME: see FeParseMessage::parse
     415             552 :         if kind != b'S' {
     416 UBC           0 :             return Err(ProtocolError::BadMessage(
     417               0 :                 "only prepared statemement Describe is implemented".to_string(),
     418               0 :             ));
     419 CBC         552 :         }
     420             552 : 
     421             552 :         Ok(FeMessage::Describe(FeDescribeMessage { kind }))
     422             552 :     }
     423                 : }
     424                 : 
     425                 : impl FeExecuteMessage {
     426             552 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     427             552 :         let portal_name = read_cstr(&mut buf)?;
     428             552 :         if buf.remaining() < 4 {
     429 UBC           0 :             return Err(ProtocolError::BadMessage(
     430               0 :                 "FeExecuteMessage message is malformed, maxrows missing".to_string(),
     431               0 :             ));
     432 CBC         552 :         }
     433             552 :         let maxrows = buf.get_i32();
     434             552 : 
     435             552 :         if !portal_name.is_empty() {
     436 UBC           0 :             return Err(ProtocolError::BadMessage(
     437               0 :                 "named portals not implemented".to_string(),
     438               0 :             ));
     439 CBC         552 :         }
     440             552 :         if maxrows != 0 {
     441 UBC           0 :             return Err(ProtocolError::BadMessage(
     442               0 :                 "row limit in Execute message not implemented".to_string(),
     443               0 :             ));
     444 CBC         552 :         }
     445             552 : 
     446             552 :         Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
     447             552 :     }
     448                 : }
     449                 : 
     450                 : impl FeBindMessage {
     451             552 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     452             552 :         let portal_name = read_cstr(&mut buf)?;
     453             552 :         let _pstmt_name = read_cstr(&mut buf)?;
     454                 : 
     455                 :         // FIXME: see FeParseMessage::parse
     456             552 :         if !portal_name.is_empty() {
     457 UBC           0 :             return Err(ProtocolError::BadMessage(
     458               0 :                 "named portals not implemented".to_string(),
     459               0 :             ));
     460 CBC         552 :         }
     461             552 : 
     462             552 :         Ok(FeMessage::Bind(FeBindMessage))
     463             552 :     }
     464                 : }
     465                 : 
     466                 : impl FeCloseMessage {
     467             551 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     468             551 :         let _kind = buf.get_u8();
     469             551 :         let _pstmt_or_portal_name = read_cstr(&mut buf)?;
     470                 : 
     471                 :         // FIXME: we do nothing with Close
     472             551 :         Ok(FeMessage::Close(FeCloseMessage))
     473             551 :     }
     474                 : }
     475                 : 
     476                 : // Backend
     477                 : 
     478 UBC           0 : #[derive(Debug)]
     479                 : pub enum BeMessage<'a> {
     480                 :     AuthenticationOk,
     481                 :     AuthenticationMD5Password([u8; 4]),
     482                 :     AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
     483                 :     AuthenticationCleartextPassword,
     484                 :     BackendKeyData(CancelKeyData),
     485                 :     BindComplete,
     486                 :     CommandComplete(&'a [u8]),
     487                 :     CopyData(&'a [u8]),
     488                 :     CopyDone,
     489                 :     CopyFail,
     490                 :     CopyInResponse,
     491                 :     CopyOutResponse,
     492                 :     CopyBothResponse,
     493                 :     CloseComplete,
     494                 :     // None means column is NULL
     495                 :     DataRow(&'a [Option<&'a [u8]>]),
     496                 :     // None errcode means internal_error will be sent.
     497                 :     ErrorResponse(&'a str, Option<&'a [u8; 5]>),
     498                 :     /// Single byte - used in response to SSLRequest/GSSENCRequest.
     499                 :     EncryptionResponse(bool),
     500                 :     NoData,
     501                 :     ParameterDescription,
     502                 :     ParameterStatus {
     503                 :         name: &'a [u8],
     504                 :         value: &'a [u8],
     505                 :     },
     506                 :     ParseComplete,
     507                 :     ReadyForQuery,
     508                 :     RowDescription(&'a [RowDescriptor<'a>]),
     509                 :     XLogData(XLogDataBody<'a>),
     510                 :     NoticeResponse(&'a str),
     511                 :     KeepAlive(WalSndKeepAlive),
     512                 : }
     513                 : 
     514                 : /// Common shorthands.
     515                 : impl<'a> BeMessage<'a> {
     516                 :     /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
     517                 :     /// This is a sensible default, given that:
     518                 :     ///  * rust strings only support this encoding out of the box.
     519                 :     ///  * tokio-postgres, postgres-jdbc (and probably more) mandate it.
     520                 :     ///
     521                 :     /// TODO: do we need to report `server_encoding` as well?
     522                 :     pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
     523                 :         name: b"client_encoding",
     524                 :         value: b"UTF8",
     525                 :     };
     526                 : 
     527                 :     pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
     528                 :         name: b"integer_datetimes",
     529                 :         value: b"on",
     530                 :     };
     531                 : 
     532                 :     /// Build a [`BeMessage::ParameterStatus`] holding the server version.
     533 CBC       10994 :     pub fn server_version(version: &'a str) -> Self {
     534           10994 :         Self::ParameterStatus {
     535           10994 :             name: b"server_version",
     536           10994 :             value: version.as_bytes(),
     537           10994 :         }
     538           10994 :     }
     539                 : }
     540                 : 
     541 UBC           0 : #[derive(Debug)]
     542                 : pub enum BeAuthenticationSaslMessage<'a> {
     543                 :     Methods(&'a [&'a str]),
     544                 :     Continue(&'a [u8]),
     545                 :     Final(&'a [u8]),
     546                 : }
     547                 : 
     548               0 : #[derive(Debug)]
     549                 : pub enum BeParameterStatusMessage<'a> {
     550                 :     Encoding(&'a str),
     551                 :     ServerVersion(&'a str),
     552                 : }
     553                 : 
     554                 : // One row description in RowDescription packet.
     555               0 : #[derive(Debug)]
     556                 : pub struct RowDescriptor<'a> {
     557                 :     pub name: &'a [u8],
     558                 :     pub tableoid: Oid,
     559                 :     pub attnum: i16,
     560                 :     pub typoid: Oid,
     561                 :     pub typlen: i16,
     562                 :     pub typmod: i32,
     563                 :     pub formatcode: i16,
     564                 : }
     565                 : 
     566                 : impl Default for RowDescriptor<'_> {
     567 CBC        2899 :     fn default() -> RowDescriptor<'static> {
     568            2899 :         RowDescriptor {
     569            2899 :             name: b"",
     570            2899 :             tableoid: 0,
     571            2899 :             attnum: 0,
     572            2899 :             typoid: 0,
     573            2899 :             typlen: 0,
     574            2899 :             typmod: 0,
     575            2899 :             formatcode: 0,
     576            2899 :         }
     577            2899 :     }
     578                 : }
     579                 : 
     580                 : impl RowDescriptor<'_> {
     581                 :     /// Convenience function to create a RowDescriptor message for an int8 column
     582              45 :     pub const fn int8_col(name: &[u8]) -> RowDescriptor {
     583              45 :         RowDescriptor {
     584              45 :             name,
     585              45 :             tableoid: 0,
     586              45 :             attnum: 0,
     587              45 :             typoid: INT8_OID,
     588              45 :             typlen: 8,
     589              45 :             typmod: 0,
     590              45 :             formatcode: 0,
     591              45 :         }
     592              45 :     }
     593                 : 
     594            1268 :     pub const fn text_col(name: &[u8]) -> RowDescriptor {
     595            1268 :         RowDescriptor {
     596            1268 :             name,
     597            1268 :             tableoid: 0,
     598            1268 :             attnum: 0,
     599            1268 :             typoid: TEXT_OID,
     600            1268 :             typlen: -1,
     601            1268 :             typmod: 0,
     602            1268 :             formatcode: 0,
     603            1268 :         }
     604            1268 :     }
     605                 : }
     606                 : 
     607 UBC           0 : #[derive(Debug)]
     608                 : pub struct XLogDataBody<'a> {
     609                 :     pub wal_start: u64,
     610                 :     pub wal_end: u64, // current end of WAL on the server
     611                 :     pub timestamp: i64,
     612                 :     pub data: &'a [u8],
     613                 : }
     614                 : 
     615               0 : #[derive(Debug)]
     616                 : pub struct WalSndKeepAlive {
     617                 :     pub wal_end: u64, // current end of WAL on the server
     618                 :     pub timestamp: i64,
     619                 :     pub request_reply: bool,
     620                 : }
     621                 : 
     622                 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
     623                 : 
     624                 : // single text column
     625                 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
     626                 :     name: b"data",
     627                 :     tableoid: 0,
     628                 :     attnum: 0,
     629                 :     typoid: TEXT_OID,
     630                 :     typlen: -1,
     631                 :     typmod: 0,
     632                 :     formatcode: 0,
     633                 : }]);
     634                 : 
     635                 : /// Call f() to write body of the message and prepend it with 4-byte len as
     636                 : /// prescribed by the protocol.
     637 CBC     6499104 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
     638         6499104 :     let base = buf.len();
     639         6499104 :     buf.extend_from_slice(&[0; 4]);
     640         6499104 : 
     641         6499104 :     let res = f(buf);
     642         6499104 : 
     643         6499104 :     let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
     644         6499104 :     (&mut buf[base..]).put_slice(&size.to_be_bytes());
     645         6499104 : 
     646         6499104 :     res
     647         6499104 : }
     648                 : 
     649                 : /// Safe write of s into buf as cstring (String in the protocol).
     650           74334 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
     651           74334 :     let bytes = s.as_ref();
     652           74334 :     if bytes.contains(&0) {
     653 UBC           0 :         return Err(ProtocolError::BadMessage(
     654               0 :             "string contains embedded null".to_owned(),
     655               0 :         ));
     656 CBC       74334 :     }
     657           74334 :     buf.put_slice(bytes);
     658           74334 :     buf.put_u8(0);
     659           74334 :     Ok(())
     660           74334 : }
     661                 : 
     662                 : /// Read cstring from buf, advancing it.
     663         2832929 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
     664         2832929 :     let pos = buf
     665         2832929 :         .iter()
     666        40226045 :         .position(|x| *x == 0)
     667         2832929 :         .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
     668         2832929 :     let result = buf.split_to(pos);
     669         2832929 :     buf.advance(1); // drop the null terminator
     670         2832929 :     Ok(result)
     671         2832929 : }
     672                 : 
     673                 : pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
     674                 : pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
     675                 : pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
     676                 : 
     677                 : impl<'a> BeMessage<'a> {
     678                 :     /// Serialize `message` to the given `buf`.
     679                 :     /// Apart from smart memory managemet, BytesMut is good here as msg len
     680                 :     /// precedes its body and it is handy to write it down first and then fill
     681                 :     /// the length. With Write we would have to either calc it manually or have
     682                 :     /// one more buffer.
     683         6508493 :     pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
     684         6508493 :         match message {
     685           11248 :             BeMessage::AuthenticationOk => {
     686           11248 :                 buf.put_u8(b'R');
     687           11248 :                 write_body(buf, |buf| {
     688           11248 :                     buf.put_i32(0); // Specifies that the authentication was successful.
     689           11248 :                 });
     690           11248 :             }
     691                 : 
     692             218 :             BeMessage::AuthenticationCleartextPassword => {
     693             218 :                 buf.put_u8(b'R');
     694             218 :                 write_body(buf, |buf| {
     695             218 :                     buf.put_i32(3); // Specifies that clear text password is required.
     696             218 :                 });
     697             218 :             }
     698                 : 
     699 UBC           0 :             BeMessage::AuthenticationMD5Password(salt) => {
     700               0 :                 buf.put_u8(b'R');
     701               0 :                 write_body(buf, |buf| {
     702               0 :                     buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
     703               0 :                     buf.put_slice(&salt[..]);
     704               0 :                 });
     705               0 :             }
     706                 : 
     707 CBC         132 :             BeMessage::AuthenticationSasl(msg) => {
     708             132 :                 buf.put_u8(b'R');
     709             132 :                 write_body(buf, |buf| {
     710             132 :                     use BeAuthenticationSaslMessage::*;
     711             132 :                     match msg {
     712              48 :                         Methods(methods) => {
     713              48 :                             buf.put_i32(10); // Specifies that SASL auth method is used.
     714              96 :                             for method in methods.iter() {
     715              96 :                                 write_cstr(method, buf)?;
     716                 :                             }
     717              48 :                             buf.put_u8(0); // zero terminator for the list
     718                 :                         }
     719              46 :                         Continue(extra) => {
     720              46 :                             buf.put_i32(11); // Continue SASL auth.
     721              46 :                             buf.put_slice(extra);
     722              46 :                         }
     723              38 :                         Final(extra) => {
     724              38 :                             buf.put_i32(12); // Send final SASL message.
     725              38 :                             buf.put_slice(extra);
     726              38 :                         }
     727                 :                     }
     728             132 :                     Ok(())
     729             132 :                 })?;
     730                 :             }
     731                 : 
     732              38 :             BeMessage::BackendKeyData(key_data) => {
     733              38 :                 buf.put_u8(b'K');
     734              38 :                 write_body(buf, |buf| {
     735              38 :                     buf.put_i32(key_data.backend_pid);
     736              38 :                     buf.put_i32(key_data.cancel_key);
     737              38 :                 });
     738              38 :             }
     739                 : 
     740             552 :             BeMessage::BindComplete => {
     741             552 :                 buf.put_u8(b'2');
     742             552 :                 write_body(buf, |_| {});
     743             552 :             }
     744                 : 
     745             551 :             BeMessage::CloseComplete => {
     746             551 :                 buf.put_u8(b'3');
     747             551 :                 write_body(buf, |_| {});
     748             551 :             }
     749                 : 
     750            1969 :             BeMessage::CommandComplete(cmd) => {
     751            1969 :                 buf.put_u8(b'C');
     752            1969 :                 write_body(buf, |buf| write_cstr(cmd, buf))?;
     753                 :             }
     754                 : 
     755         5840052 :             BeMessage::CopyData(data) => {
     756         5840052 :                 buf.put_u8(b'd');
     757         5840052 :                 write_body(buf, |buf| {
     758         5840052 :                     buf.put_slice(data);
     759         5840052 :                 });
     760         5840052 :             }
     761                 : 
     762             557 :             BeMessage::CopyDone => {
     763             557 :                 buf.put_u8(b'c');
     764             557 :                 write_body(buf, |_| {});
     765             557 :             }
     766                 : 
     767 UBC           0 :             BeMessage::CopyFail => {
     768               0 :                 buf.put_u8(b'f');
     769               0 :                 write_body(buf, |_| {});
     770               0 :             }
     771                 : 
     772 CBC          13 :             BeMessage::CopyInResponse => {
     773              13 :                 buf.put_u8(b'G');
     774              13 :                 write_body(buf, |buf| {
     775              13 :                     buf.put_u8(1); // copy_is_binary
     776              13 :                     buf.put_i16(0); // numAttributes
     777              13 :                 });
     778              13 :             }
     779                 : 
     780             559 :             BeMessage::CopyOutResponse => {
     781             559 :                 buf.put_u8(b'H');
     782             559 :                 write_body(buf, |buf| {
     783             559 :                     buf.put_u8(0); // copy_is_binary
     784             559 :                     buf.put_i16(0); // numAttributes
     785             559 :                 });
     786             559 :             }
     787                 : 
     788            9440 :             BeMessage::CopyBothResponse => {
     789            9440 :                 buf.put_u8(b'W');
     790            9440 :                 write_body(buf, |buf| {
     791            9440 :                     // doesn't matter, used only for replication
     792            9440 :                     buf.put_u8(0); // copy_is_binary
     793            9440 :                     buf.put_i16(0); // numAttributes
     794            9440 :                 });
     795            9440 :             }
     796                 : 
     797             927 :             BeMessage::DataRow(vals) => {
     798             927 :                 buf.put_u8(b'D');
     799             927 :                 write_body(buf, |buf| {
     800             927 :                     buf.put_u16(vals.len() as u16); // num of cols
     801            3329 :                     for val_opt in vals.iter() {
     802            3329 :                         if let Some(val) = val_opt {
     803            2605 :                             buf.put_u32(val.len() as u32);
     804            2605 :                             buf.put_slice(val);
     805            2605 :                         } else {
     806             724 :                             buf.put_i32(-1);
     807             724 :                         }
     808                 :                     }
     809             927 :                 });
     810             927 :             }
     811                 : 
     812                 :             // ErrorResponse is a zero-terminated array of zero-terminated fields.
     813                 :             // First byte of each field represents type of this field. Set just enough fields
     814                 :             // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
     815                 :             // message text.
     816             658 :             BeMessage::ErrorResponse(error_msg, pg_error_code) => {
     817             658 :                 // 'E' signalizes ErrorResponse messages
     818             658 :                 buf.put_u8(b'E');
     819             658 :                 write_body(buf, |buf| {
     820             658 :                     buf.put_u8(b'S'); // severity
     821             658 :                     buf.put_slice(b"ERROR\0");
     822             658 : 
     823             658 :                     buf.put_u8(b'C'); // SQLSTATE error code
     824             658 :                     buf.put_slice(&terminate_code(
     825             658 :                         pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
     826             658 :                     ));
     827             658 : 
     828             658 :                     buf.put_u8(b'M'); // the message
     829             658 :                     write_cstr(error_msg, buf)?;
     830                 : 
     831             658 :                     buf.put_u8(0); // terminator
     832             658 :                     Ok(())
     833             658 :                 })?;
     834                 :             }
     835                 : 
     836                 :             // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
     837                 :             // message but continue listening for ReadyForQuery or ErrorResponse"
     838               6 :             BeMessage::NoticeResponse(error_msg) => {
     839               6 :                 // For all the errors set Severity to Error and error code to
     840               6 :                 // 'internal error'.
     841               6 : 
     842               6 :                 // 'N' signalizes NoticeResponse messages
     843               6 :                 buf.put_u8(b'N');
     844               6 :                 write_body(buf, |buf| {
     845               6 :                     buf.put_u8(b'S'); // severity
     846               6 :                     buf.put_slice(b"NOTICE\0");
     847               6 : 
     848               6 :                     buf.put_u8(b'C'); // SQLSTATE error code
     849               6 :                     buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
     850               6 : 
     851               6 :                     buf.put_u8(b'M'); // the message
     852               6 :                     write_cstr(error_msg.as_bytes(), buf)?;
     853                 : 
     854               6 :                     buf.put_u8(0); // terminator
     855               6 :                     Ok(())
     856               6 :                 })?;
     857                 :             }
     858                 : 
     859             552 :             BeMessage::NoData => {
     860             552 :                 buf.put_u8(b'n');
     861             552 :                 write_body(buf, |_| {});
     862             552 :             }
     863                 : 
     864            9389 :             BeMessage::EncryptionResponse(should_negotiate) => {
     865            9389 :                 let response = if *should_negotiate { b'S' } else { b'N' };
     866            9389 :                 buf.put_u8(response);
     867                 :             }
     868                 : 
     869           33695 :             BeMessage::ParameterStatus { name, value } => {
     870           33695 :                 buf.put_u8(b'S');
     871           33695 :                 write_body(buf, |buf| {
     872           33695 :                     write_cstr(name, buf)?;
     873           33695 :                     write_cstr(value, buf)
     874           33695 :                 })?;
     875                 :             }
     876                 : 
     877             552 :             BeMessage::ParameterDescription => {
     878             552 :                 buf.put_u8(b't');
     879             552 :                 write_body(buf, |buf| {
     880             552 :                     // we don't support params, so always 0
     881             552 :                     buf.put_i16(0);
     882             552 :                 });
     883             552 :             }
     884                 : 
     885             552 :             BeMessage::ParseComplete => {
     886             552 :                 buf.put_u8(b'1');
     887             552 :                 write_body(buf, |_| {});
     888             552 :             }
     889                 : 
     890           23672 :             BeMessage::ReadyForQuery => {
     891           23672 :                 buf.put_u8(b'Z');
     892           23672 :                 write_body(buf, |buf| {
     893           23672 :                     buf.put_u8(b'I');
     894           23672 :                 });
     895           23672 :             }
     896                 : 
     897            1370 :             BeMessage::RowDescription(rows) => {
     898            1370 :                 buf.put_u8(b'T');
     899            1370 :                 write_body(buf, |buf| {
     900            1370 :                     buf.put_i16(rows.len() as i16); // # of fields
     901            4215 :                     for row in rows.iter() {
     902            4215 :                         write_cstr(row.name, buf)?;
     903            4215 :                         buf.put_i32(0); /* table oid */
     904            4215 :                         buf.put_i16(0); /* attnum */
     905            4215 :                         buf.put_u32(row.typoid);
     906            4215 :                         buf.put_i16(row.typlen);
     907            4215 :                         buf.put_i32(-1); /* typmod */
     908            4215 :                         buf.put_i16(0); /* format code */
     909                 :                     }
     910            1370 :                     Ok(())
     911            1370 :                 })?;
     912                 :             }
     913                 : 
     914          568377 :             BeMessage::XLogData(body) => {
     915          568377 :                 buf.put_u8(b'd');
     916          568377 :                 write_body(buf, |buf| {
     917          568377 :                     buf.put_u8(b'w');
     918          568377 :                     buf.put_u64(body.wal_start);
     919          568377 :                     buf.put_u64(body.wal_end);
     920          568377 :                     buf.put_i64(body.timestamp);
     921          568377 :                     buf.put_slice(body.data);
     922          568377 :                 });
     923          568377 :             }
     924                 : 
     925            3414 :             BeMessage::KeepAlive(req) => {
     926            3414 :                 buf.put_u8(b'd');
     927            3414 :                 write_body(buf, |buf| {
     928            3414 :                     buf.put_u8(b'k');
     929            3414 :                     buf.put_u64(req.wal_end);
     930            3414 :                     buf.put_i64(req.timestamp);
     931            3414 :                     buf.put_u8(u8::from(req.request_reply));
     932            3414 :                 });
     933            3414 :             }
     934                 :         }
     935         6508493 :         Ok(())
     936         6508493 :     }
     937                 : }
     938                 : 
     939             664 : fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
     940             664 :     let mut terminated = [0; 6];
     941            3320 :     for (i, &elem) in code.iter().enumerate() {
     942            3320 :         terminated[i] = elem;
     943            3320 :     }
     944                 : 
     945             664 :     terminated
     946             664 : }
     947                 : 
     948                 : #[cfg(test)]
     949                 : mod tests {
     950                 :     use super::*;
     951                 : 
     952               1 :     #[test]
     953               1 :     fn test_startup_message_params_options_escaped() {
     954               4 :         fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
     955               4 :             params
     956               4 :                 .options_escaped()
     957               4 :                 .expect("options are None")
     958               4 :                 .collect()
     959               4 :         }
     960               1 : 
     961               4 :         let make_params = |options| StartupMessageParams::new([("options", options)]);
     962                 : 
     963               1 :         let params = StartupMessageParams::new([]);
     964               1 :         assert!(params.options_escaped().is_none());
     965                 : 
     966               1 :         let params = make_params("");
     967               1 :         assert!(split_options(&params).is_empty());
     968                 : 
     969               1 :         let params = make_params("foo");
     970               1 :         assert_eq!(split_options(&params), ["foo"]);
     971                 : 
     972               1 :         let params = make_params(" foo  bar ");
     973               1 :         assert_eq!(split_options(&params), ["foo", "bar"]);
     974                 : 
     975               1 :         let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
     976               1 :         assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
     977               1 :     }
     978                 : 
     979               1 :     #[test]
     980               1 :     fn parse_fe_startup_packet_regression() {
     981               1 :         let data = [0, 0, 0, 7, 0, 0, 0, 0];
     982               1 :         FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
     983               1 :     }
     984                 : }
        

Generated by: LCOV version 2.1-beta