LCOV - code coverage report
Current view: top level - libs/pq_proto/src - lib.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 84.8 % 584 495
Test Date: 2023-09-06 10:18:01 Functions: 74.6 % 122 91

            Line data    Source code
       1              : //! Postgres protocol messages serialization-deserialization. See
       2              : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
       3              : //! on message formats.
       4              : 
       5              : pub mod framed;
       6              : 
       7              : use byteorder::{BigEndian, ReadBytesExt};
       8              : use bytes::{Buf, BufMut, Bytes, BytesMut};
       9              : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
      10              : 
      11              : // re-export for use in utils pageserver_feedback.rs
      12              : pub use postgres_protocol::PG_EPOCH;
      13              : 
      14              : pub type Oid = u32;
      15              : pub type SystemId = u64;
      16              : 
      17              : pub const INT8_OID: Oid = 20;
      18              : pub const INT4_OID: Oid = 23;
      19              : pub const TEXT_OID: Oid = 25;
      20              : 
      21          135 : #[derive(Debug)]
      22              : pub enum FeMessage {
      23              :     // Simple query.
      24              :     Query(Bytes),
      25              :     // Extended query protocol.
      26              :     Parse(FeParseMessage),
      27              :     Describe(FeDescribeMessage),
      28              :     Bind(FeBindMessage),
      29              :     Execute(FeExecuteMessage),
      30              :     Close(FeCloseMessage),
      31              :     Sync,
      32              :     Terminate,
      33              :     CopyData(Bytes),
      34              :     CopyDone,
      35              :     CopyFail,
      36              :     PasswordMessage(Bytes),
      37              : }
      38              : 
      39          132 : #[derive(Debug)]
      40              : pub enum FeStartupPacket {
      41              :     CancelRequest(CancelKeyData),
      42              :     SslRequest,
      43              :     GssEncRequest,
      44              :     StartupMessage {
      45              :         major_version: u32,
      46              :         minor_version: u32,
      47              :         params: StartupMessageParams,
      48              :     },
      49              : }
      50              : 
      51           70 : #[derive(Debug)]
      52              : pub struct StartupMessageParams {
      53              :     params: HashMap<String, String>,
      54              : }
      55              : 
      56              : impl StartupMessageParams {
      57              :     /// Get parameter's value by its name.
      58         7750 :     pub fn get(&self, name: &str) -> Option<&str> {
      59         7750 :         self.params.get(name).map(|s| s.as_str())
      60         7750 :     }
      61              : 
      62              :     /// Split command-line options according to PostgreSQL's logic,
      63              :     /// taking into account all escape sequences but leaving them as-is.
      64              :     /// [`None`] means that there's no `options` in [`Self`].
      65         3819 :     pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
      66         3819 :         self.get("options").map(Self::parse_options_raw)
      67         3819 :     }
      68              : 
      69              :     /// Split command-line options according to PostgreSQL's logic,
      70              :     /// applying all escape sequences (using owned strings as needed).
      71              :     /// [`None`] means that there's no `options` in [`Self`].
      72            5 :     pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
      73            5 :         self.get("options").map(Self::parse_options_escaped)
      74            5 :     }
      75              : 
      76              :     /// Split command-line options according to PostgreSQL's logic,
      77              :     /// taking into account all escape sequences but leaving them as-is.
      78         3797 :     pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
      79         3797 :         // See `postgres: pg_split_opts`.
      80         3797 :         let mut last_was_escape = false;
      81         3797 :         input
      82       338194 :             .split(move |c: char| {
      83              :                 // We split by non-escaped whitespace symbols.
      84       338194 :                 let should_split = c.is_ascii_whitespace() && !last_was_escape;
      85       338194 :                 last_was_escape = c == '\\' && !last_was_escape;
      86       338194 :                 should_split
      87       338194 :             })
      88        11356 :             .filter(|s| !s.is_empty())
      89         3797 :     }
      90              : 
      91              :     /// Split command-line options according to PostgreSQL's logic,
      92              :     /// applying all escape sequences (using owned strings as needed).
      93            4 :     pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
      94            4 :         // See `postgres: pg_split_opts`.
      95            7 :         Self::parse_options_raw(input).map(|s| {
      96            7 :             let mut preserve_next_escape = false;
      97           17 :             let escape = |c| {
      98              :                 // We should remove '\\' unless it's preceded by '\\'.
      99           17 :                 let should_remove = c == '\\' && !preserve_next_escape;
     100           17 :                 preserve_next_escape = should_remove;
     101           17 :                 should_remove
     102           17 :             };
     103              : 
     104            7 :             match s.contains('\\') {
     105            3 :                 true => Cow::Owned(s.replace(escape, "")),
     106            4 :                 false => Cow::Borrowed(s),
     107              :             }
     108            7 :         })
     109            4 :     }
     110              : 
     111              :     /// Iterate through key-value pairs in an arbitrary order.
     112            0 :     pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
     113            0 :         self.params.iter().map(|(k, v)| (k.as_str(), v.as_str()))
     114            0 :     }
     115              : 
     116              :     // This function is mostly useful in tests.
     117              :     #[doc(hidden)]
     118           41 :     pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
     119           41 :         Self {
     120           88 :             params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
     121           41 :         }
     122           41 :     }
     123              : }
     124              : 
     125           92 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
     126              : pub struct CancelKeyData {
     127              :     pub backend_pid: i32,
     128              :     pub cancel_key: i32,
     129              : }
     130              : 
     131              : impl fmt::Display for CancelKeyData {
     132          124 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     133          124 :         let hi = (self.backend_pid as u64) << 32;
     134          124 :         let lo = self.cancel_key as u64;
     135          124 :         let id = hi | lo;
     136          124 : 
     137          124 :         // This format is more compact and might work better for logs.
     138          124 :         f.debug_tuple("CancelKeyData")
     139          124 :             .field(&format_args!("{:x}", id))
     140          124 :             .finish()
     141          124 :     }
     142              : }
     143              : 
     144              : use rand::distributions::{Distribution, Standard};
     145              : impl Distribution<CancelKeyData> for Standard {
     146           32 :     fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
     147           32 :         CancelKeyData {
     148           32 :             backend_pid: rng.gen(),
     149           32 :             cancel_key: rng.gen(),
     150           32 :         }
     151           32 :     }
     152              : }
     153              : 
     154              : // We only support the simple case of Parse on unnamed prepared statement and
     155              : // no params
     156            0 : #[derive(Debug)]
     157              : pub struct FeParseMessage {
     158              :     pub query_string: Bytes,
     159              : }
     160              : 
     161            0 : #[derive(Debug)]
     162              : pub struct FeDescribeMessage {
     163              :     pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
     164              :                   // we only support unnamed prepared stmt or portal
     165              : }
     166              : 
     167              : // we only support unnamed prepared stmt and portal
     168            0 : #[derive(Debug)]
     169              : pub struct FeBindMessage;
     170              : 
     171              : // we only support unnamed prepared stmt or portal
     172            0 : #[derive(Debug)]
     173              : pub struct FeExecuteMessage {
     174              :     /// max # of rows
     175              :     pub maxrows: i32,
     176              : }
     177              : 
     178              : // we only support unnamed prepared stmt and portal
     179            0 : #[derive(Debug)]
     180              : pub struct FeCloseMessage;
     181              : 
     182              : /// An error occurred while parsing or serializing raw stream into Postgres
     183              : /// messages.
     184            0 : #[derive(thiserror::Error, Debug)]
     185              : pub enum ProtocolError {
     186              :     /// Invalid packet was received from the client (e.g. unexpected message
     187              :     /// type or broken len).
     188              :     #[error("Protocol error: {0}")]
     189              :     Protocol(String),
     190              :     /// Failed to parse or, (unlikely), serialize a protocol message.
     191              :     #[error("Message parse error: {0}")]
     192              :     BadMessage(String),
     193              : }
     194              : 
     195              : impl ProtocolError {
     196              :     /// Proxy stream.rs uses only io::Error; provide it.
     197            0 :     pub fn into_io_error(self) -> io::Error {
     198            0 :         io::Error::new(io::ErrorKind::Other, self.to_string())
     199            0 :     }
     200              : }
     201              : 
     202              : impl FeMessage {
     203              :     /// Read and parse one message from the `buf` input buffer. If there is at
     204              :     /// least one valid message, returns it, advancing `buf`; redundant copies
     205              :     /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
     206              :     /// directly into the `buf` (processed data is garbage collected after
     207              :     /// parsed message is dropped).
     208              :     ///
     209              :     /// Returns None if `buf` doesn't contain enough data for a single message.
     210              :     /// For efficiency, tries to reserve large enough space in `buf` for the
     211              :     /// next message in this case to save the repeated calls.
     212              :     ///
     213              :     /// Returns Error if message is malformed, the only possible ErrorKind is
     214              :     /// InvalidInput.
     215              :     //
     216              :     // Inspired by rust-postgres Message::parse.
     217     15417306 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
     218     15417306 :         // Every message contains message type byte and 4 bytes len; can't do
     219     15417306 :         // much without them.
     220     15417306 :         if buf.len() < 5 {
     221      7278005 :             let to_read = 5 - buf.len();
     222      7278005 :             buf.reserve(to_read);
     223      7278005 :             return Ok(None);
     224      8139301 :         }
     225      8139301 : 
     226      8139301 :         // We shouldn't advance `buf` as probably full message is not there yet,
     227      8139301 :         // so can't directly use Bytes::get_u32 etc.
     228      8139301 :         let tag = buf[0];
     229      8139301 :         let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
     230      8139301 :         if len < 4 {
     231            0 :             return Err(ProtocolError::Protocol(format!(
     232            0 :                 "invalid message length {}",
     233            0 :                 len
     234            0 :             )));
     235      8139301 :         }
     236      8139301 : 
     237      8139301 :         // length field includes itself, but not message type.
     238      8139301 :         let total_len = len as usize + 1;
     239      8139301 :         if buf.len() < total_len {
     240              :             // Don't have full message yet.
     241       111963 :             let to_read = total_len - buf.len();
     242       111963 :             buf.reserve(to_read);
     243       111963 :             return Ok(None);
     244      8027338 :         }
     245      8027338 : 
     246      8027338 :         // got the message, advance buffer
     247      8027338 :         let mut msg = buf.split_to(total_len).freeze();
     248      8027338 :         msg.advance(5); // consume message type and len
     249      8027338 : 
     250      8027338 :         match tag {
     251         9172 :             b'Q' => Ok(Some(FeMessage::Query(msg))),
     252          662 :             b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
     253          662 :             b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
     254          662 :             b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
     255          662 :             b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
     256          661 :             b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
     257         2001 :             b'S' => Ok(Some(FeMessage::Sync)),
     258         1586 :             b'X' => Ok(Some(FeMessage::Terminate)),
     259      8010840 :             b'd' => Ok(Some(FeMessage::CopyData(msg))),
     260            6 :             b'c' => Ok(Some(FeMessage::CopyDone)),
     261          152 :             b'f' => Ok(Some(FeMessage::CopyFail)),
     262          272 :             b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
     263            0 :             tag => Err(ProtocolError::Protocol(format!(
     264            0 :                 "unknown message tag: {tag},'{msg:?}'"
     265            0 :             ))),
     266              :         }
     267     15417306 :     }
     268              : }
     269              : 
     270              : impl FeStartupPacket {
     271              :     /// Read and parse startup message from the `buf` input buffer. It is
     272              :     /// different from [`FeMessage::parse`] because startup messages don't have
     273              :     /// message type byte; otherwise, its comments apply.
     274        31999 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
     275        31999 :         const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
     276        31999 :         const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
     277        31999 :         const CANCEL_REQUEST_CODE: u32 = 5678;
     278        31999 :         const NEGOTIATE_SSL_CODE: u32 = 5679;
     279        31999 :         const NEGOTIATE_GSS_CODE: u32 = 5680;
     280        31999 : 
     281        31999 :         // need at least 4 bytes with packet len
     282        31999 :         if buf.len() < 4 {
     283        16000 :             let to_read = 4 - buf.len();
     284        16000 :             buf.reserve(to_read);
     285        16000 :             return Ok(None);
     286        15999 :         }
     287        15999 : 
     288        15999 :         // We shouldn't advance `buf` as probably full message is not there yet,
     289        15999 :         // so can't directly use Bytes::get_u32 etc.
     290        15999 :         let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
     291        15999 :         // The proposed replacement is `!(4..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
     292        15999 :         // which is less readable
     293        15999 :         #[allow(clippy::manual_range_contains)]
     294        15999 :         if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
     295            0 :             return Err(ProtocolError::Protocol(format!(
     296            0 :                 "invalid startup packet message length {}",
     297            0 :                 len
     298            0 :             )));
     299        15999 :         }
     300        15999 : 
     301        15999 :         if buf.len() < len {
     302              :             // Don't have full message yet.
     303            0 :             let to_read = len - buf.len();
     304            0 :             buf.reserve(to_read);
     305            0 :             return Ok(None);
     306        15999 :         }
     307        15999 : 
     308        15999 :         // got the message, advance buffer
     309        15999 :         let mut msg = buf.split_to(len).freeze();
     310        15999 :         msg.advance(4); // consume len
     311        15999 : 
     312        15999 :         let request_code = msg.get_u32();
     313        15999 :         let req_hi = request_code >> 16;
     314        15999 :         let req_lo = request_code & ((1 << 16) - 1);
     315              :         // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
     316        15999 :         let message = match (req_hi, req_lo) {
     317              :             (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
     318            0 :                 if msg.remaining() != 8 {
     319            0 :                     return Err(ProtocolError::BadMessage(
     320            0 :                         "CancelRequest message is malformed, backend PID / secret key missing"
     321            0 :                             .to_owned(),
     322            0 :                     ));
     323            0 :                 }
     324            0 :                 FeStartupPacket::CancelRequest(CancelKeyData {
     325            0 :                     backend_pid: msg.get_i32(),
     326            0 :                     cancel_key: msg.get_i32(),
     327            0 :                 })
     328              :             }
     329              :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
     330              :                 // Requested upgrade to SSL (aka TLS)
     331         6898 :                 FeStartupPacket::SslRequest
     332              :             }
     333              :             (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
     334              :                 // Requested upgrade to GSSAPI
     335            0 :                 FeStartupPacket::GssEncRequest
     336              :             }
     337            0 :             (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
     338            0 :                 return Err(ProtocolError::Protocol(format!(
     339            0 :                     "Unrecognized request code {unrecognized_code}"
     340            0 :                 )));
     341              :             }
     342              :             // TODO bail if protocol major_version is not 3?
     343         9101 :             (major_version, minor_version) => {
     344              :                 // StartupMessage
     345              : 
     346              :                 // Parse pairs of null-terminated strings (key, value).
     347              :                 // See `postgres: ProcessStartupPacket, build_startup_packet`.
     348         9101 :                 let mut tokens = str::from_utf8(&msg)
     349         9101 :                     .map_err(|_e| {
     350            0 :                         ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
     351         9101 :                     })?
     352         9101 :                     .strip_suffix('\0') // drop packet's own null
     353         9101 :                     .ok_or_else(|| {
     354            0 :                         ProtocolError::Protocol(
     355            0 :                             "StartupMessage params: missing null terminator".to_string(),
     356            0 :                         )
     357         9101 :                     })?
     358         9101 :                     .split_terminator('\0');
     359         9101 : 
     360         9101 :                 let mut params = HashMap::new();
     361        32027 :                 while let Some(name) = tokens.next() {
     362        22926 :                     let value = tokens.next().ok_or_else(|| {
     363            0 :                         ProtocolError::Protocol(
     364            0 :                             "StartupMessage params: key without value".to_string(),
     365            0 :                         )
     366        22926 :                     })?;
     367              : 
     368        22926 :                     params.insert(name.to_owned(), value.to_owned());
     369              :                 }
     370              : 
     371         9101 :                 FeStartupPacket::StartupMessage {
     372         9101 :                     major_version,
     373         9101 :                     minor_version,
     374         9101 :                     params: StartupMessageParams { params },
     375         9101 :                 }
     376              :             }
     377              :         };
     378        15999 :         Ok(Some(message))
     379        31999 :     }
     380              : }
     381              : 
     382              : impl FeParseMessage {
     383          662 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     384              :         // FIXME: the rust-postgres driver uses a named prepared statement
     385              :         // for copy_out(). We're not prepared to handle that correctly. For
     386              :         // now, just ignore the statement name, assuming that the client never
     387              :         // uses more than one prepared statement at a time.
     388              : 
     389          662 :         let _pstmt_name = read_cstr(&mut buf)?;
     390          662 :         let query_string = read_cstr(&mut buf)?;
     391          662 :         if buf.remaining() < 2 {
     392            0 :             return Err(ProtocolError::BadMessage(
     393            0 :                 "Parse message is malformed, nparams missing".to_string(),
     394            0 :             ));
     395          662 :         }
     396          662 :         let nparams = buf.get_i16();
     397          662 : 
     398          662 :         if nparams != 0 {
     399            0 :             return Err(ProtocolError::BadMessage(
     400            0 :                 "query params not implemented".to_string(),
     401            0 :             ));
     402          662 :         }
     403          662 : 
     404          662 :         Ok(FeMessage::Parse(FeParseMessage { query_string }))
     405          662 :     }
     406              : }
     407              : 
     408              : impl FeDescribeMessage {
     409          662 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     410          662 :         let kind = buf.get_u8();
     411          662 :         let _pstmt_name = read_cstr(&mut buf)?;
     412              : 
     413              :         // FIXME: see FeParseMessage::parse
     414          662 :         if kind != b'S' {
     415            0 :             return Err(ProtocolError::BadMessage(
     416            0 :                 "only prepared statemement Describe is implemented".to_string(),
     417            0 :             ));
     418          662 :         }
     419          662 : 
     420          662 :         Ok(FeMessage::Describe(FeDescribeMessage { kind }))
     421          662 :     }
     422              : }
     423              : 
     424              : impl FeExecuteMessage {
     425          662 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     426          662 :         let portal_name = read_cstr(&mut buf)?;
     427          662 :         if buf.remaining() < 4 {
     428            0 :             return Err(ProtocolError::BadMessage(
     429            0 :                 "FeExecuteMessage message is malformed, maxrows missing".to_string(),
     430            0 :             ));
     431          662 :         }
     432          662 :         let maxrows = buf.get_i32();
     433          662 : 
     434          662 :         if !portal_name.is_empty() {
     435            0 :             return Err(ProtocolError::BadMessage(
     436            0 :                 "named portals not implemented".to_string(),
     437            0 :             ));
     438          662 :         }
     439          662 :         if maxrows != 0 {
     440            0 :             return Err(ProtocolError::BadMessage(
     441            0 :                 "row limit in Execute message not implemented".to_string(),
     442            0 :             ));
     443          662 :         }
     444          662 : 
     445          662 :         Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
     446          662 :     }
     447              : }
     448              : 
     449              : impl FeBindMessage {
     450          662 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     451          662 :         let portal_name = read_cstr(&mut buf)?;
     452          662 :         let _pstmt_name = read_cstr(&mut buf)?;
     453              : 
     454              :         // FIXME: see FeParseMessage::parse
     455          662 :         if !portal_name.is_empty() {
     456            0 :             return Err(ProtocolError::BadMessage(
     457            0 :                 "named portals not implemented".to_string(),
     458            0 :             ));
     459          662 :         }
     460          662 : 
     461          662 :         Ok(FeMessage::Bind(FeBindMessage))
     462          662 :     }
     463              : }
     464              : 
     465              : impl FeCloseMessage {
     466          661 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     467          661 :         let _kind = buf.get_u8();
     468          661 :         let _pstmt_or_portal_name = read_cstr(&mut buf)?;
     469              : 
     470              :         // FIXME: we do nothing with Close
     471          661 :         Ok(FeMessage::Close(FeCloseMessage))
     472          661 :     }
     473              : }
     474              : 
     475              : // Backend
     476              : 
     477            0 : #[derive(Debug)]
     478              : pub enum BeMessage<'a> {
     479              :     AuthenticationOk,
     480              :     AuthenticationMD5Password([u8; 4]),
     481              :     AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
     482              :     AuthenticationCleartextPassword,
     483              :     BackendKeyData(CancelKeyData),
     484              :     BindComplete,
     485              :     CommandComplete(&'a [u8]),
     486              :     CopyData(&'a [u8]),
     487              :     CopyDone,
     488              :     CopyFail,
     489              :     CopyInResponse,
     490              :     CopyOutResponse,
     491              :     CopyBothResponse,
     492              :     CloseComplete,
     493              :     // None means column is NULL
     494              :     DataRow(&'a [Option<&'a [u8]>]),
     495              :     // None errcode means internal_error will be sent.
     496              :     ErrorResponse(&'a str, Option<&'a [u8; 5]>),
     497              :     /// Single byte - used in response to SSLRequest/GSSENCRequest.
     498              :     EncryptionResponse(bool),
     499              :     NoData,
     500              :     ParameterDescription,
     501              :     ParameterStatus {
     502              :         name: &'a [u8],
     503              :         value: &'a [u8],
     504              :     },
     505              :     ParseComplete,
     506              :     ReadyForQuery,
     507              :     RowDescription(&'a [RowDescriptor<'a>]),
     508              :     XLogData(XLogDataBody<'a>),
     509              :     NoticeResponse(&'a str),
     510              :     KeepAlive(WalSndKeepAlive),
     511              : }
     512              : 
     513              : /// Common shorthands.
     514              : impl<'a> BeMessage<'a> {
     515              :     /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
     516              :     /// This is a sensible default, given that:
     517              :     ///  * rust strings only support this encoding out of the box.
     518              :     ///  * tokio-postgres, postgres-jdbc (and probably more) mandate it.
     519              :     ///
     520              :     /// TODO: do we need to report `server_encoding` as well?
     521              :     pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
     522              :         name: b"client_encoding",
     523              :         value: b"UTF8",
     524              :     };
     525              : 
     526              :     pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
     527              :         name: b"integer_datetimes",
     528              :         value: b"on",
     529              :     };
     530              : 
     531              :     /// Build a [`BeMessage::ParameterStatus`] holding the server version.
     532         8843 :     pub fn server_version(version: &'a str) -> Self {
     533         8843 :         Self::ParameterStatus {
     534         8843 :             name: b"server_version",
     535         8843 :             value: version.as_bytes(),
     536         8843 :         }
     537         8843 :     }
     538              : }
     539              : 
     540            0 : #[derive(Debug)]
     541              : pub enum BeAuthenticationSaslMessage<'a> {
     542              :     Methods(&'a [&'a str]),
     543              :     Continue(&'a [u8]),
     544              :     Final(&'a [u8]),
     545              : }
     546              : 
     547            0 : #[derive(Debug)]
     548              : pub enum BeParameterStatusMessage<'a> {
     549              :     Encoding(&'a str),
     550              :     ServerVersion(&'a str),
     551              : }
     552              : 
     553              : // One row description in RowDescription packet.
     554            0 : #[derive(Debug)]
     555              : pub struct RowDescriptor<'a> {
     556              :     pub name: &'a [u8],
     557              :     pub tableoid: Oid,
     558              :     pub attnum: i16,
     559              :     pub typoid: Oid,
     560              :     pub typlen: i16,
     561              :     pub typmod: i32,
     562              :     pub formatcode: i16,
     563              : }
     564              : 
     565              : impl Default for RowDescriptor<'_> {
     566         3099 :     fn default() -> RowDescriptor<'static> {
     567         3099 :         RowDescriptor {
     568         3099 :             name: b"",
     569         3099 :             tableoid: 0,
     570         3099 :             attnum: 0,
     571         3099 :             typoid: 0,
     572         3099 :             typlen: 0,
     573         3099 :             typmod: 0,
     574         3099 :             formatcode: 0,
     575         3099 :         }
     576         3099 :     }
     577              : }
     578              : 
     579              : impl RowDescriptor<'_> {
     580              :     /// Convenience function to create a RowDescriptor message for an int8 column
     581           45 :     pub const fn int8_col(name: &[u8]) -> RowDescriptor {
     582           45 :         RowDescriptor {
     583           45 :             name,
     584           45 :             tableoid: 0,
     585           45 :             attnum: 0,
     586           45 :             typoid: INT8_OID,
     587           45 :             typlen: 8,
     588           45 :             typmod: 0,
     589           45 :             formatcode: 0,
     590           45 :         }
     591           45 :     }
     592              : 
     593         1560 :     pub const fn text_col(name: &[u8]) -> RowDescriptor {
     594         1560 :         RowDescriptor {
     595         1560 :             name,
     596         1560 :             tableoid: 0,
     597         1560 :             attnum: 0,
     598         1560 :             typoid: TEXT_OID,
     599         1560 :             typlen: -1,
     600         1560 :             typmod: 0,
     601         1560 :             formatcode: 0,
     602         1560 :         }
     603         1560 :     }
     604              : }
     605              : 
     606            0 : #[derive(Debug)]
     607              : pub struct XLogDataBody<'a> {
     608              :     pub wal_start: u64,
     609              :     pub wal_end: u64, // current end of WAL on the server
     610              :     pub timestamp: i64,
     611              :     pub data: &'a [u8],
     612              : }
     613              : 
     614            0 : #[derive(Debug)]
     615              : pub struct WalSndKeepAlive {
     616              :     pub wal_end: u64, // current end of WAL on the server
     617              :     pub timestamp: i64,
     618              :     pub request_reply: bool,
     619              : }
     620              : 
     621              : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
     622              : 
     623              : // single text column
     624              : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
     625              :     name: b"data",
     626              :     tableoid: 0,
     627              :     attnum: 0,
     628              :     typoid: TEXT_OID,
     629              :     typlen: -1,
     630              :     typmod: 0,
     631              :     formatcode: 0,
     632              : }]);
     633              : 
     634              : /// Call f() to write body of the message and prepend it with 4-byte len as
     635              : /// prescribed by the protocol.
     636      8628863 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
     637      8628863 :     let base = buf.len();
     638      8628863 :     buf.extend_from_slice(&[0; 4]);
     639      8628863 : 
     640      8628863 :     let res = f(buf);
     641      8628863 : 
     642      8628863 :     let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
     643      8628863 :     (&mut buf[base..]).put_slice(&size.to_be_bytes());
     644      8628863 : 
     645      8628863 :     res
     646      8628863 : }
     647              : 
     648              : /// Safe write of s into buf as cstring (String in the protocol).
     649        61453 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
     650        61453 :     let bytes = s.as_ref();
     651        61453 :     if bytes.contains(&0) {
     652            0 :         return Err(ProtocolError::BadMessage(
     653            0 :             "string contains embedded null".to_owned(),
     654            0 :         ));
     655        61453 :     }
     656        61453 :     buf.put_slice(bytes);
     657        61453 :     buf.put_u8(0);
     658        61453 :     Ok(())
     659        61453 : }
     660              : 
     661              : /// Read cstring from buf, advancing it.
     662      3686889 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
     663      3686889 :     let pos = buf
     664      3686889 :         .iter()
     665     52352084 :         .position(|x| *x == 0)
     666      3686889 :         .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
     667      3686889 :     let result = buf.split_to(pos);
     668      3686889 :     buf.advance(1); // drop the null terminator
     669      3686889 :     Ok(result)
     670      3686889 : }
     671              : 
     672              : pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
     673              : pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
     674              : 
     675              : impl<'a> BeMessage<'a> {
     676              :     /// Serialize `message` to the given `buf`.
     677              :     /// Apart from smart memory managemet, BytesMut is good here as msg len
     678              :     /// precedes its body and it is handy to write it down first and then fill
     679              :     /// the length. With Write we would have to either calc it manually or have
     680              :     /// one more buffer.
     681      8635761 :     pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
     682      8635761 :         match message {
     683         9085 :             BeMessage::AuthenticationOk => {
     684         9085 :                 buf.put_u8(b'R');
     685         9085 :                 write_body(buf, |buf| {
     686         9085 :                     buf.put_i32(0); // Specifies that the authentication was successful.
     687         9085 :                 });
     688         9085 :             }
     689              : 
     690          219 :             BeMessage::AuthenticationCleartextPassword => {
     691          219 :                 buf.put_u8(b'R');
     692          219 :                 write_body(buf, |buf| {
     693          219 :                     buf.put_i32(3); // Specifies that clear text password is required.
     694          219 :                 });
     695          219 :             }
     696              : 
     697            0 :             BeMessage::AuthenticationMD5Password(salt) => {
     698            0 :                 buf.put_u8(b'R');
     699            0 :                 write_body(buf, |buf| {
     700            0 :                     buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
     701            0 :                     buf.put_slice(&salt[..]);
     702            0 :                 });
     703            0 :             }
     704              : 
     705           83 :             BeMessage::AuthenticationSasl(msg) => {
     706           83 :                 buf.put_u8(b'R');
     707           83 :                 write_body(buf, |buf| {
     708           83 :                     use BeAuthenticationSaslMessage::*;
     709           83 :                     match msg {
     710           29 :                         Methods(methods) => {
     711           29 :                             buf.put_i32(10); // Specifies that SASL auth method is used.
     712           29 :                             for method in methods.iter() {
     713           29 :                                 write_cstr(method, buf)?;
     714              :                             }
     715           29 :                             buf.put_u8(0); // zero terminator for the list
     716              :                         }
     717           29 :                         Continue(extra) => {
     718           29 :                             buf.put_i32(11); // Continue SASL auth.
     719           29 :                             buf.put_slice(extra);
     720           29 :                         }
     721           25 :                         Final(extra) => {
     722           25 :                             buf.put_i32(12); // Send final SASL message.
     723           25 :                             buf.put_slice(extra);
     724           25 :                         }
     725              :                     }
     726           83 :                     Ok(())
     727           83 :                 })?;
     728              :             }
     729              : 
     730           27 :             BeMessage::BackendKeyData(key_data) => {
     731           27 :                 buf.put_u8(b'K');
     732           27 :                 write_body(buf, |buf| {
     733           27 :                     buf.put_i32(key_data.backend_pid);
     734           27 :                     buf.put_i32(key_data.cancel_key);
     735           27 :                 });
     736           27 :             }
     737              : 
     738          662 :             BeMessage::BindComplete => {
     739          662 :                 buf.put_u8(b'2');
     740          662 :                 write_body(buf, |_| {});
     741          662 :             }
     742              : 
     743          661 :             BeMessage::CloseComplete => {
     744          661 :                 buf.put_u8(b'3');
     745          661 :                 write_body(buf, |_| {});
     746          661 :             }
     747              : 
     748         2262 :             BeMessage::CommandComplete(cmd) => {
     749         2262 :                 buf.put_u8(b'C');
     750         2262 :                 write_body(buf, |buf| write_cstr(cmd, buf))?;
     751              :             }
     752              : 
     753      7813023 :             BeMessage::CopyData(data) => {
     754      7813023 :                 buf.put_u8(b'd');
     755      7813023 :                 write_body(buf, |buf| {
     756      7813023 :                     buf.put_slice(data);
     757      7813023 :                 });
     758      7813023 :             }
     759              : 
     760          660 :             BeMessage::CopyDone => {
     761          660 :                 buf.put_u8(b'c');
     762          660 :                 write_body(buf, |_| {});
     763          660 :             }
     764              : 
     765            0 :             BeMessage::CopyFail => {
     766            0 :                 buf.put_u8(b'f');
     767            0 :                 write_body(buf, |_| {});
     768            0 :             }
     769              : 
     770            7 :             BeMessage::CopyInResponse => {
     771            7 :                 buf.put_u8(b'G');
     772            7 :                 write_body(buf, |buf| {
     773            7 :                     buf.put_u8(1); // copy_is_binary
     774            7 :                     buf.put_i16(0); // numAttributes
     775            7 :                 });
     776            7 :             }
     777              : 
     778          662 :             BeMessage::CopyOutResponse => {
     779          662 :                 buf.put_u8(b'H');
     780          662 :                 write_body(buf, |buf| {
     781          662 :                     buf.put_u8(0); // copy_is_binary
     782          662 :                     buf.put_i16(0); // numAttributes
     783          662 :                 });
     784          662 :             }
     785              : 
     786         7488 :             BeMessage::CopyBothResponse => {
     787         7488 :                 buf.put_u8(b'W');
     788         7488 :                 write_body(buf, |buf| {
     789         7488 :                     // doesn't matter, used only for replication
     790         7488 :                     buf.put_u8(0); // copy_is_binary
     791         7488 :                     buf.put_i16(0); // numAttributes
     792         7488 :                 });
     793         7488 :             }
     794              : 
     795         1056 :             BeMessage::DataRow(vals) => {
     796         1056 :                 buf.put_u8(b'D');
     797         1056 :                 write_body(buf, |buf| {
     798         1056 :                     buf.put_u16(vals.len() as u16); // num of cols
     799         3687 :                     for val_opt in vals.iter() {
     800         3687 :                         if let Some(val) = val_opt {
     801         2913 :                             buf.put_u32(val.len() as u32);
     802         2913 :                             buf.put_slice(val);
     803         2913 :                         } else {
     804          774 :                             buf.put_i32(-1);
     805          774 :                         }
     806              :                     }
     807         1056 :                 });
     808         1056 :             }
     809              : 
     810              :             // ErrorResponse is a zero-terminated array of zero-terminated fields.
     811              :             // First byte of each field represents type of this field. Set just enough fields
     812              :             // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
     813              :             // message text.
     814          253 :             BeMessage::ErrorResponse(error_msg, pg_error_code) => {
     815          253 :                 // 'E' signalizes ErrorResponse messages
     816          253 :                 buf.put_u8(b'E');
     817          253 :                 write_body(buf, |buf| {
     818          253 :                     buf.put_u8(b'S'); // severity
     819          253 :                     buf.put_slice(b"ERROR\0");
     820          253 : 
     821          253 :                     buf.put_u8(b'C'); // SQLSTATE error code
     822          253 :                     buf.put_slice(&terminate_code(
     823          253 :                         pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
     824          253 :                     ));
     825          253 : 
     826          253 :                     buf.put_u8(b'M'); // the message
     827          253 :                     write_cstr(error_msg, buf)?;
     828              : 
     829          253 :                     buf.put_u8(0); // terminator
     830          253 :                     Ok(())
     831          253 :                 })?;
     832              :             }
     833              : 
     834              :             // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
     835              :             // message but continue listening for ReadyForQuery or ErrorResponse"
     836            6 :             BeMessage::NoticeResponse(error_msg) => {
     837            6 :                 // For all the errors set Severity to Error and error code to
     838            6 :                 // 'internal error'.
     839            6 : 
     840            6 :                 // 'N' signalizes NoticeResponse messages
     841            6 :                 buf.put_u8(b'N');
     842            6 :                 write_body(buf, |buf| {
     843            6 :                     buf.put_u8(b'S'); // severity
     844            6 :                     buf.put_slice(b"NOTICE\0");
     845            6 : 
     846            6 :                     buf.put_u8(b'C'); // SQLSTATE error code
     847            6 :                     buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
     848            6 : 
     849            6 :                     buf.put_u8(b'M'); // the message
     850            6 :                     write_cstr(error_msg.as_bytes(), buf)?;
     851              : 
     852            6 :                     buf.put_u8(0); // terminator
     853            6 :                     Ok(())
     854            6 :                 })?;
     855              :             }
     856              : 
     857          662 :             BeMessage::NoData => {
     858          662 :                 buf.put_u8(b'n');
     859          662 :                 write_body(buf, |_| {});
     860          662 :             }
     861              : 
     862         6898 :             BeMessage::EncryptionResponse(should_negotiate) => {
     863         6898 :                 let response = if *should_negotiate { b'S' } else { b'N' };
     864         6898 :                 buf.put_u8(response);
     865              :             }
     866              : 
     867        27098 :             BeMessage::ParameterStatus { name, value } => {
     868        27098 :                 buf.put_u8(b'S');
     869        27098 :                 write_body(buf, |buf| {
     870        27098 :                     write_cstr(name, buf)?;
     871        27098 :                     write_cstr(value, buf)
     872        27098 :                 })?;
     873              :             }
     874              : 
     875          662 :             BeMessage::ParameterDescription => {
     876          662 :                 buf.put_u8(b't');
     877          662 :                 write_body(buf, |buf| {
     878          662 :                     // we don't support params, so always 0
     879          662 :                     buf.put_i16(0);
     880          662 :                 });
     881          662 :             }
     882              : 
     883          662 :             BeMessage::ParseComplete => {
     884          662 :                 buf.put_u8(b'1');
     885          662 :                 write_body(buf, |_| {});
     886          662 :             }
     887              : 
     888        19814 :             BeMessage::ReadyForQuery => {
     889        19814 :                 buf.put_u8(b'Z');
     890        19814 :                 write_body(buf, |buf| {
     891        19814 :                     buf.put_u8(b'I');
     892        19814 :                 });
     893        19814 :             }
     894              : 
     895         1566 :             BeMessage::RowDescription(rows) => {
     896         1566 :                 buf.put_u8(b'T');
     897         1566 :                 write_body(buf, |buf| {
     898         1566 :                     buf.put_i16(rows.len() as i16); // # of fields
     899         4707 :                     for row in rows.iter() {
     900         4707 :                         write_cstr(row.name, buf)?;
     901         4707 :                         buf.put_i32(0); /* table oid */
     902         4707 :                         buf.put_i16(0); /* attnum */
     903         4707 :                         buf.put_u32(row.typoid);
     904         4707 :                         buf.put_i16(row.typlen);
     905         4707 :                         buf.put_i32(-1); /* typmod */
     906         4707 :                         buf.put_i16(0); /* format code */
     907              :                     }
     908         1566 :                     Ok(())
     909         1566 :                 })?;
     910              :             }
     911              : 
     912       740437 :             BeMessage::XLogData(body) => {
     913       740437 :                 buf.put_u8(b'd');
     914       740437 :                 write_body(buf, |buf| {
     915       740437 :                     buf.put_u8(b'w');
     916       740437 :                     buf.put_u64(body.wal_start);
     917       740437 :                     buf.put_u64(body.wal_end);
     918       740437 :                     buf.put_i64(body.timestamp);
     919       740437 :                     buf.put_slice(body.data);
     920       740437 :                 });
     921       740437 :             }
     922              : 
     923         1808 :             BeMessage::KeepAlive(req) => {
     924         1808 :                 buf.put_u8(b'd');
     925         1808 :                 write_body(buf, |buf| {
     926         1808 :                     buf.put_u8(b'k');
     927         1808 :                     buf.put_u64(req.wal_end);
     928         1808 :                     buf.put_i64(req.timestamp);
     929         1808 :                     buf.put_u8(u8::from(req.request_reply));
     930         1808 :                 });
     931         1808 :             }
     932              :         }
     933      8635761 :         Ok(())
     934      8635761 :     }
     935              : }
     936              : 
     937          259 : fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
     938          259 :     let mut terminated = [0; 6];
     939         1295 :     for (i, &elem) in code.iter().enumerate() {
     940         1295 :         terminated[i] = elem;
     941         1295 :     }
     942              : 
     943          259 :     terminated
     944          259 : }
     945              : 
     946              : #[cfg(test)]
     947              : mod tests {
     948              :     use super::*;
     949              : 
     950            1 :     #[test]
     951            1 :     fn test_startup_message_params_options_escaped() {
     952            4 :         fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
     953            4 :             params
     954            4 :                 .options_escaped()
     955            4 :                 .expect("options are None")
     956            4 :                 .collect()
     957            4 :         }
     958            1 : 
     959            4 :         let make_params = |options| StartupMessageParams::new([("options", options)]);
     960              : 
     961            1 :         let params = StartupMessageParams::new([]);
     962            1 :         assert!(matches!(params.options_escaped(), None));
     963              : 
     964            1 :         let params = make_params("");
     965            1 :         assert!(split_options(&params).is_empty());
     966              : 
     967            1 :         let params = make_params("foo");
     968            1 :         assert_eq!(split_options(&params), ["foo"]);
     969              : 
     970            1 :         let params = make_params(" foo  bar ");
     971            1 :         assert_eq!(split_options(&params), ["foo", "bar"]);
     972              : 
     973            1 :         let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
     974            1 :         assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
     975            1 :     }
     976              : }
        

Generated by: LCOV version 2.1-beta