LCOV - differential code coverage report
Current view: top level - libs/postgres_backend/src - lib.rs (source / functions) Coverage Total Hit LBC UBC GBC CBC
Current: cd44433dd675caa99df17a61b18949c8387e2242.info Lines: 81.9 % 568 465 4 99 2 463
Current Date: 2024-01-09 02:06:09 Functions: 48.7 % 357 174 183 174
Baseline: 66c52a629a0f4a503e193045e0df4c77139e344b.info
Baseline Date: 2024-01-08 15:34:46

           TLA  Line data    Source code
       1                 : //! Server-side asynchronous Postgres connection, as limited as we need.
       2                 : //! To use, create PostgresBackend and run() it, passing the Handler
       3                 : //! implementation determining how to process the queries. Currently its API
       4                 : //! is rather narrow, but we can extend it once required.
       5                 : #![deny(unsafe_code)]
       6                 : #![deny(clippy::undocumented_unsafe_blocks)]
       7                 : use anyhow::Context;
       8                 : use bytes::Bytes;
       9                 : use futures::pin_mut;
      10                 : use serde::{Deserialize, Serialize};
      11                 : use std::io::ErrorKind;
      12                 : use std::net::SocketAddr;
      13                 : use std::pin::Pin;
      14                 : use std::sync::Arc;
      15                 : use std::task::{ready, Poll};
      16                 : use std::{fmt, io};
      17                 : use std::{future::Future, str::FromStr};
      18                 : use tokio::io::{AsyncRead, AsyncWrite};
      19                 : use tokio_rustls::TlsAcceptor;
      20                 : use tracing::{debug, error, info, trace, warn};
      21                 : 
      22                 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
      23                 : use pq_proto::{
      24                 :     BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
      25                 :     SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
      26                 : };
      27                 : 
      28                 : /// An error, occurred during query processing:
      29                 : /// either during the connection ([`ConnectionError`]) or before/after it.
      30 CBC          84 : #[derive(thiserror::Error, Debug)]
      31                 : pub enum QueryError {
      32                 :     /// The connection was lost while processing the query.
      33                 :     #[error(transparent)]
      34                 :     Disconnected(#[from] ConnectionError),
      35                 :     /// We were instructed to shutdown while processing the query
      36                 :     #[error("Shutting down")]
      37                 :     Shutdown,
      38                 :     /// Query handler indicated that client should reconnect
      39                 :     #[error("Server requested reconnect")]
      40                 :     Reconnect,
      41                 :     /// Query named an entity that was not found
      42                 :     #[error("Not found: {0}")]
      43                 :     NotFound(std::borrow::Cow<'static, str>),
      44                 :     /// Authentication failure
      45                 :     #[error("Unauthorized: {0}")]
      46                 :     Unauthorized(std::borrow::Cow<'static, str>),
      47                 :     #[error("Simulated Connection Error")]
      48                 :     SimulatedConnectionError,
      49                 :     /// Some other error
      50                 :     #[error(transparent)]
      51                 :     Other(#[from] anyhow::Error),
      52                 : }
      53                 : 
      54                 : impl From<io::Error> for QueryError {
      55             322 :     fn from(e: io::Error) -> Self {
      56             322 :         Self::Disconnected(ConnectionError::Io(e))
      57             322 :     }
      58                 : }
      59                 : 
      60                 : impl QueryError {
      61             567 :     pub fn pg_error_code(&self) -> &'static [u8; 5] {
      62             567 :         match self {
      63              38 :             Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
      64 UBC           0 :             Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
      65 CBC           5 :             Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
      66             524 :             Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
      67                 :         }
      68             567 :     }
      69                 : }
      70                 : 
      71                 : /// Returns true if the given error is a normal consequence of a network issue,
      72                 : /// or the client closing the connection. These errors can happen during normal
      73                 : /// operations, and don't indicate a bug in our code.
      74            1091 : pub fn is_expected_io_error(e: &io::Error) -> bool {
      75                 :     use io::ErrorKind::*;
      76 UBC           0 :     matches!(
      77 CBC        1091 :         e.kind(),
      78                 :         BrokenPipe | ConnectionRefused | ConnectionAborted | ConnectionReset | TimedOut
      79                 :     )
      80            1091 : }
      81                 : 
      82                 : #[async_trait::async_trait]
      83                 : pub trait Handler<IO> {
      84                 :     /// Handle single query.
      85                 :     /// postgres_backend will issue ReadyForQuery after calling this (this
      86                 :     /// might be not what we want after CopyData streaming, but currently we don't
      87                 :     /// care). It will also flush out the output buffer.
      88                 :     async fn process_query(
      89                 :         &mut self,
      90                 :         pgb: &mut PostgresBackend<IO>,
      91                 :         query_string: &str,
      92                 :     ) -> Result<(), QueryError>;
      93                 : 
      94                 :     /// Called on startup packet receival, allows to process params.
      95                 :     ///
      96                 :     /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
      97                 :     /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
      98                 :     /// to override whole init logic in implementations.
      99               5 :     fn startup(
     100               5 :         &mut self,
     101               5 :         _pgb: &mut PostgresBackend<IO>,
     102               5 :         _sm: &FeStartupPacket,
     103               5 :     ) -> Result<(), QueryError> {
     104               5 :         Ok(())
     105               5 :     }
     106                 : 
     107                 :     /// Check auth jwt
     108 UBC           0 :     fn check_auth_jwt(
     109               0 :         &mut self,
     110               0 :         _pgb: &mut PostgresBackend<IO>,
     111               0 :         _jwt_response: &[u8],
     112               0 :     ) -> Result<(), QueryError> {
     113               0 :         Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
     114               0 :     }
     115                 : }
     116                 : 
     117                 : /// PostgresBackend protocol state.
     118                 : /// XXX: The order of the constructors matters.
     119 CBC       52265 : #[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
     120                 : pub enum ProtoState {
     121                 :     /// Nothing happened yet.
     122                 :     Initialization,
     123                 :     /// Encryption handshake is done; waiting for encrypted Startup message.
     124                 :     Encrypted,
     125                 :     /// Waiting for password (auth token).
     126                 :     Authentication,
     127                 :     /// Performed handshake and auth, ReadyForQuery is issued.
     128                 :     Established,
     129                 :     Closed,
     130                 : }
     131                 : 
     132 UBC           0 : #[derive(Clone, Copy)]
     133                 : pub enum ProcessMsgResult {
     134                 :     Continue,
     135                 :     Break,
     136                 : }
     137                 : 
     138                 : /// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
     139                 : pub enum MaybeTlsStream<IO> {
     140                 :     Unencrypted(IO),
     141                 :     Tls(Box<tokio_rustls::server::TlsStream<IO>>),
     142                 : }
     143                 : 
     144                 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
     145 CBC     6493844 :     fn poll_write(
     146         6493844 :         self: Pin<&mut Self>,
     147         6493844 :         cx: &mut std::task::Context<'_>,
     148         6493844 :         buf: &[u8],
     149         6493844 :     ) -> Poll<io::Result<usize>> {
     150         6493844 :         match self.get_mut() {
     151         6493842 :             Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
     152               2 :             Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
     153                 :         }
     154         6493844 :     }
     155         6508846 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
     156         6508846 :         match self.get_mut() {
     157         6508844 :             Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
     158               2 :             Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
     159                 :         }
     160         6508846 :     }
     161           10471 :     fn poll_shutdown(
     162           10471 :         self: Pin<&mut Self>,
     163           10471 :         cx: &mut std::task::Context<'_>,
     164           10471 :     ) -> Poll<io::Result<()>> {
     165           10471 :         match self.get_mut() {
     166           10471 :             Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
     167 UBC           0 :             Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
     168                 :         }
     169 CBC       10471 :     }
     170                 : }
     171                 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
     172        13787171 :     fn poll_read(
     173        13787171 :         self: Pin<&mut Self>,
     174        13787171 :         cx: &mut std::task::Context<'_>,
     175        13787171 :         buf: &mut tokio::io::ReadBuf<'_>,
     176        13787171 :     ) -> Poll<io::Result<()>> {
     177        13787171 :         match self.get_mut() {
     178        13787166 :             Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
     179               5 :             Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
     180                 :         }
     181        13787171 :     }
     182                 : }
     183                 : 
     184           59558 : #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
     185                 : pub enum AuthType {
     186                 :     Trust,
     187                 :     // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
     188                 :     NeonJWT,
     189                 : }
     190                 : 
     191                 : impl FromStr for AuthType {
     192                 :     type Err = anyhow::Error;
     193                 : 
     194            1812 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     195            1812 :         match s {
     196            1812 :             "Trust" => Ok(Self::Trust),
     197              44 :             "NeonJWT" => Ok(Self::NeonJWT),
     198 UBC           0 :             _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
     199                 :         }
     200 CBC        1812 :     }
     201                 : }
     202                 : 
     203                 : impl fmt::Display for AuthType {
     204            1812 :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     205            1812 :         f.write_str(match self {
     206            1768 :             AuthType::Trust => "Trust",
     207              44 :             AuthType::NeonJWT => "NeonJWT",
     208                 :         })
     209            1812 :     }
     210                 : }
     211                 : 
     212                 : /// Either full duplex Framed or write only half; the latter is left in
     213                 : /// PostgresBackend after call to `split`. In principle we could always store a
     214                 : /// pair of splitted handles, but that would force to to pay splitting price
     215                 : /// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
     216                 : enum MaybeWriteOnly<IO> {
     217                 :     Full(Framed<MaybeTlsStream<IO>>),
     218                 :     WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
     219                 :     Broken, // temporary value palmed off during the split
     220                 : }
     221                 : 
     222                 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
     223           20528 :     async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
     224           20528 :         match self {
     225           20528 :             MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
     226                 :             MaybeWriteOnly::WriteOnly(_) => {
     227 UBC           0 :                 Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
     228                 :             }
     229               0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     230                 :         }
     231 CBC       20528 :     }
     232                 : 
     233         3758127 :     async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     234         3758127 :         match self {
     235         3758127 :             MaybeWriteOnly::Full(framed) => framed.read_message().await,
     236                 :             MaybeWriteOnly::WriteOnly(_) => {
     237 UBC           0 :                 Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
     238                 :             }
     239               0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     240                 :         }
     241 CBC     3757922 :     }
     242                 : 
     243         6507626 :     fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     244         6507626 :         match self {
     245         4638125 :             MaybeWriteOnly::Full(framed) => framed.write_message(msg),
     246         1869501 :             MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
     247 UBC           0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     248                 :         }
     249 CBC     6507626 :     }
     250                 : 
     251         6506561 :     async fn flush(&mut self) -> io::Result<()> {
     252         6506561 :         match self {
     253         4637060 :             MaybeWriteOnly::Full(framed) => framed.flush().await,
     254         1869501 :             MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
     255 UBC           0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     256                 :         }
     257 CBC     6498717 :     }
     258                 : 
     259                 :     /// Cancellation safe as long as the underlying IO is cancellation safe.
     260           10764 :     async fn shutdown(&mut self) -> io::Result<()> {
     261           10764 :         match self {
     262           10764 :             MaybeWriteOnly::Full(framed) => framed.shutdown().await,
     263 UBC           0 :             MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
     264               0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     265                 :         }
     266 CBC       10764 :     }
     267                 : }
     268                 : 
     269                 : pub struct PostgresBackend<IO> {
     270                 :     framed: MaybeWriteOnly<IO>,
     271                 : 
     272                 :     pub state: ProtoState,
     273                 : 
     274                 :     auth_type: AuthType,
     275                 : 
     276                 :     peer_addr: SocketAddr,
     277                 :     pub tls_config: Option<Arc<rustls::ServerConfig>>,
     278                 : }
     279                 : 
     280                 : pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
     281                 : 
     282 UBC           0 : pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
     283               0 :     let mut query_string = query_string.to_vec();
     284               0 :     if let Some(ch) = query_string.last() {
     285               0 :         if *ch == 0 {
     286               0 :             query_string.pop();
     287               0 :         }
     288               0 :     }
     289               0 :     query_string
     290               0 : }
     291                 : 
     292                 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
     293 CBC       11945 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
     294           11945 :     let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
     295           11945 :     std::str::from_utf8(without_null).map_err(|e| e.into())
     296           11945 : }
     297                 : 
     298                 : impl PostgresBackend<tokio::net::TcpStream> {
     299               5 :     pub fn new(
     300               5 :         socket: tokio::net::TcpStream,
     301               5 :         auth_type: AuthType,
     302               5 :         tls_config: Option<Arc<rustls::ServerConfig>>,
     303               5 :     ) -> io::Result<Self> {
     304               5 :         let peer_addr = socket.peer_addr()?;
     305               5 :         let stream = MaybeTlsStream::Unencrypted(socket);
     306               5 : 
     307               5 :         Ok(Self {
     308               5 :             framed: MaybeWriteOnly::Full(Framed::new(stream)),
     309               5 :             state: ProtoState::Initialization,
     310               5 :             auth_type,
     311               5 :             tls_config,
     312               5 :             peer_addr,
     313               5 :         })
     314               5 :     }
     315                 : }
     316                 : 
     317                 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
     318           11204 :     pub fn new_from_io(
     319           11204 :         socket: IO,
     320           11204 :         peer_addr: SocketAddr,
     321           11204 :         auth_type: AuthType,
     322           11204 :         tls_config: Option<Arc<rustls::ServerConfig>>,
     323           11204 :     ) -> io::Result<Self> {
     324           11204 :         let stream = MaybeTlsStream::Unencrypted(socket);
     325           11204 : 
     326           11204 :         Ok(Self {
     327           11204 :             framed: MaybeWriteOnly::Full(Framed::new(stream)),
     328           11204 :             state: ProtoState::Initialization,
     329           11204 :             auth_type,
     330           11204 :             tls_config,
     331           11204 :             peer_addr,
     332           11204 :         })
     333           11204 :     }
     334                 : 
     335            2479 :     pub fn get_peer_addr(&self) -> &SocketAddr {
     336            2479 :         &self.peer_addr
     337            2479 :     }
     338                 : 
     339                 :     /// Read full message or return None if connection is cleanly closed with no
     340                 :     /// unprocessed data.
     341         3757934 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     342         3757933 :         if let ProtoState::Closed = self.state {
     343              21 :             Ok(None)
     344                 :         } else {
     345         3757912 :             match self.framed.read_message().await {
     346         3757676 :                 Ok(m) => {
     347 UBC           0 :                     trace!("read msg {:?}", m);
     348 CBC     3757676 :                     Ok(m)
     349                 :                 }
     350              31 :                 Err(e) => {
     351              31 :                     // remember not to try to read anymore
     352              31 :                     self.state = ProtoState::Closed;
     353              31 :                     Err(e)
     354                 :                 }
     355                 :             }
     356                 :         }
     357         3757728 :     }
     358                 : 
     359                 :     /// Write message into internal output buffer, doesn't flush it. Technically
     360                 :     /// error type can be only ProtocolError here (if, unlikely, serialization
     361                 :     /// fails), but callers typically wrap it anyway.
     362         6507626 :     pub fn write_message_noflush(
     363         6507626 :         &mut self,
     364         6507626 :         message: &BeMessage<'_>,
     365         6507626 :     ) -> Result<&mut Self, ConnectionError> {
     366         6507626 :         self.framed.write_message_noflush(message)?;
     367         6507626 :         trace!("wrote msg {:?}", message);
     368         6507626 :         Ok(self)
     369         6507626 :     }
     370                 : 
     371                 :     /// Flush output buffer into the socket.
     372         6506561 :     pub async fn flush(&mut self) -> io::Result<()> {
     373         6506561 :         self.framed.flush().await
     374         6498717 :     }
     375                 : 
     376                 :     /// Polling version of `flush()`, saves the caller need to pin.
     377          947031 :     pub fn poll_flush(
     378          947031 :         &mut self,
     379          947031 :         cx: &mut std::task::Context<'_>,
     380          947031 :     ) -> Poll<Result<(), std::io::Error>> {
     381          947031 :         let flush_fut = self.flush();
     382          947031 :         pin_mut!(flush_fut);
     383          947031 :         flush_fut.poll(cx)
     384          947031 :     }
     385                 : 
     386                 :     /// Write message into internal output buffer and flush it to the stream.
     387         1892785 :     pub async fn write_message(
     388         1892785 :         &mut self,
     389         1892785 :         message: &BeMessage<'_>,
     390         1892785 :     ) -> Result<&mut Self, ConnectionError> {
     391         1892785 :         self.write_message_noflush(message)?;
     392         1892785 :         self.flush().await?;
     393         1892743 :         Ok(self)
     394         1892763 :     }
     395                 : 
     396                 :     /// Returns an AsyncWrite implementation that wraps all the data written
     397                 :     /// to it in CopyData messages, and writes them to the connection
     398                 :     ///
     399                 :     /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
     400             559 :     pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
     401             559 :         CopyDataWriter { pgb: self }
     402             559 :     }
     403                 : 
     404                 :     /// Wrapper for run_message_loop() that shuts down socket when we are done
     405           11209 :     pub async fn run<F, S>(
     406           11209 :         mut self,
     407           11209 :         handler: &mut impl Handler<IO>,
     408           11209 :         shutdown_watcher: F,
     409           11209 :     ) -> Result<(), QueryError>
     410           11209 :     where
     411           11209 :         F: Fn() -> S + Clone,
     412           11209 :         S: Future,
     413           11209 :     {
     414           11209 :         let ret = self
     415           11209 :             .run_message_loop(handler, shutdown_watcher.clone())
     416         9905962 :             .await;
     417                 : 
     418           10764 :         tokio::select! {
     419           10765 :             _ = shutdown_watcher() => {
     420           10765 :                 // do nothing; we most likely got already stopped by shutdown and will log it next.
     421           10765 :             }
     422           10765 :             _ = self.framed.shutdown() => {
     423           10765 :                 // socket might be already closed, e.g. if previously received error,
     424           10765 :                 // so ignore result.
     425           10765 :             },
     426           10765 :         }
     427                 : 
     428             378 :         match ret {
     429           10386 :             Ok(()) => Ok(()),
     430                 :             Err(QueryError::Shutdown) => {
     431 UBC           0 :                 info!("Stopped due to shutdown");
     432               0 :                 Ok(())
     433                 :             }
     434                 :             Err(QueryError::Reconnect) => {
     435                 :                 // Dropping out of this loop implicitly disconnects
     436               0 :                 info!("Stopped due to handler reconnect request");
     437               0 :                 Ok(())
     438                 :             }
     439 CBC         320 :             Err(QueryError::Disconnected(e)) => {
     440             320 :                 info!("Disconnected ({e:#})");
     441                 :                 // Disconnection is not an error: we just use it that way internally to drop
     442                 :                 // out of loops.
     443             320 :                 Ok(())
     444                 :             }
     445              58 :             e => e,
     446                 :         }
     447           10764 :     }
     448                 : 
     449           11209 :     async fn run_message_loop<F, S>(
     450           11209 :         &mut self,
     451           11209 :         handler: &mut impl Handler<IO>,
     452           11209 :         shutdown_watcher: F,
     453           11209 :     ) -> Result<(), QueryError>
     454           11209 :     where
     455           11209 :         F: Fn() -> S,
     456           11209 :         S: Future,
     457           11209 :     {
     458 UBC           0 :         trace!("postgres backend to {:?} started", self.peer_addr);
     459                 : 
     460 CBC       11209 :         tokio::select!(
     461                 :             biased;
     462                 : 
     463                 :             _ = shutdown_watcher() => {
     464                 :                 // We were requested to shut down.
     465 UBC           0 :                 tracing::info!("shutdown request received during handshake");
     466                 :                 return Err(QueryError::Shutdown)
     467                 :             },
     468                 : 
     469 CBC       11209 :             handshake_r = self.handshake(handler) => {
     470                 :                 handshake_r?;
     471                 :             }
     472                 :         );
     473                 : 
     474                 :         // Authentication completed
     475           11203 :         let mut query_string = Bytes::new();
     476           26093 :         while let Some(msg) = tokio::select!(
     477                 :             biased;
     478                 :             _ = shutdown_watcher() => {
     479                 :                 // We were requested to shut down.
     480 UBC           0 :                 tracing::info!("shutdown request received in run_message_loop");
     481                 :                 return Err(QueryError::Shutdown)
     482                 :             },
     483 CBC       26091 :             msg = self.read_message() => { msg },
     484              22 :         )? {
     485 UBC           0 :             trace!("got message {:?}", msg);
     486                 : 
     487 CBC     9872183 :             let result = self.process_message(handler, msg, &mut query_string).await;
     488           17148 :             tokio::select!(
     489                 :                 biased;
     490                 :                 _ = shutdown_watcher() => {
     491                 :                     // We were requested to shut down.
     492 UBC           0 :                     tracing::info!("shutdown request received during response flush");
     493                 : 
     494                 :                     // If we exited process_message with a shutdown error, there may be
     495                 :                     // some valid response content on in our transmit buffer: permit sending
     496                 :                     // this within a short timeout.  This is a best effort thing so we don't
     497                 :                     // care about the result.
     498                 :                     tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
     499                 : 
     500                 :                     return Err(QueryError::Shutdown)
     501                 :                 },
     502 CBC       17148 :                 flush_r = self.flush() => {
     503                 :                     flush_r?;
     504                 :                 }
     505                 :             );
     506                 : 
     507           16855 :             match result? {
     508                 :                 ProcessMsgResult::Continue => {
     509           14890 :                     continue;
     510                 :                 }
     511            1908 :                 ProcessMsgResult::Break => break,
     512                 :             }
     513                 :         }
     514                 : 
     515 UBC           0 :         trace!("postgres backend to {:?} exited", self.peer_addr);
     516 CBC       10386 :         Ok(())
     517           10764 :     }
     518                 : 
     519                 :     /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
     520               1 :     async fn tls_upgrade(
     521               1 :         src: MaybeTlsStream<IO>,
     522               1 :         tls_config: Arc<rustls::ServerConfig>,
     523               1 :     ) -> anyhow::Result<MaybeTlsStream<IO>> {
     524               1 :         match src {
     525               1 :             MaybeTlsStream::Unencrypted(s) => {
     526               1 :                 let acceptor = TlsAcceptor::from(tls_config);
     527               2 :                 let tls_stream = acceptor.accept(s).await?;
     528               1 :                 Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
     529                 :             }
     530                 :             MaybeTlsStream::Tls(_) => {
     531 UBC           0 :                 anyhow::bail!("TLS already started");
     532                 :             }
     533                 :         }
     534 CBC           1 :     }
     535                 : 
     536               1 :     async fn start_tls(&mut self) -> anyhow::Result<()> {
     537               1 :         // temporary replace stream with fake to cook TLS one, Indiana Jones style
     538               1 :         match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
     539               1 :             MaybeWriteOnly::Full(framed) => {
     540               1 :                 let tls_config = self
     541               1 :                     .tls_config
     542               1 :                     .as_ref()
     543               1 :                     .context("start_tls called without conf")?
     544               1 :                     .clone();
     545               1 :                 let tls_framed = framed
     546               1 :                     .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
     547               2 :                     .await?;
     548                 :                 // push back ready TLS stream
     549               1 :                 self.framed = MaybeWriteOnly::Full(tls_framed);
     550               1 :                 Ok(())
     551                 :             }
     552                 :             MaybeWriteOnly::WriteOnly(_) => {
     553 UBC           0 :                 anyhow::bail!("TLS upgrade attempt in split state")
     554                 :             }
     555               0 :             MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
     556                 :         }
     557 CBC           1 :     }
     558                 : 
     559                 :     /// Split off owned read part from which messages can be read in different
     560                 :     /// task/thread.
     561            2479 :     pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
     562            2479 :         // temporary replace stream with fake to cook split one, Indiana Jones style
     563            2479 :         match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
     564            2479 :             MaybeWriteOnly::Full(framed) => {
     565            2479 :                 let (reader, writer) = framed.split();
     566            2479 :                 self.framed = MaybeWriteOnly::WriteOnly(writer);
     567            2479 :                 Ok(PostgresBackendReader {
     568            2479 :                     reader,
     569            2479 :                     closed: false,
     570            2479 :                 })
     571                 :             }
     572                 :             MaybeWriteOnly::WriteOnly(_) => {
     573 UBC           0 :                 anyhow::bail!("PostgresBackend is already split")
     574                 :             }
     575               0 :             MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
     576                 :         }
     577 CBC        2479 :     }
     578                 : 
     579                 :     /// Join read part back.
     580            2089 :     pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
     581            2089 :         // temporary replace stream with fake to cook joined one, Indiana Jones style
     582            2089 :         match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
     583                 :             MaybeWriteOnly::Full(_) => {
     584 UBC           0 :                 anyhow::bail!("PostgresBackend is not split")
     585                 :             }
     586 CBC        2089 :             MaybeWriteOnly::WriteOnly(writer) => {
     587            2089 :                 let joined = Framed::unsplit(reader.reader, writer);
     588            2089 :                 self.framed = MaybeWriteOnly::Full(joined);
     589            2089 :                 // if reader encountered connection error, do not attempt reading anymore
     590            2089 :                 if reader.closed {
     591             215 :                     self.state = ProtoState::Closed;
     592            1874 :                 }
     593            2089 :                 Ok(())
     594                 :             }
     595 UBC           0 :             MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
     596                 :         }
     597 CBC        2089 :     }
     598                 : 
     599                 :     /// Perform handshake with the client, transitioning to Established.
     600                 :     /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
     601           11209 :     async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
     602           31737 :         while self.state < ProtoState::Authentication {
     603           20528 :             match self.framed.read_startup_message().await? {
     604           20528 :                 Some(msg) => {
     605           20528 :                     self.process_startup_message(handler, msg).await?;
     606                 :                 }
     607                 :                 None => {
     608 UBC           0 :                     trace!(
     609               0 :                         "postgres backend to {:?} received EOF during handshake",
     610               0 :                         self.peer_addr
     611               0 :                     );
     612 LBC         (1) :                     self.state = ProtoState::Closed;
     613             (1) :                     return Err(QueryError::Disconnected(ConnectionError::Protocol(
     614             (1) :                         ProtocolError::Protocol("EOF during handshake".to_string()),
     615             (1) :                     )));
     616                 :                 }
     617                 :             }
     618                 :         }
     619                 : 
     620                 :         // Perform auth, if needed.
     621 CBC       11209 :         if self.state == ProtoState::Authentication {
     622             215 :             match self.framed.read_message().await? {
     623             210 :                 Some(FeMessage::PasswordMessage(m)) => {
     624             210 :                     assert!(self.auth_type == AuthType::NeonJWT);
     625                 : 
     626             210 :                     let (_, jwt_response) = m.split_last().context("protocol violation")?;
     627                 : 
     628             210 :                     if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
     629               1 :                         self.write_message_noflush(&BeMessage::ErrorResponse(
     630               1 :                             &short_error(&e),
     631               1 :                             Some(e.pg_error_code()),
     632               1 :                         ))?;
     633               1 :                         return Err(e);
     634             209 :                     }
     635             209 : 
     636             209 :                     self.write_message_noflush(&BeMessage::AuthenticationOk)?
     637             209 :                         .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
     638             209 :                         .write_message(&BeMessage::ReadyForQuery)
     639 UBC           0 :                         .await?;
     640 CBC         209 :                     self.state = ProtoState::Established;
     641                 :                 }
     642 UBC           0 :                 Some(m) => {
     643               0 :                     return Err(QueryError::Other(anyhow::anyhow!(
     644               0 :                         "Unexpected message {:?} while waiting for handshake",
     645               0 :                         m
     646               0 :                     )));
     647                 :                 }
     648                 :                 None => {
     649               0 :                     trace!(
     650               0 :                         "postgres backend to {:?} received EOF during auth",
     651               0 :                         self.peer_addr
     652               0 :                     );
     653 CBC           5 :                     self.state = ProtoState::Closed;
     654               5 :                     return Err(QueryError::Disconnected(ConnectionError::Protocol(
     655               5 :                         ProtocolError::Protocol("EOF during auth".to_string()),
     656               5 :                     )));
     657                 :                 }
     658                 :             }
     659           10994 :         }
     660                 : 
     661           11203 :         Ok(())
     662           11209 :     }
     663                 : 
     664                 :     /// Process startup packet:
     665                 :     /// - transition to Established if auth type is trust
     666                 :     /// - transition to Authentication if auth type is NeonJWT.
     667                 :     /// - or perform TLS handshake -- then need to call this again to receive
     668                 :     ///   actual startup packet.
     669           20528 :     async fn process_startup_message(
     670           20528 :         &mut self,
     671           20528 :         handler: &mut impl Handler<IO>,
     672           20528 :         msg: FeStartupPacket,
     673           20528 :     ) -> Result<(), QueryError> {
     674           20528 :         assert!(self.state < ProtoState::Authentication);
     675           20528 :         let have_tls = self.tls_config.is_some();
     676           20528 :         match msg {
     677                 :             FeStartupPacket::SslRequest => {
     678 UBC           0 :                 debug!("SSL requested");
     679                 : 
     680 CBC        9319 :                 self.write_message(&BeMessage::EncryptionResponse(have_tls))
     681 UBC           0 :                     .await?;
     682                 : 
     683 CBC        9319 :                 if have_tls {
     684               2 :                     self.start_tls().await?;
     685               1 :                     self.state = ProtoState::Encrypted;
     686            9318 :                 }
     687                 :             }
     688                 :             FeStartupPacket::GssEncRequest => {
     689 UBC           0 :                 debug!("GSS requested");
     690               0 :                 self.write_message(&BeMessage::EncryptionResponse(false))
     691               0 :                     .await?;
     692                 :             }
     693                 :             FeStartupPacket::StartupMessage { .. } => {
     694 CBC       11209 :                 if have_tls && !matches!(self.state, ProtoState::Encrypted) {
     695 UBC           0 :                     self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
     696               0 :                         .await?;
     697               0 :                     return Err(QueryError::Other(anyhow::anyhow!(
     698               0 :                         "client did not connect with TLS"
     699               0 :                     )));
     700 CBC       11209 :                 }
     701           11209 : 
     702           11209 :                 // NB: startup() may change self.auth_type -- we are using that in proxy code
     703           11209 :                 // to bypass auth for new users.
     704           11209 :                 handler.startup(self, &msg)?;
     705                 : 
     706           11209 :                 match self.auth_type {
     707                 :                     AuthType::Trust => {
     708           10994 :                         self.write_message_noflush(&BeMessage::AuthenticationOk)?
     709           10994 :                             .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
     710           10994 :                             .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
     711                 :                             // The async python driver requires a valid server_version
     712           10994 :                             .write_message_noflush(&BeMessage::server_version("14.1"))?
     713           10994 :                             .write_message(&BeMessage::ReadyForQuery)
     714 UBC           0 :                             .await?;
     715 CBC       10994 :                         self.state = ProtoState::Established;
     716                 :                     }
     717                 :                     AuthType::NeonJWT => {
     718             215 :                         self.write_message(&BeMessage::AuthenticationCleartextPassword)
     719 UBC           0 :                             .await?;
     720 CBC         215 :                         self.state = ProtoState::Authentication;
     721                 :                     }
     722                 :                 }
     723                 :             }
     724                 :             FeStartupPacket::CancelRequest { .. } => {
     725 UBC           0 :                 return Err(QueryError::Other(anyhow::anyhow!(
     726               0 :                     "Unexpected CancelRequest message during handshake"
     727               0 :                 )));
     728                 :             }
     729                 :         }
     730 CBC       20528 :         Ok(())
     731           20528 :     }
     732                 : 
     733           17591 :     async fn process_message(
     734           17591 :         &mut self,
     735           17591 :         handler: &mut impl Handler<IO>,
     736           17591 :         msg: FeMessage,
     737           17591 :         unnamed_query_string: &mut Bytes,
     738           17591 :     ) -> Result<ProcessMsgResult, QueryError> {
     739           17591 :         // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
     740           17591 :         // TODO: change that to proper top-level match of protocol state with separate message handling for each state
     741           17591 :         assert!(self.state == ProtoState::Established);
     742                 : 
     743           17591 :         match msg {
     744           11393 :             FeMessage::Query(body) => {
     745                 :                 // remove null terminator
     746           11393 :                 let query_string = cstr_to_str(&body)?;
     747                 : 
     748 UBC           0 :                 trace!("got query {query_string:?}");
     749 CBC     9847474 :                 if let Err(e) = handler.process_query(self, query_string).await {
     750             745 :                     match e {
     751             181 :                         QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
     752                 :                         QueryError::SimulatedConnectionError => {
     753 GBC           7 :                             return Err(QueryError::SimulatedConnectionError)
     754                 :                         }
     755 CBC         557 :                         e => {
     756             557 :                             log_query_error(query_string, &e);
     757             557 :                             let short_error = short_error(&e);
     758             557 :                             self.write_message_noflush(&BeMessage::ErrorResponse(
     759             557 :                                 &short_error,
     760             557 :                                 Some(e.pg_error_code()),
     761             557 :                             ))?;
     762                 :                         }
     763                 :                     }
     764           10206 :                 }
     765           10762 :                 self.write_message_noflush(&BeMessage::ReadyForQuery)?;
     766                 :             }
     767                 : 
     768             552 :             FeMessage::Parse(m) => {
     769             552 :                 *unnamed_query_string = m.query_string;
     770             552 :                 self.write_message_noflush(&BeMessage::ParseComplete)?;
     771                 :             }
     772                 : 
     773                 :             FeMessage::Describe(_) => {
     774             552 :                 self.write_message_noflush(&BeMessage::ParameterDescription)?
     775             552 :                     .write_message_noflush(&BeMessage::NoData)?;
     776                 :             }
     777                 : 
     778                 :             FeMessage::Bind(_) => {
     779             552 :                 self.write_message_noflush(&BeMessage::BindComplete)?;
     780                 :             }
     781                 : 
     782                 :             FeMessage::Close(_) => {
     783             551 :                 self.write_message_noflush(&BeMessage::CloseComplete)?;
     784                 :             }
     785                 : 
     786                 :             FeMessage::Execute(_) => {
     787             552 :                 let query_string = cstr_to_str(unnamed_query_string)?;
     788 UBC           0 :                 trace!("got execute {query_string:?}");
     789 CBC       24709 :                 if let Err(e) = handler.process_query(self, query_string).await {
     790               8 :                     log_query_error(query_string, &e);
     791               8 :                     self.write_message_noflush(&BeMessage::ErrorResponse(
     792               8 :                         &e.to_string(),
     793               8 :                         Some(e.pg_error_code()),
     794               8 :                     ))?;
     795             544 :                 }
     796                 :                 // NOTE there is no ReadyForQuery message. This handler is used
     797                 :                 // for basebackup and it uses CopyOut which doesn't require
     798                 :                 // ReadyForQuery message and backend just switches back to
     799                 :                 // processing mode after sending CopyDone or ErrorResponse.
     800                 :             }
     801                 : 
     802                 :             FeMessage::Sync => {
     803            1662 :                 self.write_message_noflush(&BeMessage::ReadyForQuery)?;
     804                 :             }
     805                 : 
     806                 :             FeMessage::Terminate => {
     807            1727 :                 return Ok(ProcessMsgResult::Break);
     808                 :             }
     809                 : 
     810                 :             // We prefer explicit pattern matching to wildcards, because
     811                 :             // this helps us spot the places where new variants are missing
     812                 :             FeMessage::CopyData(_)
     813                 :             | FeMessage::CopyDone
     814                 :             | FeMessage::CopyFail
     815                 :             | FeMessage::PasswordMessage(_) => {
     816              50 :                 return Err(QueryError::Other(anyhow::anyhow!(
     817              50 :                     "unexpected message type: {msg:?}",
     818              50 :                 )));
     819                 :             }
     820                 :         }
     821                 : 
     822           15183 :         Ok(ProcessMsgResult::Continue)
     823           17148 :     }
     824                 : 
     825                 :     /// Log as info/error result of handling COPY stream and send back
     826                 :     /// ErrorResponse if that makes sense. Shutdown the stream if we got
     827                 :     /// Terminate. TODO: transition into waiting for Sync msg if we initiate the
     828                 :     /// close.
     829            2089 :     pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
     830                 :         use CopyStreamHandlerEnd::*;
     831                 : 
     832            2089 :         let expected_end = match &end {
     833            1854 :             ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
     834             234 :             CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
     835             234 :                 if is_expected_io_error(io_error) =>
     836             234 :             {
     837             234 :                 true
     838                 :             }
     839               1 :             _ => false,
     840                 :         };
     841            2089 :         if expected_end {
     842            2088 :             info!("terminated: {:#}", end);
     843                 :         } else {
     844               1 :             error!("terminated: {:?}", end);
     845                 :         }
     846                 : 
     847                 :         // Note: no current usages ever send this
     848            2089 :         if let CopyDone = &end {
     849 UBC           0 :             if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
     850               0 :                 error!("failed to send CopyDone: {}", e);
     851               0 :             }
     852 CBC        2089 :         }
     853                 : 
     854            2089 :         if let Terminate = &end {
     855              19 :             self.state = ProtoState::Closed;
     856            2070 :         }
     857                 : 
     858            2089 :         let err_to_send_and_errcode = match &end {
     859              48 :             ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
     860               1 :             Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
     861                 :             // Note: CopyFail in duplex copy is somewhat unexpected (at least to
     862                 :             // PG walsender; evidently and per my docs reading client should
     863                 :             // finish it with CopyDone). It is not a problem to recover from it
     864                 :             // finishing the stream in both directions like we do, but note that
     865                 :             // sync rust-postgres client (which we don't use anymore) hangs if
     866                 :             // socket is not closed here.
     867                 :             // https://github.com/sfackler/rust-postgres/issues/755
     868                 :             // https://github.com/neondatabase/neon/issues/935
     869                 :             //
     870                 :             // Currently, the version of tokio_postgres replication patch we use
     871                 :             // sends this when it closes the stream (e.g. pageserver decided to
     872                 :             // switch conn to another safekeeper and client gets dropped).
     873                 :             // Moreover, seems like 'connection' task errors with 'unexpected
     874                 :             // message from server' when it receives ErrorResponse (anything but
     875                 :             // CopyData/CopyDone) back.
     876              19 :             CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
     877            2021 :             _ => None,
     878                 :         };
     879            2089 :         if let Some((err, errcode)) = err_to_send_and_errcode {
     880              68 :             if let Err(ee) = self
     881              68 :                 .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
     882 GBC           1 :                 .await
     883                 :             {
     884 CBC           1 :                 error!("failed to send ErrorResponse: {}", ee);
     885              67 :             }
     886            2021 :         }
     887            2089 :     }
     888                 : }
     889                 : 
     890                 : pub struct PostgresBackendReader<IO> {
     891                 :     reader: FramedReader<MaybeTlsStream<IO>>,
     892                 :     closed: bool, // true if received error closing the connection
     893                 : }
     894                 : 
     895                 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
     896                 :     /// Read full message or return None if connection is cleanly closed with no
     897                 :     /// unprocessed data.
     898         2549745 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     899         4817654 :         match self.reader.read_message().await {
     900         2549072 :             Ok(m) => {
     901 UBC           0 :                 trace!("read msg {:?}", m);
     902 CBC     2549072 :                 Ok(m)
     903                 :             }
     904             215 :             Err(e) => {
     905             215 :                 self.closed = true;
     906             215 :                 Err(e)
     907                 :             }
     908                 :         }
     909         2549287 :     }
     910                 : 
     911                 :     /// Get CopyData contents of the next message in COPY stream or error
     912                 :     /// closing it. The error type is wider than actual errors which can happen
     913                 :     /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
     914                 :     /// current callers.
     915         2549745 :     pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
     916         4817654 :         match self.read_message().await? {
     917         2547304 :             Some(msg) => match msg {
     918         2547266 :                 FeMessage::CopyData(m) => Ok(m),
     919 UBC           0 :                 FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
     920 CBC          19 :                 FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
     921              19 :                 FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
     922 UBC           0 :                 _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
     923               0 :                     ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
     924               0 :                 ))),
     925                 :             },
     926 CBC        1768 :             None => Err(CopyStreamHandlerEnd::EOF),
     927                 :         }
     928         2549287 :     }
     929                 : }
     930                 : 
     931                 : ///
     932                 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
     933                 : /// messages.
     934                 : ///
     935                 : 
     936                 : pub struct CopyDataWriter<'a, IO> {
     937                 :     pgb: &'a mut PostgresBackend<IO>,
     938                 : }
     939                 : 
     940                 : impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
     941          908485 :     fn poll_write(
     942          908485 :         self: Pin<&mut Self>,
     943          908485 :         cx: &mut std::task::Context<'_>,
     944          908485 :         buf: &[u8],
     945          908485 :     ) -> Poll<Result<usize, std::io::Error>> {
     946          908485 :         let this = self.get_mut();
     947                 : 
     948                 :         // It's not strictly required to flush between each message, but makes it easier
     949                 :         // to view in wireshark, and usually the messages that the callers write are
     950                 :         // decently-sized anyway.
     951          908485 :         if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
     952 UBC           0 :             return Poll::Ready(Err(err));
     953 CBC      900834 :         }
     954          900834 : 
     955          900834 :         // CopyData
     956          900834 :         // XXX: if the input is large, we should split it into multiple messages.
     957          900834 :         // Not sure what the threshold should be, but the ultimate hard limit is that
     958          900834 :         // the length cannot exceed u32.
     959          900834 :         this.pgb
     960          900834 :             .write_message_noflush(&BeMessage::CopyData(buf))
     961          900834 :             // write_message only writes to the buffer, so it can fail iff the
     962          900834 :             // message is invaid, but CopyData can't be invalid.
     963          900834 :             .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
     964                 : 
     965          900834 :         Poll::Ready(Ok(buf.len()))
     966          908485 :     }
     967                 : 
     968           38382 :     fn poll_flush(
     969           38382 :         self: Pin<&mut Self>,
     970           38382 :         cx: &mut std::task::Context<'_>,
     971           38382 :     ) -> Poll<Result<(), std::io::Error>> {
     972           38382 :         let this = self.get_mut();
     973           38382 :         this.pgb.poll_flush(cx)
     974           38382 :     }
     975                 : 
     976             164 :     fn poll_shutdown(
     977             164 :         self: Pin<&mut Self>,
     978             164 :         cx: &mut std::task::Context<'_>,
     979             164 :     ) -> Poll<Result<(), std::io::Error>> {
     980             164 :         let this = self.get_mut();
     981             164 :         this.pgb.poll_flush(cx)
     982             164 :     }
     983                 : }
     984                 : 
     985             557 : pub fn short_error(e: &QueryError) -> String {
     986             557 :     match e {
     987              38 :         QueryError::Disconnected(connection_error) => connection_error.to_string(),
     988 UBC           0 :         QueryError::Reconnect => "reconnect".to_string(),
     989               0 :         QueryError::Shutdown => "shutdown".to_string(),
     990               0 :         QueryError::NotFound(_) => "not found".to_string(),
     991 CBC           5 :         QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
     992 UBC           0 :         QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
     993 CBC         514 :         QueryError::Other(e) => format!("{e:#}"),
     994                 :     }
     995             557 : }
     996                 : 
     997             565 : fn log_query_error(query: &str, e: &QueryError) {
     998              38 :     match e {
     999              38 :         QueryError::Disconnected(ConnectionError::Io(io_error)) => {
    1000              38 :             if is_expected_io_error(io_error) {
    1001              38 :                 info!("query handler for '{query}' failed with expected io error: {io_error}");
    1002                 :             } else {
    1003 UBC           0 :                 error!("query handler for '{query}' failed with io error: {io_error}");
    1004                 :             }
    1005                 :         }
    1006               0 :         QueryError::Disconnected(other_connection_error) => {
    1007               0 :             error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
    1008                 :         }
    1009                 :         QueryError::SimulatedConnectionError => {
    1010               0 :             error!("query handler for query '{query}' failed due to a simulated connection error")
    1011                 :         }
    1012                 :         QueryError::Reconnect => {
    1013               0 :             info!("query handler for '{query}' requested client to reconnect")
    1014                 :         }
    1015                 :         QueryError::Shutdown => {
    1016               0 :             info!("query handler for '{query}' cancelled during tenant shutdown")
    1017                 :         }
    1018               0 :         QueryError::NotFound(reason) => {
    1019               0 :             info!("query handler for '{query}' entity not found: {reason}")
    1020                 :         }
    1021 CBC           4 :         QueryError::Unauthorized(e) => {
    1022               4 :             warn!("query handler for '{query}' failed with authentication error: {e}");
    1023                 :         }
    1024             523 :         QueryError::Other(e) => {
    1025             523 :             error!("query handler for '{query}' failed: {e:?}");
    1026                 :         }
    1027                 :     }
    1028             565 : }
    1029                 : 
    1030                 : /// Something finishing handling of COPY stream, see handle_copy_stream_end.
    1031                 : /// This is not always a real error, but it allows to use ? and thiserror impls.
    1032            2156 : #[derive(thiserror::Error, Debug)]
    1033                 : pub enum CopyStreamHandlerEnd {
    1034                 :     /// Handler initiates the end of streaming.
    1035                 :     #[error("{0}")]
    1036                 :     ServerInitiated(String),
    1037                 :     #[error("received CopyDone")]
    1038                 :     CopyDone,
    1039                 :     #[error("received CopyFail")]
    1040                 :     CopyFail,
    1041                 :     #[error("received Terminate")]
    1042                 :     Terminate,
    1043                 :     #[error("EOF on COPY stream")]
    1044                 :     EOF,
    1045                 :     /// The connection was lost
    1046                 :     #[error("connection error: {0}")]
    1047                 :     Disconnected(#[from] ConnectionError),
    1048                 :     /// Some other error
    1049                 :     #[error(transparent)]
    1050                 :     Other(#[from] anyhow::Error),
    1051                 : }
        

Generated by: LCOV version 2.1-beta