LCOV - code coverage report
Current view: top level - libs/proxy/tokio-postgres2/src - connection.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 0.0 % 181 0
Test Date: 2025-02-20 13:11:02 Functions: 0.0 % 23 0

            Line data    Source code
       1              : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
       2              : use crate::error::DbError;
       3              : use crate::maybe_tls_stream::MaybeTlsStream;
       4              : use crate::{AsyncMessage, Error, Notification};
       5              : use bytes::BytesMut;
       6              : use fallible_iterator::FallibleIterator;
       7              : use futures_util::{ready, Sink, Stream};
       8              : use log::{info, trace};
       9              : use postgres_protocol2::message::backend::Message;
      10              : use postgres_protocol2::message::frontend;
      11              : use std::collections::{HashMap, VecDeque};
      12              : use std::future::Future;
      13              : use std::pin::Pin;
      14              : use std::task::{Context, Poll};
      15              : use tokio::io::{AsyncRead, AsyncWrite};
      16              : use tokio::sync::mpsc;
      17              : use tokio_util::codec::Framed;
      18              : use tokio_util::sync::PollSender;
      19              : 
      20              : pub enum RequestMessages {
      21              :     Single(FrontendMessage),
      22              : }
      23              : 
      24              : pub struct Request {
      25              :     pub messages: RequestMessages,
      26              :     pub sender: mpsc::Sender<BackendMessages>,
      27              : }
      28              : 
      29              : pub struct Response {
      30              :     sender: PollSender<BackendMessages>,
      31              : }
      32              : 
      33              : #[derive(PartialEq, Debug)]
      34              : enum State {
      35              :     Active,
      36              :     Closing,
      37              : }
      38              : 
      39              : enum WriteReady {
      40              :     Terminating,
      41              :     WaitingOnRead,
      42              : }
      43              : 
      44              : /// A connection to a PostgreSQL database.
      45              : ///
      46              : /// This is one half of what is returned when a new connection is established. It performs the actual IO with the
      47              : /// server, and should generally be spawned off onto an executor to run in the background.
      48              : ///
      49              : /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
      50              : /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
      51              : #[must_use = "futures do nothing unless polled"]
      52              : pub struct Connection<S, T> {
      53              :     /// HACK: we need this in the Neon Proxy.
      54              :     pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      55              :     /// HACK: we need this in the Neon Proxy to forward params.
      56              :     pub parameters: HashMap<String, String>,
      57              :     receiver: mpsc::UnboundedReceiver<Request>,
      58              :     pending_responses: VecDeque<BackendMessage>,
      59              :     responses: VecDeque<Response>,
      60              :     state: State,
      61              : }
      62              : 
      63              : impl<S, T> Connection<S, T>
      64              : where
      65              :     S: AsyncRead + AsyncWrite + Unpin,
      66              :     T: AsyncRead + AsyncWrite + Unpin,
      67              : {
      68            0 :     pub(crate) fn new(
      69            0 :         stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
      70            0 :         pending_responses: VecDeque<BackendMessage>,
      71            0 :         parameters: HashMap<String, String>,
      72            0 :         receiver: mpsc::UnboundedReceiver<Request>,
      73            0 :     ) -> Connection<S, T> {
      74            0 :         Connection {
      75            0 :             stream,
      76            0 :             parameters,
      77            0 :             receiver,
      78            0 :             pending_responses,
      79            0 :             responses: VecDeque::new(),
      80            0 :             state: State::Active,
      81            0 :         }
      82            0 :     }
      83              : 
      84            0 :     fn poll_response(
      85            0 :         &mut self,
      86            0 :         cx: &mut Context<'_>,
      87            0 :     ) -> Poll<Option<Result<BackendMessage, Error>>> {
      88            0 :         if let Some(message) = self.pending_responses.pop_front() {
      89            0 :             trace!("retrying pending response");
      90            0 :             return Poll::Ready(Some(Ok(message)));
      91            0 :         }
      92            0 : 
      93            0 :         Pin::new(&mut self.stream)
      94            0 :             .poll_next(cx)
      95            0 :             .map(|o| o.map(|r| r.map_err(Error::io)))
      96            0 :     }
      97              : 
      98              :     /// Read and process messages from the connection to postgres.
      99              :     /// client <- postgres
     100            0 :     fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
     101              :         loop {
     102            0 :             let message = match self.poll_response(cx)? {
     103            0 :                 Poll::Ready(Some(message)) => message,
     104            0 :                 Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
     105              :                 Poll::Pending => {
     106            0 :                     trace!("poll_read: waiting on response");
     107            0 :                     return Poll::Pending;
     108              :                 }
     109              :             };
     110              : 
     111            0 :             let (mut messages, request_complete) = match message {
     112            0 :                 BackendMessage::Async(Message::NoticeResponse(body)) => {
     113            0 :                     let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
     114            0 :                     return Poll::Ready(Ok(AsyncMessage::Notice(error)));
     115              :                 }
     116            0 :                 BackendMessage::Async(Message::NotificationResponse(body)) => {
     117            0 :                     let notification = Notification {
     118            0 :                         process_id: body.process_id(),
     119            0 :                         channel: body.channel().map_err(Error::parse)?.to_string(),
     120            0 :                         payload: body.message().map_err(Error::parse)?.to_string(),
     121            0 :                     };
     122            0 :                     return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
     123              :                 }
     124            0 :                 BackendMessage::Async(Message::ParameterStatus(body)) => {
     125            0 :                     self.parameters.insert(
     126            0 :                         body.name().map_err(Error::parse)?.to_string(),
     127            0 :                         body.value().map_err(Error::parse)?.to_string(),
     128            0 :                     );
     129            0 :                     continue;
     130              :                 }
     131            0 :                 BackendMessage::Async(_) => unreachable!(),
     132              :                 BackendMessage::Normal {
     133            0 :                     messages,
     134            0 :                     request_complete,
     135            0 :                 } => (messages, request_complete),
     136              :             };
     137              : 
     138            0 :             let mut response = match self.responses.pop_front() {
     139            0 :                 Some(response) => response,
     140            0 :                 None => match messages.next().map_err(Error::parse)? {
     141            0 :                     Some(Message::ErrorResponse(error)) => {
     142            0 :                         return Poll::Ready(Err(Error::db(error)))
     143              :                     }
     144            0 :                     _ => return Poll::Ready(Err(Error::unexpected_message())),
     145              :                 },
     146              :             };
     147              : 
     148            0 :             match response.sender.poll_reserve(cx) {
     149              :                 Poll::Ready(Ok(())) => {
     150            0 :                     let _ = response.sender.send_item(messages);
     151            0 :                     if !request_complete {
     152            0 :                         self.responses.push_front(response);
     153            0 :                     }
     154              :                 }
     155              :                 Poll::Ready(Err(_)) => {
     156              :                     // we need to keep paging through the rest of the messages even if the receiver's hung up
     157            0 :                     if !request_complete {
     158            0 :                         self.responses.push_front(response);
     159            0 :                     }
     160              :                 }
     161              :                 Poll::Pending => {
     162            0 :                     self.responses.push_front(response);
     163            0 :                     self.pending_responses.push_back(BackendMessage::Normal {
     164            0 :                         messages,
     165            0 :                         request_complete,
     166            0 :                     });
     167            0 :                     trace!("poll_read: waiting on sender");
     168            0 :                     return Poll::Pending;
     169              :                 }
     170              :             }
     171              :         }
     172            0 :     }
     173              : 
     174              :     /// Fetch the next client request and enqueue the response sender.
     175            0 :     fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
     176            0 :         if self.receiver.is_closed() {
     177            0 :             return Poll::Ready(None);
     178            0 :         }
     179            0 : 
     180            0 :         match self.receiver.poll_recv(cx) {
     181            0 :             Poll::Ready(Some(request)) => {
     182            0 :                 trace!("polled new request");
     183            0 :                 self.responses.push_back(Response {
     184            0 :                     sender: PollSender::new(request.sender),
     185            0 :                 });
     186            0 :                 Poll::Ready(Some(request.messages))
     187              :             }
     188            0 :             Poll::Ready(None) => Poll::Ready(None),
     189            0 :             Poll::Pending => Poll::Pending,
     190              :         }
     191            0 :     }
     192              : 
     193              :     /// Process client requests and write them to the postgres connection, flushing if necessary.
     194              :     /// client -> postgres
     195            0 :     fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
     196              :         loop {
     197            0 :             if Pin::new(&mut self.stream)
     198            0 :                 .poll_ready(cx)
     199            0 :                 .map_err(Error::io)?
     200            0 :                 .is_pending()
     201              :             {
     202            0 :                 trace!("poll_write: waiting on socket");
     203              : 
     204              :                 // poll_ready is self-flushing.
     205            0 :                 return Poll::Pending;
     206            0 :             }
     207            0 : 
     208            0 :             match self.poll_request(cx) {
     209              :                 // send the message to postgres
     210            0 :                 Poll::Ready(Some(RequestMessages::Single(request))) => {
     211            0 :                     Pin::new(&mut self.stream)
     212            0 :                         .start_send(request)
     213            0 :                         .map_err(Error::io)?;
     214              :                 }
     215              :                 // No more messages from the client, and no more responses to wait for.
     216              :                 // Send a terminate message to postgres
     217            0 :                 Poll::Ready(None) if self.responses.is_empty() => {
     218            0 :                     trace!("poll_write: at eof, terminating");
     219            0 :                     let mut request = BytesMut::new();
     220            0 :                     frontend::terminate(&mut request);
     221            0 :                     let request = FrontendMessage::Raw(request.freeze());
     222            0 : 
     223            0 :                     Pin::new(&mut self.stream)
     224            0 :                         .start_send(request)
     225            0 :                         .map_err(Error::io)?;
     226              : 
     227            0 :                     trace!("poll_write: sent eof, closing");
     228            0 :                     trace!("poll_write: done");
     229            0 :                     return Poll::Ready(Ok(WriteReady::Terminating));
     230              :                 }
     231              :                 // No more messages from the client, but there are still some responses to wait for.
     232              :                 Poll::Ready(None) => {
     233            0 :                     trace!(
     234            0 :                         "poll_write: at eof, pending responses {}",
     235            0 :                         self.responses.len()
     236              :                     );
     237            0 :                     ready!(self.poll_flush(cx))?;
     238            0 :                     return Poll::Ready(Ok(WriteReady::WaitingOnRead));
     239              :                 }
     240              :                 // Still waiting for a message from the client.
     241              :                 Poll::Pending => {
     242            0 :                     trace!("poll_write: waiting on request");
     243            0 :                     ready!(self.poll_flush(cx))?;
     244            0 :                     return Poll::Pending;
     245              :                 }
     246              :             }
     247              :         }
     248            0 :     }
     249              : 
     250            0 :     fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     251            0 :         match Pin::new(&mut self.stream)
     252            0 :             .poll_flush(cx)
     253            0 :             .map_err(Error::io)?
     254              :         {
     255              :             Poll::Ready(()) => {
     256            0 :                 trace!("poll_flush: flushed");
     257            0 :                 Poll::Ready(Ok(()))
     258              :             }
     259              :             Poll::Pending => {
     260            0 :                 trace!("poll_flush: waiting on socket");
     261            0 :                 Poll::Pending
     262              :             }
     263              :         }
     264            0 :     }
     265              : 
     266            0 :     fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     267            0 :         match Pin::new(&mut self.stream)
     268            0 :             .poll_close(cx)
     269            0 :             .map_err(Error::io)?
     270              :         {
     271              :             Poll::Ready(()) => {
     272            0 :                 trace!("poll_shutdown: complete");
     273            0 :                 Poll::Ready(Ok(()))
     274              :             }
     275              :             Poll::Pending => {
     276            0 :                 trace!("poll_shutdown: waiting on socket");
     277            0 :                 Poll::Pending
     278              :             }
     279              :         }
     280            0 :     }
     281              : 
     282              :     /// Returns the value of a runtime parameter for this connection.
     283            0 :     pub fn parameter(&self, name: &str) -> Option<&str> {
     284            0 :         self.parameters.get(name).map(|s| &**s)
     285            0 :     }
     286              : 
     287              :     /// Polls for asynchronous messages from the server.
     288              :     ///
     289              :     /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
     290              :     /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
     291            0 :     pub fn poll_message(
     292            0 :         &mut self,
     293            0 :         cx: &mut Context<'_>,
     294            0 :     ) -> Poll<Option<Result<AsyncMessage, Error>>> {
     295            0 :         if self.state != State::Closing {
     296              :             // if the state is still active, try read from and write to postgres.
     297            0 :             let message = self.poll_read(cx)?;
     298            0 :             let closing = self.poll_write(cx)?;
     299            0 :             if let Poll::Ready(WriteReady::Terminating) = closing {
     300            0 :                 self.state = State::Closing;
     301            0 :             }
     302              : 
     303            0 :             if let Poll::Ready(message) = message {
     304            0 :                 return Poll::Ready(Some(Ok(message)));
     305            0 :             }
     306            0 : 
     307            0 :             // poll_read returned Pending.
     308            0 :             // poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
     309            0 :             // if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
     310            0 :             if self.state != State::Closing {
     311            0 :                 return Poll::Pending;
     312            0 :             }
     313            0 :         }
     314              : 
     315            0 :         match self.poll_shutdown(cx) {
     316            0 :             Poll::Ready(Ok(())) => Poll::Ready(None),
     317            0 :             Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
     318            0 :             Poll::Pending => Poll::Pending,
     319              :         }
     320            0 :     }
     321              : }
     322              : 
     323              : impl<S, T> Future for Connection<S, T>
     324              : where
     325              :     S: AsyncRead + AsyncWrite + Unpin,
     326              :     T: AsyncRead + AsyncWrite + Unpin,
     327              : {
     328              :     type Output = Result<(), Error>;
     329              : 
     330            0 :     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
     331            0 :         while let Some(message) = ready!(self.poll_message(cx)?) {
     332            0 :             if let AsyncMessage::Notice(notice) = message {
     333            0 :                 info!("{}: {}", notice.severity(), notice.message());
     334            0 :             }
     335              :         }
     336            0 :         Poll::Ready(Ok(()))
     337            0 :     }
     338              : }
        

Generated by: LCOV version 2.1-beta