LCOV - code coverage report
Current view: top level - proxy/src - pqproto.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 78.1 % 352 275
Test Date: 2025-07-16 12:29:03 Functions: 72.0 % 100 72

            Line data    Source code
       1              : //! Postgres protocol codec
       2              : //!
       3              : //! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
       4              : 
       5              : use std::fmt;
       6              : use std::io::{self, Cursor};
       7              : 
       8              : use bytes::{Buf, BufMut};
       9              : use itertools::Itertools;
      10              : use rand::distributions::{Distribution, Standard};
      11              : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
      12              : use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
      13              : 
      14              : pub type ErrorCode = [u8; 5];
      15              : 
      16              : pub const FE_PASSWORD_MESSAGE: u8 = b'p';
      17              : 
      18              : pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
      19              : 
      20              : /// The protocol version number.
      21              : ///
      22              : /// The most significant 16 bits are the major version number (3 for the protocol described here).
      23              : /// The least significant 16 bits are the minor version number (0 for the protocol described here).
      24              : /// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE>
      25              : #[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)]
      26              : #[repr(C)]
      27              : pub struct ProtocolVersion {
      28              :     major: big_endian::U16,
      29              :     minor: big_endian::U16,
      30              : }
      31              : 
      32              : impl ProtocolVersion {
      33            3 :     pub const fn new(major: u16, minor: u16) -> Self {
      34            3 :         Self {
      35            3 :             major: big_endian::U16::new(major),
      36            3 :             minor: big_endian::U16::new(minor),
      37            3 :         }
      38            3 :     }
      39            1 :     pub const fn minor(self) -> u16 {
      40            1 :         self.minor.get()
      41            1 :     }
      42           24 :     pub const fn major(self) -> u16 {
      43           24 :         self.major.get()
      44           24 :     }
      45              : }
      46              : 
      47              : impl fmt::Debug for ProtocolVersion {
      48            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
      49            0 :         f.debug_list()
      50            0 :             .entry(&self.major())
      51            0 :             .entry(&self.minor())
      52            0 :             .finish()
      53            0 :     }
      54              : }
      55              : 
      56              : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
      57              : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
      58              : const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
      59              : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
      60              : const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
      61              : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
      62              : const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
      63              : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
      64              : const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
      65              : 
      66              : /// This first reads the startup message header, is 8 bytes.
      67              : /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
      68              : ///
      69              : /// The length value is inclusive of the header. For example,
      70              : /// an empty message will always have length 8.
      71              : #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
      72              : #[repr(C)]
      73              : struct StartupHeader {
      74              :     len: big_endian::U32,
      75              :     version: ProtocolVersion,
      76              : }
      77              : 
      78              : /// read the type from the stream using zerocopy.
      79              : ///
      80              : /// not cancel safe.
      81              : macro_rules! read {
      82              :     ($s:expr => $t:ty) => {{
      83              :         // cannot be implemented as a function due to lack of const-generic-expr
      84              :         let mut buf = [0; size_of::<$t>()];
      85              :         $s.read_exact(&mut buf).await?;
      86              :         let res: $t = zerocopy::transmute!(buf);
      87              :         res
      88              :     }};
      89              : }
      90              : 
      91              : /// Returns true if TLS is supported.
      92              : ///
      93              : /// This is not cancel safe.
      94            0 : pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
      95            0 : where
      96            0 :     S: AsyncRead + AsyncWrite + Unpin,
      97            0 : {
      98            0 :     let payload = StartupHeader {
      99            0 :         len: 8.into(),
     100            0 :         version: NEGOTIATE_SSL_CODE,
     101            0 :     };
     102            0 :     stream.write_all(payload.as_bytes()).await?;
     103            0 :     stream.flush().await?;
     104              : 
     105              :     // we expect back either `S` or `N` as a single byte.
     106            0 :     let mut res = *b"0";
     107            0 :     stream.read_exact(&mut res).await?;
     108              : 
     109            0 :     debug_assert!(
     110            0 :         res == *b"S" || res == *b"N",
     111            0 :         "unexpected SSL negotiation response: {}",
     112            0 :         char::from(res[0]),
     113              :     );
     114              : 
     115              :     // S for SSL.
     116            0 :     Ok(res == *b"S")
     117            0 : }
     118              : 
     119           46 : pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
     120           46 : where
     121           46 :     S: AsyncRead + Unpin,
     122           46 : {
     123           46 :     let header = read!(stream => StartupHeader);
     124              : 
     125              :     // <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
     126              :     // First byte indicates standard SSL handshake message
     127              :     // (It can't be a Postgres startup length because in network byte order
     128              :     // that would be a startup packet hundreds of megabytes long)
     129           46 :     if header.as_bytes()[0] == 0x16 {
     130              :         return Ok(FeStartupPacket::SslRequest {
     131              :             // The bytes we read for the header are actually part of a TLS ClientHello.
     132              :             // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here.
     133              :             // In practice though, I see no world where a ClientHello is less than 8 bytes
     134              :             // since it includes ephemeral keys etc.
     135            1 :             direct: Some(zerocopy::transmute!(header)),
     136              :         });
     137           45 :     }
     138              : 
     139           45 :     let Some(len) = (header.len.get() as usize).checked_sub(8) else {
     140            0 :         return Err(io::Error::other(format!(
     141            0 :             "invalid startup message length {}, must be at least 8.",
     142            0 :             header.len,
     143            0 :         )));
     144              :     };
     145              : 
     146              :     // TODO: add a histogram for startup packet lengths
     147           45 :     if len > MAX_STARTUP_PACKET_LENGTH {
     148            1 :         tracing::warn!("large startup message detected: {len} bytes");
     149            1 :         return Err(io::Error::other(format!(
     150            1 :             "invalid startup message length {len}"
     151            1 :         )));
     152           44 :     }
     153              : 
     154           23 :     match header.version {
     155              :         // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-CANCELREQUEST>
     156              :         CANCEL_REQUEST_CODE => {
     157            0 :             if len != 8 {
     158            0 :                 return Err(io::Error::other(
     159            0 :                     "CancelRequest message is malformed, backend PID / secret key missing",
     160            0 :                 ));
     161            0 :             }
     162              : 
     163              :             Ok(FeStartupPacket::CancelRequest(
     164            0 :                 read!(stream => CancelKeyData),
     165              :             ))
     166              :         }
     167              :         // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST>
     168              :         NEGOTIATE_SSL_CODE => {
     169              :             // Requested upgrade to SSL (aka TLS)
     170           21 :             Ok(FeStartupPacket::SslRequest { direct: None })
     171              :         }
     172              :         NEGOTIATE_GSS_CODE => {
     173              :             // Requested upgrade to GSSAPI
     174            0 :             Ok(FeStartupPacket::GssEncRequest)
     175              :         }
     176           23 :         version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other(
     177            0 :             format!("Unrecognized request code {version:?}"),
     178            0 :         )),
     179              :         // StartupMessage
     180           23 :         version => {
     181              :             // The protocol version number is followed by one or more pairs of parameter name and value strings.
     182              :             // A zero byte is required as a terminator after the last name/value pair.
     183              :             // Parameters can appear in any order. user is required, others are optional.
     184              : 
     185           23 :             let mut buf = vec![0; len];
     186           23 :             stream.read_exact(&mut buf).await?;
     187              : 
     188           23 :             if buf.pop() != Some(b'\0') {
     189            0 :                 return Err(io::Error::other(
     190            0 :                     "StartupMessage params: missing null terminator",
     191            0 :                 ));
     192           23 :             }
     193              : 
     194              :             // TODO: Don't do this.
     195              :             // There's no guarantee that these messages are utf8,
     196              :             // but they usually happen to be simple ascii.
     197           23 :             let params = String::from_utf8(buf)
     198           23 :                 .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?;
     199              : 
     200           23 :             Ok(FeStartupPacket::StartupMessage {
     201           23 :                 version,
     202           23 :                 params: StartupMessageParams { params },
     203           23 :             })
     204              :         }
     205              :     }
     206           46 : }
     207              : 
     208              : /// Read a raw postgres packet, which will respect the max length requested.
     209              : ///
     210              : /// This returns the message tag, as well as the message body. The message
     211              : /// body is written into `buf`, and it is otherwise completely overwritten.
     212              : ///
     213              : /// This is not cancel safe.
     214           29 : pub async fn read_message<'a, S>(
     215           29 :     stream: &mut S,
     216           29 :     buf: &'a mut Vec<u8>,
     217           29 :     max: u32,
     218           29 : ) -> io::Result<(u8, &'a mut [u8])>
     219           29 : where
     220           29 :     S: AsyncRead + Unpin,
     221           29 : {
     222              :     /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes.
     223              :     /// The first byte is a message tag, and the next 4 bytes is a big-endian length.
     224              :     ///
     225              :     /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example,
     226              :     /// an empty message will always have length 4.
     227              :     #[derive(Clone, Copy, FromBytes)]
     228              :     #[repr(C)]
     229              :     struct Header {
     230              :         tag: u8,
     231              :         len: big_endian::U32,
     232              :     }
     233              : 
     234           29 :     let header = read!(stream => Header);
     235              : 
     236              :     // as described above, the length must be at least 4.
     237           28 :     let Some(len) = header.len.get().checked_sub(4) else {
     238            0 :         return Err(io::Error::other(format!(
     239            0 :             "invalid startup message length {}, must be at least 4.",
     240            0 :             header.len,
     241            0 :         )));
     242              :     };
     243              : 
     244              :     // TODO: add a histogram for message lengths
     245              : 
     246              :     // check if the message exceeds our desired max.
     247           28 :     if len > max {
     248            1 :         tracing::warn!("large postgres message detected: {len} bytes");
     249            1 :         return Err(io::Error::other(format!("invalid message length {len}")));
     250           27 :     }
     251              : 
     252              :     // read in our entire message.
     253           27 :     buf.resize(len as usize, 0);
     254           27 :     stream.read_exact(buf).await?;
     255              : 
     256           27 :     Ok((header.tag, buf))
     257           29 : }
     258              : 
     259              : pub struct WriteBuf(Cursor<Vec<u8>>);
     260              : 
     261              : impl Buf for WriteBuf {
     262              :     #[inline]
     263          117 :     fn remaining(&self) -> usize {
     264          117 :         self.0.remaining()
     265          117 :     }
     266              : 
     267              :     #[inline]
     268           55 :     fn chunk(&self) -> &[u8] {
     269           55 :         self.0.chunk()
     270           55 :     }
     271              : 
     272              :     #[inline]
     273           55 :     fn advance(&mut self, cnt: usize) {
     274           55 :         self.0.advance(cnt);
     275           55 :     }
     276              : }
     277              : 
     278              : impl WriteBuf {
     279           45 :     pub const fn new() -> Self {
     280           45 :         Self(Cursor::new(Vec::new()))
     281           45 :     }
     282              : 
     283              :     /// Use a heuristic to determine if we should shrink the write buffer.
     284              :     #[inline]
     285           57 :     fn should_shrink(&self) -> bool {
     286           57 :         let n = self.0.position() as usize;
     287           57 :         let len = self.0.get_ref().len();
     288              : 
     289              :         // the unused space at the front of our buffer is 2x the size of our filled portion.
     290           57 :         n + n > len
     291           57 :     }
     292              : 
     293              :     /// Shrink the write buffer so that subsequent writes have more spare capacity.
     294              :     #[cold]
     295            1 :     fn shrink(&mut self) {
     296            1 :         let n = self.0.position() as usize;
     297            1 :         let buf = self.0.get_mut();
     298              : 
     299              :         // buf repr:
     300              :         // [----unused------|-----filled-----|-----uninit-----]
     301              :         //                  ^ n              ^ buf.len()      ^ buf.capacity()
     302            1 :         let filled = n..buf.len();
     303            1 :         let filled_len = filled.len();
     304            1 :         buf.copy_within(filled, 0);
     305            1 :         buf.truncate(filled_len);
     306            1 :         self.0.set_position(0);
     307            1 :     }
     308              : 
     309              :     /// clear the write buffer.
     310           62 :     pub fn reset(&mut self) {
     311           62 :         let buf = self.0.get_mut();
     312           62 :         buf.clear();
     313           62 :         self.0.set_position(0);
     314           62 :     }
     315              : 
     316              :     /// Write a raw message to the internal buffer.
     317              :     ///
     318              :     /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
     319              :     /// we calculate the length after the fact.
     320           57 :     pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
     321           57 :         if self.should_shrink() {
     322            0 :             self.shrink();
     323           57 :         }
     324              : 
     325           57 :         let buf = self.0.get_mut();
     326           57 :         buf.reserve(5 + size_hint);
     327              : 
     328           57 :         buf.push(tag);
     329           57 :         let start = buf.len();
     330           57 :         buf.extend_from_slice(&[0, 0, 0, 0]);
     331              : 
     332           57 :         f(buf);
     333              : 
     334           57 :         let end = buf.len();
     335           57 :         let len = (end - start) as u32;
     336           57 :         buf[start..start + 4].copy_from_slice(&len.to_be_bytes());
     337           57 :     }
     338              : 
     339              :     /// Write an encryption response message.
     340           20 :     pub fn encryption(&mut self, m: u8) {
     341           20 :         self.0.get_mut().push(m);
     342           20 :     }
     343              : 
     344            1 :     pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) {
     345            1 :         self.shrink();
     346              : 
     347              :         // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE>
     348              :         // <https://www.postgresql.org/docs/current/protocol-error-fields.html>
     349              :         // "SERROR\0CXXXXX\0M\0\0".len() == 17
     350            1 :         self.write_raw(17 + msg.len(), b'E', |buf| {
     351              :             // Severity: ERROR
     352            1 :             buf.put_slice(b"SERROR\0");
     353              : 
     354              :             // Code: error_code
     355            1 :             buf.put_u8(b'C');
     356            1 :             buf.put_slice(&error_code);
     357            1 :             buf.put_u8(0);
     358              : 
     359              :             // Message: msg
     360            1 :             buf.put_u8(b'M');
     361            1 :             buf.put_slice(msg.as_bytes());
     362            1 :             buf.put_u8(0);
     363              : 
     364              :             // End.
     365            1 :             buf.put_u8(0);
     366            1 :         });
     367            1 :     }
     368              : }
     369              : 
     370              : #[derive(Debug)]
     371              : pub enum FeStartupPacket {
     372              :     CancelRequest(CancelKeyData),
     373              :     SslRequest {
     374              :         direct: Option<[u8; 8]>,
     375              :     },
     376              :     GssEncRequest,
     377              :     StartupMessage {
     378              :         version: ProtocolVersion,
     379              :         params: StartupMessageParams,
     380              :     },
     381              : }
     382              : 
     383              : #[derive(Debug, Clone, Default)]
     384              : pub struct StartupMessageParams {
     385              :     pub params: String,
     386              : }
     387              : 
     388              : impl StartupMessageParams {
     389              :     /// Get parameter's value by its name.
     390           41 :     pub fn get(&self, name: &str) -> Option<&str> {
     391           60 :         self.iter().find_map(|(k, v)| (k == name).then_some(v))
     392           41 :     }
     393              : 
     394              :     /// Split command-line options according to PostgreSQL's logic,
     395              :     /// taking into account all escape sequences but leaving them as-is.
     396              :     /// [`None`] means that there's no `options` in [`Self`].
     397           27 :     pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
     398           27 :         self.get("options").map(Self::parse_options_raw)
     399           27 :     }
     400              : 
     401              :     /// Split command-line options according to PostgreSQL's logic,
     402              :     /// taking into account all escape sequences but leaving them as-is.
     403           34 :     pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
     404              :         // See `postgres: pg_split_opts`.
     405           34 :         let mut last_was_escape = false;
     406           34 :         input
     407          608 :             .split(move |c: char| {
     408              :                 // We split by non-escaped whitespace symbols.
     409          608 :                 let should_split = c.is_ascii_whitespace() && !last_was_escape;
     410          608 :                 last_was_escape = c == '\\' && !last_was_escape;
     411          608 :                 should_split
     412          608 :             })
     413           74 :             .filter(|s| !s.is_empty())
     414           34 :     }
     415              : 
     416              :     /// Iterate through key-value pairs in an arbitrary order.
     417           41 :     pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
     418           41 :         self.params.split_terminator('\0').tuples()
     419           41 :     }
     420              : 
     421              :     // This function is mostly useful in tests.
     422              :     #[cfg(test)]
     423           13 :     pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
     424           13 :         let mut b = Self {
     425           13 :             params: String::new(),
     426           13 :         };
     427           36 :         for (k, v) in pairs {
     428           23 :             b.insert(k, v);
     429           23 :         }
     430           13 :         b
     431           13 :     }
     432              : 
     433              :     /// Set parameter's value by its name.
     434              :     /// name and value must not contain a \0 byte
     435           23 :     pub fn insert(&mut self, name: &str, value: &str) {
     436           23 :         self.params.reserve(name.len() + value.len() + 2);
     437           23 :         self.params.push_str(name);
     438           23 :         self.params.push('\0');
     439           23 :         self.params.push_str(value);
     440           23 :         self.params.push('\0');
     441           23 :     }
     442              : }
     443              : 
     444              : /// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just
     445              : /// opaque bytes.
     446              : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)]
     447              : pub struct CancelKeyData(pub big_endian::U64);
     448              : 
     449            1 : pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
     450            1 :     CancelKeyData(big_endian::U64::new(id))
     451            1 : }
     452              : 
     453              : impl fmt::Display for CancelKeyData {
     454            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     455            0 :         let id = self.0;
     456            0 :         f.debug_tuple("CancelKeyData")
     457            0 :             .field(&format_args!("{id:x}"))
     458            0 :             .finish()
     459            0 :     }
     460              : }
     461              : impl Distribution<CancelKeyData> for Standard {
     462            0 :     fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
     463            0 :         id_to_cancel_key(rng.r#gen())
     464            0 :     }
     465              : }
     466              : 
     467              : pub enum BeMessage<'a> {
     468              :     AuthenticationOk,
     469              :     AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
     470              :     AuthenticationCleartextPassword,
     471              :     BackendKeyData(CancelKeyData),
     472              :     ParameterStatus {
     473              :         name: &'a [u8],
     474              :         value: &'a [u8],
     475              :     },
     476              :     ReadyForQuery,
     477              :     NoticeResponse(&'a str),
     478              :     NegotiateProtocolVersion {
     479              :         version: ProtocolVersion,
     480              :         options: &'a [&'a str],
     481              :     },
     482              : }
     483              : 
     484              : #[derive(Debug)]
     485              : pub enum BeAuthenticationSaslMessage<'a> {
     486              :     Methods(&'a [&'a str]),
     487              :     Continue(&'a [u8]),
     488              :     Final(&'a [u8]),
     489              : }
     490              : 
     491              : impl BeMessage<'_> {
     492              :     /// Write the message into an internal buffer
     493           56 :     pub fn write_message(self, buf: &mut WriteBuf) {
     494           30 :         match self {
     495              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
     496              :             BeMessage::AuthenticationOk => {
     497           10 :                 buf.write_raw(1, b'R', |buf| buf.put_i32(0));
     498              :             }
     499              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
     500              :             BeMessage::AuthenticationCleartextPassword => {
     501            2 :                 buf.write_raw(1, b'R', |buf| buf.put_i32(3));
     502              :             }
     503              : 
     504              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
     505           13 :             BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
     506           25 :                 let len: usize = methods.iter().map(|m| m.len() + 1).sum();
     507           13 :                 buf.write_raw(len + 2, b'R', |buf| {
     508           13 :                     buf.put_i32(10); // Specifies that SASL auth method is used.
     509           38 :                     for method in methods {
     510           25 :                         buf.put_slice(method.as_bytes());
     511           25 :                         buf.put_u8(0);
     512           25 :                     }
     513           13 :                     buf.put_u8(0); // zero terminator for the list
     514           13 :                 });
     515              :             }
     516              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
     517           11 :             BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
     518           11 :                 buf.write_raw(extra.len() + 1, b'R', |buf| {
     519           11 :                     buf.put_i32(11); // Continue SASL auth.
     520           11 :                     buf.put_slice(extra);
     521           11 :                 });
     522              :             }
     523              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
     524            6 :             BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
     525            6 :                 buf.write_raw(extra.len() + 1, b'R', |buf| {
     526            6 :                     buf.put_i32(12); // Send final SASL message.
     527            6 :                     buf.put_slice(extra);
     528            6 :                 });
     529              :             }
     530              : 
     531              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
     532            0 :             BeMessage::BackendKeyData(key_data) => {
     533            0 :                 buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
     534              :             }
     535              : 
     536              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
     537              :             // <https://www.postgresql.org/docs/current/protocol-error-fields.html>
     538            0 :             BeMessage::NoticeResponse(msg) => {
     539              :                 // 'N' signalizes NoticeResponse messages
     540            0 :                 buf.write_raw(18 + msg.len(), b'N', |buf| {
     541              :                     // Severity: NOTICE
     542            0 :                     buf.put_slice(b"SNOTICE\0");
     543              : 
     544              :                     // Code: XX000 (ignored for notice, but still required)
     545            0 :                     buf.put_slice(b"CXX000\0");
     546              : 
     547              :                     // Message: msg
     548            0 :                     buf.put_u8(b'M');
     549            0 :                     buf.put_slice(msg.as_bytes());
     550            0 :                     buf.put_u8(0);
     551              : 
     552              :                     // End notice.
     553            0 :                     buf.put_u8(0);
     554            0 :                 });
     555              :             }
     556              : 
     557              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-PARAMETERSTATUS>
     558            7 :             BeMessage::ParameterStatus { name, value } => {
     559            7 :                 buf.write_raw(name.len() + value.len() + 2, b'S', |buf| {
     560            7 :                     buf.put_slice(name.as_bytes());
     561            7 :                     buf.put_u8(0);
     562            7 :                     buf.put_slice(value.as_bytes());
     563            7 :                     buf.put_u8(0);
     564            7 :                 });
     565              :             }
     566              : 
     567              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
     568              :             BeMessage::ReadyForQuery => {
     569            7 :                 buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
     570              :             }
     571              : 
     572              :             // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
     573            0 :             BeMessage::NegotiateProtocolVersion { version, options } => {
     574            0 :                 let len: usize = options.iter().map(|o| o.len() + 1).sum();
     575            0 :                 buf.write_raw(8 + len, b'v', |buf| {
     576            0 :                     buf.put_slice(version.as_bytes());
     577            0 :                     buf.put_u32(options.len() as u32);
     578            0 :                     for option in options {
     579            0 :                         buf.put_slice(option.as_bytes());
     580            0 :                         buf.put_u8(0);
     581            0 :                     }
     582            0 :                 });
     583              :             }
     584              :         }
     585           56 :     }
     586              : }
     587              : 
     588              : #[cfg(test)]
     589              : mod tests {
     590              :     use std::io::Cursor;
     591              : 
     592              :     use tokio::io::{AsyncWriteExt, duplex};
     593              :     use zerocopy::IntoBytes;
     594              : 
     595              :     use super::ProtocolVersion;
     596              :     use crate::pqproto::{FeStartupPacket, read_message, read_startup};
     597              : 
     598              :     #[tokio::test]
     599            1 :     async fn reject_large_startup() {
     600              :         // we're going to define a v3.0 startup message with far too many parameters.
     601            1 :         let mut payload = vec![];
     602              :         // 10001 + 8 bytes.
     603            1 :         payload.extend_from_slice(&10009_u32.to_be_bytes());
     604            1 :         payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
     605            1 :         payload.resize(10009, b'a');
     606              : 
     607            1 :         let (mut server, mut client) = duplex(128);
     608              :         #[rustfmt::skip]
     609            1 :         let (server, client) = tokio::join!(
     610            1 :             async move { read_startup(&mut server).await.unwrap_err() },
     611            1 :             async move { client.write_all(&payload).await.unwrap_err() },
     612              :         );
     613              : 
     614            1 :         assert_eq!(server.to_string(), "invalid startup message length 10001");
     615            1 :         assert_eq!(client.to_string(), "broken pipe");
     616            1 :     }
     617              : 
     618              :     #[tokio::test]
     619            1 :     async fn reject_large_password() {
     620              :         // we're going to define a password message that is far too long.
     621            1 :         let mut payload = vec![];
     622            1 :         payload.push(b'p');
     623            1 :         payload.extend_from_slice(&517_u32.to_be_bytes());
     624            1 :         payload.resize(518, b'a');
     625              : 
     626            1 :         let (mut server, mut client) = duplex(128);
     627              :         #[rustfmt::skip]
     628            1 :         let (server, client) = tokio::join!(
     629            1 :             async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() },
     630            1 :             async move { client.write_all(&payload).await.unwrap_err() },
     631              :         );
     632              : 
     633            1 :         assert_eq!(server.to_string(), "invalid message length 513");
     634            1 :         assert_eq!(client.to_string(), "broken pipe");
     635            1 :     }
     636              : 
     637              :     #[tokio::test]
     638            1 :     async fn read_startup_message() {
     639            1 :         let mut payload = vec![];
     640            1 :         payload.extend_from_slice(&17_u32.to_be_bytes());
     641            1 :         payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
     642            1 :         payload.extend_from_slice(b"abc\0def\0\0");
     643              : 
     644            1 :         let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
     645            1 :         let FeStartupPacket::StartupMessage { version, params } = startup else {
     646            0 :             panic!("unexpected startup message: {startup:?}");
     647              :         };
     648              : 
     649            1 :         assert_eq!(version.major(), 3);
     650            1 :         assert_eq!(version.minor(), 0);
     651            1 :         assert_eq!(params.params, "abc\0def\0");
     652            1 :     }
     653              : 
     654              :     #[tokio::test]
     655            1 :     async fn read_ssl_message() {
     656            1 :         let mut payload = vec![];
     657            1 :         payload.extend_from_slice(&8_u32.to_be_bytes());
     658            1 :         payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes());
     659              : 
     660            1 :         let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
     661            1 :         let FeStartupPacket::SslRequest { direct: None } = startup else {
     662            1 :             panic!("unexpected startup message: {startup:?}");
     663            1 :         };
     664            1 :     }
     665              : 
     666              :     #[tokio::test]
     667            1 :     async fn read_tls_message() {
     668              :         // sample client hello taken from <https://tls13.xargs.org/#client-hello>
     669            1 :         let client_hello = [
     670            1 :             0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02,
     671            1 :             0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
     672            1 :             0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e,
     673            1 :             0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
     674            1 :             0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9,
     675            1 :             0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01,
     676            1 :             0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00,
     677            1 :             0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65,
     678            1 :             0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
     679            1 :             0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19,
     680            1 :             0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23,
     681            1 :             0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e,
     682            1 :             0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09,
     683            1 :             0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
     684            1 :             0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01,
     685            1 :             0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72,
     686            1 :             0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38,
     687            1 :             0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62,
     688            1 :             0x54,
     689            1 :         ];
     690              : 
     691            1 :         let mut cursor = Cursor::new(&client_hello);
     692              : 
     693            1 :         let startup = read_startup(&mut cursor).await.unwrap();
     694              :         let FeStartupPacket::SslRequest {
     695            1 :             direct: Some(prefix),
     696            1 :         } = startup
     697              :         else {
     698            0 :             panic!("unexpected startup message: {startup:?}");
     699              :         };
     700              : 
     701              :         // check that no data is lost.
     702            1 :         assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]);
     703            1 :         assert_eq!(cursor.position(), 8);
     704            1 :     }
     705              : 
     706              :     #[tokio::test]
     707            1 :     async fn read_message_success() {
     708            1 :         let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2";
     709            1 :         let mut cursor = Cursor::new(&query);
     710              : 
     711            1 :         let mut buf = vec![];
     712            1 :         let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
     713            1 :         assert_eq!(tag, b'Q');
     714            1 :         assert_eq!(message, b"SELECT 1");
     715              : 
     716            1 :         let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
     717            1 :         assert_eq!(tag, b'Q');
     718            1 :         assert_eq!(message, b"SELECT 2");
     719            1 :     }
     720              : }
        

Generated by: LCOV version 2.1-beta