LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - client.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 188 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 40 0

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::fmt;
       3              : use std::net::IpAddr;
       4              : use std::task::{Context, Poll};
       5              : use std::time::Duration;
       6              : 
       7              : use bytes::BytesMut;
       8              : use fallible_iterator::FallibleIterator;
       9              : use futures_util::{TryStreamExt, future, ready};
      10              : use postgres_protocol2::message::backend::Message;
      11              : use postgres_protocol2::message::frontend;
      12              : use serde::{Deserialize, Serialize};
      13              : use tokio::sync::mpsc;
      14              : 
      15              : use crate::cancel_token::RawCancelToken;
      16              : use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
      17              : use crate::config::{Host, SslMode};
      18              : use crate::connection::gc_bytesmut;
      19              : use crate::query::RowStream;
      20              : use crate::simple_query::SimpleQueryStream;
      21              : use crate::types::{Oid, Type};
      22              : use crate::{
      23              :     CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
      24              :     query, simple_query,
      25              : };
      26              : 
      27              : pub struct Responses {
      28              :     /// new messages from conn
      29              :     receiver: mpsc::Receiver<BackendMessages>,
      30              :     /// current batch of messages
      31              :     cur: BackendMessages,
      32              :     /// number of total queries sent.
      33              :     waiting: usize,
      34              :     /// number of ReadyForQuery messages received.
      35              :     received: usize,
      36              : }
      37              : 
      38              : impl Responses {
      39            0 :     pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
      40              :         loop {
      41              :             // get the next saved message
      42            0 :             if let Some(message) = self.cur.next().map_err(Error::parse)? {
      43            0 :                 let received = self.received;
      44              : 
      45              :                 // increase the query head if this is the last message.
      46            0 :                 if let Message::ReadyForQuery(_) = message {
      47            0 :                     self.received += 1;
      48            0 :                 }
      49              : 
      50              :                 // check if the client has skipped this query.
      51            0 :                 if received + 1 < self.waiting {
      52              :                     // grab the next message.
      53            0 :                     continue;
      54            0 :                 }
      55              : 
      56              :                 // convenience: turn the error messaage into a proper error.
      57            0 :                 let res = match message {
      58            0 :                     Message::ErrorResponse(body) => Err(Error::db(body)),
      59            0 :                     message => Ok(message),
      60              :                 };
      61            0 :                 return Poll::Ready(res);
      62            0 :             }
      63              : 
      64              :             // get the next batch of messages.
      65            0 :             match ready!(self.receiver.poll_recv(cx)) {
      66            0 :                 Some(messages) => self.cur = messages,
      67            0 :                 None => return Poll::Ready(Err(Error::closed())),
      68              :             }
      69              :         }
      70            0 :     }
      71              : 
      72            0 :     pub async fn next(&mut self) -> Result<Message, Error> {
      73            0 :         future::poll_fn(|cx| self.poll_next(cx)).await
      74            0 :     }
      75              : }
      76              : 
      77              : /// A cache of type info and prepared statements for fetching type info
      78              : /// (corresponding to the queries in the [crate::prepare] module).
      79              : #[derive(Default)]
      80              : pub(crate) struct CachedTypeInfo {
      81              :     /// Cache of types already looked up.
      82              :     pub(crate) types: HashMap<Oid, Type>,
      83              : }
      84              : 
      85              : pub struct InnerClient {
      86              :     sender: mpsc::UnboundedSender<FrontendMessage>,
      87              :     responses: Responses,
      88              : 
      89              :     /// A buffer to use when writing out postgres commands.
      90              :     buffer: BytesMut,
      91              : }
      92              : 
      93              : impl InnerClient {
      94            0 :     pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
      95            0 :         self.responses.waiting += 1;
      96            0 :         Ok(PartialQuery(Some(self)))
      97            0 :     }
      98              : 
      99            0 :     pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
     100            0 :         self.responses.waiting += 1;
     101              : 
     102            0 :         self.buffer.clear();
     103              :         // simple queries do not need sync.
     104            0 :         frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
     105            0 :         let buf = self.buffer.split();
     106            0 :         self.send_message(FrontendMessage::Raw(buf))
     107            0 :     }
     108              : 
     109            0 :     fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> {
     110            0 :         self.sender.send(messages).map_err(|_| Error::closed())?;
     111            0 :         Ok(&mut self.responses)
     112            0 :     }
     113              : }
     114              : 
     115              : pub struct PartialQuery<'a>(Option<&'a mut InnerClient>);
     116              : 
     117              : impl Drop for PartialQuery<'_> {
     118            0 :     fn drop(&mut self) {
     119            0 :         if let Some(client) = self.0.take() {
     120            0 :             client.buffer.clear();
     121            0 :             frontend::sync(&mut client.buffer);
     122            0 :             let buf = client.buffer.split();
     123            0 :             let _ = client.send_message(FrontendMessage::Raw(buf));
     124            0 :         }
     125            0 :     }
     126              : }
     127              : 
     128              : impl<'a> PartialQuery<'a> {
     129            0 :     pub fn send_with_flush<F>(&mut self, f: F) -> Result<&mut Responses, Error>
     130            0 :     where
     131            0 :         F: FnOnce(&mut BytesMut) -> Result<(), Error>,
     132              :     {
     133            0 :         let client = self.0.as_deref_mut().unwrap();
     134              : 
     135            0 :         client.buffer.clear();
     136            0 :         f(&mut client.buffer)?;
     137            0 :         frontend::flush(&mut client.buffer);
     138            0 :         let buf = client.buffer.split();
     139            0 :         client.send_message(FrontendMessage::Raw(buf))
     140            0 :     }
     141              : 
     142            0 :     pub fn send_with_sync<F>(mut self, f: F) -> Result<&'a mut Responses, Error>
     143            0 :     where
     144            0 :         F: FnOnce(&mut BytesMut) -> Result<(), Error>,
     145              :     {
     146            0 :         let client = self.0.as_deref_mut().unwrap();
     147              : 
     148            0 :         client.buffer.clear();
     149            0 :         f(&mut client.buffer)?;
     150            0 :         frontend::sync(&mut client.buffer);
     151            0 :         let buf = client.buffer.split();
     152            0 :         let _ = client.send_message(FrontendMessage::Raw(buf));
     153              : 
     154            0 :         Ok(&mut self.0.take().unwrap().responses)
     155            0 :     }
     156              : }
     157              : 
     158            0 : #[derive(Clone, Serialize, Deserialize)]
     159              : pub struct SocketConfig {
     160              :     pub host_addr: Option<IpAddr>,
     161              :     pub host: Host,
     162              :     pub port: u16,
     163              :     pub connect_timeout: Option<Duration>,
     164              : }
     165              : 
     166              : /// An asynchronous PostgreSQL client.
     167              : ///
     168              : /// The client is one half of what is returned when a connection is established. Users interact with the database
     169              : /// through this client object.
     170              : pub struct Client {
     171              :     inner: InnerClient,
     172              :     cached_typeinfo: CachedTypeInfo,
     173              : 
     174              :     socket_config: SocketConfig,
     175              :     ssl_mode: SslMode,
     176              :     process_id: i32,
     177              :     secret_key: i32,
     178              : }
     179              : 
     180              : impl Client {
     181            0 :     pub(crate) fn new(
     182            0 :         sender: mpsc::UnboundedSender<FrontendMessage>,
     183            0 :         receiver: mpsc::Receiver<BackendMessages>,
     184            0 :         socket_config: SocketConfig,
     185            0 :         ssl_mode: SslMode,
     186            0 :         process_id: i32,
     187            0 :         secret_key: i32,
     188            0 :         write_buf: BytesMut,
     189            0 :     ) -> Client {
     190            0 :         Client {
     191            0 :             inner: InnerClient {
     192            0 :                 sender,
     193            0 :                 responses: Responses {
     194            0 :                     receiver,
     195            0 :                     cur: BackendMessages::empty(),
     196            0 :                     waiting: 0,
     197            0 :                     received: 0,
     198            0 :                 },
     199            0 :                 buffer: write_buf,
     200            0 :             },
     201            0 :             cached_typeinfo: Default::default(),
     202            0 : 
     203            0 :             socket_config,
     204            0 :             ssl_mode,
     205            0 :             process_id,
     206            0 :             secret_key,
     207            0 :         }
     208            0 :     }
     209              : 
     210              :     /// Returns process_id.
     211            0 :     pub fn get_process_id(&self) -> i32 {
     212            0 :         self.process_id
     213            0 :     }
     214              : 
     215            0 :     pub(crate) fn inner_mut(&mut self) -> &mut InnerClient {
     216            0 :         &mut self.inner
     217            0 :     }
     218              : 
     219            0 :     pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
     220            0 :         let (tx, rx) = mpsc::unbounded_channel();
     221              : 
     222            0 :         let notices = RecordNotices { sender: tx, limit };
     223            0 :         self.inner
     224            0 :             .sender
     225            0 :             .send(FrontendMessage::RecordNotices(notices))
     226            0 :             .ok();
     227              : 
     228            0 :         rx
     229            0 :     }
     230              : 
     231              :     /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
     232              :     /// to save a roundtrip
     233            0 :     pub async fn query_raw_txt<S, I>(
     234            0 :         &mut self,
     235            0 :         statement: &str,
     236            0 :         params: I,
     237            0 :     ) -> Result<RowStream<'_>, Error>
     238            0 :     where
     239            0 :         S: AsRef<str>,
     240            0 :         I: IntoIterator<Item = Option<S>>,
     241            0 :         I::IntoIter: ExactSizeIterator,
     242            0 :     {
     243            0 :         query::query_txt(
     244            0 :             &mut self.inner,
     245            0 :             &mut self.cached_typeinfo,
     246            0 :             statement,
     247            0 :             params,
     248            0 :         )
     249            0 :         .await
     250            0 :     }
     251              : 
     252              :     /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
     253              :     ///
     254              :     /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
     255              :     /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
     256              :     /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
     257              :     /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
     258              :     /// or a row of data. This preserves the framing between the separate statements in the request.
     259              :     ///
     260              :     /// # Warning
     261              :     ///
     262              :     /// Prepared statements should be use for any query which contains user-specified data, as they provided the
     263              :     /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
     264              :     /// them to this method!
     265            0 :     pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
     266            0 :         self.simple_query_raw(query).await?.try_collect().await
     267            0 :     }
     268              : 
     269            0 :     pub(crate) async fn simple_query_raw(
     270            0 :         &mut self,
     271            0 :         query: &str,
     272            0 :     ) -> Result<SimpleQueryStream<'_>, Error> {
     273            0 :         simple_query::simple_query(self.inner_mut(), query).await
     274            0 :     }
     275              : 
     276              :     /// Executes a sequence of SQL statements using the simple query protocol.
     277              :     ///
     278              :     /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
     279              :     /// point. This is intended for use when, for example, initializing a database schema.
     280              :     ///
     281              :     /// # Warning
     282              :     ///
     283              :     /// Prepared statements should be use for any query which contains user-specified data, as they provided the
     284              :     /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
     285              :     /// them to this method!
     286            0 :     pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
     287            0 :         simple_query::batch_execute(self.inner_mut(), query).await
     288            0 :     }
     289              : 
     290              :     /// Similar to `discard_all`, but it does not clear any query plans
     291              :     ///
     292              :     /// This runs in the background, so it can be executed without `await`ing.
     293            0 :     pub fn reset_session_background(&mut self) -> Result<(), Error> {
     294              :         // "CLOSE ALL": closes any cursors
     295              :         // "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
     296              :         // "RESET ALL": resets any GUCs back to their session defaults.
     297              :         // "DEALLOCATE ALL": deallocates any prepared statements
     298              :         // "UNLISTEN *": stops listening on all channels
     299              :         // "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
     300              :         // "DISCARD TEMP;": drops all temporary tables
     301              :         // "DISCARD SEQUENCES;": deallocates all cached sequence state
     302              : 
     303            0 :         let _responses = self.inner_mut().send_simple_query(
     304            0 :             "ROLLBACK;
     305            0 :             CLOSE ALL;
     306            0 :             SET SESSION AUTHORIZATION DEFAULT;
     307            0 :             RESET ALL;
     308            0 :             DEALLOCATE ALL;
     309            0 :             UNLISTEN *;
     310            0 :             SELECT pg_advisory_unlock_all();
     311            0 :             DISCARD TEMP;
     312            0 :             DISCARD SEQUENCES;",
     313            0 :         )?;
     314              : 
     315              :         // Clean up memory usage.
     316            0 :         gc_bytesmut(&mut self.inner_mut().buffer);
     317              : 
     318            0 :         Ok(())
     319            0 :     }
     320              : 
     321              :     /// Begins a new database transaction.
     322              :     ///
     323              :     /// The transaction will roll back by default - use the `commit` method to commit it.
     324            0 :     pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
     325              :         struct RollbackIfNotDone<'me> {
     326              :             client: &'me mut Client,
     327              :             done: bool,
     328              :         }
     329              : 
     330              :         impl Drop for RollbackIfNotDone<'_> {
     331            0 :             fn drop(&mut self) {
     332            0 :                 if self.done {
     333            0 :                     return;
     334            0 :                 }
     335              : 
     336            0 :                 let _ = self.client.inner.send_simple_query("ROLLBACK");
     337            0 :             }
     338              :         }
     339              : 
     340              :         // This is done, as `Future` created by this method can be dropped after
     341              :         // `RequestMessages` is synchronously send to the `Connection` by
     342              :         // `batch_execute()`, but before `Responses` is asynchronously polled to
     343              :         // completion. In that case `Transaction` won't be created and thus
     344              :         // won't be rolled back.
     345              :         {
     346            0 :             let mut cleaner = RollbackIfNotDone {
     347            0 :                 client: self,
     348            0 :                 done: false,
     349            0 :             };
     350            0 :             cleaner.client.batch_execute("BEGIN").await?;
     351            0 :             cleaner.done = true;
     352              :         }
     353              : 
     354            0 :         Ok(Transaction::new(self))
     355            0 :     }
     356              : 
     357              :     /// Returns a builder for a transaction with custom settings.
     358              :     ///
     359              :     /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
     360              :     /// attributes.
     361            0 :     pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
     362            0 :         TransactionBuilder::new(self)
     363            0 :     }
     364              : 
     365              :     /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
     366              :     /// connection associated with this client.
     367            0 :     pub fn cancel_token(&self) -> CancelToken {
     368            0 :         CancelToken {
     369            0 :             socket_config: self.socket_config.clone(),
     370            0 :             raw: RawCancelToken {
     371            0 :                 ssl_mode: self.ssl_mode,
     372            0 :                 process_id: self.process_id,
     373            0 :                 secret_key: self.secret_key,
     374            0 :             },
     375            0 :         }
     376            0 :     }
     377              : 
     378              :     /// Determines if the connection to the server has already closed.
     379              :     ///
     380              :     /// In that case, all future queries will fail.
     381            0 :     pub fn is_closed(&self) -> bool {
     382            0 :         self.inner.sender.is_closed()
     383            0 :     }
     384              : }
     385              : 
     386              : impl fmt::Debug for Client {
     387            0 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     388            0 :         f.debug_struct("Client").finish()
     389            0 :     }
     390              : }
        

Generated by: LCOV version 2.1-beta