LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - prepare.rs (source / functions) Coverage Total Hit
Test: 4f58e98c51285c7fa348e0b410c88a10caf68ad2.info Lines: 0.0 % 154 0
Test Date: 2025-01-07 20:58:07 Functions: 0.0 % 20 0

            Line data    Source code
       1              : use crate::client::InnerClient;
       2              : use crate::codec::FrontendMessage;
       3              : use crate::connection::RequestMessages;
       4              : use crate::error::SqlState;
       5              : use crate::types::{Field, Kind, Oid, Type};
       6              : use crate::{query, slice_iter};
       7              : use crate::{Column, Error, Statement};
       8              : use bytes::Bytes;
       9              : use fallible_iterator::FallibleIterator;
      10              : use futures_util::{pin_mut, TryStreamExt};
      11              : use log::debug;
      12              : use postgres_protocol2::message::backend::Message;
      13              : use postgres_protocol2::message::frontend;
      14              : use std::future::Future;
      15              : use std::pin::Pin;
      16              : use std::sync::atomic::{AtomicUsize, Ordering};
      17              : use std::sync::Arc;
      18              : 
      19              : pub(crate) const TYPEINFO_QUERY: &str = "\
      20              : SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
      21              : FROM pg_catalog.pg_type t
      22              : LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
      23              : INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
      24              : WHERE t.oid = $1
      25              : ";
      26              : 
      27              : // Range types weren't added until Postgres 9.2, so pg_range may not exist
      28              : const TYPEINFO_FALLBACK_QUERY: &str = "\
      29              : SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid
      30              : FROM pg_catalog.pg_type t
      31              : INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
      32              : WHERE t.oid = $1
      33              : ";
      34              : 
      35              : const TYPEINFO_ENUM_QUERY: &str = "\
      36              : SELECT enumlabel
      37              : FROM pg_catalog.pg_enum
      38              : WHERE enumtypid = $1
      39              : ORDER BY enumsortorder
      40              : ";
      41              : 
      42              : // Postgres 9.0 didn't have enumsortorder
      43              : const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\
      44              : SELECT enumlabel
      45              : FROM pg_catalog.pg_enum
      46              : WHERE enumtypid = $1
      47              : ORDER BY oid
      48              : ";
      49              : 
      50              : pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\
      51              : SELECT attname, atttypid
      52              : FROM pg_catalog.pg_attribute
      53              : WHERE attrelid = $1
      54              : AND NOT attisdropped
      55              : AND attnum > 0
      56              : ORDER BY attnum
      57              : ";
      58              : 
      59              : static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
      60              : 
      61            0 : pub async fn prepare(
      62            0 :     client: &Arc<InnerClient>,
      63            0 :     query: &str,
      64            0 :     types: &[Type],
      65            0 : ) -> Result<Statement, Error> {
      66            0 :     let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
      67            0 :     let buf = encode(client, &name, query, types)?;
      68            0 :     let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
      69              : 
      70            0 :     match responses.next().await? {
      71            0 :         Message::ParseComplete => {}
      72            0 :         _ => return Err(Error::unexpected_message()),
      73              :     }
      74              : 
      75            0 :     let parameter_description = match responses.next().await? {
      76            0 :         Message::ParameterDescription(body) => body,
      77            0 :         _ => return Err(Error::unexpected_message()),
      78              :     };
      79              : 
      80            0 :     let row_description = match responses.next().await? {
      81            0 :         Message::RowDescription(body) => Some(body),
      82            0 :         Message::NoData => None,
      83            0 :         _ => return Err(Error::unexpected_message()),
      84              :     };
      85              : 
      86            0 :     let mut parameters = vec![];
      87            0 :     let mut it = parameter_description.parameters();
      88            0 :     while let Some(oid) = it.next().map_err(Error::parse)? {
      89            0 :         let type_ = get_type(client, oid).await?;
      90            0 :         parameters.push(type_);
      91              :     }
      92              : 
      93            0 :     let mut columns = vec![];
      94            0 :     if let Some(row_description) = row_description {
      95            0 :         let mut it = row_description.fields();
      96            0 :         while let Some(field) = it.next().map_err(Error::parse)? {
      97            0 :             let type_ = get_type(client, field.type_oid()).await?;
      98            0 :             let column = Column::new(field.name().to_string(), type_, field);
      99            0 :             columns.push(column);
     100              :         }
     101            0 :     }
     102              : 
     103            0 :     Ok(Statement::new(client, name, parameters, columns))
     104            0 : }
     105              : 
     106            0 : fn prepare_rec<'a>(
     107            0 :     client: &'a Arc<InnerClient>,
     108            0 :     query: &'a str,
     109            0 :     types: &'a [Type],
     110            0 : ) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
     111            0 :     Box::pin(prepare(client, query, types))
     112            0 : }
     113              : 
     114            0 : fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
     115            0 :     if types.is_empty() {
     116            0 :         debug!("preparing query {}: {}", name, query);
     117              :     } else {
     118            0 :         debug!("preparing query {} with types {:?}: {}", name, types, query);
     119              :     }
     120              : 
     121            0 :     client.with_buf(|buf| {
     122            0 :         frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
     123            0 :         frontend::describe(b'S', name, buf).map_err(Error::encode)?;
     124            0 :         frontend::sync(buf);
     125            0 :         Ok(buf.split().freeze())
     126            0 :     })
     127            0 : }
     128              : 
     129            0 : pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
     130            0 :     if let Some(type_) = Type::from_oid(oid) {
     131            0 :         return Ok(type_);
     132            0 :     }
     133              : 
     134            0 :     if let Some(type_) = client.type_(oid) {
     135            0 :         return Ok(type_);
     136            0 :     }
     137              : 
     138            0 :     let stmt = typeinfo_statement(client).await?;
     139              : 
     140            0 :     let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
     141            0 :     pin_mut!(rows);
     142              : 
     143            0 :     let row = match rows.try_next().await? {
     144            0 :         Some(row) => row,
     145            0 :         None => return Err(Error::unexpected_message()),
     146              :     };
     147              : 
     148            0 :     let name: String = row.try_get(0)?;
     149            0 :     let type_: i8 = row.try_get(1)?;
     150            0 :     let elem_oid: Oid = row.try_get(2)?;
     151            0 :     let rngsubtype: Option<Oid> = row.try_get(3)?;
     152            0 :     let basetype: Oid = row.try_get(4)?;
     153            0 :     let schema: String = row.try_get(5)?;
     154            0 :     let relid: Oid = row.try_get(6)?;
     155              : 
     156            0 :     let kind = if type_ == b'e' as i8 {
     157            0 :         let variants = get_enum_variants(client, oid).await?;
     158            0 :         Kind::Enum(variants)
     159            0 :     } else if type_ == b'p' as i8 {
     160            0 :         Kind::Pseudo
     161            0 :     } else if basetype != 0 {
     162            0 :         let type_ = get_type_rec(client, basetype).await?;
     163            0 :         Kind::Domain(type_)
     164            0 :     } else if elem_oid != 0 {
     165            0 :         let type_ = get_type_rec(client, elem_oid).await?;
     166            0 :         Kind::Array(type_)
     167            0 :     } else if relid != 0 {
     168            0 :         let fields = get_composite_fields(client, relid).await?;
     169            0 :         Kind::Composite(fields)
     170            0 :     } else if let Some(rngsubtype) = rngsubtype {
     171            0 :         let type_ = get_type_rec(client, rngsubtype).await?;
     172            0 :         Kind::Range(type_)
     173              :     } else {
     174            0 :         Kind::Simple
     175              :     };
     176              : 
     177            0 :     let type_ = Type::new(name, oid, kind, schema);
     178            0 :     client.set_type(oid, &type_);
     179            0 : 
     180            0 :     Ok(type_)
     181            0 : }
     182              : 
     183            0 : fn get_type_rec<'a>(
     184            0 :     client: &'a Arc<InnerClient>,
     185            0 :     oid: Oid,
     186            0 : ) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
     187            0 :     Box::pin(get_type(client, oid))
     188            0 : }
     189              : 
     190            0 : async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
     191            0 :     if let Some(stmt) = client.typeinfo() {
     192            0 :         return Ok(stmt);
     193            0 :     }
     194              : 
     195            0 :     let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await {
     196            0 :         Ok(stmt) => stmt,
     197            0 :         Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
     198            0 :             prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await?
     199              :         }
     200            0 :         Err(e) => return Err(e),
     201              :     };
     202              : 
     203            0 :     client.set_typeinfo(&stmt);
     204            0 :     Ok(stmt)
     205            0 : }
     206              : 
     207            0 : async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
     208            0 :     let stmt = typeinfo_enum_statement(client).await?;
     209              : 
     210            0 :     query::query(client, stmt, slice_iter(&[&oid]))
     211            0 :         .await?
     212            0 :         .and_then(|row| async move { row.try_get(0) })
     213            0 :         .try_collect()
     214            0 :         .await
     215            0 : }
     216              : 
     217            0 : async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
     218            0 :     if let Some(stmt) = client.typeinfo_enum() {
     219            0 :         return Ok(stmt);
     220            0 :     }
     221              : 
     222            0 :     let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await {
     223            0 :         Ok(stmt) => stmt,
     224            0 :         Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
     225            0 :             prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
     226              :         }
     227            0 :         Err(e) => return Err(e),
     228              :     };
     229              : 
     230            0 :     client.set_typeinfo_enum(&stmt);
     231            0 :     Ok(stmt)
     232            0 : }
     233              : 
     234            0 : async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
     235            0 :     let stmt = typeinfo_composite_statement(client).await?;
     236              : 
     237            0 :     let rows = query::query(client, stmt, slice_iter(&[&oid]))
     238            0 :         .await?
     239            0 :         .try_collect::<Vec<_>>()
     240            0 :         .await?;
     241              : 
     242            0 :     let mut fields = vec![];
     243            0 :     for row in rows {
     244            0 :         let name = row.try_get(0)?;
     245            0 :         let oid = row.try_get(1)?;
     246            0 :         let type_ = get_type_rec(client, oid).await?;
     247            0 :         fields.push(Field::new(name, type_));
     248              :     }
     249              : 
     250            0 :     Ok(fields)
     251            0 : }
     252              : 
     253            0 : async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
     254            0 :     if let Some(stmt) = client.typeinfo_composite() {
     255            0 :         return Ok(stmt);
     256            0 :     }
     257              : 
     258            0 :     let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
     259              : 
     260            0 :     client.set_typeinfo_composite(&stmt);
     261            0 :     Ok(stmt)
     262            0 : }
        

Generated by: LCOV version 2.1-beta