LCOV - code coverage report
Current view: top level - libs/pq_proto/src - lib.rs (source / functions) Coverage Total Hit
Test: 53437f7e869ac68c86c7d3e4c20964c0156f158c.info Lines: 52.6 % 595 313
Test Date: 2024-09-20 16:14:12 Functions: 45.5 % 121 55

            Line data    Source code
       1              : //! Postgres protocol messages serialization-deserialization. See
       2              : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
       3              : //! on message formats.
       4              : #![deny(clippy::undocumented_unsafe_blocks)]
       5              : 
       6              : pub mod framed;
       7              : 
       8              : use byteorder::{BigEndian, ReadBytesExt};
       9              : use bytes::{Buf, BufMut, Bytes, BytesMut};
      10              : use itertools::Itertools;
      11              : use serde::{Deserialize, Serialize};
      12              : use std::{borrow::Cow, fmt, io, str};
      13              : 
      14              : // re-export for use in utils pageserver_feedback.rs
      15              : pub use postgres_protocol::PG_EPOCH;
      16              : 
      17              : pub type Oid = u32;
      18              : pub type SystemId = u64;
      19              : 
      20              : pub const INT8_OID: Oid = 20;
      21              : pub const INT4_OID: Oid = 23;
      22              : pub const TEXT_OID: Oid = 25;
      23              : 
      24              : #[derive(Debug)]
      25              : pub enum FeMessage {
      26              :     // Simple query.
      27              :     Query(Bytes),
      28              :     // Extended query protocol.
      29              :     Parse(FeParseMessage),
      30              :     Describe(FeDescribeMessage),
      31              :     Bind(FeBindMessage),
      32              :     Execute(FeExecuteMessage),
      33              :     Close(FeCloseMessage),
      34              :     Sync,
      35              :     Terminate,
      36              :     CopyData(Bytes),
      37              :     CopyDone,
      38              :     CopyFail,
      39              :     PasswordMessage(Bytes),
      40              : }
      41              : 
      42              : #[derive(Clone, Copy, PartialEq, PartialOrd)]
      43              : pub struct ProtocolVersion(u32);
      44              : 
      45              : impl ProtocolVersion {
      46            0 :     pub const fn new(major: u16, minor: u16) -> Self {
      47            0 :         Self((major as u32) << 16 | minor as u32)
      48            0 :     }
      49            0 :     pub const fn minor(self) -> u16 {
      50            0 :         self.0 as u16
      51            0 :     }
      52           24 :     pub const fn major(self) -> u16 {
      53           24 :         (self.0 >> 16) as u16
      54           24 :     }
      55              : }
      56              : 
      57              : impl fmt::Debug for ProtocolVersion {
      58            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      59            0 :         f.debug_list()
      60            0 :             .entry(&self.major())
      61            0 :             .entry(&self.minor())
      62            0 :             .finish()
      63            0 :     }
      64              : }
      65              : 
      66              : #[derive(Debug)]
      67              : pub enum FeStartupPacket {
      68              :     CancelRequest(CancelKeyData),
      69              :     SslRequest {
      70              :         direct: bool,
      71              :     },
      72              :     GssEncRequest,
      73              :     StartupMessage {
      74              :         version: ProtocolVersion,
      75              :         params: StartupMessageParams,
      76              :     },
      77              : }
      78              : 
      79              : #[derive(Debug, Clone, Default)]
      80              : pub struct StartupMessageParamsBuilder {
      81              :     params: BytesMut,
      82              : }
      83              : 
      84              : impl StartupMessageParamsBuilder {
      85              :     /// Set parameter's value by its name.
      86              :     /// name and value must not contain a \0 byte
      87           31 :     pub fn insert(&mut self, name: &str, value: &str) {
      88           31 :         self.params.put(name.as_bytes());
      89           31 :         self.params.put(&b"\0"[..]);
      90           31 :         self.params.put(value.as_bytes());
      91           31 :         self.params.put(&b"\0"[..]);
      92           31 :     }
      93              : 
      94           23 :     pub fn freeze(self) -> StartupMessageParams {
      95           23 :         StartupMessageParams {
      96           23 :             params: self.params.freeze(),
      97           23 :         }
      98           23 :     }
      99              : }
     100              : 
     101              : #[derive(Debug, Clone, Default)]
     102              : pub struct StartupMessageParams {
     103              :     params: Bytes,
     104              : }
     105              : 
     106              : impl StartupMessageParams {
     107              :     /// Get parameter's value by its name.
     108           48 :     pub fn get(&self, name: &str) -> Option<&str> {
     109           64 :         self.iter().find_map(|(k, v)| (k == name).then_some(v))
     110           48 :     }
     111              : 
     112              :     /// Split command-line options according to PostgreSQL's logic,
     113              :     /// taking into account all escape sequences but leaving them as-is.
     114              :     /// [`None`] means that there's no `options` in [`Self`].
     115           30 :     pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
     116           30 :         self.get("options").map(Self::parse_options_raw)
     117           30 :     }
     118              : 
     119              :     /// Split command-line options according to PostgreSQL's logic,
     120              :     /// applying all escape sequences (using owned strings as needed).
     121              :     /// [`None`] means that there's no `options` in [`Self`].
     122            5 :     pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
     123            5 :         self.get("options").map(Self::parse_options_escaped)
     124            5 :     }
     125              : 
     126              :     /// Split command-line options according to PostgreSQL's logic,
     127              :     /// taking into account all escape sequences but leaving them as-is.
     128           30 :     pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
     129           30 :         // See `postgres: pg_split_opts`.
     130           30 :         let mut last_was_escape = false;
     131           30 :         input
     132          564 :             .split(move |c: char| {
     133              :                 // We split by non-escaped whitespace symbols.
     134          564 :                 let should_split = c.is_ascii_whitespace() && !last_was_escape;
     135          564 :                 last_was_escape = c == '\\' && !last_was_escape;
     136          564 :                 should_split
     137          564 :             })
     138           76 :             .filter(|s| !s.is_empty())
     139           30 :     }
     140              : 
     141              :     /// Split command-line options according to PostgreSQL's logic,
     142              :     /// applying all escape sequences (using owned strings as needed).
     143            4 :     pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
     144            4 :         // See `postgres: pg_split_opts`.
     145            7 :         Self::parse_options_raw(input).map(|s| {
     146            7 :             let mut preserve_next_escape = false;
     147           17 :             let escape = |c| {
     148              :                 // We should remove '\\' unless it's preceded by '\\'.
     149           17 :                 let should_remove = c == '\\' && !preserve_next_escape;
     150           17 :                 preserve_next_escape = should_remove;
     151           17 :                 should_remove
     152           17 :             };
     153              : 
     154            7 :             match s.contains('\\') {
     155            3 :                 true => Cow::Owned(s.replace(escape, "")),
     156            4 :                 false => Cow::Borrowed(s),
     157              :             }
     158            7 :         })
     159            4 :     }
     160              : 
     161              :     /// Iterate through key-value pairs in an arbitrary order.
     162           55 :     pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
     163           55 :         let params =
     164           55 :             std::str::from_utf8(&self.params).expect("should be validated as utf8 already");
     165           55 :         params.split_terminator('\0').tuples()
     166           55 :     }
     167              : 
     168              :     // This function is mostly useful in tests.
     169              :     #[doc(hidden)]
     170           23 :     pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
     171           23 :         let mut b = StartupMessageParamsBuilder::default();
     172           54 :         for (k, v) in pairs {
     173           31 :             b.insert(k, v)
     174              :         }
     175           23 :         b.freeze()
     176           23 :     }
     177              : }
     178              : 
     179            6 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
     180              : pub struct CancelKeyData {
     181              :     pub backend_pid: i32,
     182              :     pub cancel_key: i32,
     183              : }
     184              : 
     185              : impl fmt::Display for CancelKeyData {
     186            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     187            0 :         let hi = (self.backend_pid as u64) << 32;
     188            0 :         let lo = self.cancel_key as u64;
     189            0 :         let id = hi | lo;
     190            0 : 
     191            0 :         // This format is more compact and might work better for logs.
     192            0 :         f.debug_tuple("CancelKeyData")
     193            0 :             .field(&format_args!("{:x}", id))
     194            0 :             .finish()
     195            0 :     }
     196              : }
     197              : 
     198              : use rand::distributions::{Distribution, Standard};
     199              : impl Distribution<CancelKeyData> for Standard {
     200            1 :     fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
     201            1 :         CancelKeyData {
     202            1 :             backend_pid: rng.gen(),
     203            1 :             cancel_key: rng.gen(),
     204            1 :         }
     205            1 :     }
     206              : }
     207              : 
     208              : // We only support the simple case of Parse on unnamed prepared statement and
     209              : // no params
     210              : #[derive(Debug)]
     211              : pub struct FeParseMessage {
     212              :     pub query_string: Bytes,
     213              : }
     214              : 
     215              : #[derive(Debug)]
     216              : pub struct FeDescribeMessage {
     217              :     pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
     218              :                   // we only support unnamed prepared stmt or portal
     219              : }
     220              : 
     221              : // we only support unnamed prepared stmt and portal
     222              : #[derive(Debug)]
     223              : pub struct FeBindMessage;
     224              : 
     225              : // we only support unnamed prepared stmt or portal
     226              : #[derive(Debug)]
     227              : pub struct FeExecuteMessage {
     228              :     /// max # of rows
     229              :     pub maxrows: i32,
     230              : }
     231              : 
     232              : // we only support unnamed prepared stmt and portal
     233              : #[derive(Debug)]
     234              : pub struct FeCloseMessage;
     235              : 
     236              : /// An error occurred while parsing or serializing raw stream into Postgres
     237              : /// messages.
     238            0 : #[derive(thiserror::Error, Debug)]
     239              : pub enum ProtocolError {
     240              :     /// Invalid packet was received from the client (e.g. unexpected message
     241              :     /// type or broken len).
     242              :     #[error("Protocol error: {0}")]
     243              :     Protocol(String),
     244              :     /// Failed to parse or, (unlikely), serialize a protocol message.
     245              :     #[error("Message parse error: {0}")]
     246              :     BadMessage(String),
     247              : }
     248              : 
     249              : impl ProtocolError {
     250              :     /// Proxy stream.rs uses only io::Error; provide it.
     251            0 :     pub fn into_io_error(self) -> io::Error {
     252            0 :         io::Error::new(io::ErrorKind::Other, self.to_string())
     253            0 :     }
     254              : }
     255              : 
     256              : impl FeMessage {
     257              :     /// Read and parse one message from the `buf` input buffer. If there is at
     258              :     /// least one valid message, returns it, advancing `buf`; redundant copies
     259              :     /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
     260              :     /// directly into the `buf` (processed data is garbage collected after
     261              :     /// parsed message is dropped).
     262              :     ///
     263              :     /// Returns None if `buf` doesn't contain enough data for a single message.
     264              :     /// For efficiency, tries to reserve large enough space in `buf` for the
     265              :     /// next message in this case to save the repeated calls.
     266              :     ///
     267              :     /// Returns Error if message is malformed, the only possible ErrorKind is
     268              :     /// InvalidInput.
     269              :     //
     270              :     // Inspired by rust-postgres Message::parse.
     271           57 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
     272           57 :         // Every message contains message type byte and 4 bytes len; can't do
     273           57 :         // much without them.
     274           57 :         if buf.len() < 5 {
     275           30 :             let to_read = 5 - buf.len();
     276           30 :             buf.reserve(to_read);
     277           30 :             return Ok(None);
     278           27 :         }
     279           27 : 
     280           27 :         // We shouldn't advance `buf` as probably full message is not there yet,
     281           27 :         // so can't directly use Bytes::get_u32 etc.
     282           27 :         let tag = buf[0];
     283           27 :         let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
     284           27 :         if len < 4 {
     285            0 :             return Err(ProtocolError::Protocol(format!(
     286            0 :                 "invalid message length {}",
     287            0 :                 len
     288            0 :             )));
     289           27 :         }
     290           27 : 
     291           27 :         // length field includes itself, but not message type.
     292           27 :         let total_len = len as usize + 1;
     293           27 :         if buf.len() < total_len {
     294              :             // Don't have full message yet.
     295            0 :             let to_read = total_len - buf.len();
     296            0 :             buf.reserve(to_read);
     297            0 :             return Ok(None);
     298           27 :         }
     299           27 : 
     300           27 :         // got the message, advance buffer
     301           27 :         let mut msg = buf.split_to(total_len).freeze();
     302           27 :         msg.advance(5); // consume message type and len
     303           27 : 
     304           27 :         match tag {
     305            2 :             b'Q' => Ok(Some(FeMessage::Query(msg))),
     306            0 :             b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
     307            0 :             b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
     308            0 :             b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
     309            0 :             b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
     310            0 :             b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
     311            0 :             b'S' => Ok(Some(FeMessage::Sync)),
     312            0 :             b'X' => Ok(Some(FeMessage::Terminate)),
     313            0 :             b'd' => Ok(Some(FeMessage::CopyData(msg))),
     314            0 :             b'c' => Ok(Some(FeMessage::CopyDone)),
     315            0 :             b'f' => Ok(Some(FeMessage::CopyFail)),
     316           25 :             b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
     317            0 :             tag => Err(ProtocolError::Protocol(format!(
     318            0 :                 "unknown message tag: {tag},'{msg:?}'"
     319            0 :             ))),
     320              :         }
     321           57 :     }
     322              : }
     323              : 
     324              : impl FeStartupPacket {
     325              :     /// Read and parse startup message from the `buf` input buffer. It is
     326              :     /// different from [`FeMessage::parse`] because startup messages don't have
     327              :     /// message type byte; otherwise, its comments apply.
     328           91 :     pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
     329              :         /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
     330              :         const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
     331              :         const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
     332              :         /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
     333              :         const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
     334              :         /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
     335              :         const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
     336              :         /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
     337              :         const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
     338              : 
     339              :         // <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
     340              :         // First byte indicates standard SSL handshake message
     341              :         // (It can't be a Postgres startup length because in network byte order
     342              :         // that would be a startup packet hundreds of megabytes long)
     343           91 :         if buf.first() == Some(&0x16) {
     344            0 :             return Ok(Some(FeStartupPacket::SslRequest { direct: true }));
     345           91 :         }
     346           91 : 
     347           91 :         // need at least 4 bytes with packet len
     348           91 :         if buf.len() < 4 {
     349           45 :             let to_read = 4 - buf.len();
     350           45 :             buf.reserve(to_read);
     351           45 :             return Ok(None);
     352           46 :         }
     353           46 : 
     354           46 :         // We shouldn't advance `buf` as probably full message is not there yet,
     355           46 :         // so can't directly use Bytes::get_u32 etc.
     356           46 :         let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
     357           46 :         // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
     358           46 :         // which is less readable
     359           46 :         #[allow(clippy::manual_range_contains)]
     360           46 :         if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
     361            1 :             return Err(ProtocolError::Protocol(format!(
     362            1 :                 "invalid startup packet message length {}",
     363            1 :                 len
     364            1 :             )));
     365           45 :         }
     366           45 : 
     367           45 :         if buf.len() < len {
     368              :             // Don't have full message yet.
     369            0 :             let to_read = len - buf.len();
     370            0 :             buf.reserve(to_read);
     371            0 :             return Ok(None);
     372           45 :         }
     373           45 : 
     374           45 :         // got the message, advance buffer
     375           45 :         let mut msg = buf.split_to(len).freeze();
     376           45 :         msg.advance(4); // consume len
     377           45 : 
     378           45 :         let request_code = ProtocolVersion(msg.get_u32());
     379              :         // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
     380           45 :         let message = match request_code {
     381              :             CANCEL_REQUEST_CODE => {
     382            0 :                 if msg.remaining() != 8 {
     383            0 :                     return Err(ProtocolError::BadMessage(
     384            0 :                         "CancelRequest message is malformed, backend PID / secret key missing"
     385            0 :                             .to_owned(),
     386            0 :                     ));
     387            0 :                 }
     388            0 :                 FeStartupPacket::CancelRequest(CancelKeyData {
     389            0 :                     backend_pid: msg.get_i32(),
     390            0 :                     cancel_key: msg.get_i32(),
     391            0 :                 })
     392              :             }
     393              :             NEGOTIATE_SSL_CODE => {
     394              :                 // Requested upgrade to SSL (aka TLS)
     395           21 :                 FeStartupPacket::SslRequest { direct: false }
     396              :             }
     397              :             NEGOTIATE_GSS_CODE => {
     398              :                 // Requested upgrade to GSSAPI
     399            0 :                 FeStartupPacket::GssEncRequest
     400              :             }
     401           24 :             version if version.major() == RESERVED_INVALID_MAJOR_VERSION => {
     402            0 :                 return Err(ProtocolError::Protocol(format!(
     403            0 :                     "Unrecognized request code {}",
     404            0 :                     version.minor()
     405            0 :                 )));
     406              :             }
     407              :             // TODO bail if protocol major_version is not 3?
     408           24 :             version => {
     409              :                 // StartupMessage
     410              : 
     411           24 :                 let s = str::from_utf8(&msg).map_err(|_e| {
     412            0 :                     ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
     413           24 :                 })?;
     414           24 :                 let s = s.strip_suffix('\0').ok_or_else(|| {
     415            0 :                     ProtocolError::Protocol(
     416            0 :                         "StartupMessage params: missing null terminator".to_string(),
     417            0 :                     )
     418           24 :                 })?;
     419              : 
     420           24 :                 FeStartupPacket::StartupMessage {
     421           24 :                     version,
     422           24 :                     params: StartupMessageParams {
     423           24 :                         params: msg.slice_ref(s.as_bytes()),
     424           24 :                     },
     425           24 :                 }
     426              :             }
     427              :         };
     428           45 :         Ok(Some(message))
     429           91 :     }
     430              : }
     431              : 
     432              : impl FeParseMessage {
     433            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     434              :         // FIXME: the rust-postgres driver uses a named prepared statement
     435              :         // for copy_out(). We're not prepared to handle that correctly. For
     436              :         // now, just ignore the statement name, assuming that the client never
     437              :         // uses more than one prepared statement at a time.
     438              : 
     439            0 :         let _pstmt_name = read_cstr(&mut buf)?;
     440            0 :         let query_string = read_cstr(&mut buf)?;
     441            0 :         if buf.remaining() < 2 {
     442            0 :             return Err(ProtocolError::BadMessage(
     443            0 :                 "Parse message is malformed, nparams missing".to_string(),
     444            0 :             ));
     445            0 :         }
     446            0 :         let nparams = buf.get_i16();
     447            0 : 
     448            0 :         if nparams != 0 {
     449            0 :             return Err(ProtocolError::BadMessage(
     450            0 :                 "query params not implemented".to_string(),
     451            0 :             ));
     452            0 :         }
     453            0 : 
     454            0 :         Ok(FeMessage::Parse(FeParseMessage { query_string }))
     455            0 :     }
     456              : }
     457              : 
     458              : impl FeDescribeMessage {
     459            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     460            0 :         let kind = buf.get_u8();
     461            0 :         let _pstmt_name = read_cstr(&mut buf)?;
     462              : 
     463              :         // FIXME: see FeParseMessage::parse
     464            0 :         if kind != b'S' {
     465            0 :             return Err(ProtocolError::BadMessage(
     466            0 :                 "only prepared statemement Describe is implemented".to_string(),
     467            0 :             ));
     468            0 :         }
     469            0 : 
     470            0 :         Ok(FeMessage::Describe(FeDescribeMessage { kind }))
     471            0 :     }
     472              : }
     473              : 
     474              : impl FeExecuteMessage {
     475            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     476            0 :         let portal_name = read_cstr(&mut buf)?;
     477            0 :         if buf.remaining() < 4 {
     478            0 :             return Err(ProtocolError::BadMessage(
     479            0 :                 "FeExecuteMessage message is malformed, maxrows missing".to_string(),
     480            0 :             ));
     481            0 :         }
     482            0 :         let maxrows = buf.get_i32();
     483            0 : 
     484            0 :         if !portal_name.is_empty() {
     485            0 :             return Err(ProtocolError::BadMessage(
     486            0 :                 "named portals not implemented".to_string(),
     487            0 :             ));
     488            0 :         }
     489            0 :         if maxrows != 0 {
     490            0 :             return Err(ProtocolError::BadMessage(
     491            0 :                 "row limit in Execute message not implemented".to_string(),
     492            0 :             ));
     493            0 :         }
     494            0 : 
     495            0 :         Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
     496            0 :     }
     497              : }
     498              : 
     499              : impl FeBindMessage {
     500            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     501            0 :         let portal_name = read_cstr(&mut buf)?;
     502            0 :         let _pstmt_name = read_cstr(&mut buf)?;
     503              : 
     504              :         // FIXME: see FeParseMessage::parse
     505            0 :         if !portal_name.is_empty() {
     506            0 :             return Err(ProtocolError::BadMessage(
     507            0 :                 "named portals not implemented".to_string(),
     508            0 :             ));
     509            0 :         }
     510            0 : 
     511            0 :         Ok(FeMessage::Bind(FeBindMessage))
     512            0 :     }
     513              : }
     514              : 
     515              : impl FeCloseMessage {
     516            0 :     fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
     517            0 :         let _kind = buf.get_u8();
     518            0 :         let _pstmt_or_portal_name = read_cstr(&mut buf)?;
     519              : 
     520              :         // FIXME: we do nothing with Close
     521            0 :         Ok(FeMessage::Close(FeCloseMessage))
     522            0 :     }
     523              : }
     524              : 
     525              : // Backend
     526              : 
/// A backend (server-to-client) protocol message, serialized by `BeMessage::write`.
///
/// Payloads borrow from the caller for the lifetime `'a`; nothing is copied until the message
/// is written into a buffer.
#[derive(Debug)]
pub enum BeMessage<'a> {
    /// `R` message with authentication code 0 (authentication succeeded).
    AuthenticationOk,
    /// `R` message with authentication code 5 followed by this 4-byte salt.
    AuthenticationMD5Password([u8; 4]),
    /// `R` message carrying one of the SASL exchange steps (codes 10, 11 or 12).
    AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
    /// `R` message with authentication code 3 (clear-text password requested).
    AuthenticationCleartextPassword,
    /// `K` message with the backend pid and cancel key.
    BackendKeyData(CancelKeyData),
    /// `2` message with an empty body.
    BindComplete,
    /// `C` message; the payload is the command tag, written as a cstring (must not contain NUL).
    CommandComplete(&'a [u8]),
    /// `d` message; the payload is written verbatim.
    CopyData(&'a [u8]),
    /// `c` message with an empty body.
    CopyDone,
    /// `f` message with an empty body.
    ///
    /// NOTE(review): in the protocol docs CopyFail is a frontend message carrying an error
    /// string; confirm how this backend variant is meant to be used.
    CopyFail,
    /// `G` message: binary copy format, zero columns.
    CopyInResponse,
    /// `H` message: text copy format, zero columns.
    CopyOutResponse,
    /// `W` message: text copy format, zero columns.
    CopyBothResponse,
    /// `3` message with an empty body.
    CloseComplete,
    /// `D` message; one entry per column.
    // None means column is NULL
    DataRow(&'a [Option<&'a [u8]>]),
    /// `E` message with severity `ERROR`, the given SQLSTATE and the message text
    /// (the text must not contain NUL).
    // None errcode means internal_error will be sent.
    ErrorResponse(&'a str, Option<&'a [u8; 5]>),
    /// Single byte - used in response to SSLRequest/GSSENCRequest.
    /// `true` is sent as `S`, `false` as `N`; no length word is written.
    EncryptionResponse(bool),
    /// `n` message with an empty body.
    NoData,
    /// `t` message that always reports zero parameters.
    ParameterDescription,
    /// `S` message with `name` and `value` written as cstrings (neither may contain NUL).
    ParameterStatus {
        name: &'a [u8],
        value: &'a [u8],
    },
    /// `1` message with an empty body.
    ParseComplete,
    /// `Z` message; always reports transaction status `I` (idle).
    ReadyForQuery,
    /// `T` message describing the columns of subsequent `DataRow`s.
    RowDescription(&'a [RowDescriptor<'a>]),
    /// WAL data: a `d` (CopyData) message whose body starts with `w`.
    XLogData(XLogDataBody<'a>),
    /// `N` message with severity `NOTICE`, SQLSTATE `XX000` and this text (must not contain NUL).
    NoticeResponse(&'a str),
    /// `v` message: the supported protocol version followed by the unrecognized option names.
    NegotiateProtocolVersion {
        version: ProtocolVersion,
        options: &'a [&'a str],
    },
    /// Keepalive: a `d` (CopyData) message whose body starts with `k`.
    KeepAlive(WalSndKeepAlive),
}
     566              : 
     567              : /// Common shorthands.
     568              : impl<'a> BeMessage<'a> {
     569              :     /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
     570              :     /// This is a sensible default, given that:
     571              :     ///  * rust strings only support this encoding out of the box.
     572              :     ///  * tokio-postgres, postgres-jdbc (and probably more) mandate it.
     573              :     ///
     574              :     /// TODO: do we need to report `server_encoding` as well?
     575              :     pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
     576              :         name: b"client_encoding",
     577              :         value: b"UTF8",
     578              :     };
     579              : 
     580              :     pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
     581              :         name: b"integer_datetimes",
     582              :         value: b"on",
     583              :     };
     584              : 
     585              :     /// Build a [`BeMessage::ParameterStatus`] holding the server version.
     586            2 :     pub fn server_version(version: &'a str) -> Self {
     587            2 :         Self::ParameterStatus {
     588            2 :             name: b"server_version",
     589            2 :             value: version.as_bytes(),
     590            2 :         }
     591            2 :     }
     592              : }
     593              : 
/// Payload of a [`BeMessage::AuthenticationSasl`] (`R`) message.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Auth code 10: offer SASL. Each mechanism name is written as a cstring, and the list is
    /// closed by an extra NUL byte.
    Methods(&'a [&'a str]),
    /// Auth code 11: intermediate SASL data, written verbatim.
    Continue(&'a [u8]),
    /// Auth code 12: final SASL data, written verbatim.
    Final(&'a [u8]),
}
     600              : 
/// Typed parameter-status values.
///
/// NOTE(review): this enum is not referenced anywhere in this part of the file. Parameter
/// status messages are serialized through [`BeMessage::ParameterStatus`]; confirm whether this
/// type is still used elsewhere.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
     606              : 
// One row description in RowDescription packet.
/// Description of a single column (field) in a [`BeMessage::RowDescription`].
///
/// NOTE(review): `BeMessage::write` currently serializes only `name`, `typoid` and `typlen`
/// from this struct. It sends constants for table oid (0), attnum (0), typmod (-1) and format
/// code (0) regardless of `tableoid`, `attnum`, `typmod` and `formatcode`.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
    /// Column name, written as a cstring (must not contain NUL).
    pub name: &'a [u8],
    pub tableoid: Oid,
    pub attnum: i16,
    /// Type OID of the column, e.g. `TEXT_OID` or `INT8_OID`.
    pub typoid: Oid,
    /// Type length; the helper constructors use 8 for int8 and -1 for text.
    pub typlen: i16,
    pub typmod: i32,
    pub formatcode: i16,
}
     618              : 
     619              : impl Default for RowDescriptor<'_> {
     620            0 :     fn default() -> RowDescriptor<'static> {
     621            0 :         RowDescriptor {
     622            0 :             name: b"",
     623            0 :             tableoid: 0,
     624            0 :             attnum: 0,
     625            0 :             typoid: 0,
     626            0 :             typlen: 0,
     627            0 :             typmod: 0,
     628            0 :             formatcode: 0,
     629            0 :         }
     630            0 :     }
     631              : }
     632              : 
     633              : impl RowDescriptor<'_> {
     634              :     /// Convenience function to create a RowDescriptor message for an int8 column
     635            0 :     pub const fn int8_col(name: &[u8]) -> RowDescriptor {
     636            0 :         RowDescriptor {
     637            0 :             name,
     638            0 :             tableoid: 0,
     639            0 :             attnum: 0,
     640            0 :             typoid: INT8_OID,
     641            0 :             typlen: 8,
     642            0 :             typmod: 0,
     643            0 :             formatcode: 0,
     644            0 :         }
     645            0 :     }
     646              : 
     647            2 :     pub const fn text_col(name: &[u8]) -> RowDescriptor {
     648            2 :         RowDescriptor {
     649            2 :             name,
     650            2 :             tableoid: 0,
     651            2 :             attnum: 0,
     652            2 :             typoid: TEXT_OID,
     653            2 :             typlen: -1,
     654            2 :             typmod: 0,
     655            2 :             formatcode: 0,
     656            2 :         }
     657            2 :     }
     658              : }
     659              : 
/// Body of a [`BeMessage::XLogData`] message.
///
/// It is serialized inside a `d` (CopyData) message as `w`, `wal_start`, `wal_end`,
/// `timestamp` (all big-endian), followed by `data` verbatim.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    /// Presumably the WAL position of the first byte of `data` (per the replication protocol
    /// docs) — TODO confirm against callers.
    pub wal_start: u64,
    pub wal_end: u64, // current end of WAL on the server
    /// NOTE(review): the protocol docs define this as microseconds since the Postgres epoch
    /// (see the `PG_EPOCH` re-export); confirm callers fill it that way.
    pub timestamp: i64,
    /// Raw WAL bytes, written without a length prefix.
    pub data: &'a [u8],
}
     667              : 
/// Body of a [`BeMessage::KeepAlive`] message.
///
/// It is serialized inside a `d` (CopyData) message as `k`, `wal_end`, `timestamp`, and then
/// `request_reply` as a single byte (0 or 1).
#[derive(Debug)]
pub struct WalSndKeepAlive {
    pub wal_end: u64, // current end of WAL on the server
    /// NOTE(review): presumably microseconds since the Postgres epoch, as for XLogData — confirm.
    pub timestamp: i64,
    /// When `true`, asks the client to reply to this keepalive.
    pub request_reply: bool,
}
     674              : 
/// A `DataRow` with a single non-NULL column whose value is `hello world`.
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
     676              : 
// single text column
/// `RowDescription` for a single `text` column named `data`.
///
/// The field values match what `RowDescriptor::text_col(b"data")` builds; they are spelled
/// out as a literal so the slice can be promoted to `'static`.
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
    name: b"data",
    tableoid: 0,
    attnum: 0,
    typoid: TEXT_OID,
    typlen: -1,
    typmod: 0,
    formatcode: 0,
}]);
     687              : 
     688              : /// Call f() to write body of the message and prepend it with 4-byte len as
     689              : /// prescribed by the protocol.
     690           75 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
     691           75 :     let base = buf.len();
     692           75 :     buf.extend_from_slice(&[0; 4]);
     693           75 : 
     694           75 :     let res = f(buf);
     695           75 : 
     696           75 :     let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
     697           75 :     (&mut buf[base..]).put_slice(&size.to_be_bytes());
     698           75 : 
     699           75 :     res
     700           75 : }
     701              : 
     702              : /// Safe write of s into buf as cstring (String in the protocol).
     703           56 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
     704           56 :     let bytes = s.as_ref();
     705           56 :     if bytes.contains(&0) {
     706            0 :         return Err(ProtocolError::BadMessage(
     707            0 :             "string contains embedded null".to_owned(),
     708            0 :         ));
     709           56 :     }
     710           56 :     buf.put_slice(bytes);
     711           56 :     buf.put_u8(0);
     712           56 :     Ok(())
     713           56 : }
     714              : 
     715              : /// Read cstring from buf, advancing it.
     716           11 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
     717           11 :     let pos = buf
     718           11 :         .iter()
     719          156 :         .position(|x| *x == 0)
     720           11 :         .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
     721           11 :     let result = buf.split_to(pos);
     722           11 :     buf.advance(1); // drop the null terminator
     723           11 :     Ok(result)
     724           11 : }
     725              : 
/// SQLSTATE `XX000` (internal_error).
///
/// `ErrorResponse` sends this when no code is given, and `NoticeResponse` always sends it.
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE `57P01` (admin_shutdown).
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE `00000` (successful_completion).
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
     729              : 
     730              : impl<'a> BeMessage<'a> {
     731              :     /// Serialize `message` to the given `buf`.
     732              :     /// Apart from smart memory managemet, BytesMut is good here as msg len
     733              :     /// precedes its body and it is handy to write it down first and then fill
     734              :     /// the length. With Write we would have to either calc it manually or have
     735              :     /// one more buffer.
     736           96 :     pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
     737           96 :         match message {
     738           12 :             BeMessage::AuthenticationOk => {
     739           12 :                 buf.put_u8(b'R');
     740           12 :                 write_body(buf, |buf| {
     741           12 :                     buf.put_i32(0); // Specifies that the authentication was successful.
     742           12 :                 });
     743           12 :             }
     744              : 
     745            2 :             BeMessage::AuthenticationCleartextPassword => {
     746            2 :                 buf.put_u8(b'R');
     747            2 :                 write_body(buf, |buf| {
     748            2 :                     buf.put_i32(3); // Specifies that clear text password is required.
     749            2 :                 });
     750            2 :             }
     751              : 
     752            0 :             BeMessage::AuthenticationMD5Password(salt) => {
     753            0 :                 buf.put_u8(b'R');
     754            0 :                 write_body(buf, |buf| {
     755            0 :                     buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
     756            0 :                     buf.put_slice(&salt[..]);
     757            0 :                 });
     758            0 :             }
     759              : 
     760           30 :             BeMessage::AuthenticationSasl(msg) => {
     761           30 :                 buf.put_u8(b'R');
     762           30 :                 write_body(buf, |buf| {
     763              :                     use BeAuthenticationSaslMessage::*;
     764           30 :                     match msg {
     765           13 :                         Methods(methods) => {
     766           13 :                             buf.put_i32(10); // Specifies that SASL auth method is used.
     767           25 :                             for method in methods.iter() {
     768           25 :                                 write_cstr(method, buf)?;
     769              :                             }
     770           13 :                             buf.put_u8(0); // zero terminator for the list
     771              :                         }
     772           11 :                         Continue(extra) => {
     773           11 :                             buf.put_i32(11); // Continue SASL auth.
     774           11 :                             buf.put_slice(extra);
     775           11 :                         }
     776            6 :                         Final(extra) => {
     777            6 :                             buf.put_i32(12); // Send final SASL message.
     778            6 :                             buf.put_slice(extra);
     779            6 :                         }
     780              :                     }
     781           30 :                     Ok(())
     782           30 :                 })?;
     783              :             }
     784              : 
     785            0 :             BeMessage::BackendKeyData(key_data) => {
     786            0 :                 buf.put_u8(b'K');
     787            0 :                 write_body(buf, |buf| {
     788            0 :                     buf.put_i32(key_data.backend_pid);
     789            0 :                     buf.put_i32(key_data.cancel_key);
     790            0 :                 });
     791            0 :             }
     792              : 
     793            0 :             BeMessage::BindComplete => {
     794            0 :                 buf.put_u8(b'2');
     795            0 :                 write_body(buf, |_| {});
     796            0 :             }
     797              : 
     798            0 :             BeMessage::CloseComplete => {
     799            0 :                 buf.put_u8(b'3');
     800            0 :                 write_body(buf, |_| {});
     801            0 :             }
     802              : 
     803            2 :             BeMessage::CommandComplete(cmd) => {
     804            2 :                 buf.put_u8(b'C');
     805            2 :                 write_body(buf, |buf| write_cstr(cmd, buf))?;
     806              :             }
     807              : 
     808            0 :             BeMessage::CopyData(data) => {
     809            0 :                 buf.put_u8(b'd');
     810            0 :                 write_body(buf, |buf| {
     811            0 :                     buf.put_slice(data);
     812            0 :                 });
     813            0 :             }
     814              : 
     815            0 :             BeMessage::CopyDone => {
     816            0 :                 buf.put_u8(b'c');
     817            0 :                 write_body(buf, |_| {});
     818            0 :             }
     819              : 
     820            0 :             BeMessage::CopyFail => {
     821            0 :                 buf.put_u8(b'f');
     822            0 :                 write_body(buf, |_| {});
     823            0 :             }
     824              : 
     825            0 :             BeMessage::CopyInResponse => {
     826            0 :                 buf.put_u8(b'G');
     827            0 :                 write_body(buf, |buf| {
     828            0 :                     buf.put_u8(1); // copy_is_binary
     829            0 :                     buf.put_i16(0); // numAttributes
     830            0 :                 });
     831            0 :             }
     832              : 
     833            0 :             BeMessage::CopyOutResponse => {
     834            0 :                 buf.put_u8(b'H');
     835            0 :                 write_body(buf, |buf| {
     836            0 :                     buf.put_u8(0); // copy_is_binary
     837            0 :                     buf.put_i16(0); // numAttributes
     838            0 :                 });
     839            0 :             }
     840              : 
     841            0 :             BeMessage::CopyBothResponse => {
     842            0 :                 buf.put_u8(b'W');
     843            0 :                 write_body(buf, |buf| {
     844            0 :                     // doesn't matter, used only for replication
     845            0 :                     buf.put_u8(0); // copy_is_binary
     846            0 :                     buf.put_i16(0); // numAttributes
     847            0 :                 });
     848            0 :             }
     849              : 
     850            2 :             BeMessage::DataRow(vals) => {
     851            2 :                 buf.put_u8(b'D');
     852            2 :                 write_body(buf, |buf| {
     853            2 :                     buf.put_u16(vals.len() as u16); // num of cols
     854            2 :                     for val_opt in vals.iter() {
     855            2 :                         if let Some(val) = val_opt {
     856            2 :                             buf.put_u32(val.len() as u32);
     857            2 :                             buf.put_slice(val);
     858            2 :                         } else {
     859            0 :                             buf.put_i32(-1);
     860            0 :                         }
     861              :                     }
     862            2 :                 });
     863            2 :             }
     864              : 
     865              :             // ErrorResponse is a zero-terminated array of zero-terminated fields.
     866              :             // First byte of each field represents type of this field. Set just enough fields
     867              :             // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
     868              :             // message text.
     869            1 :             BeMessage::ErrorResponse(error_msg, pg_error_code) => {
     870            1 :                 // 'E' signalizes ErrorResponse messages
     871            1 :                 buf.put_u8(b'E');
     872            1 :                 write_body(buf, |buf| {
     873            1 :                     buf.put_u8(b'S'); // severity
     874            1 :                     buf.put_slice(b"ERROR\0");
     875            1 : 
     876            1 :                     buf.put_u8(b'C'); // SQLSTATE error code
     877            1 :                     buf.put_slice(&terminate_code(
     878            1 :                         pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
     879            1 :                     ));
     880            1 : 
     881            1 :                     buf.put_u8(b'M'); // the message
     882            1 :                     write_cstr(error_msg, buf)?;
     883              : 
     884            1 :                     buf.put_u8(0); // terminator
     885            1 :                     Ok(())
     886            1 :                 })?;
     887              :             }
     888              : 
     889              :             // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
     890              :             // message but continue listening for ReadyForQuery or ErrorResponse"
     891            0 :             BeMessage::NoticeResponse(error_msg) => {
     892            0 :                 // For all the errors set Severity to Error and error code to
     893            0 :                 // 'internal error'.
     894            0 : 
     895            0 :                 // 'N' signalizes NoticeResponse messages
     896            0 :                 buf.put_u8(b'N');
     897            0 :                 write_body(buf, |buf| {
     898            0 :                     buf.put_u8(b'S'); // severity
     899            0 :                     buf.put_slice(b"NOTICE\0");
     900            0 : 
     901            0 :                     buf.put_u8(b'C'); // SQLSTATE error code
     902            0 :                     buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
     903            0 : 
     904            0 :                     buf.put_u8(b'M'); // the message
     905            0 :                     write_cstr(error_msg.as_bytes(), buf)?;
     906              : 
     907            0 :                     buf.put_u8(0); // terminator
     908            0 :                     Ok(())
     909            0 :                 })?;
     910              :             }
     911              : 
     912            0 :             BeMessage::NoData => {
     913            0 :                 buf.put_u8(b'n');
     914            0 :                 write_body(buf, |_| {});
     915            0 :             }
     916              : 
     917           21 :             BeMessage::EncryptionResponse(should_negotiate) => {
     918           21 :                 let response = if *should_negotiate { b'S' } else { b'N' };
     919           21 :                 buf.put_u8(response);
     920              :             }
     921              : 
     922           13 :             BeMessage::ParameterStatus { name, value } => {
     923           13 :                 buf.put_u8(b'S');
     924           13 :                 write_body(buf, |buf| {
     925           13 :                     write_cstr(name, buf)?;
     926           13 :                     write_cstr(value, buf)
     927           13 :                 })?;
     928              :             }
     929              : 
     930            0 :             BeMessage::ParameterDescription => {
     931            0 :                 buf.put_u8(b't');
     932            0 :                 write_body(buf, |buf| {
     933            0 :                     // we don't support params, so always 0
     934            0 :                     buf.put_i16(0);
     935            0 :                 });
     936            0 :             }
     937              : 
     938            0 :             BeMessage::ParseComplete => {
     939            0 :                 buf.put_u8(b'1');
     940            0 :                 write_body(buf, |_| {});
     941            0 :             }
     942              : 
     943           11 :             BeMessage::ReadyForQuery => {
     944           11 :                 buf.put_u8(b'Z');
     945           11 :                 write_body(buf, |buf| {
     946           11 :                     buf.put_u8(b'I');
     947           11 :                 });
     948           11 :             }
     949              : 
     950            2 :             BeMessage::RowDescription(rows) => {
     951            2 :                 buf.put_u8(b'T');
     952            2 :                 write_body(buf, |buf| {
     953            2 :                     buf.put_i16(rows.len() as i16); // # of fields
     954            2 :                     for row in rows.iter() {
     955            2 :                         write_cstr(row.name, buf)?;
     956            2 :                         buf.put_i32(0); /* table oid */
     957            2 :                         buf.put_i16(0); /* attnum */
     958            2 :                         buf.put_u32(row.typoid);
     959            2 :                         buf.put_i16(row.typlen);
     960            2 :                         buf.put_i32(-1); /* typmod */
     961            2 :                         buf.put_i16(0); /* format code */
     962              :                     }
     963            2 :                     Ok(())
     964            2 :                 })?;
     965              :             }
     966              : 
     967            0 :             BeMessage::XLogData(body) => {
     968            0 :                 buf.put_u8(b'd');
     969            0 :                 write_body(buf, |buf| {
     970            0 :                     buf.put_u8(b'w');
     971            0 :                     buf.put_u64(body.wal_start);
     972            0 :                     buf.put_u64(body.wal_end);
     973            0 :                     buf.put_i64(body.timestamp);
     974            0 :                     buf.put_slice(body.data);
     975            0 :                 });
     976            0 :             }
     977              : 
     978            0 :             BeMessage::KeepAlive(req) => {
     979            0 :                 buf.put_u8(b'd');
     980            0 :                 write_body(buf, |buf| {
     981            0 :                     buf.put_u8(b'k');
     982            0 :                     buf.put_u64(req.wal_end);
     983            0 :                     buf.put_i64(req.timestamp);
     984            0 :                     buf.put_u8(u8::from(req.request_reply));
     985            0 :                 });
     986            0 :             }
     987              : 
     988            0 :             BeMessage::NegotiateProtocolVersion { version, options } => {
     989            0 :                 buf.put_u8(b'v');
     990            0 :                 write_body(buf, |buf| {
     991            0 :                     buf.put_u32(version.0);
     992            0 :                     buf.put_u32(options.len() as u32);
     993            0 :                     for option in options.iter() {
     994            0 :                         write_cstr(option, buf)?;
     995              :                     }
     996            0 :                     Ok(())
     997            0 :                 })?
     998              :             }
     999              :         }
    1000           96 :         Ok(())
    1001           96 :     }
    1002              : }
    1003              : 
    1004            1 : fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    1005            1 :     let mut terminated = [0; 6];
    1006            5 :     for (i, &elem) in code.iter().enumerate() {
    1007            5 :         terminated[i] = elem;
    1008            5 :     }
    1009              : 
    1010            1 :     terminated
    1011            1 : }
    1012              : 
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks how the `options` startup parameter is split into individual options.
    ///
    /// The assertions below pin down these rules:
    /// - unescaped whitespace separates options, and runs of it collapse;
    /// - `\ ` yields a literal space inside an option;
    /// - `\\` yields a literal backslash.
    #[test]
    fn test_startup_message_params_options_escaped() {
        // Collect the split options, failing the test if `options` is absent.
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        // No `options` parameter at all: nothing to split.
        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        // Present but empty: splits into zero options.
        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        // Leading, trailing and repeated spaces are all separators.
        let params = make_params(" foo  bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // Runtime string is `foo\ bar \ \\ baz\  lol`: escaped spaces and backslashes are
        // kept inside options.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    /// Regression test: a startup packet declaring a length of 7 must be rejected with an
    /// error rather than accepted (or panicking).
    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }
}
        

Generated by: LCOV version 2.1-beta