LCOV - code coverage report
Current view: top level - libs/proxy/postgres-protocol2/src/message - frontend.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 29.3 % 222 65
Test Date: 2025-02-20 13:11:02 Functions: 16.8 % 101 17

            Line data    Source code
       1              : //! Frontend message serialization.
       2              : #![allow(missing_docs)]
       3              : 
       4              : use byteorder::{BigEndian, ByteOrder};
       5              : use bytes::{Buf, BufMut, BytesMut};
       6              : use std::error::Error;
       7              : use std::io;
       8              : use std::marker;
       9              : 
      10              : use crate::{write_nullable, FromUsize, IsNull, Oid};
      11              : 
      12              : #[inline]
      13           69 : fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
      14           69 : where
      15           69 :     F: FnOnce(&mut BytesMut) -> Result<(), E>,
      16           69 :     E: From<io::Error>,
      17           69 : {
      18           69 :     let base = buf.len();
      19           69 :     buf.extend_from_slice(&[0; 4]);
      20           69 : 
      21           69 :     f(buf)?;
      22              : 
      23           69 :     let size = i32::from_usize(buf.len() - base)?;
      24           69 :     BigEndian::write_i32(&mut buf[base..], size);
      25           69 :     Ok(())
      26           69 : }
      27              : 
      28              : pub enum BindError {
      29              :     Conversion(Box<dyn Error + marker::Sync + Send>),
      30              :     Serialization(io::Error),
      31              : }
      32              : 
      33              : impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
      34              :     #[inline]
      35            0 :     fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
      36            0 :         BindError::Conversion(e)
      37            0 :     }
      38              : }
      39              : 
      40              : impl From<io::Error> for BindError {
      41              :     #[inline]
      42            0 :     fn from(e: io::Error) -> BindError {
      43            0 :         BindError::Serialization(e)
      44            0 :     }
      45              : }
      46              : 
      47              : #[inline]
      48            0 : pub fn bind<I, J, F, T, K>(
      49            0 :     portal: &str,
      50            0 :     statement: &str,
      51            0 :     formats: I,
      52            0 :     values: J,
      53            0 :     mut serializer: F,
      54            0 :     result_formats: K,
      55            0 :     buf: &mut BytesMut,
      56            0 : ) -> Result<(), BindError>
      57            0 : where
      58            0 :     I: IntoIterator<Item = i16>,
      59            0 :     J: IntoIterator<Item = T>,
      60            0 :     F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
      61            0 :     K: IntoIterator<Item = i16>,
      62            0 : {
      63            0 :     buf.put_u8(b'B');
      64            0 : 
      65            0 :     write_body(buf, |buf| {
      66            0 :         write_cstr(portal.as_bytes(), buf)?;
      67            0 :         write_cstr(statement.as_bytes(), buf)?;
      68            0 :         write_counted(
      69            0 :             formats,
      70            0 :             |f, buf| {
      71            0 :                 buf.put_i16(f);
      72            0 :                 Ok::<_, io::Error>(())
      73            0 :             },
      74            0 :             buf,
      75            0 :         )?;
      76            0 :         write_counted(
      77            0 :             values,
      78            0 :             |v, buf| write_nullable(|buf| serializer(v, buf), buf),
      79            0 :             buf,
      80            0 :         )?;
      81            0 :         write_counted(
      82            0 :             result_formats,
      83            0 :             |f, buf| {
      84            0 :                 buf.put_i16(f);
      85            0 :                 Ok::<_, io::Error>(())
      86            0 :             },
      87            0 :             buf,
      88            0 :         )?;
      89              : 
      90            0 :         Ok(())
      91            0 :     })
      92            0 : }
      93              : 
      94              : #[inline]
      95            0 : fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
      96            0 : where
      97            0 :     I: IntoIterator<Item = T>,
      98            0 :     F: FnMut(T, &mut BytesMut) -> Result<(), E>,
      99            0 :     E: From<io::Error>,
     100            0 : {
     101            0 :     let base = buf.len();
     102            0 :     buf.extend_from_slice(&[0; 2]);
     103            0 :     let mut count = 0;
     104            0 :     for item in items {
     105            0 :         serializer(item, buf)?;
     106            0 :         count += 1;
     107              :     }
     108            0 :     let count = i16::from_usize(count)?;
     109            0 :     BigEndian::write_i16(&mut buf[base..], count);
     110            0 : 
     111            0 :     Ok(())
     112            0 : }
     113              : 
     114              : #[inline]
     115            0 : pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
     116            0 :     write_body(buf, |buf| {
     117            0 :         buf.put_i32(80_877_102);
     118            0 :         buf.put_i32(process_id);
     119            0 :         buf.put_i32(secret_key);
     120            0 :         Ok::<_, io::Error>(())
     121            0 :     })
     122            0 :     .unwrap();
     123            0 : }
     124              : 
     125              : #[inline]
     126            0 : pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
     127            0 :     buf.put_u8(b'C');
     128            0 :     write_body(buf, |buf| {
     129            0 :         buf.put_u8(variant);
     130            0 :         write_cstr(name.as_bytes(), buf)
     131            0 :     })
     132            0 : }
     133              : 
     134              : pub struct CopyData<T> {
     135              :     buf: T,
     136              :     len: i32,
     137              : }
     138              : 
     139              : impl<T> CopyData<T>
     140              : where
     141              :     T: Buf,
     142              : {
     143            0 :     pub fn new(buf: T) -> io::Result<CopyData<T>> {
     144            0 :         let len = buf
     145            0 :             .remaining()
     146            0 :             .checked_add(4)
     147            0 :             .and_then(|l| i32::try_from(l).ok())
     148            0 :             .ok_or_else(|| {
     149            0 :                 io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
     150            0 :             })?;
     151              : 
     152            0 :         Ok(CopyData { buf, len })
     153            0 :     }
     154              : 
     155            0 :     pub fn write(self, out: &mut BytesMut) {
     156            0 :         out.put_u8(b'd');
     157            0 :         out.put_i32(self.len);
     158            0 :         out.put(self.buf);
     159            0 :     }
     160              : }
     161              : 
     162              : #[inline]
     163            0 : pub fn copy_done(buf: &mut BytesMut) {
     164            0 :     buf.put_u8(b'c');
     165            0 :     write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
     166            0 : }
     167              : 
     168              : #[inline]
     169            0 : pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
     170            0 :     buf.put_u8(b'f');
     171            0 :     write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
     172            0 : }
     173              : 
     174              : #[inline]
     175            0 : pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
     176            0 :     buf.put_u8(b'D');
     177            0 :     write_body(buf, |buf| {
     178            0 :         buf.put_u8(variant);
     179            0 :         write_cstr(name.as_bytes(), buf)
     180            0 :     })
     181            0 : }
     182              : 
     183              : #[inline]
     184            0 : pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
     185            0 :     buf.put_u8(b'E');
     186            0 :     write_body(buf, |buf| {
     187            0 :         write_cstr(portal.as_bytes(), buf)?;
     188            0 :         buf.put_i32(max_rows);
     189            0 :         Ok(())
     190            0 :     })
     191            0 : }
     192              : 
     193              : #[inline]
     194            0 : pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
     195            0 : where
     196            0 :     I: IntoIterator<Item = Oid>,
     197            0 : {
     198            0 :     buf.put_u8(b'P');
     199            0 :     write_body(buf, |buf| {
     200            0 :         write_cstr(name.as_bytes(), buf)?;
     201            0 :         write_cstr(query.as_bytes(), buf)?;
     202            0 :         write_counted(
     203            0 :             param_types,
     204            0 :             |t, buf| {
     205            0 :                 buf.put_u32(t);
     206            0 :                 Ok::<_, io::Error>(())
     207            0 :             },
     208            0 :             buf,
     209            0 :         )?;
     210            0 :         Ok(())
     211            0 :     })
     212            0 : }
     213              : 
     214              : #[inline]
     215            2 : pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
     216            2 :     buf.put_u8(b'p');
     217            2 :     write_body(buf, |buf| write_cstr(password, buf))
     218            2 : }
     219              : 
     220              : #[inline]
     221            0 : pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
     222            0 :     buf.put_u8(b'Q');
     223            0 :     write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
     224            0 : }
     225              : 
     226              : #[inline]
     227           14 : pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
     228           14 :     buf.put_u8(b'p');
     229           14 :     write_body(buf, |buf| {
     230           14 :         write_cstr(mechanism.as_bytes(), buf)?;
     231           14 :         let len = i32::from_usize(data.len())?;
     232           14 :         buf.put_i32(len);
     233           14 :         buf.put_slice(data);
     234           14 :         Ok(())
     235           14 :     })
     236           14 : }
     237              : 
     238              : #[inline]
     239           11 : pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
     240           11 :     buf.put_u8(b'p');
     241           11 :     write_body(buf, |buf| {
     242           11 :         buf.put_slice(data);
     243           11 :         Ok(())
     244           11 :     })
     245           11 : }
     246              : 
     247              : #[inline]
     248           20 : pub fn ssl_request(buf: &mut BytesMut) {
     249           20 :     write_body(buf, |buf| {
     250           20 :         buf.put_i32(80_877_103);
     251           20 :         Ok::<_, io::Error>(())
     252           20 :     })
     253           20 :     .unwrap();
     254           20 : }
     255              : 
     256              : #[inline]
     257           22 : pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
     258           22 :     write_body(buf, |buf| {
     259           22 :         // postgres protocol version 3.0(196608) in bigger-endian
     260           22 :         buf.put_i32(0x00_03_00_00);
     261           22 :         buf.put_slice(&parameters.params);
     262           22 :         buf.put_u8(0);
     263           22 :         Ok(())
     264           22 :     })
     265           22 : }
     266              : 
     267              : #[derive(Debug, Clone, Default, PartialEq, Eq)]
     268              : pub struct StartupMessageParams {
     269              :     pub params: BytesMut,
     270              : }
     271              : 
     272              : impl StartupMessageParams {
     273              :     /// Set parameter's value by its name.
     274           31 :     pub fn insert(&mut self, name: &str, value: &str) {
     275           31 :         if name.contains('\0') || value.contains('\0') {
     276            0 :             panic!("startup parameter name or value contained a null")
     277           31 :         }
     278           31 :         self.params.put_slice(name.as_bytes());
     279           31 :         self.params.put_u8(0);
     280           31 :         self.params.put_slice(value.as_bytes());
     281           31 :         self.params.put_u8(0);
     282           31 :     }
     283              : }
     284              : 
     285              : #[inline]
     286            0 : pub fn sync(buf: &mut BytesMut) {
     287            0 :     buf.put_u8(b'S');
     288            0 :     write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
     289            0 : }
     290              : 
     291              : #[inline]
     292            0 : pub fn terminate(buf: &mut BytesMut) {
     293            0 :     buf.put_u8(b'X');
     294            0 :     write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
     295            0 : }
     296              : 
     297              : #[inline]
     298           16 : fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
     299           16 :     if s.contains(&0) {
     300            0 :         return Err(io::Error::new(
     301            0 :             io::ErrorKind::InvalidInput,
     302            0 :             "string contains embedded null",
     303            0 :         ));
     304           16 :     }
     305           16 :     buf.put_slice(s);
     306           16 :     buf.put_u8(0);
     307           16 :     Ok(())
     308           16 : }
        

Generated by: LCOV version 2.1-beta