LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - prepare.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 0.0 % 168 0
Test Date: 2025-07-16 12:29:03 Functions: 0.0 % 18 0

            Line data    Source code
       1              : use bytes::BytesMut;
       2              : use fallible_iterator::FallibleIterator;
       3              : use postgres_protocol2::IsNull;
       4              : use postgres_protocol2::message::backend::{Message, RowDescriptionBody};
       5              : use postgres_protocol2::message::frontend;
       6              : use postgres_protocol2::types::oid_to_sql;
       7              : use postgres_types2::Format;
       8              : 
       9              : use crate::client::{CachedTypeInfo, PartialQuery, Responses};
      10              : use crate::types::{Kind, Oid, Type};
      11              : use crate::{Column, Error, Row, Statement};
      12              : 
      13              : pub(crate) const TYPEINFO_QUERY: &str = "\
      14              : SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
      15              : FROM pg_catalog.pg_type t
      16              : LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
      17              : INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
      18              : WHERE t.oid = $1
      19              : ";
      20              : 
      21              : /// we need to make sure we close this prepared statement.
      22              : struct CloseStmt<'a, 'b> {
      23              :     client: Option<&'a mut PartialQuery<'b>>,
      24              :     name: &'static str,
      25              : }
      26              : 
      27              : impl<'a> CloseStmt<'a, '_> {
      28            0 :     fn close(mut self) -> Result<&'a mut Responses, Error> {
      29            0 :         let client = self.client.take().unwrap();
      30            0 :         client.send_with_flush(|buf| {
      31            0 :             frontend::close(b'S', self.name, buf).map_err(Error::encode)?;
      32            0 :             Ok(())
      33            0 :         })
      34            0 :     }
      35              : }
      36              : 
      37              : impl Drop for CloseStmt<'_, '_> {
      38            0 :     fn drop(&mut self) {
      39            0 :         if let Some(client) = self.client.take() {
      40            0 :             let _ = client.send_with_flush(|buf| {
      41            0 :                 frontend::close(b'S', self.name, buf).map_err(Error::encode)?;
      42            0 :                 Ok(())
      43            0 :             });
      44            0 :         }
      45            0 :     }
      46              : }
      47              : 
      48            0 : async fn prepare_typecheck(
      49            0 :     client: &mut PartialQuery<'_>,
      50            0 :     name: &'static str,
      51            0 :     query: &str,
      52            0 : ) -> Result<Statement, Error> {
      53            0 :     let responses = client.send_with_flush(|buf| {
      54            0 :         frontend::parse(name, query, [], buf).map_err(Error::encode)?;
      55            0 :         frontend::describe(b'S', name, buf).map_err(Error::encode)?;
      56            0 :         Ok(())
      57            0 :     })?;
      58              : 
      59            0 :     match responses.next().await? {
      60            0 :         Message::ParseComplete => {}
      61            0 :         _ => return Err(Error::unexpected_message()),
      62              :     }
      63              : 
      64            0 :     match responses.next().await? {
      65            0 :         Message::ParameterDescription(_) => {}
      66            0 :         _ => return Err(Error::unexpected_message()),
      67              :     };
      68              : 
      69            0 :     let row_description = match responses.next().await? {
      70            0 :         Message::RowDescription(body) => Some(body),
      71            0 :         Message::NoData => None,
      72            0 :         _ => return Err(Error::unexpected_message()),
      73              :     };
      74              : 
      75            0 :     let mut columns = vec![];
      76            0 :     if let Some(row_description) = row_description {
      77            0 :         let mut it = row_description.fields();
      78            0 :         while let Some(field) = it.next().map_err(Error::parse)? {
      79            0 :             let type_ = Type::from_oid(field.type_oid()).ok_or_else(Error::unexpected_message)?;
      80            0 :             let column = Column::new(field.name().to_string(), type_, field);
      81            0 :             columns.push(column);
      82              :         }
      83            0 :     }
      84              : 
      85            0 :     Ok(Statement::new(name, columns))
      86            0 : }
      87              : 
      88            0 : fn try_from_cache(typecache: &CachedTypeInfo, oid: Oid) -> Option<Type> {
      89            0 :     if let Some(type_) = Type::from_oid(oid) {
      90            0 :         return Some(type_);
      91            0 :     }
      92              : 
      93            0 :     if let Some(type_) = typecache.types.get(&oid) {
      94            0 :         return Some(type_.clone());
      95            0 :     };
      96              : 
      97            0 :     None
      98            0 : }
      99              : 
     100            0 : pub async fn parse_row_description(
     101            0 :     client: &mut PartialQuery<'_>,
     102            0 :     typecache: &mut CachedTypeInfo,
     103            0 :     row_description: Option<RowDescriptionBody>,
     104            0 : ) -> Result<Vec<Column>, Error> {
     105            0 :     let mut columns = vec![];
     106              : 
     107            0 :     if let Some(row_description) = row_description {
     108            0 :         let mut it = row_description.fields();
     109            0 :         while let Some(field) = it.next().map_err(Error::parse)? {
     110            0 :             let type_ = try_from_cache(typecache, field.type_oid()).unwrap_or(Type::UNKNOWN);
     111            0 :             let column = Column::new(field.name().to_string(), type_, field);
     112            0 :             columns.push(column);
     113            0 :         }
     114            0 :     }
     115              : 
     116            0 :     let all_known = columns.iter().all(|c| c.type_ != Type::UNKNOWN);
     117            0 :     if all_known {
     118              :         // all known, return early.
     119            0 :         return Ok(columns);
     120            0 :     }
     121              : 
     122            0 :     let typeinfo = "neon_proxy_typeinfo";
     123              : 
     124              :     // make sure to close the typeinfo statement before exiting.
     125            0 :     let mut guard = CloseStmt {
     126            0 :         name: typeinfo,
     127            0 :         client: None,
     128            0 :     };
     129            0 :     let client = guard.client.insert(client);
     130              : 
     131              :     // get the typeinfo statement.
     132            0 :     let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY).await?;
     133              : 
     134            0 :     for column in &mut columns {
     135            0 :         column.type_ = get_type(client, typecache, &stmt, column.type_oid()).await?;
     136              :     }
     137              : 
     138              :     // cancel the close guard.
     139            0 :     let responses = guard.close()?;
     140              : 
     141            0 :     match responses.next().await? {
     142            0 :         Message::CloseComplete => {}
     143            0 :         _ => return Err(Error::unexpected_message()),
     144              :     }
     145              : 
     146            0 :     Ok(columns)
     147            0 : }
     148              : 
     149            0 : async fn get_type(
     150            0 :     client: &mut PartialQuery<'_>,
     151            0 :     typecache: &mut CachedTypeInfo,
     152            0 :     stmt: &Statement,
     153            0 :     mut oid: Oid,
     154            0 : ) -> Result<Type, Error> {
     155            0 :     let mut stack = vec![];
     156            0 :     let mut type_ = loop {
     157            0 :         if let Some(type_) = try_from_cache(typecache, oid) {
     158            0 :             break type_;
     159            0 :         }
     160              : 
     161            0 :         let row = exec(client, stmt, oid).await?;
     162            0 :         if stack.len() > 8 {
     163            0 :             return Err(Error::unexpected_message());
     164            0 :         }
     165              : 
     166            0 :         let name: String = row.try_get(0)?;
     167            0 :         let type_: i8 = row.try_get(1)?;
     168            0 :         let elem_oid: Oid = row.try_get(2)?;
     169            0 :         let rngsubtype: Option<Oid> = row.try_get(3)?;
     170            0 :         let basetype: Oid = row.try_get(4)?;
     171            0 :         let schema: String = row.try_get(5)?;
     172            0 :         let relid: Oid = row.try_get(6)?;
     173              : 
     174            0 :         let kind = if type_ == b'e' as i8 {
     175            0 :             Kind::Enum
     176            0 :         } else if type_ == b'p' as i8 {
     177            0 :             Kind::Pseudo
     178            0 :         } else if basetype != 0 {
     179            0 :             Kind::Domain(basetype)
     180            0 :         } else if elem_oid != 0 {
     181            0 :             stack.push((name, oid, schema));
     182            0 :             oid = elem_oid;
     183            0 :             continue;
     184            0 :         } else if relid != 0 {
     185            0 :             Kind::Composite(relid)
     186            0 :         } else if let Some(rngsubtype) = rngsubtype {
     187            0 :             Kind::Range(rngsubtype)
     188              :         } else {
     189            0 :             Kind::Simple
     190              :         };
     191              : 
     192            0 :         let type_ = Type::new(name, oid, kind, schema);
     193            0 :         typecache.types.insert(oid, type_.clone());
     194            0 :         break type_;
     195              :     };
     196              : 
     197            0 :     while let Some((name, oid, schema)) = stack.pop() {
     198            0 :         type_ = Type::new(name, oid, Kind::Array(type_), schema);
     199            0 :         typecache.types.insert(oid, type_.clone());
     200            0 :     }
     201              : 
     202            0 :     Ok(type_)
     203            0 : }
     204              : 
     205              : /// exec the typeinfo statement returning one row.
     206            0 : async fn exec(
     207            0 :     client: &mut PartialQuery<'_>,
     208            0 :     statement: &Statement,
     209            0 :     param: Oid,
     210            0 : ) -> Result<Row, Error> {
     211            0 :     let responses = client.send_with_flush(|buf| {
     212            0 :         encode_bind(statement, param, "", buf);
     213            0 :         frontend::execute("", 0, buf).map_err(Error::encode)?;
     214            0 :         Ok(())
     215            0 :     })?;
     216              : 
     217            0 :     match responses.next().await? {
     218            0 :         Message::BindComplete => {}
     219            0 :         _ => return Err(Error::unexpected_message()),
     220              :     }
     221              : 
     222            0 :     let row = match responses.next().await? {
     223            0 :         Message::DataRow(body) => Row::new(statement.clone(), body, Format::Binary)?,
     224            0 :         _ => return Err(Error::unexpected_message()),
     225              :     };
     226              : 
     227            0 :     match responses.next().await? {
     228            0 :         Message::CommandComplete(_) => {}
     229            0 :         _ => return Err(Error::unexpected_message()),
     230              :     };
     231              : 
     232            0 :     Ok(row)
     233            0 : }
     234              : 
     235            0 : fn encode_bind(statement: &Statement, param: Oid, portal: &str, buf: &mut BytesMut) {
     236            0 :     frontend::bind(
     237            0 :         portal,
     238            0 :         statement.name(),
     239            0 :         [Format::Binary as i16],
     240            0 :         [param],
     241            0 :         |param, buf| {
     242            0 :             oid_to_sql(param, buf);
     243            0 :             Ok(IsNull::No)
     244            0 :         },
     245            0 :         [Format::Binary as i16],
     246            0 :         buf,
     247              :     )
     248            0 :     .unwrap();
     249            0 : }
        

Generated by: LCOV version 2.1-beta