LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connection.rs (source / functions) Coverage Total Hit
Test: 6df3fc19ec669bcfbbf9aba41d1338898d24eaa0.info Lines: 0.0 % 181 0
Test Date: 2025-03-12 18:28:53 Functions: 0.0 % 23 0

            Line data    Source code
       1              : use std::collections::{HashMap, VecDeque};
       2              : use std::future::Future;
       3              : use std::pin::Pin;
       4              : use std::task::{Context, Poll};
       5              : 
       6              : use bytes::BytesMut;
       7              : use fallible_iterator::FallibleIterator;
       8              : use futures_util::{Sink, Stream, ready};
       9              : use log::{info, trace};
      10              : use postgres_protocol2::message::backend::Message;
      11              : use postgres_protocol2::message::frontend;
      12              : use tokio::io::{AsyncRead, AsyncWrite};
      13              : use tokio::sync::mpsc;
      14              : use tokio_util::codec::Framed;
      15              : use tokio_util::sync::PollSender;
      16              : 
      17              : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
      18              : use crate::error::DbError;
      19              : use crate::maybe_tls_stream::MaybeTlsStream;
      20              : use crate::{AsyncMessage, Error, Notification};
      21              : 
      22              : pub enum RequestMessages {
      23              :     Single(FrontendMessage),
      24              : }
      25              : 
      26              : pub struct Request {
      27              :     pub messages: RequestMessages,
      28              :     pub sender: mpsc::Sender<BackendMessages>,
      29              : }
      30              : 
      31              : pub struct Response {
      32              :     sender: PollSender<BackendMessages>,
      33              : }
      34              : 
      35              : #[derive(PartialEq, Debug)]
      36              : enum State {
      37              :     Active,
      38              :     Closing,
      39              : }
      40              : 
      41              : enum WriteReady {
      42              :     Terminating,
      43              :     WaitingOnRead,
      44              : }
      45              : 
      46              : /// A connection to a PostgreSQL database.
      47              : ///
      48              : /// This is one half of what is returned when a new connection is established. It performs the actual IO with the
      49              : /// server, and should generally be spawned off onto an executor to run in the background.
      50              : ///
      51              : /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
      52              : /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
      53              : #[must_use = "futures do nothing unless polled"]
      54              : pub struct Connection<S, T> {
      55              :     /// HACK: we need this in the Neon Proxy.
      56              :     pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      57              :     /// HACK: we need this in the Neon Proxy to forward params.
      58              :     pub parameters: HashMap<String, String>,
      59              :     receiver: mpsc::UnboundedReceiver<Request>,
      60              :     pending_responses: VecDeque<BackendMessage>,
      61              :     responses: VecDeque<Response>,
      62              :     state: State,
      63              : }
      64              : 
      65              : impl<S, T> Connection<S, T>
      66              : where
      67              :     S: AsyncRead + AsyncWrite + Unpin,
      68              :     T: AsyncRead + AsyncWrite + Unpin,
      69              : {
      70            0 :     pub(crate) fn new(
      71            0 :         stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      72            0 :         pending_responses: VecDeque<BackendMessage>,
      73            0 :         parameters: HashMap<String, String>,
      74            0 :         receiver: mpsc::UnboundedReceiver<Request>,
      75            0 :     ) -> Connection<S, T> {
      76            0 :         Connection {
      77            0 :             stream,
      78            0 :             parameters,
      79            0 :             receiver,
      80            0 :             pending_responses,
      81            0 :             responses: VecDeque::new(),
      82            0 :             state: State::Active,
      83            0 :         }
      84            0 :     }
      85              : 
      86            0 :     fn poll_response(
      87            0 :         &mut self,
      88            0 :         cx: &mut Context<'_>,
      89            0 :     ) -> Poll<Option<Result<BackendMessage, Error>>> {
      90            0 :         if let Some(message) = self.pending_responses.pop_front() {
      91            0 :             trace!("retrying pending response");
      92            0 :             return Poll::Ready(Some(Ok(message)));
      93            0 :         }
      94            0 : 
      95            0 :         Pin::new(&mut self.stream)
      96            0 :             .poll_next(cx)
      97            0 :             .map(|o| o.map(|r| r.map_err(Error::io)))
      98            0 :     }
      99              : 
     100              :     /// Read and process messages from the connection to postgres.
     101              :     /// client <- postgres
     102            0 :     fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
     103              :         loop {
     104            0 :             let message = match self.poll_response(cx)? {
     105            0 :                 Poll::Ready(Some(message)) => message,
     106            0 :                 Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
     107              :                 Poll::Pending => {
     108            0 :                     trace!("poll_read: waiting on response");
     109            0 :                     return Poll::Pending;
     110              :                 }
     111              :             };
     112              : 
     113            0 :             let (mut messages, request_complete) = match message {
     114            0 :                 BackendMessage::Async(Message::NoticeResponse(body)) => {
     115            0 :                     let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
     116            0 :                     return Poll::Ready(Ok(AsyncMessage::Notice(error)));
     117              :                 }
     118            0 :                 BackendMessage::Async(Message::NotificationResponse(body)) => {
     119            0 :                     let notification = Notification {
     120            0 :                         process_id: body.process_id(),
     121            0 :                         channel: body.channel().map_err(Error::parse)?.to_string(),
     122            0 :                         payload: body.message().map_err(Error::parse)?.to_string(),
     123            0 :                     };
     124            0 :                     return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
     125              :                 }
     126            0 :                 BackendMessage::Async(Message::ParameterStatus(body)) => {
     127            0 :                     self.parameters.insert(
     128            0 :                         body.name().map_err(Error::parse)?.to_string(),
     129            0 :                         body.value().map_err(Error::parse)?.to_string(),
     130            0 :                     );
     131            0 :                     continue;
     132              :                 }
     133            0 :                 BackendMessage::Async(_) => unreachable!(),
     134              :                 BackendMessage::Normal {
     135            0 :                     messages,
     136            0 :                     request_complete,
     137            0 :                 } => (messages, request_complete),
     138              :             };
     139              : 
     140            0 :             let mut response = match self.responses.pop_front() {
     141            0 :                 Some(response) => response,
     142            0 :                 None => match messages.next().map_err(Error::parse)? {
     143            0 :                     Some(Message::ErrorResponse(error)) => {
     144            0 :                         return Poll::Ready(Err(Error::db(error)));
     145              :                     }
     146            0 :                     _ => return Poll::Ready(Err(Error::unexpected_message())),
     147              :                 },
     148              :             };
     149              : 
     150            0 :             match response.sender.poll_reserve(cx) {
     151              :                 Poll::Ready(Ok(())) => {
     152            0 :                     let _ = response.sender.send_item(messages);
     153            0 :                     if !request_complete {
     154            0 :                         self.responses.push_front(response);
     155            0 :                     }
     156              :                 }
     157              :                 Poll::Ready(Err(_)) => {
     158              :                     // we need to keep paging through the rest of the messages even if the receiver's hung up
     159            0 :                     if !request_complete {
     160            0 :                         self.responses.push_front(response);
     161            0 :                     }
     162              :                 }
     163              :                 Poll::Pending => {
     164            0 :                     self.responses.push_front(response);
     165            0 :                     self.pending_responses.push_back(BackendMessage::Normal {
     166            0 :                         messages,
     167            0 :                         request_complete,
     168            0 :                     });
     169            0 :                     trace!("poll_read: waiting on sender");
     170            0 :                     return Poll::Pending;
     171              :                 }
     172              :             }
     173              :         }
     174            0 :     }
     175              : 
     176              :     /// Fetch the next client request and enqueue the response sender.
     177            0 :     fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
     178            0 :         if self.receiver.is_closed() {
     179            0 :             return Poll::Ready(None);
     180            0 :         }
     181            0 : 
     182            0 :         match self.receiver.poll_recv(cx) {
     183            0 :             Poll::Ready(Some(request)) => {
     184            0 :                 trace!("polled new request");
     185            0 :                 self.responses.push_back(Response {
     186            0 :                     sender: PollSender::new(request.sender),
     187            0 :                 });
     188            0 :                 Poll::Ready(Some(request.messages))
     189              :             }
     190            0 :             Poll::Ready(None) => Poll::Ready(None),
     191            0 :             Poll::Pending => Poll::Pending,
     192              :         }
     193            0 :     }
     194              : 
     195              :     /// Process client requests and write them to the postgres connection, flushing if necessary.
     196              :     /// client -> postgres
     197            0 :     fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
     198              :         loop {
     199            0 :             if Pin::new(&mut self.stream)
     200            0 :                 .poll_ready(cx)
     201            0 :                 .map_err(Error::io)?
     202            0 :                 .is_pending()
     203              :             {
     204            0 :                 trace!("poll_write: waiting on socket");
     205              : 
     206              :                 // poll_ready is self-flushing.
     207            0 :                 return Poll::Pending;
     208            0 :             }
     209            0 : 
     210            0 :             match self.poll_request(cx) {
     211              :                 // send the message to postgres
     212            0 :                 Poll::Ready(Some(RequestMessages::Single(request))) => {
     213            0 :                     Pin::new(&mut self.stream)
     214            0 :                         .start_send(request)
     215            0 :                         .map_err(Error::io)?;
     216              :                 }
     217              :                 // No more messages from the client, and no more responses to wait for.
     218              :                 // Send a terminate message to postgres
     219            0 :                 Poll::Ready(None) if self.responses.is_empty() => {
     220            0 :                     trace!("poll_write: at eof, terminating");
     221            0 :                     let mut request = BytesMut::new();
     222            0 :                     frontend::terminate(&mut request);
     223            0 :                     let request = FrontendMessage::Raw(request.freeze());
     224            0 : 
     225            0 :                     Pin::new(&mut self.stream)
     226            0 :                         .start_send(request)
     227            0 :                         .map_err(Error::io)?;
     228              : 
     229            0 :                     trace!("poll_write: sent eof, closing");
     230            0 :                     trace!("poll_write: done");
     231            0 :                     return Poll::Ready(Ok(WriteReady::Terminating));
     232              :                 }
     233              :                 // No more messages from the client, but there are still some responses to wait for.
     234              :                 Poll::Ready(None) => {
     235            0 :                     trace!(
     236            0 :                         "poll_write: at eof, pending responses {}",
     237            0 :                         self.responses.len()
     238              :                     );
     239            0 :                     ready!(self.poll_flush(cx))?;
     240            0 :                     return Poll::Ready(Ok(WriteReady::WaitingOnRead));
     241              :                 }
     242              :                 // Still waiting for a message from the client.
     243              :                 Poll::Pending => {
     244            0 :                     trace!("poll_write: waiting on request");
     245            0 :                     ready!(self.poll_flush(cx))?;
     246            0 :                     return Poll::Pending;
     247              :                 }
     248              :             }
     249              :         }
     250            0 :     }
     251              : 
     252            0 :     fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     253            0 :         match Pin::new(&mut self.stream)
     254            0 :             .poll_flush(cx)
     255            0 :             .map_err(Error::io)?
     256              :         {
     257              :             Poll::Ready(()) => {
     258            0 :                 trace!("poll_flush: flushed");
     259            0 :                 Poll::Ready(Ok(()))
     260              :             }
     261              :             Poll::Pending => {
     262            0 :                 trace!("poll_flush: waiting on socket");
     263            0 :                 Poll::Pending
     264              :             }
     265              :         }
     266            0 :     }
     267              : 
     268            0 :     fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     269            0 :         match Pin::new(&mut self.stream)
     270            0 :             .poll_close(cx)
     271            0 :             .map_err(Error::io)?
     272              :         {
     273              :             Poll::Ready(()) => {
     274            0 :                 trace!("poll_shutdown: complete");
     275            0 :                 Poll::Ready(Ok(()))
     276              :             }
     277              :             Poll::Pending => {
     278            0 :                 trace!("poll_shutdown: waiting on socket");
     279            0 :                 Poll::Pending
     280              :             }
     281              :         }
     282            0 :     }
     283              : 
     284              :     /// Returns the value of a runtime parameter for this connection.
     285            0 :     pub fn parameter(&self, name: &str) -> Option<&str> {
     286            0 :         self.parameters.get(name).map(|s| &**s)
     287            0 :     }
     288              : 
     289              :     /// Polls for asynchronous messages from the server.
     290              :     ///
     291              :     /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
     292              :     /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
     293            0 :     pub fn poll_message(
     294            0 :         &mut self,
     295            0 :         cx: &mut Context<'_>,
     296            0 :     ) -> Poll<Option<Result<AsyncMessage, Error>>> {
     297            0 :         if self.state != State::Closing {
     298              :             // if the state is still active, try read from and write to postgres.
     299            0 :             let message = self.poll_read(cx)?;
     300            0 :             let closing = self.poll_write(cx)?;
     301            0 :             if let Poll::Ready(WriteReady::Terminating) = closing {
     302            0 :                 self.state = State::Closing;
     303            0 :             }
     304              : 
     305            0 :             if let Poll::Ready(message) = message {
     306            0 :                 return Poll::Ready(Some(Ok(message)));
     307            0 :             }
     308            0 : 
     309            0 :             // poll_read returned Pending.
     310            0 :             // poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
     311            0 :             // if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
     312            0 :             if self.state != State::Closing {
     313            0 :                 return Poll::Pending;
     314            0 :             }
     315            0 :         }
     316              : 
     317            0 :         match self.poll_shutdown(cx) {
     318            0 :             Poll::Ready(Ok(())) => Poll::Ready(None),
     319            0 :             Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
     320            0 :             Poll::Pending => Poll::Pending,
     321              :         }
     322            0 :     }
     323              : }
     324              : 
     325              : impl<S, T> Future for Connection<S, T>
     326              : where
     327              :     S: AsyncRead + AsyncWrite + Unpin,
     328              :     T: AsyncRead + AsyncWrite + Unpin,
     329              : {
     330              :     type Output = Result<(), Error>;
     331              : 
     332            0 :     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     333            0 :         while let Some(message) = ready!(self.poll_message(cx)?) {
     334            0 :             if let AsyncMessage::Notice(notice) = message {
     335            0 :                 info!("{}: {}", notice.severity(), notice.message());
     336            0 :             }
     337              :         }
     338            0 :         Poll::Ready(Ok(()))
     339            0 :     }
     340              : }
        

Generated by: LCOV version 2.1-beta