LCOV - code coverage report
Current view: top level - libs/proxy/postgres-protocol2/src/message - frontend.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 26.1 % 218 57
Test Date: 2025-07-16 12:29:03 Functions: 18.3 % 93 17

            Line data    Source code
       1              : //! Frontend message serialization.
       2              : #![allow(missing_docs)]
       3              : 
       4              : use std::error::Error;
       5              : use std::{io, marker};
       6              : 
       7              : use byteorder::{BigEndian, ByteOrder};
       8              : use bytes::{Buf, BufMut, BytesMut};
       9              : 
      10              : use crate::{FromUsize, IsNull, Oid, write_nullable};
      11              : 
      12              : #[inline]
      13           69 : fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
      14           69 : where
      15           69 :     F: FnOnce(&mut BytesMut) -> Result<(), E>,
      16           69 :     E: From<io::Error>,
      17              : {
      18           69 :     let base = buf.len();
      19           69 :     buf.extend_from_slice(&[0; 4]);
      20              : 
      21           69 :     f(buf)?;
      22              : 
      23           69 :     let size = i32::from_usize(buf.len() - base)?;
      24           69 :     BigEndian::write_i32(&mut buf[base..], size);
      25           69 :     Ok(())
      26           69 : }
      27              : 
      28              : #[derive(Debug)]
      29              : pub enum BindError {
      30              :     Conversion(Box<dyn Error + marker::Sync + Send>),
      31              :     Serialization(io::Error),
      32              : }
      33              : 
      34              : impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
      35              :     #[inline]
      36            0 :     fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
      37            0 :         BindError::Conversion(e)
      38            0 :     }
      39              : }
      40              : 
      41              : impl From<io::Error> for BindError {
      42              :     #[inline]
      43            0 :     fn from(e: io::Error) -> BindError {
      44            0 :         BindError::Serialization(e)
      45            0 :     }
      46              : }
      47              : 
      48              : #[inline]
      49            0 : pub fn bind<I, J, F, T, K>(
      50            0 :     portal: &str,
      51            0 :     statement: &str,
      52            0 :     formats: I,
      53            0 :     values: J,
      54            0 :     mut serializer: F,
      55            0 :     result_formats: K,
      56            0 :     buf: &mut BytesMut,
      57            0 : ) -> Result<(), BindError>
      58            0 : where
      59            0 :     I: IntoIterator<Item = i16>,
      60            0 :     J: IntoIterator<Item = T>,
      61            0 :     F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
      62            0 :     K: IntoIterator<Item = i16>,
      63              : {
      64            0 :     buf.put_u8(b'B');
      65              : 
      66            0 :     write_body(buf, |buf| {
      67            0 :         write_cstr(portal.as_bytes(), buf)?;
      68            0 :         write_cstr(statement.as_bytes(), buf)?;
      69            0 :         write_counted(
      70            0 :             formats,
      71            0 :             |f, buf| {
      72            0 :                 buf.put_i16(f);
      73            0 :                 Ok::<_, io::Error>(())
      74            0 :             },
      75            0 :             buf,
      76            0 :         )?;
      77            0 :         write_counted(
      78            0 :             values,
      79            0 :             |v, buf| write_nullable(|buf| serializer(v, buf), buf),
      80            0 :             buf,
      81            0 :         )?;
      82            0 :         write_counted(
      83            0 :             result_formats,
      84            0 :             |f, buf| {
      85            0 :                 buf.put_i16(f);
      86            0 :                 Ok::<_, io::Error>(())
      87            0 :             },
      88            0 :             buf,
      89            0 :         )?;
      90              : 
      91            0 :         Ok(())
      92            0 :     })
      93            0 : }
      94              : 
      95              : #[inline]
      96            0 : fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
      97            0 : where
      98            0 :     I: IntoIterator<Item = T>,
      99            0 :     F: FnMut(T, &mut BytesMut) -> Result<(), E>,
     100            0 :     E: From<io::Error>,
     101              : {
     102            0 :     let base = buf.len();
     103            0 :     buf.extend_from_slice(&[0; 2]);
     104            0 :     let mut count = 0;
     105            0 :     for item in items {
     106            0 :         serializer(item, buf)?;
     107            0 :         count += 1;
     108              :     }
     109            0 :     let count = i16::from_usize(count)?;
     110            0 :     BigEndian::write_i16(&mut buf[base..], count);
     111              : 
     112            0 :     Ok(())
     113            0 : }
     114              : 
     115              : #[inline]
     116            0 : pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
     117            0 :     write_body(buf, |buf| {
     118            0 :         buf.put_i32(80_877_102);
     119            0 :         buf.put_i32(process_id);
     120            0 :         buf.put_i32(secret_key);
     121            0 :         Ok::<_, io::Error>(())
     122            0 :     })
     123            0 :     .unwrap();
     124            0 : }
     125              : 
     126              : #[inline]
     127            0 : pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
     128            0 :     buf.put_u8(b'C');
     129            0 :     write_body(buf, |buf| {
     130            0 :         buf.put_u8(variant);
     131            0 :         write_cstr(name.as_bytes(), buf)
     132            0 :     })
     133            0 : }
     134              : 
     135              : pub struct CopyData<T> {
     136              :     buf: T,
     137              :     len: i32,
     138              : }
     139              : 
     140              : impl<T> CopyData<T>
     141              : where
     142              :     T: Buf,
     143              : {
     144            0 :     pub fn new(buf: T) -> io::Result<CopyData<T>> {
     145            0 :         let len = buf
     146            0 :             .remaining()
     147            0 :             .checked_add(4)
     148            0 :             .and_then(|l| i32::try_from(l).ok())
     149            0 :             .ok_or_else(|| {
     150            0 :                 io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
     151            0 :             })?;
     152              : 
     153            0 :         Ok(CopyData { buf, len })
     154            0 :     }
     155              : 
     156            0 :     pub fn write(self, out: &mut BytesMut) {
     157            0 :         out.put_u8(b'd');
     158            0 :         out.put_i32(self.len);
     159            0 :         out.put(self.buf);
     160            0 :     }
     161              : }
     162              : 
     163              : #[inline]
     164            0 : pub fn copy_done(buf: &mut BytesMut) {
     165            0 :     buf.put_u8(b'c');
     166            0 :     write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
     167            0 : }
     168              : 
     169              : #[inline]
     170            0 : pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
     171            0 :     buf.put_u8(b'f');
     172            0 :     write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
     173            0 : }
     174              : 
     175              : #[inline]
     176            0 : pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
     177            0 :     buf.put_u8(b'D');
     178            0 :     write_body(buf, |buf| {
     179            0 :         buf.put_u8(variant);
     180            0 :         write_cstr(name.as_bytes(), buf)
     181            0 :     })
     182            0 : }
     183              : 
     184              : #[inline]
     185            0 : pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
     186            0 :     buf.put_u8(b'E');
     187            0 :     write_body(buf, |buf| {
     188            0 :         write_cstr(portal.as_bytes(), buf)?;
     189            0 :         buf.put_i32(max_rows);
     190            0 :         Ok(())
     191            0 :     })
     192            0 : }
     193              : 
     194              : #[inline]
     195            0 : pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
     196            0 : where
     197            0 :     I: IntoIterator<Item = Oid>,
     198              : {
     199            0 :     buf.put_u8(b'P');
     200            0 :     write_body(buf, |buf| {
     201            0 :         write_cstr(name.as_bytes(), buf)?;
     202            0 :         write_cstr(query.as_bytes(), buf)?;
     203            0 :         write_counted(
     204            0 :             param_types,
     205            0 :             |t, buf| {
     206            0 :                 buf.put_u32(t);
     207            0 :                 Ok::<_, io::Error>(())
     208            0 :             },
     209            0 :             buf,
     210            0 :         )?;
     211            0 :         Ok(())
     212            0 :     })
     213            0 : }
     214              : 
     215              : #[inline]
     216            2 : pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
     217            2 :     buf.put_u8(b'p');
     218            2 :     write_body(buf, |buf| write_cstr(password, buf))
     219            0 : }
     220              : 
     221              : #[inline]
     222            0 : pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
     223            0 :     buf.put_u8(b'Q');
     224            0 :     write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
     225            0 : }
     226              : 
     227              : #[inline]
     228           14 : pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
     229           14 :     buf.put_u8(b'p');
     230           14 :     write_body(buf, |buf| {
     231           14 :         write_cstr(mechanism.as_bytes(), buf)?;
     232           14 :         let len = i32::from_usize(data.len())?;
     233           14 :         buf.put_i32(len);
     234           14 :         buf.put_slice(data);
     235           14 :         Ok(())
     236           14 :     })
     237            0 : }
     238              : 
     239              : #[inline]
     240           11 : pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
     241           11 :     buf.put_u8(b'p');
     242           11 :     write_body(buf, |buf| {
     243           11 :         buf.put_slice(data);
     244           11 :         Ok(())
     245           11 :     })
     246            0 : }
     247              : 
     248              : #[inline]
     249           20 : pub fn ssl_request(buf: &mut BytesMut) {
     250           20 :     write_body(buf, |buf| {
     251           20 :         buf.put_i32(80_877_103);
     252           20 :         Ok::<_, io::Error>(())
     253           20 :     })
     254           20 :     .unwrap();
     255            0 : }
     256              : 
     257              : #[inline]
     258           22 : pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
     259           22 :     write_body(buf, |buf| {
     260              :         // postgres protocol version 3.0(196608) in bigger-endian
     261           22 :         buf.put_i32(0x00_03_00_00);
     262           22 :         buf.put_slice(&parameters.params);
     263           22 :         buf.put_u8(0);
     264           22 :         Ok(())
     265           22 :     })
     266            0 : }
     267              : 
     268              : #[derive(Debug, Clone, Default, PartialEq, Eq)]
     269              : pub struct StartupMessageParams {
     270              :     pub params: BytesMut,
     271              : }
     272              : 
     273              : impl StartupMessageParams {
     274              :     /// Set parameter's value by its name.
     275           31 :     pub fn insert(&mut self, name: &str, value: &str) {
     276           31 :         if name.contains('\0') || value.contains('\0') {
     277            0 :             panic!("startup parameter name or value contained a null")
     278           31 :         }
     279           31 :         self.params.put_slice(name.as_bytes());
     280           31 :         self.params.put_u8(0);
     281           31 :         self.params.put_slice(value.as_bytes());
     282           31 :         self.params.put_u8(0);
     283           31 :     }
     284              : }
     285              : 
     286              : #[inline]
     287            0 : pub fn sync(buf: &mut BytesMut) {
     288            0 :     buf.put_u8(b'S');
     289            0 :     write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
     290            0 : }
     291              : 
     292              : #[inline]
     293            0 : pub fn flush(buf: &mut BytesMut) {
     294            0 :     buf.put_u8(b'H');
     295            0 :     write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
     296            0 : }
     297              : 
     298              : #[inline]
     299            0 : pub fn terminate(buf: &mut BytesMut) {
     300            0 :     buf.put_u8(b'X');
     301            0 :     write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
     302            0 : }
     303              : 
     304              : #[inline]
     305           16 : fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
     306           16 :     if s.contains(&0) {
     307            0 :         return Err(io::Error::new(
     308            0 :             io::ErrorKind::InvalidInput,
     309            0 :             "string contains embedded null",
     310            0 :         ));
     311           16 :     }
     312           16 :     buf.put_slice(s);
     313           16 :     buf.put_u8(0);
     314           16 :     Ok(())
     315           16 : }
        

Generated by: LCOV version 2.1-beta