LCOV - code coverage report
Current view: top level - libs/pq_proto/src - lib.rs (source / functions) Coverage Total Hit
Test: aca8877be6ceba750c1be359ed71bc1799d52b30.info Lines: 86.2 % 589 508
Test Date: 2024-02-14 18:05:35 Functions: 68.2 % 148 101

            Line data    Source code
       1              : //! Postgres protocol messages serialization-deserialization. See
       2              : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
       3              : //! on message formats.
       4              : #![deny(clippy::undocumented_unsafe_blocks)]
       5              : 
       6              : pub mod framed;
       7              : 
       8              : use byteorder::{BigEndian, ReadBytesExt};
       9              : use bytes::{Buf, BufMut, Bytes, BytesMut};
      10              : use serde::{Deserialize, Serialize};
      11              : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
      12              : 
      13              : // re-export for use in utils pageserver_feedback.rs
      14              : pub use postgres_protocol::PG_EPOCH;
      15              : 
      16              : pub type Oid = u32;
      17              : pub type SystemId = u64;
      18              : 
      19              : pub const INT8_OID: Oid = 20;
      20              : pub const INT4_OID: Oid = 23;
      21              : pub const TEXT_OID: Oid = 25;
      22              : 
      23           46 : #[derive(Debug)]
      24              : pub enum FeMessage {
      25              :     // Simple query.
      26              :     Query(Bytes),
      27              :     // Extended query protocol.
      28              :     Parse(FeParseMessage),
      29              :     Describe(FeDescribeMessage),
      30              :     Bind(FeBindMessage),
      31              :     Execute(FeExecuteMessage),
      32              :     Close(FeCloseMessage),
      33              :     Sync,
      34              :     Terminate,
      35              :     CopyData(Bytes),
      36              :     CopyDone,
      37              :     CopyFail,
      38              :     PasswordMessage(Bytes),
      39              : }
      40              : 
      41          230 : #[derive(Debug)]
      42              : pub enum FeStartupPacket {
      43              :     CancelRequest(CancelKeyData),
      44              :     SslRequest,
      45              :     GssEncRequest,
      46              :     StartupMessage {
      47              :         major_version: u32,
      48              :         minor_version: u32,
      49              :         params: StartupMessageParams,
      50              :     },
      51              : }
      52              : 
      53          126 : #[derive(Debug)]
      54              : pub struct StartupMessageParams {
      55              :     params: HashMap<String, String>,
      56              : }
      57              : 
      58              : impl StartupMessageParams {
      59              :     /// Get parameter's value by its name.
      60         7242 :     pub fn get(&self, name: &str) -> Option<&str> {
      61         7242 :         self.params.get(name).map(|s| s.as_str())
      62         7242 :     }
      63              : 
      64              :     /// Split command-line options according to PostgreSQL's logic,
      65              :     /// taking into account all escape sequences but leaving them as-is.
      66              :     /// [`None`] means that there's no `options` in [`Self`].
      67         3510 :     pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
      68         3510 :         self.get("options").map(Self::parse_options_raw)
      69         3510 :     }
      70              : 
      71              :     /// Split command-line options according to PostgreSQL's logic,
      72              :     /// applying all escape sequences (using owned strings as needed).
      73              :     /// [`None`] means that there's no `options` in [`Self`].
      74           10 :     pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
      75           10 :         self.get("options").map(Self::parse_options_escaped)
      76           10 :     }
      77              : 
      78              :     /// Split command-line options according to PostgreSQL's logic,
      79              :     /// taking into account all escape sequences but leaving them as-is.
      80         3514 :     pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
      81         3514 :         // See `postgres: pg_split_opts`.
      82         3514 :         let mut last_was_escape = false;
      83         3514 :         input
      84       304287 :             .split(move |c: char| {
      85              :                 // We split by non-escaped whitespace symbols.
      86       304287 :                 let should_split = c.is_ascii_whitespace() && !last_was_escape;
      87       304287 :                 last_was_escape = c == '\\' && !last_was_escape;
      88       304287 :                 should_split
      89       304287 :             })
      90        10393 :             .filter(|s| !s.is_empty())
      91         3514 :     }
      92              : 
      93              :     /// Split command-line options according to PostgreSQL's logic,
      94              :     /// applying all escape sequences (using owned strings as needed).
      95            8 :     pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
      96            8 :         // See `postgres: pg_split_opts`.
      97           14 :         Self::parse_options_raw(input).map(|s| {
      98           14 :             let mut preserve_next_escape = false;
      99           34 :             let escape = |c| {
     100              :                 // We should remove '\\' unless it's preceded by '\\'.
     101           34 :                 let should_remove = c == '\\' && !preserve_next_escape;
     102           34 :                 preserve_next_escape = should_remove;
     103           34 :                 should_remove
     104           34 :             };
     105              : 
     106           14 :             match s.contains('\\') {
     107            6 :                 true => Cow::Owned(s.replace(escape, "")),
     108            8 :                 false => Cow::Borrowed(s),
     109              :             }
     110           14 :         })
     111            8 :     }
     112              : 
     113              :     /// Iterate through key-value pairs in an arbitrary order.
     114           14 :     pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
     115           42 :         self.params.iter().map(|(k, v)| (k.as_str(), v.as_str()))
     116           14 :     }
     117              : 
     118              :     // This function is mostly useful in tests.
     119              :     #[doc(hidden)]
     120           46 :     pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
     121           46 :         Self {
     122           62 :             params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
     123           46 :         }
     124           46 :     }
     125              : }
     126              : 
     127          301 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
     128              : pub struct CancelKeyData {
     129              :     pub backend_pid: i32,
     130              :     pub cancel_key: i32,
     131              : }
     132              : 
     133              : impl fmt::Display for CancelKeyData {
     134          164 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     135          164 :         let hi = (self.backend_pid as u64) << 32;
     136          164 :         let lo = self.cancel_key as u64;
     137          164 :         let id = hi | lo;
     138          164 : 
     139          164 :         // This format is more compact and might work better for logs.
     140          164 :         f.debug_tuple("CancelKeyData")
     141          164 :             .field(&format_args!("{:x}", id))
     142          164 :             .finish()
     143          164 :     }
     144              : }
     145              : 
     146              : use rand::distributions::{Distribution, Standard};
     147              : impl Distribution<CancelKeyData> for Standard {
     148           43 :     fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
     149           43 :         CancelKeyData {
     150           43 :             backend_pid: rng.gen(),
     151           43 :             cancel_key: rng.gen(),
     152           43 :         }
     153           43 :     }
     154              : }
     155              : 
     156              : // We only support the simple case of Parse on unnamed prepared statement and
     157              : // no params
     158            0 : #[derive(Debug)]
     159              : pub struct FeParseMessage {
     160              :     pub query_string: Bytes,
     161              : }
     162              : 
     163            0 : #[derive(Debug)]
     164              : pub struct FeDescribeMessage {
     165              :     pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
     166              :                   // we only support unnamed prepared stmt or portal
     167              : }
     168              : 
     169              : // we only support unnamed prepared stmt and portal
     170            0 : #[derive(Debug)]
     171              : pub struct FeBindMessage;
     172              : 
     173              : // we only support unnamed prepared stmt or portal
     174            0 : #[derive(Debug)]
     175              : pub struct FeExecuteMessage {
     176              :     /// max # of rows
     177              :     pub maxrows: i32,
     178              : }
     179              : 
     180              : // we only support unnamed prepared stmt and portal
     181            0 : #[derive(Debug)]
     182              : pub struct FeCloseMessage;
     183              : 
     184              : /// An error occurred while parsing or serializing raw stream into Postgres
     185              : /// messages.
     186            6 : #[derive(thiserror::Error, Debug)]
     187              : pub enum ProtocolError {
     188              :     /// Invalid packet was received from the client (e.g. unexpected message
     189              :     /// type or broken len).
     190              :     #[error("Protocol error: {0}")]
     191              :     Protocol(String),
     192              :     /// Failed to parse or, (unlikely), serialize a protocol message.
     193              :     #[error("Message parse error: {0}")]
     194              :     BadMessage(String),
     195              : }
     196              : 
     197              : impl ProtocolError {
     198              :     /// Proxy stream.rs uses only io::Error; provide it.
     199            0 :     pub fn into_io_error(self) -> io::Error {
     200            0 :         io::Error::new(io::ErrorKind::Other, self.to_string())
     201            0 :     }
     202              : }
     203              : 
     204              : impl FeMessage {
     205              :     /// Read and parse one message from the `buf` input buffer. If there is at
     206              :     /// least one valid message, returns it, advancing `buf`; redundant copies
     207              :     /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
     208              :     /// directly into the `buf` (processed data is garbage collected after
     209              :     /// parsed message is dropped).
     210              :     ///
     211              :     /// Returns None if `buf` doesn't contain enough data for a single message.
     212              :     /// For efficiency, tries to reserve large enough space in `buf` for the
     213              :     /// next message in this case to save the repeated calls.
     214              :     ///
     215              :     /// Returns Error if message is malformed, the only possible ErrorKind is
     216              :     /// InvalidInput.
     217              :     //
     218              :     // Inspired by rust-postgres Message::parse.
     219     14509451 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
     220     14509451 :         // Every message contains message type byte and 4 bytes len; can't do
     221     14509451 :         // much without them.
     222     14509451 :         if buf.len() < 5 {
     223      6600309 :             let to_read = 5 - buf.len();
     224      6600309 :             buf.reserve(to_read);
     225      6600309 :             return Ok(None);
     226      7909142 :         }
     227      7909142 : 
     228      7909142 :         // We shouldn't advance `buf` as probably full message is not there yet,
     229      7909142 :         // so can't directly use Bytes::get_u32 etc.
     230      7909142 :         let tag = buf[0];
     231      7909142 :         let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
     232      7909142 :         if len < 4 {
     233            0 :             return Err(ProtocolError::Protocol(format!(
     234            0 :                 "invalid message length {}",
     235            0 :                 len
     236            0 :             )));
     237      7909142 :         }
     238      7909142 : 
     239      7909142 :         // length field includes itself, but not message type.
     240      7909142 :         let total_len = len as usize + 1;
     241      7909142 :         if buf.len() < total_len {
     242              :             // Don't have full message yet.
     243       170419 :             let to_read = total_len - buf.len();
     244       170419 :             buf.reserve(to_read);
     245       170419 :             return Ok(None);
     246      7738723 :         }
     247      7738723 : 
     248      7738723 :         // got the message, advance buffer
     249      7738723 :         let mut msg = buf.split_to(total_len).freeze();
     250      7738723 :         msg.advance(5); // consume message type and len
     251      7738723 : 
     252      7738723 :         match tag {
     253        14089 :             b'Q' => Ok(Some(FeMessage::Query(msg))),
     254          630 :             b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
     255          630 :             b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
     256          630 :             b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
     257          630 :             b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
     258          629 :             b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
     259         1909 :             b'S' => Ok(Some(FeMessage::Sync)),
     260         2170 :             b'X' => Ok(Some(FeMessage::Terminate)),
     261      7716986 :             b'd' => Ok(Some(FeMessage::CopyData(msg))),
     262           13 :             b'c' => Ok(Some(FeMessage::CopyDone)),
     263           73 :             b'f' => Ok(Some(FeMessage::CopyFail)),
     264          334 :             b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
     265            0 :             tag => Err(ProtocolError::Protocol(format!(
     266            0 :                 "unknown message tag: {tag},'{msg:?}'"
     267            0 :             ))),
     268              :         }
     269     14509451 :     }
     270              : }
     271              : 
     272              : impl FeStartupPacket {
     273              :     /// Read and parse startup message from the `buf` input buffer. It is
     274              :     /// different from [`FeMessage::parse`] because startup messages don't have
     275              :     /// message type byte; otherwise, its comments apply.
     276        52036 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
     277        52036 :         const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
     278        52036 :         const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
     279        52036 :         const CANCEL_REQUEST_CODE: u32 = 5678;
     280        52036 :         const NEGOTIATE_SSL_CODE: u32 = 5679;
     281        52036 :         const NEGOTIATE_GSS_CODE: u32 = 5680;
     282        52036 : 
     283        52036 :         // need at least 4 bytes with packet len
     284        52036 :         if buf.len() < 4 {
     285        26018 :             let to_read = 4 - buf.len();
     286        26018 :             buf.reserve(to_read);
     287        26018 :             return Ok(None);
     288        26018 :         }
     289        26018 : 
     290        26018 :         // We shouldn't advance `buf` as probably full message is not there yet,
     291        26018 :         // so can't directly use Bytes::get_u32 etc.
     292        26018 :         let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
     293        26018 :         // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
     294        26018 :         // which is less readable
     295        26018 :         #[allow(clippy::manual_range_contains)]
     296        26018 :         if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
     297            2 :             return Err(ProtocolError::Protocol(format!(
     298            2 :                 "invalid startup packet message length {}",
     299            2 :                 len
     300            2 :             )));
     301        26016 :         }
     302        26016 : 
     303        26016 :         if buf.len() < len {
     304              :             // Don't have full message yet.
     305            0 :             let to_read = len - buf.len();
     306            0 :             buf.reserve(to_read);
     307            0 :             return Ok(None);
     308        26016 :         }
     309        26016 : 
     310        26016 :         // got the message, advance buffer
     311        26016 :         let mut msg = buf.split_to(len).freeze();
     312        26016 :         msg.advance(4); // consume len
     313        26016 : 
     314        26016 :         let request_code = msg.get_u32();
     315        26016 :         let req_hi = request_code >> 16;
     316        26016 :         let req_lo = request_code & ((1 << 16) - 1);
     317              :         // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
     318        26016 :         let message = match (req_hi, req_lo) {
     319              :             (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
     320            0 :                 if msg.remaining() != 8 {
     321            0 :                     return Err(ProtocolError::BadMessage(
     322            0 :                         "CancelRequest message is malformed, backend PID / secret key missing"
     323            0 :                             .to_owned(),
     324            0 :                     ));
     325            0 :                 }
     326            0 :                 FeStartupPacket::CancelRequest(CancelKeyData {
     327            0 :                     backend_pid: msg.get_i32(),
     328            0 :                     cancel_key: msg.get_i32(),
     329            0 :                 })
     330              :             }
     331              :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
     332              :                 // Requested upgrade to SSL (aka TLS)
     333        11974 :                 FeStartupPacket::SslRequest
     334              :             }
     335              :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
     336              :                 // Requested upgrade to GSSAPI
     337            0 :                 FeStartupPacket::GssEncRequest
     338              :             }
     339            0 :             (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
     340            0 :                 return Err(ProtocolError::Protocol(format!(
     341            0 :                     "Unrecognized request code {unrecognized_code}"
     342            0 :                 )));
     343              :             }
     344              :             // TODO bail if protocol major_version is not 3?
     345        14042 :             (major_version, minor_version) => {
     346              :                 // StartupMessage
     347              : 
     348              :                 // Parse pairs of null-terminated strings (key, value).
     349              :                 // See `postgres: ProcessStartupPacket, build_startup_packet`.
     350        14042 :                 let mut tokens = str::from_utf8(&msg)
     351        14042 :                     .map_err(|_e| {
     352            0 :                         ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
     353        14042 :                     })?
     354        14042 :                     .strip_suffix('\0') // drop packet's own null
     355        14042 :                     .ok_or_else(|| {
     356            0 :                         ProtocolError::Protocol(
     357            0 :                             "StartupMessage params: missing null terminator".to_string(),
     358            0 :                         )
     359        14042 :                     })?
     360        14042 :                     .split_terminator('\0');
     361        14042 : 
     362        14042 :                 let mut params = HashMap::new();
     363        46390 :                 while let Some(name) = tokens.next() {
     364        32348 :                     let value = tokens.next().ok_or_else(|| {
     365            0 :                         ProtocolError::Protocol(
     366            0 :                             "StartupMessage params: key without value".to_string(),
     367            0 :                         )
     368        32348 :                     })?;
     369              : 
     370        32348 :                     params.insert(name.to_owned(), value.to_owned());
     371              :                 }
     372              : 
     373        14042 :                 FeStartupPacket::StartupMessage {
     374        14042 :                     major_version,
     375        14042 :                     minor_version,
     376        14042 :                     params: StartupMessageParams { params },
     377        14042 :                 }
     378              :             }
     379              :         };
     380        26016 :         Ok(Some(message))
     381        52036 :     }
     382              : }
     383              : 
     384              : impl FeParseMessage {
     385          630 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     386              :         // FIXME: the rust-postgres driver uses a named prepared statement
     387              :         // for copy_out(). We're not prepared to handle that correctly. For
     388              :         // now, just ignore the statement name, assuming that the client never
     389              :         // uses more than one prepared statement at a time.
     390              : 
     391          630 :         let _pstmt_name = read_cstr(&mut buf)?;
     392          630 :         let query_string = read_cstr(&mut buf)?;
     393          630 :         if buf.remaining() < 2 {
     394            0 :             return Err(ProtocolError::BadMessage(
     395            0 :                 "Parse message is malformed, nparams missing".to_string(),
     396            0 :             ));
     397          630 :         }
     398          630 :         let nparams = buf.get_i16();
     399          630 : 
     400          630 :         if nparams != 0 {
     401            0 :             return Err(ProtocolError::BadMessage(
     402            0 :                 "query params not implemented".to_string(),
     403            0 :             ));
     404          630 :         }
     405          630 : 
     406          630 :         Ok(FeMessage::Parse(FeParseMessage { query_string }))
     407          630 :     }
     408              : }
     409              : 
     410              : impl FeDescribeMessage {
     411          630 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     412          630 :         let kind = buf.get_u8();
     413          630 :         let _pstmt_name = read_cstr(&mut buf)?;
     414              : 
     415              :         // FIXME: see FeParseMessage::parse
     416          630 :         if kind != b'S' {
     417            0 :             return Err(ProtocolError::BadMessage(
     418            0 :                 "only prepared statemement Describe is implemented".to_string(),
     419            0 :             ));
     420          630 :         }
     421          630 : 
     422          630 :         Ok(FeMessage::Describe(FeDescribeMessage { kind }))
     423          630 :     }
     424              : }
     425              : 
     426              : impl FeExecuteMessage {
     427          630 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     428          630 :         let portal_name = read_cstr(&mut buf)?;
     429          630 :         if buf.remaining() < 4 {
     430            0 :             return Err(ProtocolError::BadMessage(
     431            0 :                 "FeExecuteMessage message is malformed, maxrows missing".to_string(),
     432            0 :             ));
     433          630 :         }
     434          630 :         let maxrows = buf.get_i32();
     435          630 : 
     436          630 :         if !portal_name.is_empty() {
     437            0 :             return Err(ProtocolError::BadMessage(
     438            0 :                 "named portals not implemented".to_string(),
     439            0 :             ));
     440          630 :         }
     441          630 :         if maxrows != 0 {
     442            0 :             return Err(ProtocolError::BadMessage(
     443            0 :                 "row limit in Execute message not implemented".to_string(),
     444            0 :             ));
     445          630 :         }
     446          630 : 
     447          630 :         Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
     448          630 :     }
     449              : }
     450              : 
     451              : impl FeBindMessage {
     452          630 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     453          630 :         let portal_name = read_cstr(&mut buf)?;
     454          630 :         let _pstmt_name = read_cstr(&mut buf)?;
     455              : 
     456              :         // FIXME: see FeParseMessage::parse
     457          630 :         if !portal_name.is_empty() {
     458            0 :             return Err(ProtocolError::BadMessage(
     459            0 :                 "named portals not implemented".to_string(),
     460            0 :             ));
     461          630 :         }
     462          630 : 
     463          630 :         Ok(FeMessage::Bind(FeBindMessage))
     464          630 :     }
     465              : }
     466              : 
     467              : impl FeCloseMessage {
     468          629 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     469          629 :         let _kind = buf.get_u8();
     470          629 :         let _pstmt_or_portal_name = read_cstr(&mut buf)?;
     471              : 
     472              :         // FIXME: we do nothing with Close
     473          629 :         Ok(FeMessage::Close(FeCloseMessage))
     474          629 :     }
     475              : }
     476              : 
     477              : // Backend
     478              : 
     479            0 : #[derive(Debug)]
     480              : pub enum BeMessage<'a> {
     481              :     AuthenticationOk,
     482              :     AuthenticationMD5Password([u8; 4]),
     483              :     AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
     484              :     AuthenticationCleartextPassword,
     485              :     BackendKeyData(CancelKeyData),
     486              :     BindComplete,
     487              :     CommandComplete(&'a [u8]),
     488              :     CopyData(&'a [u8]),
     489              :     CopyDone,
     490              :     CopyFail,
     491              :     CopyInResponse,
     492              :     CopyOutResponse,
     493              :     CopyBothResponse,
     494              :     CloseComplete,
     495              :     // None means column is NULL
     496              :     DataRow(&'a [Option<&'a [u8]>]),
     497              :     // None errcode means internal_error will be sent.
     498              :     ErrorResponse(&'a str, Option<&'a [u8; 5]>),
     499              :     /// Single byte - used in response to SSLRequest/GSSENCRequest.
     500              :     EncryptionResponse(bool),
     501              :     NoData,
     502              :     ParameterDescription,
     503              :     ParameterStatus {
     504              :         name: &'a [u8],
     505              :         value: &'a [u8],
     506              :     },
     507              :     ParseComplete,
     508              :     ReadyForQuery,
     509              :     RowDescription(&'a [RowDescriptor<'a>]),
     510              :     XLogData(XLogDataBody<'a>),
     511              :     NoticeResponse(&'a str),
     512              :     KeepAlive(WalSndKeepAlive),
     513              : }
     514              : 
     515              : /// Common shorthands.
     516              : impl<'a> BeMessage<'a> {
     517              :     /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
     518              :     /// This is a sensible default, given that:
     519              :     ///  * rust strings only support this encoding out of the box.
     520              :     ///  * tokio-postgres, postgres-jdbc (and probably more) mandate it.
     521              :     ///
     522              :     /// TODO: do we need to report `server_encoding` as well?
     523              :     pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
     524              :         name: b"client_encoding",
     525              :         value: b"UTF8",
     526              :     };
     527              : 
     528              :     pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
     529              :         name: b"integer_datetimes",
     530              :         value: b"on",
     531              :     };
     532              : 
     533              :     /// Build a [`BeMessage::ParameterStatus`] holding the server version.
     534        13719 :     pub fn server_version(version: &'a str) -> Self {
     535        13719 :         Self::ParameterStatus {
     536        13719 :             name: b"server_version",
     537        13719 :             value: version.as_bytes(),
     538        13719 :         }
     539        13719 :     }
     540              : }
     541              : 
     542            0 : #[derive(Debug)]
     543              : pub enum BeAuthenticationSaslMessage<'a> {
     544              :     Methods(&'a [&'a str]),
     545              :     Continue(&'a [u8]),
     546              :     Final(&'a [u8]),
     547              : }
     548              : 
     549            0 : #[derive(Debug)]
     550              : pub enum BeParameterStatusMessage<'a> {
     551              :     Encoding(&'a str),
     552              :     ServerVersion(&'a str),
     553              : }
     554              : 
     555              : // One row description in RowDescription packet.
     556            0 : #[derive(Debug)]
     557              : pub struct RowDescriptor<'a> {
     558              :     pub name: &'a [u8],
     559              :     pub tableoid: Oid,
     560              :     pub attnum: i16,
     561              :     pub typoid: Oid,
     562              :     pub typlen: i16,
     563              :     pub typmod: i32,
     564              :     pub formatcode: i16,
     565              : }
     566              : 
     567              : impl Default for RowDescriptor<'_> {
     568         3091 :     fn default() -> RowDescriptor<'static> {
     569         3091 :         RowDescriptor {
     570         3091 :             name: b"",
     571         3091 :             tableoid: 0,
     572         3091 :             attnum: 0,
     573         3091 :             typoid: 0,
     574         3091 :             typlen: 0,
     575         3091 :             typmod: 0,
     576         3091 :             formatcode: 0,
     577         3091 :         }
     578         3091 :     }
     579              : }
     580              : 
     581              : impl RowDescriptor<'_> {
     582              :     /// Convenience function to create a RowDescriptor message for an int8 column
     583           45 :     pub const fn int8_col(name: &[u8]) -> RowDescriptor {
     584           45 :         RowDescriptor {
     585           45 :             name,
     586           45 :             tableoid: 0,
     587           45 :             attnum: 0,
     588           45 :             typoid: INT8_OID,
     589           45 :             typlen: 8,
     590           45 :             typmod: 0,
     591           45 :             formatcode: 0,
     592           45 :         }
     593           45 :     }
     594              : 
     595         1342 :     pub const fn text_col(name: &[u8]) -> RowDescriptor {
     596         1342 :         RowDescriptor {
     597         1342 :             name,
     598         1342 :             tableoid: 0,
     599         1342 :             attnum: 0,
     600         1342 :             typoid: TEXT_OID,
     601         1342 :             typlen: -1,
     602         1342 :             typmod: 0,
     603         1342 :             formatcode: 0,
     604         1342 :         }
     605         1342 :     }
     606              : }
     607              : 
     608            0 : #[derive(Debug)]
     609              : pub struct XLogDataBody<'a> {
     610              :     pub wal_start: u64,
     611              :     pub wal_end: u64, // current end of WAL on the server
     612              :     pub timestamp: i64,
     613              :     pub data: &'a [u8],
     614              : }
     615              : 
     616            0 : #[derive(Debug)]
     617              : pub struct WalSndKeepAlive {
     618              :     pub wal_end: u64, // current end of WAL on the server
     619              :     pub timestamp: i64,
     620              :     pub request_reply: bool,
     621              : }
     622              : 
     623              : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
     624              : 
     625              : // single text column
     626              : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
     627              :     name: b"data",
     628              :     tableoid: 0,
     629              :     attnum: 0,
     630              :     typoid: TEXT_OID,
     631              :     typlen: -1,
     632              :     typmod: 0,
     633              :     formatcode: 0,
     634              : }]);
     635              : 
     636              : /// Call f() to write body of the message and prepend it with 4-byte len as
     637              : /// prescribed by the protocol.
     638      7828537 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
     639      7828537 :     let base = buf.len();
     640      7828537 :     buf.extend_from_slice(&[0; 4]);
     641      7828537 : 
     642      7828537 :     let res = f(buf);
     643      7828537 : 
     644      7828537 :     let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
     645      7828537 :     (&mut buf[base..]).put_slice(&size.to_be_bytes());
     646      7828537 : 
     647      7828537 :     res
     648      7828537 : }
     649              : 
     650              : /// Safe write of s into buf as cstring (String in the protocol).
     651        90737 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
     652        90737 :     let bytes = s.as_ref();
     653        90737 :     if bytes.contains(&0) {
     654            0 :         return Err(ProtocolError::BadMessage(
     655            0 :             "string contains embedded null".to_owned(),
     656            0 :         ));
     657        90737 :     }
     658        90737 :     buf.put_slice(bytes);
     659        90737 :     buf.put_u8(0);
     660        90737 :     Ok(())
     661        90737 : }
     662              : 
     663              : /// Read cstring from buf, advancing it.
     664      3764746 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
     665      3764746 :     let pos = buf
     666      3764746 :         .iter()
     667     53457811 :         .position(|x| *x == 0)
     668      3764746 :         .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
     669      3764746 :     let result = buf.split_to(pos);
     670      3764746 :     buf.advance(1); // drop the null terminator
     671      3764746 :     Ok(result)
     672      3764746 : }
     673              : 
     674              : pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
     675              : pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
     676              : pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
     677              : 
     678              : impl<'a> BeMessage<'a> {
     679              :     /// Serialize `message` to the given `buf`.
     680              :     /// Apart from smart memory managemet, BytesMut is good here as msg len
     681              :     /// precedes its body and it is handy to write it down first and then fill
     682              :     /// the length. With Write we would have to either calc it manually or have
     683              :     /// one more buffer.
     684      7840511 :     pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
     685      7840511 :         match message {
     686        13984 :             BeMessage::AuthenticationOk => {
     687        13984 :                 buf.put_u8(b'R');
     688        13984 :                 write_body(buf, |buf| {
     689        13984 :                     buf.put_i32(0); // Specifies that the authentication was successful.
     690        13984 :                 });
     691        13984 :             }
     692              : 
     693          219 :             BeMessage::AuthenticationCleartextPassword => {
     694          219 :                 buf.put_u8(b'R');
     695          219 :                 write_body(buf, |buf| {
     696          219 :                     buf.put_i32(3); // Specifies that clear text password is required.
     697          219 :                 });
     698          219 :             }
     699              : 
     700            0 :             BeMessage::AuthenticationMD5Password(salt) => {
     701            0 :                 buf.put_u8(b'R');
     702            0 :                 write_body(buf, |buf| {
     703            0 :                     buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
     704            0 :                     buf.put_slice(&salt[..]);
     705            0 :                 });
     706            0 :             }
     707              : 
     708          168 :             BeMessage::AuthenticationSasl(msg) => {
     709          168 :                 buf.put_u8(b'R');
     710          168 :                 write_body(buf, |buf| {
     711          168 :                     use BeAuthenticationSaslMessage::*;
     712          168 :                     match msg {
     713           63 :                         Methods(methods) => {
     714           63 :                             buf.put_i32(10); // Specifies that SASL auth method is used.
     715          126 :                             for method in methods.iter() {
     716          126 :                                 write_cstr(method, buf)?;
     717              :                             }
     718           63 :                             buf.put_u8(0); // zero terminator for the list
     719              :                         }
     720           59 :                         Continue(extra) => {
     721           59 :                             buf.put_i32(11); // Continue SASL auth.
     722           59 :                             buf.put_slice(extra);
     723           59 :                         }
     724           46 :                         Final(extra) => {
     725           46 :                             buf.put_i32(12); // Send final SASL message.
     726           46 :                             buf.put_slice(extra);
     727           46 :                         }
     728              :                     }
     729          168 :                     Ok(())
     730          168 :                 })?;
     731              :             }
     732              : 
     733           41 :             BeMessage::BackendKeyData(key_data) => {
     734           41 :                 buf.put_u8(b'K');
     735           41 :                 write_body(buf, |buf| {
     736           41 :                     buf.put_i32(key_data.backend_pid);
     737           41 :                     buf.put_i32(key_data.cancel_key);
     738           41 :                 });
     739           41 :             }
     740              : 
     741          630 :             BeMessage::BindComplete => {
     742          630 :                 buf.put_u8(b'2');
     743          630 :                 write_body(buf, |_| {});
     744          630 :             }
     745              : 
     746          629 :             BeMessage::CloseComplete => {
     747          629 :                 buf.put_u8(b'3');
     748          629 :                 write_body(buf, |_| {});
     749          629 :             }
     750              : 
     751         2093 :             BeMessage::CommandComplete(cmd) => {
     752         2093 :                 buf.put_u8(b'C');
     753         2093 :                 write_body(buf, |buf| write_cstr(cmd, buf))?;
     754              :             }
     755              : 
     756      6963135 :             BeMessage::CopyData(data) => {
     757      6963135 :                 buf.put_u8(b'd');
     758      6963135 :                 write_body(buf, |buf| {
     759      6963135 :                     buf.put_slice(data);
     760      6963135 :                 });
     761      6963135 :             }
     762              : 
     763          595 :             BeMessage::CopyDone => {
     764          595 :                 buf.put_u8(b'c');
     765          595 :                 write_body(buf, |_| {});
     766          595 :             }
     767              : 
     768            0 :             BeMessage::CopyFail => {
     769            0 :                 buf.put_u8(b'f');
     770            0 :                 write_body(buf, |_| {});
     771            0 :             }
     772              : 
     773           14 :             BeMessage::CopyInResponse => {
     774           14 :                 buf.put_u8(b'G');
     775           14 :                 write_body(buf, |buf| {
     776           14 :                     buf.put_u8(1); // copy_is_binary
     777           14 :                     buf.put_i16(0); // numAttributes
     778           14 :                 });
     779           14 :             }
     780              : 
     781          607 :             BeMessage::CopyOutResponse => {
     782          607 :                 buf.put_u8(b'H');
     783          607 :                 write_body(buf, |buf| {
     784          607 :                     buf.put_u8(0); // copy_is_binary
     785          607 :                     buf.put_i16(0); // numAttributes
     786          607 :                 });
     787          607 :             }
     788              : 
     789        12540 :             BeMessage::CopyBothResponse => {
     790        12540 :                 buf.put_u8(b'W');
     791        12540 :                 write_body(buf, |buf| {
     792        12540 :                     // doesn't matter, used only for replication
     793        12540 :                     buf.put_u8(0); // copy_is_binary
     794        12540 :                     buf.put_i16(0); // numAttributes
     795        12540 :                 });
     796        12540 :             }
     797              : 
     798          982 :             BeMessage::DataRow(vals) => {
     799          982 :                 buf.put_u8(b'D');
     800          982 :                 write_body(buf, |buf| {
     801          982 :                     buf.put_u16(vals.len() as u16); // num of cols
     802         3533 :                     for val_opt in vals.iter() {
     803         3533 :                         if let Some(val) = val_opt {
     804         2761 :                             buf.put_u32(val.len() as u32);
     805         2761 :                             buf.put_slice(val);
     806         2761 :                         } else {
     807          772 :                             buf.put_i32(-1);
     808          772 :                         }
     809              :                     }
     810          982 :                 });
     811          982 :             }
     812              : 
     813              :             // ErrorResponse is a zero-terminated array of zero-terminated fields.
     814              :             // First byte of each field represents type of this field. Set just enough fields
     815              :             // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
     816              :             // message text.
     817          197 :             BeMessage::ErrorResponse(error_msg, pg_error_code) => {
     818          197 :                 // 'E' signalizes ErrorResponse messages
     819          197 :                 buf.put_u8(b'E');
     820          197 :                 write_body(buf, |buf| {
     821          197 :                     buf.put_u8(b'S'); // severity
     822          197 :                     buf.put_slice(b"ERROR\0");
     823          197 : 
     824          197 :                     buf.put_u8(b'C'); // SQLSTATE error code
     825          197 :                     buf.put_slice(&terminate_code(
     826          197 :                         pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
     827          197 :                     ));
     828          197 : 
     829          197 :                     buf.put_u8(b'M'); // the message
     830          197 :                     write_cstr(error_msg, buf)?;
     831              : 
     832          197 :                     buf.put_u8(0); // terminator
     833          197 :                     Ok(())
     834          197 :                 })?;
     835              :             }
     836              : 
     837              :             // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
     838              :             // message but continue listening for ReadyForQuery or ErrorResponse"
     839            6 :             BeMessage::NoticeResponse(error_msg) => {
     840            6 :                 // For all the errors set Severity to Error and error code to
     841            6 :                 // 'internal error'.
     842            6 : 
     843            6 :                 // 'N' signalizes NoticeResponse messages
     844            6 :                 buf.put_u8(b'N');
     845            6 :                 write_body(buf, |buf| {
     846            6 :                     buf.put_u8(b'S'); // severity
     847            6 :                     buf.put_slice(b"NOTICE\0");
     848            6 : 
     849            6 :                     buf.put_u8(b'C'); // SQLSTATE error code
     850            6 :                     buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
     851            6 : 
     852            6 :                     buf.put_u8(b'M'); // the message
     853            6 :                     write_cstr(error_msg.as_bytes(), buf)?;
     854              : 
     855            6 :                     buf.put_u8(0); // terminator
     856            6 :                     Ok(())
     857            6 :                 })?;
     858              :             }
     859              : 
     860          630 :             BeMessage::NoData => {
     861          630 :                 buf.put_u8(b'n');
     862          630 :                 write_body(buf, |_| {});
     863          630 :             }
     864              : 
     865        11974 :             BeMessage::EncryptionResponse(should_negotiate) => {
     866        11974 :                 let response = if *should_negotiate { b'S' } else { b'N' };
     867        11974 :                 buf.put_u8(response);
     868              :             }
     869              : 
     870        41917 :             BeMessage::ParameterStatus { name, value } => {
     871        41917 :                 buf.put_u8(b'S');
     872        41917 :                 write_body(buf, |buf| {
     873        41917 :                     write_cstr(name, buf)?;
     874        41917 :                     write_cstr(value, buf)
     875        41917 :                 })?;
     876              :             }
     877              : 
     878          630 :             BeMessage::ParameterDescription => {
     879          630 :                 buf.put_u8(b't');
     880          630 :                 write_body(buf, |buf| {
     881          630 :                     // we don't support params, so always 0
     882          630 :                     buf.put_i16(0);
     883          630 :                 });
     884          630 :             }
     885              : 
     886          630 :             BeMessage::ParseComplete => {
     887          630 :                 buf.put_u8(b'1');
     888          630 :                 write_body(buf, |_| {});
     889          630 :             }
     890              : 
     891        29314 :             BeMessage::ReadyForQuery => {
     892        29314 :                 buf.put_u8(b'Z');
     893        29314 :                 write_body(buf, |buf| {
     894        29314 :                     buf.put_u8(b'I');
     895        29314 :                 });
     896        29314 :             }
     897              : 
     898         1456 :             BeMessage::RowDescription(rows) => {
     899         1456 :                 buf.put_u8(b'T');
     900         1456 :                 write_body(buf, |buf| {
     901         1456 :                     buf.put_i16(rows.len() as i16); // # of fields
     902         4481 :                     for row in rows.iter() {
     903         4481 :                         write_cstr(row.name, buf)?;
     904         4481 :                         buf.put_i32(0); /* table oid */
     905         4481 :                         buf.put_i16(0); /* attnum */
     906         4481 :                         buf.put_u32(row.typoid);
     907         4481 :                         buf.put_i16(row.typlen);
     908         4481 :                         buf.put_i32(-1); /* typmod */
     909         4481 :                         buf.put_i16(0); /* format code */
     910              :                     }
     911         1456 :                     Ok(())
     912         1456 :                 })?;
     913              :             }
     914              : 
     915       754616 :             BeMessage::XLogData(body) => {
     916       754616 :                 buf.put_u8(b'd');
     917       754616 :                 write_body(buf, |buf| {
     918       754616 :                     buf.put_u8(b'w');
     919       754616 :                     buf.put_u64(body.wal_start);
     920       754616 :                     buf.put_u64(body.wal_end);
     921       754616 :                     buf.put_i64(body.timestamp);
     922       754616 :                     buf.put_slice(body.data);
     923       754616 :                 });
     924       754616 :             }
     925              : 
     926         3504 :             BeMessage::KeepAlive(req) => {
     927         3504 :                 buf.put_u8(b'd');
     928         3504 :                 write_body(buf, |buf| {
     929         3504 :                     buf.put_u8(b'k');
     930         3504 :                     buf.put_u64(req.wal_end);
     931         3504 :                     buf.put_i64(req.timestamp);
     932         3504 :                     buf.put_u8(u8::from(req.request_reply));
     933         3504 :                 });
     934         3504 :             }
     935              :         }
     936      7840511 :         Ok(())
     937      7840511 :     }
     938              : }
     939              : 
     940          203 : fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
     941          203 :     let mut terminated = [0; 6];
     942         1015 :     for (i, &elem) in code.iter().enumerate() {
     943         1015 :         terminated[i] = elem;
     944         1015 :     }
     945              : 
     946          203 :     terminated
     947          203 : }
     948              : 
     949              : #[cfg(test)]
     950              : mod tests {
     951              :     use super::*;
     952              : 
     953            2 :     #[test]
     954            2 :     fn test_startup_message_params_options_escaped() {
     955            8 :         fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
     956            8 :             params
     957            8 :                 .options_escaped()
     958            8 :                 .expect("options are None")
     959            8 :                 .collect()
     960            8 :         }
     961            2 : 
     962            8 :         let make_params = |options| StartupMessageParams::new([("options", options)]);
     963              : 
     964            2 :         let params = StartupMessageParams::new([]);
     965            2 :         assert!(params.options_escaped().is_none());
     966              : 
     967            2 :         let params = make_params("");
     968            2 :         assert!(split_options(&params).is_empty());
     969              : 
     970            2 :         let params = make_params("foo");
     971            2 :         assert_eq!(split_options(&params), ["foo"]);
     972              : 
     973            2 :         let params = make_params(" foo  bar ");
     974            2 :         assert_eq!(split_options(&params), ["foo", "bar"]);
     975              : 
     976            2 :         let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
     977            2 :         assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
     978            2 :     }
     979              : 
     980            2 :     #[test]
     981            2 :     fn parse_fe_startup_packet_regression() {
     982            2 :         let data = [0, 0, 0, 7, 0, 0, 0, 0];
     983            2 :         FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
     984            2 :     }
     985              : }
        

Generated by: LCOV version 2.1-beta