LCOV - code coverage report
Current view: top level - libs/postgres_backend/src - lib.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 83.3 % 520 433
Test Date: 2023-09-06 10:18:01 Functions: 49.4 % 322 159

            Line data    Source code
       1              : //! Server-side asynchronous Postgres connection, as limited as we need.
       2              : //! To use, create PostgresBackend and run() it, passing the Handler
       3              : //! implementation determining how to process the queries. Currently its API
       4              : //! is rather narrow, but we can extend it once required.
       5              : use anyhow::Context;
       6              : use bytes::Bytes;
       7              : use futures::pin_mut;
       8              : use serde::{Deserialize, Serialize};
       9              : use std::io::ErrorKind;
      10              : use std::net::SocketAddr;
      11              : use std::pin::Pin;
      12              : use std::sync::Arc;
      13              : use std::task::{ready, Poll};
      14              : use std::{fmt, io};
      15              : use std::{future::Future, str::FromStr};
      16              : use tokio::io::{AsyncRead, AsyncWrite};
      17              : use tokio_rustls::TlsAcceptor;
      18              : use tracing::{debug, error, info, trace};
      19              : 
      20              : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
      21              : use pq_proto::{
      22              :     BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
      23              :     SQLSTATE_SUCCESSFUL_COMPLETION,
      24              : };
      25              : 
      26              : /// An error, occurred during query processing:
      27              : /// either during the connection ([`ConnectionError`]) or before/after it.
      28          502 : #[derive(thiserror::Error, Debug)]
      29              : pub enum QueryError {
      30              :     /// The connection was lost while processing the query.
      31              :     #[error(transparent)]
      32              :     Disconnected(#[from] ConnectionError),
      33              :     /// Some other error
      34              :     #[error(transparent)]
      35              :     Other(#[from] anyhow::Error),
      36              : }
      37              : 
      38              : impl From<io::Error> for QueryError {
      39          327 :     fn from(e: io::Error) -> Self {
      40          327 :         Self::Disconnected(ConnectionError::Io(e))
      41          327 :     }
      42              : }
      43              : 
      44              : impl QueryError {
      45           91 :     pub fn pg_error_code(&self) -> &'static [u8; 5] {
      46           91 :         match self {
      47            6 :             Self::Disconnected(_) => b"08006",         // connection failure
      48           85 :             Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
      49              :         }
      50           91 :     }
      51              : }
      52              : 
      53              : /// Returns true if the given error is a normal consequence of a network issue,
      54              : /// or the client closing the connection. These errors can happen during normal
      55              : /// operations, and don't indicate a bug in our code.
      56          804 : pub fn is_expected_io_error(e: &io::Error) -> bool {
      57              :     use io::ErrorKind::*;
      58            0 :     matches!(
      59          804 :         e.kind(),
      60              :         BrokenPipe | ConnectionRefused | ConnectionAborted | ConnectionReset | TimedOut
      61              :     )
      62          804 : }
      63              : 
      64              : #[async_trait::async_trait]
      65              : pub trait Handler<IO> {
      66              :     /// Handle single query.
      67              :     /// postgres_backend will issue ReadyForQuery after calling this (this
      68              :     /// might be not what we want after CopyData streaming, but currently we don't
      69              :     /// care). It will also flush out the output buffer.
      70              :     async fn process_query(
      71              :         &mut self,
      72              :         pgb: &mut PostgresBackend<IO>,
      73              :         query_string: &str,
      74              :     ) -> Result<(), QueryError>;
      75              : 
      76              :     /// Called on startup packet receival, allows to process params.
      77              :     ///
      78              :     /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
      79              :     /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
      80              :     /// to override whole init logic in implementations.
      81            5 :     fn startup(
      82            5 :         &mut self,
      83            5 :         _pgb: &mut PostgresBackend<IO>,
      84            5 :         _sm: &FeStartupPacket,
      85            5 :     ) -> Result<(), QueryError> {
      86            5 :         Ok(())
      87            5 :     }
      88              : 
      89              :     /// Check auth jwt
      90            0 :     fn check_auth_jwt(
      91            0 :         &mut self,
      92            0 :         _pgb: &mut PostgresBackend<IO>,
      93            0 :         _jwt_response: &[u8],
      94            0 :     ) -> Result<(), QueryError> {
      95            0 :         Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
      96            0 :     }
      97              : }
      98              : 
      99              : /// PostgresBackend protocol state.
     100              : /// XXX: The order of the constructors matters.
     101        40899 : #[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
     102              : pub enum ProtoState {
     103              :     /// Nothing happened yet.
     104              :     Initialization,
     105              :     /// Encryption handshake is done; waiting for encrypted Startup message.
     106              :     Encrypted,
     107              :     /// Waiting for password (auth token).
     108              :     Authentication,
     109              :     /// Performed handshake and auth, ReadyForQuery is issued.
     110              :     Established,
     111              :     Closed,
     112              : }
     113              : 
     114            0 : #[derive(Clone, Copy)]
     115              : pub enum ProcessMsgResult {
     116              :     Continue,
     117              :     Break,
     118              : }
     119              : 
     120              : /// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
     121              : pub enum MaybeTlsStream<IO> {
     122              :     Unencrypted(IO),
     123              :     Tls(Box<tokio_rustls::server::TlsStream<IO>>),
     124              : }
     125              : 
     126              : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
     127      8640428 :     fn poll_write(
     128      8640428 :         self: Pin<&mut Self>,
     129      8640428 :         cx: &mut std::task::Context<'_>,
     130      8640428 :         buf: &[u8],
     131      8640428 :     ) -> Poll<io::Result<usize>> {
     132      8640428 :         match self.get_mut() {
     133      8640426 :             Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
     134            2 :             Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
     135              :         }
     136      8640428 :     }
     137      8654170 :     fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
     138      8654170 :         match self.get_mut() {
     139      8654167 :             Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
     140            3 :             Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
     141              :         }
     142      8654170 :     }
     143         8294 :     fn poll_shutdown(
     144         8294 :         self: Pin<&mut Self>,
     145         8294 :         cx: &mut std::task::Context<'_>,
     146         8294 :     ) -> Poll<io::Result<()>> {
     147         8294 :         match self.get_mut() {
     148         8294 :             Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
     149            0 :             Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
     150              :         }
     151         8294 :     }
     152              : }
     153              : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
     154     18994316 :     fn poll_read(
     155     18994316 :         self: Pin<&mut Self>,
     156     18994316 :         cx: &mut std::task::Context<'_>,
     157     18994316 :         buf: &mut tokio::io::ReadBuf<'_>,
     158     18994316 :     ) -> Poll<io::Result<()>> {
     159     18994316 :         match self.get_mut() {
     160     18994311 :             Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
     161            5 :             Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
     162              :         }
     163     18994316 :     }
     164              : }
     165              : 
     166        39718 : #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
     167              : pub enum AuthType {
     168              :     Trust,
     169              :     // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
     170              :     NeonJWT,
     171              : }
     172              : 
     173              : impl FromStr for AuthType {
     174              :     type Err = anyhow::Error;
     175              : 
     176         1888 :     fn from_str(s: &str) -> Result<Self, Self::Err> {
     177         1888 :         match s {
     178         1888 :             "Trust" => Ok(Self::Trust),
     179           36 :             "NeonJWT" => Ok(Self::NeonJWT),
     180            0 :             _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
     181              :         }
     182         1888 :     }
     183              : }
     184              : 
     185              : impl fmt::Display for AuthType {
     186              :     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
     187         3038 :         f.write_str(match self {
     188         2984 :             AuthType::Trust => "Trust",
     189           54 :             AuthType::NeonJWT => "NeonJWT",
     190              :         })
     191         3038 :     }
     192              : }
     193              : 
     194              : /// Either full duplex Framed or write only half; the latter is left in
     195              : /// PostgresBackend after call to `split`. In principle we could always store a
     196              : /// pair of splitted handles, but that would force to to pay splitting price
     197              : /// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
     198              : enum MaybeWriteOnly<IO> {
     199              :     Full(Framed<MaybeTlsStream<IO>>),
     200              :     WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
     201              :     Broken, // temporary value palmed off during the split
     202              : }
     203              : 
     204              : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
     205        15920 :     async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
     206        15920 :         match self {
     207        15920 :             MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
     208              :             MaybeWriteOnly::WriteOnly(_) => {
     209            0 :                 Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
     210              :             }
     211            0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     212              :         }
     213        15920 :     }
     214              : 
     215      4653304 :     async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     216      4653304 :         match self {
     217      4653304 :             MaybeWriteOnly::Full(framed) => framed.read_message().await,
     218              :             MaybeWriteOnly::WriteOnly(_) => {
     219            0 :                 Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
     220              :             }
     221            0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     222              :         }
     223      4653177 :     }
     224              : 
     225      8635171 :     fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
     226      8635171 :         match self {
     227      5688444 :             MaybeWriteOnly::Full(framed) => framed.write_message(msg),
     228      2946727 :             MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
     229            0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     230              :         }
     231      8635171 :     }
     232              : 
     233      8654851 :     async fn flush(&mut self) -> io::Result<()> {
     234      8654851 :         match self {
     235      5708124 :             MaybeWriteOnly::Full(framed) => framed.flush().await,
     236      2946727 :             MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
     237            0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     238              :         }
     239      8646209 :     }
     240              : 
     241         8620 :     async fn shutdown(&mut self) -> io::Result<()> {
     242         8620 :         match self {
     243         8620 :             MaybeWriteOnly::Full(framed) => framed.shutdown().await,
     244            0 :             MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
     245            0 :             MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
     246              :         }
     247         8620 :     }
     248              : }
     249              : 
     250              : pub struct PostgresBackend<IO> {
     251              :     framed: MaybeWriteOnly<IO>,
     252              : 
     253              :     pub state: ProtoState,
     254              : 
     255              :     auth_type: AuthType,
     256              : 
     257              :     peer_addr: SocketAddr,
     258              :     pub tls_config: Option<Arc<rustls::ServerConfig>>,
     259              : }
     260              : 
     261              : pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
     262              : 
     263            0 : pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
     264            0 :     let mut query_string = query_string.to_vec();
     265            0 :     if let Some(ch) = query_string.last() {
     266            0 :         if *ch == 0 {
     267            0 :             query_string.pop();
     268            0 :         }
     269            0 :     }
     270            0 :     query_string
     271            0 : }
     272              : 
     273              : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
     274         9834 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
     275         9834 :     let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
     276         9834 :     std::str::from_utf8(without_null).map_err(|e| e.into())
     277         9834 : }
     278              : 
     279              : impl PostgresBackend<tokio::net::TcpStream> {
     280            5 :     pub fn new(
     281            5 :         socket: tokio::net::TcpStream,
     282            5 :         auth_type: AuthType,
     283            5 :         tls_config: Option<Arc<rustls::ServerConfig>>,
     284            5 :     ) -> io::Result<Self> {
     285            5 :         let peer_addr = socket.peer_addr()?;
     286            5 :         let stream = MaybeTlsStream::Unencrypted(socket);
     287            5 : 
     288            5 :         Ok(Self {
     289            5 :             framed: MaybeWriteOnly::Full(Framed::new(stream)),
     290            5 :             state: ProtoState::Initialization,
     291            5 :             auth_type,
     292            5 :             tls_config,
     293            5 :             peer_addr,
     294            5 :         })
     295            5 :     }
     296              : }
     297              : 
     298              : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
     299         9054 :     pub fn new_from_io(
     300         9054 :         socket: IO,
     301         9054 :         peer_addr: SocketAddr,
     302         9054 :         auth_type: AuthType,
     303         9054 :         tls_config: Option<Arc<rustls::ServerConfig>>,
     304         9054 :     ) -> io::Result<Self> {
     305         9054 :         let stream = MaybeTlsStream::Unencrypted(socket);
     306         9054 : 
     307         9054 :         Ok(Self {
     308         9054 :             framed: MaybeWriteOnly::Full(Framed::new(stream)),
     309         9054 :             state: ProtoState::Initialization,
     310         9054 :             auth_type,
     311         9054 :             tls_config,
     312         9054 :             peer_addr,
     313         9054 :         })
     314         9054 :     }
     315              : 
     316         2935 :     pub fn get_peer_addr(&self) -> &SocketAddr {
     317         2935 :         &self.peer_addr
     318         2935 :     }
     319              : 
     320              :     /// Read full message or return None if connection is cleanly closed with no
     321              :     /// unprocessed data.
     322      4653256 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     323      4653153 :         if let ProtoState::Closed = self.state {
     324           65 :             Ok(None)
     325              :         } else {
     326      4653088 :             match self.framed.read_message().await {
     327      4652923 :                 Ok(m) => {
     328            0 :                     trace!("read msg {:?}", m);
     329      4652923 :                     Ok(m)
     330              :                 }
     331           38 :                 Err(e) => {
     332           38 :                     // remember not to try to read anymore
     333           38 :                     self.state = ProtoState::Closed;
     334           38 :                     Err(e)
     335              :                 }
     336              :             }
     337              :         }
     338      4653026 :     }
     339              : 
     340              :     /// Write message into internal output buffer, doesn't flush it. Technically
     341              :     /// error type can be only ProtocolError here (if, unlikely, serialization
     342              :     /// fails), but callers typically wrap it anyway.
     343              :     pub fn write_message_noflush(
     344              :         &mut self,
     345              :         message: &BeMessage<'_>,
     346              :     ) -> Result<&mut Self, ConnectionError> {
     347      8635171 :         self.framed.write_message_noflush(message)?;
     348      8635171 :         trace!("wrote msg {:?}", message);
     349      8635171 :         Ok(self)
     350      8635171 :     }
     351              : 
     352              :     /// Flush output buffer into the socket.
     353      8654851 :     pub async fn flush(&mut self) -> io::Result<()> {
     354      8654851 :         self.framed.flush().await
     355      8646209 :     }
     356              : 
     357              :     /// Polling version of `flush()`, saves the caller need to pin.
     358      1054034 :     pub fn poll_flush(
     359      1054034 :         &mut self,
     360      1054034 :         cx: &mut std::task::Context<'_>,
     361      1054034 :     ) -> Poll<Result<(), std::io::Error>> {
     362      1054034 :         let flush_fut = self.flush();
     363      1054034 :         pin_mut!(flush_fut);
     364      1054034 :         flush_fut.poll(cx)
     365      1054034 :     }
     366              : 
     367              :     /// Write message into internal output buffer and flush it to the stream.
     368      2965945 :     pub async fn write_message(
     369      2965945 :         &mut self,
     370      2965945 :         message: &BeMessage<'_>,
     371      2965945 :     ) -> Result<&mut Self, ConnectionError> {
     372      2965945 :         self.write_message_noflush(message)?;
     373      2965945 :         self.flush().await?;
     374      2965925 :         Ok(self)
     375      2965931 :     }
     376              : 
     377              :     /// Returns an AsyncWrite implementation that wraps all the data written
     378              :     /// to it in CopyData messages, and writes them to the connection
     379              :     ///
     380              :     /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
     381          662 :     pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
     382          662 :         CopyDataWriter { pgb: self }
     383          662 :     }
     384              : 
     385              :     /// Wrapper for run_message_loop() that shuts down socket when we are done
     386         9059 :     pub async fn run<F, S>(
     387         9059 :         mut self,
     388         9059 :         handler: &mut impl Handler<IO>,
     389         9059 :         shutdown_watcher: F,
     390         9059 :     ) -> Result<(), QueryError>
     391         9059 :     where
     392         9059 :         F: Fn() -> S,
     393         9059 :         S: Future,
     394         9059 :     {
     395     20363098 :         let ret = self.run_message_loop(handler, shutdown_watcher).await;
     396              :         // socket might be already closed, e.g. if previously received error,
     397              :         // so ignore result.
     398         8620 :         self.framed.shutdown().await.ok();
     399         8620 :         ret
     400         8620 :     }
     401              : 
     402         9059 :     async fn run_message_loop<F, S>(
     403         9059 :         &mut self,
     404         9059 :         handler: &mut impl Handler<IO>,
     405         9059 :         shutdown_watcher: F,
     406         9059 :     ) -> Result<(), QueryError>
     407         9059 :     where
     408         9059 :         F: Fn() -> S,
     409         9059 :         S: Future,
     410         9059 :     {
     411            0 :         trace!("postgres backend to {:?} started", self.peer_addr);
     412              : 
     413         9059 :         tokio::select!(
     414              :             biased;
     415              : 
     416              :             _ = shutdown_watcher() => {
     417              :                 // We were requested to shut down.
     418            0 :                 tracing::info!("shutdown request received during handshake");
     419              :                 return Ok(())
     420              :             },
     421              : 
     422         9059 :             result = self.handshake(handler) => {
     423              :                 // Handshake complete.
     424              :                 result?;
     425              :                 if self.state == ProtoState::Closed {
     426              :                     return Ok(()); // EOF during handshake
     427              :                 }
     428              :             }
     429              :         );
     430              : 
     431              :         // Authentication completed
     432         9053 :         let mut query_string = Bytes::new();
     433        22765 :         while let Some(msg) = tokio::select!(
     434              :             biased;
     435              :             _ = shutdown_watcher() => {
     436              :                 // We were requested to shut down.
     437           96 :                 tracing::info!("shutdown request received in run_message_loop");
     438              :                 Ok(None)
     439              :             },
     440        22667 :             msg = self.read_message() => { msg },
     441           33 :         )? {
     442            0 :             trace!("got message {:?}", msg);
     443              : 
     444     20334774 :             let result = self.process_message(handler, msg, &mut query_string).await;
     445        15694 :             self.flush().await?;
     446        15368 :             match result? {
     447              :                 ProcessMsgResult::Continue => {
     448        13712 :                     self.flush().await?;
     449        13712 :                     continue;
     450              :                 }
     451         1521 :                 ProcessMsgResult::Break => break,
     452              :             }
     453              :         }
     454              : 
     455            0 :         trace!("postgres backend to {:?} exited", self.peer_addr);
     456         8120 :         Ok(())
     457         8620 :     }
     458              : 
     459              :     /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
     460            1 :     async fn tls_upgrade(
     461            1 :         src: MaybeTlsStream<IO>,
     462            1 :         tls_config: Arc<rustls::ServerConfig>,
     463            1 :     ) -> anyhow::Result<MaybeTlsStream<IO>> {
     464            1 :         match src {
     465            1 :             MaybeTlsStream::Unencrypted(s) => {
     466            1 :                 let acceptor = TlsAcceptor::from(tls_config);
     467            2 :                 let tls_stream = acceptor.accept(s).await?;
     468            1 :                 Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
     469              :             }
     470              :             MaybeTlsStream::Tls(_) => {
     471            0 :                 anyhow::bail!("TLS already started");
     472              :             }
     473              :         }
     474            1 :     }
     475              : 
     476            1 :     async fn start_tls(&mut self) -> anyhow::Result<()> {
     477            1 :         // temporary replace stream with fake to cook TLS one, Indiana Jones style
     478            1 :         match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
     479            1 :             MaybeWriteOnly::Full(framed) => {
     480            1 :                 let tls_config = self
     481            1 :                     .tls_config
     482            1 :                     .as_ref()
     483            1 :                     .context("start_tls called without conf")?
     484            1 :                     .clone();
     485            1 :                 let tls_framed = framed
     486            1 :                     .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
     487            2 :                     .await?;
     488              :                 // push back ready TLS stream
     489            1 :                 self.framed = MaybeWriteOnly::Full(tls_framed);
     490            1 :                 Ok(())
     491              :             }
     492              :             MaybeWriteOnly::WriteOnly(_) => {
     493            0 :                 anyhow::bail!("TLS upgrade attempt in split state")
     494              :             }
     495            0 :             MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
     496              :         }
     497            1 :     }
     498              : 
     499              :     /// Split off owned read part from which messages can be read in different
     500              :     /// task/thread.
     501         2935 :     pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
     502         2935 :         // temporary replace stream with fake to cook split one, Indiana Jones style
     503         2935 :         match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
     504         2935 :             MaybeWriteOnly::Full(framed) => {
     505         2935 :                 let (reader, writer) = framed.split();
     506         2935 :                 self.framed = MaybeWriteOnly::WriteOnly(writer);
     507         2935 :                 Ok(PostgresBackendReader {
     508         2935 :                     reader,
     509         2935 :                     closed: false,
     510         2935 :                 })
     511              :             }
     512              :             MaybeWriteOnly::WriteOnly(_) => {
     513            0 :                 anyhow::bail!("PostgresBackend is already split")
     514              :             }
     515            0 :             MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
     516              :         }
     517         2935 :     }
     518              : 
     519              :     /// Join read part back.
     520         2565 :     pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
     521         2565 :         // temporary replace stream with fake to cook joined one, Indiana Jones style
     522         2565 :         match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
     523              :             MaybeWriteOnly::Full(_) => {
     524            0 :                 anyhow::bail!("PostgresBackend is not split")
     525              :             }
     526         2565 :             MaybeWriteOnly::WriteOnly(writer) => {
     527         2565 :                 let joined = Framed::unsplit(reader.reader, writer);
     528         2565 :                 self.framed = MaybeWriteOnly::Full(joined);
     529         2565 :                 // if reader encountered connection error, do not attempt reading anymore
     530         2565 :                 if reader.closed {
     531          303 :                     self.state = ProtoState::Closed;
     532         2262 :                 }
     533         2565 :                 Ok(())
     534              :             }
     535            0 :             MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
     536              :         }
     537         2565 :     }
     538              : 
     539              :     /// Perform handshake with the client, transitioning to Established.
     540              :     /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
     541         9059 :     async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
     542        24979 :         while self.state < ProtoState::Authentication {
     543        15920 :             match self.framed.read_startup_message().await? {
     544        15920 :                 Some(msg) => {
     545        15920 :                     self.process_startup_message(handler, msg).await?;
     546              :                 }
     547              :                 None => {
     548            0 :                     trace!(
     549            0 :                         "postgres backend to {:?} received EOF during handshake",
     550            0 :                         self.peer_addr
     551            0 :                     );
     552            0 :                     self.state = ProtoState::Closed;
     553            0 :                     return Ok(());
     554              :                 }
     555              :             }
     556              :         }
     557              : 
     558              :         // Perform auth, if needed.
     559         9059 :         if self.state == ProtoState::Authentication {
     560          216 :             match self.framed.read_message().await? {
     561          211 :                 Some(FeMessage::PasswordMessage(m)) => {
     562          211 :                     assert!(self.auth_type == AuthType::NeonJWT);
     563              : 
     564          211 :                     let (_, jwt_response) = m.split_last().context("protocol violation")?;
     565              : 
     566          211 :                     if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
     567            1 :                         self.write_message_noflush(&BeMessage::ErrorResponse(
     568            1 :                             &e.to_string(),
     569            1 :                             Some(e.pg_error_code()),
     570            1 :                         ))?;
     571            1 :                         return Err(e);
     572          210 :                     }
     573          210 : 
     574          210 :                     self.write_message_noflush(&BeMessage::AuthenticationOk)?
     575          210 :                         .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
     576          210 :                         .write_message(&BeMessage::ReadyForQuery)
     577            0 :                         .await?;
     578          210 :                     self.state = ProtoState::Established;
     579              :                 }
     580            0 :                 Some(m) => {
     581            0 :                     return Err(QueryError::Other(anyhow::anyhow!(
     582            0 :                         "Unexpected message {:?} while waiting for handshake",
     583            0 :                         m
     584            0 :                     )));
     585              :                 }
     586              :                 None => {
     587            0 :                     trace!(
     588            0 :                         "postgres backend to {:?} received EOF during auth",
     589            0 :                         self.peer_addr
     590            0 :                     );
     591            5 :                     self.state = ProtoState::Closed;
     592            5 :                     return Ok(());
     593              :                 }
     594              :             }
     595         8843 :         }
     596              : 
     597         9053 :         Ok(())
     598         9059 :     }
     599              : 
     600              :     /// Process startup packet:
     601              :     /// - transition to Established if auth type is trust
     602              :     /// - transition to Authentication if auth type is NeonJWT.
     603              :     /// - or perform TLS handshake -- then need to call this again to receive
     604              :     ///   actual startup packet.
     605        15920 :     async fn process_startup_message(
     606        15920 :         &mut self,
     607        15920 :         handler: &mut impl Handler<IO>,
     608        15920 :         msg: FeStartupPacket,
     609        15920 :     ) -> Result<(), QueryError> {
     610        15920 :         assert!(self.state < ProtoState::Authentication);
     611        15920 :         let have_tls = self.tls_config.is_some();
     612        15920 :         match msg {
     613              :             FeStartupPacket::SslRequest => {
     614            0 :                 debug!("SSL requested");
     615              : 
     616         6861 :                 self.write_message(&BeMessage::EncryptionResponse(have_tls))
     617            0 :                     .await?;
     618              : 
     619         6861 :                 if have_tls {
     620            2 :                     self.start_tls().await?;
     621            1 :                     self.state = ProtoState::Encrypted;
     622         6860 :                 }
     623              :             }
     624              :             FeStartupPacket::GssEncRequest => {
     625            0 :                 debug!("GSS requested");
     626            0 :                 self.write_message(&BeMessage::EncryptionResponse(false))
     627            0 :                     .await?;
     628              :             }
     629              :             FeStartupPacket::StartupMessage { .. } => {
     630         9059 :                 if have_tls && !matches!(self.state, ProtoState::Encrypted) {
     631            0 :                     self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
     632            0 :                         .await?;
     633            0 :                     return Err(QueryError::Other(anyhow::anyhow!(
     634            0 :                         "client did not connect with TLS"
     635            0 :                     )));
     636         9059 :                 }
     637         9059 : 
     638         9059 :                 // NB: startup() may change self.auth_type -- we are using that in proxy code
     639         9059 :                 // to bypass auth for new users.
     640         9059 :                 handler.startup(self, &msg)?;
     641              : 
     642         9059 :                 match self.auth_type {
     643              :                     AuthType::Trust => {
     644         8843 :                         self.write_message_noflush(&BeMessage::AuthenticationOk)?
     645         8843 :                             .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
     646         8843 :                             .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
     647              :                             // The async python driver requires a valid server_version
     648         8843 :                             .write_message_noflush(&BeMessage::server_version("14.1"))?
     649         8843 :                             .write_message(&BeMessage::ReadyForQuery)
     650            0 :                             .await?;
     651         8843 :                         self.state = ProtoState::Established;
     652              :                     }
     653              :                     AuthType::NeonJWT => {
     654          216 :                         self.write_message(&BeMessage::AuthenticationCleartextPassword)
     655            0 :                             .await?;
     656          216 :                         self.state = ProtoState::Authentication;
     657              :                     }
     658              :                 }
     659              :             }
     660              :             FeStartupPacket::CancelRequest { .. } => {
     661            0 :                 return Err(QueryError::Other(anyhow::anyhow!(
     662            0 :                     "Unexpected CancelRequest message during handshake"
     663            0 :                 )));
     664              :             }
     665              :         }
     666        15920 :         Ok(())
     667        15920 :     }
     668              : 
     669        16131 :     async fn process_message(
     670        16131 :         &mut self,
     671        16131 :         handler: &mut impl Handler<IO>,
     672        16131 :         msg: FeMessage,
     673        16131 :         unnamed_query_string: &mut Bytes,
     674        16131 :     ) -> Result<ProcessMsgResult, QueryError> {
     675        16131 :         // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
     676        16131 :         // TODO: change that to proper top-level match of protocol state with separate message handling for each state
     677        16131 :         assert!(self.state == ProtoState::Established);
     678              : 
     679        16131 :         match msg {
     680         9172 :             FeMessage::Query(body) => {
     681              :                 // remove null terminator
     682         9172 :                 let query_string = cstr_to_str(&body)?;
     683              : 
     684            0 :                 trace!("got query {query_string:?}");
     685     20320920 :                 if let Err(e) = handler.process_query(self, query_string).await {
     686           79 :                     log_query_error(query_string, &e);
     687           79 :                     let short_error = short_error(&e);
     688           79 :                     self.write_message_noflush(&BeMessage::ErrorResponse(
     689           79 :                         &short_error,
     690           79 :                         Some(e.pg_error_code()),
     691           79 :                     ))?;
     692         8656 :                 }
     693         8735 :                 self.write_message_noflush(&BeMessage::ReadyForQuery)?;
     694              :             }
     695              : 
     696          662 :             FeMessage::Parse(m) => {
     697          662 :                 *unnamed_query_string = m.query_string;
     698          662 :                 self.write_message_noflush(&BeMessage::ParseComplete)?;
     699              :             }
     700              : 
     701              :             FeMessage::Describe(_) => {
     702          662 :                 self.write_message_noflush(&BeMessage::ParameterDescription)?
     703          662 :                     .write_message_noflush(&BeMessage::NoData)?;
     704              :             }
     705              : 
     706              :             FeMessage::Bind(_) => {
     707          662 :                 self.write_message_noflush(&BeMessage::BindComplete)?;
     708              :             }
     709              : 
     710              :             FeMessage::Close(_) => {
     711          661 :                 self.write_message_noflush(&BeMessage::CloseComplete)?;
     712              :             }
     713              : 
     714              :             FeMessage::Execute(_) => {
     715          662 :                 let query_string = cstr_to_str(unnamed_query_string)?;
     716            0 :                 trace!("got execute {query_string:?}");
     717        13854 :                 if let Err(e) = handler.process_query(self, query_string).await {
     718            9 :                     log_query_error(query_string, &e);
     719            9 :                     self.write_message_noflush(&BeMessage::ErrorResponse(
     720            9 :                         &e.to_string(),
     721            9 :                         Some(e.pg_error_code()),
     722            9 :                     ))?;
     723          653 :                 }
     724              :                 // NOTE there is no ReadyForQuery message. This handler is used
     725              :                 // for basebackup and it uses CopyOut which doesn't require
     726              :                 // ReadyForQuery message and backend just switches back to
     727              :                 // processing mode after sending CopyDone or ErrorResponse.
     728              :             }
     729              : 
     730              :             FeMessage::Sync => {
     731         1994 :                 self.write_message_noflush(&BeMessage::ReadyForQuery)?;
     732              :             }
     733              : 
     734              :             FeMessage::Terminate => {
     735         1521 :                 return Ok(ProcessMsgResult::Break);
     736              :             }
     737              : 
     738              :             // We prefer explicit pattern matching to wildcards, because
     739              :             // this helps us spot the places where new variants are missing
     740              :             FeMessage::CopyData(_)
     741              :             | FeMessage::CopyDone
     742              :             | FeMessage::CopyFail
     743              :             | FeMessage::PasswordMessage(_) => {
     744          135 :                 return Err(QueryError::Other(anyhow::anyhow!(
     745          135 :                     "unexpected message type: {msg:?}",
     746          135 :                 )));
     747              :             }
     748              :         }
     749              : 
     750        14038 :         Ok(ProcessMsgResult::Continue)
     751        15694 :     }
     752              : 
     753              :     /// Log as info/error result of handling COPY stream and send back
     754              :     /// ErrorResponse if that makes sense. Shutdown the stream if we got
     755              :     /// Terminate. TODO: transition into waiting for Sync msg if we initiate the
     756              :     /// close.
     757         2565 :     pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
     758              :         use CopyStreamHandlerEnd::*;
     759              : 
     760         2565 :         let expected_end = match &end {
     761         2255 :             ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
     762          309 :             CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
     763          309 :                 if is_expected_io_error(io_error) =>
     764          309 :             {
     765          309 :                 true
     766              :             }
     767            1 :             _ => false,
     768              :         };
     769         2565 :         if expected_end {
     770         2564 :             info!("terminated: {:#}", end);
     771              :         } else {
     772            1 :             error!("terminated: {:?}", end);
     773              :         }
     774              : 
     775              :         // Note: no current usages ever send this
     776         2565 :         if let CopyDone = &end {
     777            0 :             if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
     778            0 :                 error!("failed to send CopyDone: {}", e);
     779            0 :             }
     780         2565 :         }
     781              : 
     782         2565 :         if let Terminate = &end {
     783           65 :             self.state = ProtoState::Closed;
     784         2500 :         }
     785              : 
     786         2565 :         let err_to_send_and_errcode = match &end {
     787          133 :             ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
     788            1 :             Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
     789              :             // Note: CopyFail in duplex copy is somewhat unexpected (at least to
     790              :             // PG walsender; evidently and per my docs reading client should
     791              :             // finish it with CopyDone). It is not a problem to recover from it
     792              :             // finishing the stream in both directions like we do, but note that
     793              :             // sync rust-postgres client (which we don't use anymore) hangs if
     794              :             // socket is not closed here.
     795              :             // https://github.com/sfackler/rust-postgres/issues/755
     796              :             // https://github.com/neondatabase/neon/issues/935
     797              :             //
     798              :             // Currently, the version of tokio_postgres replication patch we use
     799              :             // sends this when it closes the stream (e.g. pageserver decided to
     800              :             // switch conn to another safekeeper and client gets dropped).
     801              :             // Moreover, seems like 'connection' task errors with 'unexpected
     802              :             // message from server' when it receives ErrorResponse (anything but
     803              :             // CopyData/CopyDone) back.
     804           19 :             CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
     805         2412 :             _ => None,
     806              :         };
     807         2565 :         if let Some((err, errcode)) = err_to_send_and_errcode {
     808          153 :             if let Err(ee) = self
     809          153 :                 .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
     810            0 :                 .await
     811              :             {
     812            0 :                 error!("failed to send ErrorResponse: {}", ee);
     813          153 :             }
     814         2412 :         }
     815         2565 :     }
     816              : }
     817              : 
/// Read half of a split [`PostgresBackend`]; it is joined back with
/// [`PostgresBackend::unsplit`].
pub struct PostgresBackendReader<IO> {
    reader: FramedReader<MaybeTlsStream<IO>>,
    /// Set once a read fails with a connection error. `unsplit` then marks the
    /// backend `Closed`, so the joined stream is not read from again.
    closed: bool,
}
     822              : 
     823              : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
     824              :     /// Read full message or return None if connection is cleanly closed with no
     825              :     /// unprocessed data.
     826      3387810 :     pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
     827      7312479 :         match self.reader.read_message().await {
     828      3386998 :             Ok(m) => {
     829            0 :                 trace!("read msg {:?}", m);
     830      3386998 :                 Ok(m)
     831              :             }
     832          303 :             Err(e) => {
     833          303 :                 self.closed = true;
     834          303 :                 Err(e)
     835              :             }
     836              :         }
     837      3387301 :     }
     838              : 
     839              :     /// Get CopyData contents of the next message in COPY stream or error
     840              :     /// closing it. The error type is wider than actual errors which can happen
     841              :     /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
     842              :     /// current callers.
     843      3387810 :     pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
     844      7312479 :         match self.read_message().await? {
     845      3384960 :             Some(msg) => match msg {
     846      3384876 :                 FeMessage::CopyData(m) => Ok(m),
     847            0 :                 FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
     848           19 :                 FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
     849           65 :                 FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
     850            0 :                 _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
     851            0 :                     ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
     852            0 :                 ))),
     853              :             },
     854         2038 :             None => Err(CopyStreamHandlerEnd::EOF),
     855              :         }
     856      3387301 :     }
     857              : }
     858              : 
/// A [`tokio::io::AsyncWrite`] adapter that wraps each buffer written to it in
/// one CopyData message, sent through the borrowed [`PostgresBackend`].
///
/// Each write first flushes output that is already buffered. Flush and
/// shutdown only flush the backend; shutdown does not close the connection.
pub struct CopyDataWriter<'a, IO> {
    pgb: &'a mut PostgresBackend<IO>,
}
     867              : 
     868              : impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
     869      1017389 :     fn poll_write(
     870      1017389 :         self: Pin<&mut Self>,
     871      1017389 :         cx: &mut std::task::Context<'_>,
     872      1017389 :         buf: &[u8],
     873      1017389 :     ) -> Poll<Result<usize, std::io::Error>> {
     874      1017389 :         let this = self.get_mut();
     875              : 
     876              :         // It's not strictly required to flush between each message, but makes it easier
     877              :         // to view in wireshark, and usually the messages that the callers write are
     878              :         // decently-sized anyway.
     879      1017389 :         if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
     880            0 :             return Poll::Ready(Err(err));
     881      1008955 :         }
     882      1008955 : 
     883      1008955 :         // CopyData
     884      1008955 :         // XXX: if the input is large, we should split it into multiple messages.
     885      1008955 :         // Not sure what the threshold should be, but the ultimate hard limit is that
     886      1008955 :         // the length cannot exceed u32.
     887      1008955 :         this.pgb
     888      1008955 :             .write_message_noflush(&BeMessage::CopyData(buf))
     889      1008955 :             // write_message only writes to the buffer, so it can fail iff the
     890      1008955 :             // message is invaid, but CopyData can't be invalid.
     891      1008955 :             .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
     892              : 
     893      1008955 :         Poll::Ready(Ok(buf.len()))
     894      1017389 :     }
     895              : 
     896        36440 :     fn poll_flush(
     897        36440 :         self: Pin<&mut Self>,
     898        36440 :         cx: &mut std::task::Context<'_>,
     899        36440 :     ) -> Poll<Result<(), std::io::Error>> {
     900        36440 :         let this = self.get_mut();
     901        36440 :         this.pgb.poll_flush(cx)
     902        36440 :     }
     903              : 
     904          205 :     fn poll_shutdown(
     905          205 :         self: Pin<&mut Self>,
     906          205 :         cx: &mut std::task::Context<'_>,
     907          205 :     ) -> Poll<Result<(), std::io::Error>> {
     908          205 :         let this = self.get_mut();
     909          205 :         this.pgb.poll_flush(cx)
     910          205 :     }
     911              : }
     912              : 
     913           79 : pub fn short_error(e: &QueryError) -> String {
     914           79 :     match e {
     915            6 :         QueryError::Disconnected(connection_error) => connection_error.to_string(),
     916           73 :         QueryError::Other(e) => format!("{e:#}"),
     917              :     }
     918           79 : }
     919              : 
     920              : fn log_query_error(query: &str, e: &QueryError) {
     921            6 :     match e {
     922            6 :         QueryError::Disconnected(ConnectionError::Io(io_error)) => {
     923            6 :             if is_expected_io_error(io_error) {
     924            6 :                 info!("query handler for '{query}' failed with expected io error: {io_error}");
     925              :             } else {
     926            0 :                 error!("query handler for '{query}' failed with io error: {io_error}");
     927              :             }
     928              :         }
     929            0 :         QueryError::Disconnected(other_connection_error) => {
     930            0 :             error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
     931              :         }
     932           82 :         QueryError::Other(e) => {
     933           82 :             error!("query handler for '{query}' failed: {e:?}");
     934              :         }
     935              :     }
     936           88 : }
     937              : 
/// Reason a COPY stream ended; see `handle_copy_stream_end`.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
    /// Handler initiates the end of streaming. `handle_copy_stream_end` sends
    /// the message to the client in an ErrorResponse with
    /// SQLSTATE_SUCCESSFUL_COMPLETION.
    #[error("{0}")]
    ServerInitiated(String),
    /// A CopyDone message was received.
    #[error("received CopyDone")]
    CopyDone,
    /// A CopyFail message was received.
    #[error("received CopyFail")]
    CopyFail,
    /// A Terminate message was received; `handle_copy_stream_end` then marks
    /// the connection `Closed`.
    #[error("received Terminate")]
    Terminate,
    /// The connection was closed cleanly with no partially received message.
    #[error("EOF on COPY stream")]
    EOF,
    /// The connection was lost, or reading from it failed (I/O or protocol error).
    #[error("connection error: {0}")]
    Disconnected(#[from] ConnectionError),
    /// Some other error; reported to the client with SQLSTATE_INTERNAL_ERROR.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
        

Generated by: LCOV version 2.1-beta