LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connection.rs (source / functions) Coverage Total Hit
Test: c8f8d331b83562868d9054d9e0e68f866772aeaa.info Lines: 0.0 % 142 0
Test Date: 2025-07-26 17:20:05 Functions: 0.0 % 28 0

            Line data    Source code
       1              : use std::future::Future;
       2              : use std::pin::Pin;
       3              : use std::task::{Context, Poll};
       4              : 
       5              : use bytes::BytesMut;
       6              : use fallible_iterator::FallibleIterator;
       7              : use futures_util::{Sink, StreamExt, ready};
       8              : use postgres_protocol2::message::backend::{Message, NoticeResponseBody};
       9              : use postgres_protocol2::message::frontend;
      10              : use tokio::io::{AsyncRead, AsyncWrite};
      11              : use tokio::sync::mpsc;
      12              : use tokio_util::codec::Framed;
      13              : use tokio_util::sync::PollSender;
      14              : use tracing::trace;
      15              : 
      16              : use crate::Error;
      17              : use crate::codec::{
      18              :     BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices,
      19              : };
      20              : use crate::maybe_tls_stream::MaybeTlsStream;
      21              : 
      22              : #[derive(PartialEq, Debug)]
      23              : enum State {
      24              :     Active,
      25              :     Closing,
      26              : }
      27              : 
      28              : /// A connection to a PostgreSQL database.
      29              : ///
      30              : /// This is one half of what is returned when a new connection is established. It performs the actual IO with the
      31              : /// server, and should generally be spawned off onto an executor to run in the background.
      32              : ///
      33              : /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
      34              : /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
      35              : #[must_use = "futures do nothing unless polled"]
      36              : pub struct Connection<S, T> {
      37              :     stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      38              : 
      39              :     sender: PollSender<BackendMessages>,
      40              :     receiver: mpsc::UnboundedReceiver<FrontendMessage>,
      41              :     notices: Option<RecordNotices>,
      42              : 
      43              :     pending_response: Option<BackendMessages>,
      44              :     state: State,
      45              : }
      46              : 
      47              : pub const INITIAL_CAPACITY: usize = 2 * 1024;
      48              : pub const GC_THRESHOLD: usize = 16 * 1024;
      49              : 
      50              : /// Garbage collect the [`BytesMut`] if it has too much spare capacity.
      51            0 : pub fn gc_bytesmut(buf: &mut BytesMut) {
      52              :     // We use a different mode to shrink the buf when above the threshold.
      53              :     // When above the threshold, we only re-allocate when the buf has 2x spare capacity.
      54            0 :     let reclaim = GC_THRESHOLD.checked_sub(buf.len()).unwrap_or(buf.len());
      55              : 
      56              :     // `try_reclaim` tries to get the capacity from any shared `BytesMut`s,
      57              :     // before then comparing the length against the capacity.
      58            0 :     if buf.try_reclaim(reclaim) {
      59            0 :         let capacity = usize::max(buf.len(), INITIAL_CAPACITY);
      60            0 : 
      61            0 :         // Allocate a new `BytesMut` so that we deallocate the old version.
      62            0 :         let mut new = BytesMut::with_capacity(capacity);
      63            0 :         new.extend_from_slice(buf);
      64            0 :         *buf = new;
      65            0 :     }
      66            0 : }
      67              : 
      68              : pub enum Never {}
      69              : 
      70              : impl<S, T> Connection<S, T>
      71              : where
      72              :     S: AsyncRead + AsyncWrite + Unpin,
      73              :     T: AsyncRead + AsyncWrite + Unpin,
      74              : {
      75            0 :     pub(crate) fn new(
      76            0 :         stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      77            0 :         sender: mpsc::Sender<BackendMessages>,
      78            0 :         receiver: mpsc::UnboundedReceiver<FrontendMessage>,
      79            0 :     ) -> Connection<S, T> {
      80            0 :         Connection {
      81            0 :             stream,
      82            0 :             sender: PollSender::new(sender),
      83            0 :             receiver,
      84            0 :             notices: None,
      85            0 :             pending_response: None,
      86            0 :             state: State::Active,
      87            0 :         }
      88            0 :     }
      89              : 
      90              :     /// Read and process messages from the connection to postgres.
      91              :     /// client <- postgres
      92            0 :     fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Never, Error>> {
      93              :         loop {
      94            0 :             let messages = match self.pending_response.take() {
      95            0 :                 Some(messages) => messages,
      96              :                 None => {
      97            0 :                     let message = match self.stream.poll_next_unpin(cx) {
      98            0 :                         Poll::Pending => return Poll::Pending,
      99            0 :                         Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
     100            0 :                         Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))),
     101            0 :                         Poll::Ready(Some(Ok(message))) => message,
     102              :                     };
     103              : 
     104            0 :                     match message {
     105            0 :                         BackendMessage::Async(Message::NoticeResponse(body)) => {
     106            0 :                             self.handle_notice(body)?;
     107            0 :                             continue;
     108              :                         }
     109            0 :                         BackendMessage::Async(_) => continue,
     110            0 :                         BackendMessage::Normal { messages, ready } => {
     111              :                             // if we read a ReadyForQuery from postgres, let's try GC the read buffer.
     112            0 :                             if ready {
     113            0 :                                 gc_bytesmut(self.stream.read_buffer_mut());
     114            0 :                             }
     115              : 
     116            0 :                             messages
     117              :                         }
     118              :                     }
     119              :                 }
     120              :             };
     121              : 
     122            0 :             match self.sender.poll_reserve(cx) {
     123            0 :                 Poll::Ready(Ok(())) => {
     124            0 :                     let _ = self.sender.send_item(messages);
     125            0 :                 }
     126              :                 Poll::Ready(Err(_)) => {
     127            0 :                     return Poll::Ready(Err(Error::closed()));
     128              :                 }
     129              :                 Poll::Pending => {
     130            0 :                     self.pending_response = Some(messages);
     131            0 :                     trace!("poll_read: waiting on sender");
     132            0 :                     return Poll::Pending;
     133              :                 }
     134              :             }
     135              :         }
     136            0 :     }
     137              : 
     138            0 :     fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> {
     139            0 :         let Some(notices) = &mut self.notices else {
     140            0 :             return Ok(());
     141              :         };
     142              : 
     143            0 :         let mut fields = body.fields();
     144            0 :         while let Some(field) = fields.next().map_err(Error::parse)? {
     145              :             // loop until we find the message field
     146            0 :             if field.type_() == b'M' {
     147              :                 // if the message field is within the limit, send it.
     148            0 :                 if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) {
     149            0 :                     match notices.sender.send(field.value().into()) {
     150              :                         // set the new limit.
     151            0 :                         Ok(()) => notices.limit = new_limit,
     152              :                         // closed.
     153            0 :                         Err(_) => self.notices = None,
     154              :                     }
     155            0 :                 }
     156            0 :                 break;
     157            0 :             }
     158              :         }
     159              : 
     160            0 :         Ok(())
     161            0 :     }
     162              : 
     163              :     /// Fetch the next client request, if any. Yields `None` once the request channel is closed.
     164            0 :     fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
     165            0 :         if self.receiver.is_closed() {
     166            0 :             return Poll::Ready(None);
     167            0 :         }
     168              : 
     169            0 :         match self.receiver.poll_recv(cx) {
     170            0 :             Poll::Ready(Some(request)) => {
     171            0 :                 trace!("polled new request");
     172            0 :                 Poll::Ready(Some(request))
     173              :             }
     174            0 :             Poll::Ready(None) => Poll::Ready(None),
     175            0 :             Poll::Pending => Poll::Pending,
     176              :         }
     177            0 :     }
     178              : 
     179              :     /// Process client requests and write them to the postgres connection, flushing if necessary.
     180              :     /// client -> postgres
     181            0 :     fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     182              :         loop {
     183            0 :             if Pin::new(&mut self.stream)
     184            0 :                 .poll_ready(cx)
     185            0 :                 .map_err(Error::io)?
     186            0 :                 .is_pending()
     187              :             {
     188            0 :                 trace!("poll_write: waiting on socket");
     189              : 
     190              :                 // poll_ready is self-flushing.
     191            0 :                 return Poll::Pending;
     192            0 :             }
     193              : 
     194            0 :             match self.poll_request(cx) {
     195              :                 // send the message to postgres
     196            0 :                 Poll::Ready(Some(FrontendMessage::Raw(request))) => {
     197            0 :                     Pin::new(&mut self.stream)
     198            0 :                         .start_send(request)
     199            0 :                         .map_err(Error::io)?;
     200              :                 }
     201            0 :                 Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => {
     202            0 :                     self.notices = Some(notices)
     203              :                 }
     204              :                 // No more messages from the client, and no more responses to wait for.
     205              :                 // Send a terminate message to postgres
     206              :                 Poll::Ready(None) => {
     207            0 :                     trace!("poll_write: at eof, terminating");
     208            0 :                     frontend::terminate(self.stream.write_buffer_mut());
     209              : 
     210            0 :                     trace!("poll_write: sent eof, closing");
     211            0 :                     trace!("poll_write: done");
     212            0 :                     return Poll::Ready(Ok(()));
     213              :                 }
     214              :                 // Still waiting for a message from the client.
     215              :                 Poll::Pending => {
     216            0 :                     trace!("poll_write: waiting on request");
     217            0 :                     ready!(self.poll_flush(cx))?;
     218            0 :                     return Poll::Pending;
     219              :                 }
     220              :             }
     221              :         }
     222            0 :     }
     223              : 
     224            0 :     fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     225            0 :         match Pin::new(&mut self.stream)
     226            0 :             .poll_flush(cx)
     227            0 :             .map_err(Error::io)?
     228              :         {
     229              :             Poll::Ready(()) => {
     230            0 :                 trace!("poll_flush: flushed");
     231              : 
     232              :                 // Since our codec prefers to share the buffer with the `Client`,
     233              :                 // if we don't release our share, then the `Client` would have to re-alloc
     234              :                 // the buffer when they next use it.
     235            0 :                 debug_assert!(self.stream.write_buffer().is_empty());
     236            0 :                 *self.stream.write_buffer_mut() = BytesMut::new();
     237              : 
     238            0 :                 Poll::Ready(Ok(()))
     239              :             }
     240              :             Poll::Pending => {
     241            0 :                 trace!("poll_flush: waiting on socket");
     242            0 :                 Poll::Pending
     243              :             }
     244              :         }
     245            0 :     }
     246              : 
     247            0 :     fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     248            0 :         match Pin::new(&mut self.stream)
     249            0 :             .poll_close(cx)
     250            0 :             .map_err(Error::io)?
     251              :         {
     252              :             Poll::Ready(()) => {
     253            0 :                 trace!("poll_shutdown: complete");
     254            0 :                 Poll::Ready(Ok(()))
     255              :             }
     256              :             Poll::Pending => {
     257            0 :                 trace!("poll_shutdown: waiting on socket");
     258            0 :                 Poll::Pending
     259              :             }
     260              :         }
     261            0 :     }
     262              : 
     263            0 :     fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Never, Error>>> {
     264            0 :         if self.state != State::Closing {
     265              :             // if the state is still active, try read from and write to postgres.
     266            0 :             let Poll::Pending = self.poll_read(cx)?;
     267            0 :             if self.poll_write(cx)?.is_ready() {
     268            0 :                 self.state = State::Closing;
     269            0 :             }
     270              : 
     271              :             // poll_read returned Pending.
     272              :             // poll_write returned Pending or Ready(()).
     273              :             // if poll_write returned Ready(()), the state is now Closing and we fall through to shutdown; otherwise we keep waiting.
     274            0 :             if self.state != State::Closing {
     275            0 :                 return Poll::Pending;
     276            0 :             }
     277            0 :         }
     278              : 
     279            0 :         match self.poll_shutdown(cx) {
     280            0 :             Poll::Ready(Ok(())) => Poll::Ready(None),
     281            0 :             Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
     282            0 :             Poll::Pending => Poll::Pending,
     283              :         }
     284            0 :     }
     285              : }
     286              : 
     287              : impl<S, T> Future for Connection<S, T>
     288              : where
     289              :     S: AsyncRead + AsyncWrite + Unpin,
     290              :     T: AsyncRead + AsyncWrite + Unpin,
     291              : {
     292              :     type Output = Result<(), Error>;
     293              : 
     294            0 :     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     295            0 :         match self.poll_message(cx)? {
     296            0 :             Poll::Ready(None) => Poll::Ready(Ok(())),
     297            0 :             Poll::Pending => Poll::Pending,
     298              :         }
     299            0 :     }
     300              : }
        

Generated by: LCOV version 2.1-beta