Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : #![deny(unsafe_code)]
6 : #![deny(clippy::undocumented_unsafe_blocks)]
7 : use std::future::Future;
8 : use std::net::SocketAddr;
9 : use std::os::fd::{AsRawFd, RawFd};
10 : use std::pin::Pin;
11 : use std::str::FromStr;
12 : use std::sync::Arc;
13 : use std::task::{Poll, ready};
14 : use std::{fmt, io};
15 :
16 : use anyhow::Context;
17 : use bytes::Bytes;
18 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
19 : use pq_proto::{
20 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
21 : SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
22 : };
23 : use serde::{Deserialize, Serialize};
24 : use tokio::io::{AsyncRead, AsyncWrite};
25 : use tokio_rustls::TlsAcceptor;
26 : use tokio_util::sync::CancellationToken;
27 : use tracing::{debug, error, info, trace, warn};
28 :
29 : /// An error, occurred during query processing:
30 : /// either during the connection ([`ConnectionError`]) or before/after it.
31 : #[derive(thiserror::Error, Debug)]
32 : pub enum QueryError {
33 : /// The connection was lost while processing the query.
34 : #[error(transparent)]
35 : Disconnected(#[from] ConnectionError),
36 : /// We were instructed to shutdown while processing the query
37 : #[error("Shutting down")]
38 : Shutdown,
39 : /// Query handler indicated that client should reconnect
40 : #[error("Server requested reconnect")]
41 : Reconnect,
42 : /// Query named an entity that was not found
43 : #[error("Not found: {0}")]
44 : NotFound(std::borrow::Cow<'static, str>),
45 : /// Authentication failure
46 : #[error("Unauthorized: {0}")]
47 : Unauthorized(std::borrow::Cow<'static, str>),
48 : #[error("Simulated Connection Error")]
49 : SimulatedConnectionError,
50 : /// Some other error
51 : #[error(transparent)]
52 : Other(#[from] anyhow::Error),
53 : }
54 :
55 : impl From<io::Error> for QueryError {
56 0 : fn from(e: io::Error) -> Self {
57 0 : Self::Disconnected(ConnectionError::Io(e))
58 0 : }
59 : }
60 :
61 : impl QueryError {
62 0 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
63 0 : match self {
64 0 : Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
65 0 : Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
66 0 : Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
67 0 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
68 : }
69 0 : }
70 : }
71 :
/// Tells whether an I/O error is an ordinary consequence of network trouble
/// or of the client closing the connection.
///
/// Such errors occur during normal operation and do not point at a bug in
/// our code.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    use io::ErrorKind;

    match e.kind() {
        ErrorKind::HostUnreachable
        | ErrorKind::NetworkUnreachable
        | ErrorKind::BrokenPipe
        | ErrorKind::ConnectionRefused
        | ErrorKind::ConnectionAborted
        | ErrorKind::ConnectionReset
        | ErrorKind::TimedOut => true,
        _ => false,
    }
}
90 :
91 : pub trait Handler<IO> {
92 : /// Handle single query.
93 : /// postgres_backend will issue ReadyForQuery after calling this (this
94 : /// might be not what we want after CopyData streaming, but currently we don't
95 : /// care). It will also flush out the output buffer.
96 : fn process_query(
97 : &mut self,
98 : pgb: &mut PostgresBackend<IO>,
99 : query_string: &str,
100 : ) -> impl Future<Output = Result<(), QueryError>>;
101 :
102 : /// Called on startup packet receival, allows to process params.
103 : ///
104 : /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
105 : /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
106 : /// to override whole init logic in implementations.
107 2 : fn startup(
108 2 : &mut self,
109 2 : _pgb: &mut PostgresBackend<IO>,
110 2 : _sm: &FeStartupPacket,
111 2 : ) -> Result<(), QueryError> {
112 2 : Ok(())
113 0 : }
114 :
115 : /// Check auth jwt
116 0 : fn check_auth_jwt(
117 0 : &mut self,
118 0 : _pgb: &mut PostgresBackend<IO>,
119 0 : _jwt_response: &[u8],
120 0 : ) -> Result<(), QueryError> {
121 0 : Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
122 0 : }
123 : }
124 :
/// Protocol state of a PostgresBackend connection.
///
/// XXX: the declaration order of the variants matters: the derived
/// `PartialOrd` is used to compare states (e.g. `state < Authentication`).
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd)]
pub enum ProtoState {
    /// Nothing has happened yet.
    Initialization,
    /// The encryption handshake is done; waiting for the encrypted Startup
    /// message.
    Encrypted,
    /// Waiting for the password (auth token).
    Authentication,
    /// Handshake and auth are done, ReadyForQuery has been issued.
    Established,
    /// No further messages are read from the connection.
    Closed,
}
139 :
/// Tells the message loop what to do after a message has been processed.
#[derive(Copy, Clone)]
pub enum ProcessMsgResult {
    /// Go on and read the next message.
    Continue,
    /// Leave the message loop.
    Break,
}
145 :
146 : /// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
147 : pub enum MaybeTlsStream<IO> {
148 : Unencrypted(IO),
149 : Tls(Box<tokio_rustls::server::TlsStream<IO>>),
150 : }
151 :
152 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
153 5 : fn poll_write(
154 5 : self: Pin<&mut Self>,
155 5 : cx: &mut std::task::Context<'_>,
156 5 : buf: &[u8],
157 5 : ) -> Poll<io::Result<usize>> {
158 5 : match self.get_mut() {
159 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
160 2 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
161 : }
162 0 : }
163 5 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
164 5 : match self.get_mut() {
165 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
166 2 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
167 : }
168 0 : }
169 0 : fn poll_shutdown(
170 0 : self: Pin<&mut Self>,
171 0 : cx: &mut std::task::Context<'_>,
172 0 : ) -> Poll<io::Result<()>> {
173 0 : match self.get_mut() {
174 0 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
175 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
176 : }
177 0 : }
178 : }
179 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
180 12 : fn poll_read(
181 12 : self: Pin<&mut Self>,
182 12 : cx: &mut std::task::Context<'_>,
183 12 : buf: &mut tokio::io::ReadBuf<'_>,
184 12 : ) -> Poll<io::Result<()>> {
185 12 : match self.get_mut() {
186 7 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
187 5 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
188 : }
189 0 : }
190 : }
191 :
192 0 : #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
193 : pub enum AuthType {
194 : Trust,
195 : // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
196 : NeonJWT,
197 : // Similar to above but uses Hadron JWT. Hadron JWTs are slightly different in that:
198 : // 1. Decoding keys are loaded from PEM-encoded X509 certificates instead of plain key files.
199 : // 2. Signature algorithm is RSA-based (may change in the future).
200 : HadronJWT,
201 : }
202 :
203 : impl FromStr for AuthType {
204 : type Err = anyhow::Error;
205 :
206 0 : fn from_str(s: &str) -> Result<Self, Self::Err> {
207 0 : match s {
208 0 : "Trust" => Ok(Self::Trust),
209 0 : "NeonJWT" => Ok(Self::NeonJWT),
210 0 : "HadronJWT" => Ok(Self::HadronJWT),
211 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
212 : }
213 0 : }
214 : }
215 :
216 : impl fmt::Display for AuthType {
217 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 0 : f.write_str(match self {
219 0 : AuthType::Trust => "Trust",
220 0 : AuthType::NeonJWT => "NeonJWT",
221 0 : AuthType::HadronJWT => "HadronJWT",
222 : })
223 0 : }
224 : }
225 :
226 : /// Either full duplex Framed or write only half; the latter is left in
227 : /// PostgresBackend after call to `split`. In principle we could always store a
228 : /// pair of splitted handles, but that would force to to pay splitting price
229 : /// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
230 : enum MaybeWriteOnly<IO> {
231 : Full(Framed<MaybeTlsStream<IO>>),
232 : WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
233 : Broken, // temporary value palmed off during the split
234 : }
235 :
236 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
237 3 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
238 3 : match self {
239 3 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
240 : MaybeWriteOnly::WriteOnly(_) => {
241 0 : Err(io::Error::other("reading from write only half").into())
242 : }
243 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
244 : }
245 0 : }
246 :
247 4 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
248 4 : match self {
249 4 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
250 : MaybeWriteOnly::WriteOnly(_) => {
251 0 : Err(io::Error::other("reading from write only half").into())
252 : }
253 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
254 : }
255 0 : }
256 :
257 19 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
258 19 : match self {
259 19 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
260 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
261 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
262 : }
263 0 : }
264 :
265 5 : async fn flush(&mut self) -> io::Result<()> {
266 5 : match self {
267 5 : MaybeWriteOnly::Full(framed) => framed.flush().await,
268 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
269 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
270 : }
271 0 : }
272 :
273 : /// Cancellation safe as long as the underlying IO is cancellation safe.
274 0 : async fn shutdown(&mut self) -> io::Result<()> {
275 0 : match self {
276 0 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
277 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
278 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
279 : }
280 0 : }
281 : }
282 :
283 : pub struct PostgresBackend<IO> {
284 : pub socket_fd: RawFd,
285 : framed: MaybeWriteOnly<IO>,
286 :
287 : pub state: ProtoState,
288 :
289 : auth_type: AuthType,
290 :
291 : peer_addr: SocketAddr,
292 : pub tls_config: Option<Arc<rustls::ServerConfig>>,
293 : }
294 :
295 : pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
296 :
297 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
298 2 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
299 2 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
300 2 : std::str::from_utf8(without_null).map_err(|e| e.into())
301 2 : }
302 :
303 : impl PostgresBackend<tokio::net::TcpStream> {
304 2 : pub fn new(
305 2 : socket: tokio::net::TcpStream,
306 2 : auth_type: AuthType,
307 2 : tls_config: Option<Arc<rustls::ServerConfig>>,
308 2 : ) -> io::Result<Self> {
309 2 : let peer_addr = socket.peer_addr()?;
310 2 : let socket_fd = socket.as_raw_fd();
311 2 : let stream = MaybeTlsStream::Unencrypted(socket);
312 :
313 2 : Ok(Self {
314 2 : socket_fd,
315 2 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
316 2 : state: ProtoState::Initialization,
317 2 : auth_type,
318 2 : tls_config,
319 2 : peer_addr,
320 2 : })
321 2 : }
322 : }
323 :
324 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
325 0 : pub fn new_from_io(
326 0 : socket_fd: RawFd,
327 0 : socket: IO,
328 0 : peer_addr: SocketAddr,
329 0 : auth_type: AuthType,
330 0 : tls_config: Option<Arc<rustls::ServerConfig>>,
331 0 : ) -> io::Result<Self> {
332 0 : let stream = MaybeTlsStream::Unencrypted(socket);
333 :
334 0 : Ok(Self {
335 0 : socket_fd,
336 0 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
337 0 : state: ProtoState::Initialization,
338 0 : auth_type,
339 0 : tls_config,
340 0 : peer_addr,
341 0 : })
342 0 : }
343 :
344 0 : pub fn get_peer_addr(&self) -> &SocketAddr {
345 0 : &self.peer_addr
346 0 : }
347 :
348 : /// Read full message or return None if connection is cleanly closed with no
349 : /// unprocessed data.
350 4 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
351 4 : if let ProtoState::Closed = self.state {
352 0 : Ok(None)
353 : } else {
354 4 : match self.framed.read_message().await {
355 2 : Ok(m) => {
356 2 : trace!("read msg {:?}", m);
357 2 : Ok(m)
358 : }
359 0 : Err(e) => {
360 : // remember not to try to read anymore
361 0 : self.state = ProtoState::Closed;
362 0 : Err(e)
363 : }
364 : }
365 : }
366 0 : }
367 :
368 : /// Write message into internal output buffer, doesn't flush it. Technically
369 : /// error type can be only ProtocolError here (if, unlikely, serialization
370 : /// fails), but callers typically wrap it anyway.
371 19 : pub fn write_message_noflush(
372 19 : &mut self,
373 19 : message: &BeMessage<'_>,
374 19 : ) -> Result<&mut Self, ConnectionError> {
375 19 : self.framed.write_message_noflush(message)?;
376 19 : trace!("wrote msg {:?}", message);
377 19 : Ok(self)
378 0 : }
379 :
380 : /// Flush output buffer into the socket.
381 5 : pub async fn flush(&mut self) -> io::Result<()> {
382 5 : self.framed.flush().await
383 0 : }
384 :
385 : /// Polling version of `flush()`, saves the caller need to pin.
386 0 : pub fn poll_flush(
387 0 : &mut self,
388 0 : cx: &mut std::task::Context<'_>,
389 0 : ) -> Poll<Result<(), std::io::Error>> {
390 0 : let flush_fut = std::pin::pin!(self.flush());
391 0 : flush_fut.poll(cx)
392 0 : }
393 :
394 : /// Write message into internal output buffer and flush it to the stream.
395 3 : pub async fn write_message(
396 3 : &mut self,
397 3 : message: &BeMessage<'_>,
398 3 : ) -> Result<&mut Self, ConnectionError> {
399 3 : self.write_message_noflush(message)?;
400 3 : self.flush().await?;
401 3 : Ok(self)
402 0 : }
403 :
404 : /// Returns an AsyncWrite implementation that wraps all the data written
405 : /// to it in CopyData messages, and writes them to the connection
406 : ///
407 : /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
408 0 : pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
409 0 : CopyDataWriter { pgb: self }
410 0 : }
411 :
412 : /// Wrapper for run_message_loop() that shuts down socket when we are done
413 2 : pub async fn run(
414 2 : mut self,
415 2 : handler: &mut impl Handler<IO>,
416 2 : cancel: &CancellationToken,
417 2 : ) -> Result<(), QueryError> {
418 2 : let ret = self.run_message_loop(handler, cancel).await;
419 :
420 0 : tokio::select! {
421 0 : _ = cancel.cancelled() => {
422 0 : // do nothing; we most likely got already stopped by shutdown and will log it next.
423 0 : }
424 0 : _ = self.framed.shutdown() => {
425 0 : // socket might be already closed, e.g. if previously received error,
426 0 : // so ignore result.
427 0 : },
428 : }
429 :
430 0 : match ret {
431 0 : Ok(()) => Ok(()),
432 : Err(QueryError::Shutdown) => {
433 0 : info!("Stopped due to shutdown");
434 0 : Ok(())
435 : }
436 : Err(QueryError::Reconnect) => {
437 : // Dropping out of this loop implicitly disconnects
438 0 : info!("Stopped due to handler reconnect request");
439 0 : Ok(())
440 : }
441 0 : Err(QueryError::Disconnected(e)) => {
442 0 : info!("Disconnected ({e:#})");
443 : // Disconnection is not an error: we just use it that way internally to drop
444 : // out of loops.
445 0 : Ok(())
446 : }
447 0 : e => e,
448 : }
449 0 : }
450 :
451 2 : async fn run_message_loop(
452 2 : &mut self,
453 2 : handler: &mut impl Handler<IO>,
454 2 : cancel: &CancellationToken,
455 2 : ) -> Result<(), QueryError> {
456 2 : trace!("postgres backend to {:?} started", self.peer_addr);
457 :
458 2 : tokio::select!(
459 : biased;
460 :
461 2 : _ = cancel.cancelled() => {
462 : // We were requested to shut down.
463 0 : tracing::info!("shutdown request received during handshake");
464 0 : return Err(QueryError::Shutdown)
465 : },
466 :
467 2 : handshake_r = self.handshake(handler) => {
468 2 : handshake_r?;
469 : }
470 : );
471 :
472 : // Authentication completed
473 2 : let mut query_string = Bytes::new();
474 4 : while let Some(msg) = tokio::select!(
475 : biased;
476 4 : _ = cancel.cancelled() => {
477 : // We were requested to shut down.
478 0 : tracing::info!("shutdown request received in run_message_loop");
479 0 : return Err(QueryError::Shutdown)
480 : },
481 4 : msg = self.read_message() => { msg },
482 0 : )? {
483 2 : trace!("got message {:?}", msg);
484 :
485 2 : let result = self.process_message(handler, msg, &mut query_string).await;
486 2 : tokio::select!(
487 : biased;
488 2 : _ = cancel.cancelled() => {
489 : // We were requested to shut down.
490 0 : tracing::info!("shutdown request received during response flush");
491 :
492 : // If we exited process_message with a shutdown error, there may be
493 : // some valid response content on in our transmit buffer: permit sending
494 : // this within a short timeout. This is a best effort thing so we don't
495 : // care about the result.
496 0 : tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
497 :
498 0 : return Err(QueryError::Shutdown)
499 : },
500 2 : flush_r = self.flush() => {
501 2 : flush_r?;
502 : }
503 : );
504 :
505 2 : match result? {
506 : ProcessMsgResult::Continue => {
507 2 : continue;
508 : }
509 0 : ProcessMsgResult::Break => break,
510 : }
511 : }
512 :
513 0 : trace!("postgres backend to {:?} exited", self.peer_addr);
514 0 : Ok(())
515 0 : }
516 :
517 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
518 1 : async fn tls_upgrade(
519 1 : src: MaybeTlsStream<IO>,
520 1 : tls_config: Arc<rustls::ServerConfig>,
521 1 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
522 1 : match src {
523 1 : MaybeTlsStream::Unencrypted(s) => {
524 1 : let acceptor = TlsAcceptor::from(tls_config);
525 1 : let tls_stream = acceptor.accept(s).await?;
526 1 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
527 : }
528 : MaybeTlsStream::Tls(_) => {
529 0 : anyhow::bail!("TLS already started");
530 : }
531 : }
532 0 : }
533 :
534 1 : async fn start_tls(&mut self) -> anyhow::Result<()> {
535 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
536 1 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
537 1 : MaybeWriteOnly::Full(framed) => {
538 1 : let tls_config = self
539 1 : .tls_config
540 1 : .as_ref()
541 1 : .context("start_tls called without conf")?
542 1 : .clone();
543 1 : let tls_framed = framed
544 1 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
545 1 : .await?;
546 : // push back ready TLS stream
547 1 : self.framed = MaybeWriteOnly::Full(tls_framed);
548 1 : Ok(())
549 : }
550 : MaybeWriteOnly::WriteOnly(_) => {
551 0 : anyhow::bail!("TLS upgrade attempt in split state")
552 : }
553 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
554 : }
555 0 : }
556 :
557 : /// Split off owned read part from which messages can be read in different
558 : /// task/thread.
559 0 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
560 : // temporary replace stream with fake to cook split one, Indiana Jones style
561 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
562 0 : MaybeWriteOnly::Full(framed) => {
563 0 : let (reader, writer) = framed.split();
564 0 : self.framed = MaybeWriteOnly::WriteOnly(writer);
565 0 : Ok(PostgresBackendReader {
566 0 : reader,
567 0 : closed: false,
568 0 : })
569 : }
570 : MaybeWriteOnly::WriteOnly(_) => {
571 0 : anyhow::bail!("PostgresBackend is already split")
572 : }
573 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
574 : }
575 0 : }
576 :
577 : /// Join read part back.
578 0 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
579 : // temporary replace stream with fake to cook joined one, Indiana Jones style
580 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
581 : MaybeWriteOnly::Full(_) => {
582 0 : anyhow::bail!("PostgresBackend is not split")
583 : }
584 0 : MaybeWriteOnly::WriteOnly(writer) => {
585 0 : let joined = Framed::unsplit(reader.reader, writer);
586 0 : self.framed = MaybeWriteOnly::Full(joined);
587 : // if reader encountered connection error, do not attempt reading anymore
588 0 : if reader.closed {
589 0 : self.state = ProtoState::Closed;
590 0 : }
591 0 : Ok(())
592 : }
593 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
594 : }
595 0 : }
596 :
597 : /// Perform handshake with the client, transitioning to Established.
598 : /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
599 2 : async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
600 5 : while self.state < ProtoState::Authentication {
601 3 : match self.framed.read_startup_message().await? {
602 3 : Some(msg) => {
603 3 : self.process_startup_message(handler, msg).await?;
604 : }
605 : None => {
606 0 : trace!(
607 0 : "postgres backend to {:?} received EOF during handshake",
608 : self.peer_addr
609 : );
610 0 : self.state = ProtoState::Closed;
611 0 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
612 0 : ProtocolError::Protocol("EOF during handshake".to_string()),
613 0 : )));
614 : }
615 : }
616 : }
617 :
618 : // Perform auth, if needed.
619 2 : if self.state == ProtoState::Authentication {
620 0 : match self.framed.read_message().await? {
621 0 : Some(FeMessage::PasswordMessage(m)) => {
622 0 : assert!(matches!(
623 0 : self.auth_type,
624 : AuthType::NeonJWT | AuthType::HadronJWT
625 : ));
626 :
627 0 : let (_, jwt_response) = m.split_last().context("protocol violation")?;
628 :
629 0 : if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
630 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
631 0 : &short_error(&e),
632 0 : Some(e.pg_error_code()),
633 0 : ))?;
634 0 : return Err(e);
635 0 : }
636 :
637 0 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
638 0 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
639 0 : .write_message(&BeMessage::ReadyForQuery)
640 0 : .await?;
641 0 : self.state = ProtoState::Established;
642 : }
643 0 : Some(m) => {
644 0 : return Err(QueryError::Other(anyhow::anyhow!(
645 0 : "Unexpected message {:?} while waiting for handshake",
646 0 : m
647 0 : )));
648 : }
649 : None => {
650 0 : trace!(
651 0 : "postgres backend to {:?} received EOF during auth",
652 : self.peer_addr
653 : );
654 0 : self.state = ProtoState::Closed;
655 0 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
656 0 : ProtocolError::Protocol("EOF during auth".to_string()),
657 0 : )));
658 : }
659 : }
660 0 : }
661 :
662 2 : Ok(())
663 0 : }
664 :
665 : /// Process startup packet:
666 : /// - transition to Established if auth type is trust
667 : /// - transition to Authentication if auth type is NeonJWT.
668 : /// - or perform TLS handshake -- then need to call this again to receive
669 : /// actual startup packet.
670 3 : async fn process_startup_message(
671 3 : &mut self,
672 3 : handler: &mut impl Handler<IO>,
673 3 : msg: FeStartupPacket,
674 3 : ) -> Result<(), QueryError> {
675 3 : assert!(self.state < ProtoState::Authentication);
676 3 : let have_tls = self.tls_config.is_some();
677 3 : match msg {
678 1 : FeStartupPacket::SslRequest { direct } => {
679 1 : debug!("SSL requested");
680 :
681 1 : if !direct {
682 1 : self.write_message(&BeMessage::EncryptionResponse(have_tls))
683 1 : .await?;
684 0 : } else if !have_tls {
685 0 : return Err(QueryError::Other(anyhow::anyhow!(
686 0 : "direct SSL negotiation but no TLS support"
687 0 : )));
688 0 : }
689 :
690 1 : if have_tls {
691 1 : self.start_tls().await?;
692 1 : self.state = ProtoState::Encrypted;
693 0 : }
694 : }
695 : FeStartupPacket::GssEncRequest => {
696 0 : debug!("GSS requested");
697 0 : self.write_message(&BeMessage::EncryptionResponse(false))
698 0 : .await?;
699 : }
700 : FeStartupPacket::StartupMessage { .. } => {
701 2 : if have_tls && !matches!(self.state, ProtoState::Encrypted) {
702 0 : self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
703 0 : .await?;
704 0 : return Err(QueryError::Other(anyhow::anyhow!(
705 0 : "client did not connect with TLS"
706 0 : )));
707 0 : }
708 :
709 : // NB: startup() may change self.auth_type -- we are using that in proxy code
710 : // to bypass auth for new users.
711 2 : handler.startup(self, &msg)?;
712 :
713 2 : match self.auth_type {
714 : AuthType::Trust => {
715 2 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
716 2 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
717 2 : .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
718 : // The async python driver requires a valid server_version
719 2 : .write_message_noflush(&BeMessage::server_version("14.1"))?
720 2 : .write_message(&BeMessage::ReadyForQuery)
721 2 : .await?;
722 2 : self.state = ProtoState::Established;
723 : }
724 : AuthType::NeonJWT | AuthType::HadronJWT => {
725 0 : self.write_message(&BeMessage::AuthenticationCleartextPassword)
726 0 : .await?;
727 0 : self.state = ProtoState::Authentication;
728 : }
729 : }
730 : }
731 : FeStartupPacket::CancelRequest { .. } => {
732 0 : return Err(QueryError::Other(anyhow::anyhow!(
733 0 : "Unexpected CancelRequest message during handshake"
734 0 : )));
735 : }
736 : }
737 3 : Ok(())
738 0 : }
739 :
740 : // Proto looks like this:
741 : // FeMessage::Query("pagestream_v2{FeMessage::CopyData(PagesetreamFeMessage::GetPage(..))}")
742 :
743 2 : async fn process_message(
744 2 : &mut self,
745 2 : handler: &mut impl Handler<IO>,
746 2 : msg: FeMessage,
747 2 : unnamed_query_string: &mut Bytes,
748 2 : ) -> Result<ProcessMsgResult, QueryError> {
749 : // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
750 : // TODO: change that to proper top-level match of protocol state with separate message handling for each state
751 2 : assert!(self.state == ProtoState::Established);
752 :
753 2 : match msg {
754 2 : FeMessage::Query(body) => {
755 : // remove null terminator
756 2 : let query_string = cstr_to_str(&body)?;
757 :
758 2 : trace!("got query {query_string:?}");
759 2 : if let Err(e) = handler.process_query(self, query_string).await {
760 0 : match e {
761 0 : err @ QueryError::Shutdown => {
762 : // Notify postgres of the connection shutdown at the libpq
763 : // protocol level. This avoids postgres having to tell apart
764 : // from an idle connection and a stale one, which is bug prone.
765 0 : let shutdown_error = short_error(&err);
766 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
767 0 : &shutdown_error,
768 0 : Some(err.pg_error_code()),
769 0 : ))?;
770 :
771 0 : return Ok(ProcessMsgResult::Break);
772 : }
773 : QueryError::SimulatedConnectionError => {
774 0 : return Err(QueryError::SimulatedConnectionError);
775 : }
776 0 : err @ QueryError::Reconnect => {
777 : // Instruct the client to reconnect, stop processing messages
778 : // from this libpq connection and, finally, disconnect from the
779 : // server side (returning an Err achieves the later).
780 : //
781 : // Note the flushing is done by the caller.
782 0 : let reconnect_error = short_error(&err);
783 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
784 0 : &reconnect_error,
785 0 : Some(err.pg_error_code()),
786 0 : ))?;
787 :
788 0 : return Err(err);
789 : }
790 0 : e => {
791 0 : log_query_error(query_string, &e);
792 0 : let short_error = short_error(&e);
793 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
794 0 : &short_error,
795 0 : Some(e.pg_error_code()),
796 0 : ))?;
797 : }
798 : }
799 0 : }
800 2 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
801 : }
802 :
803 0 : FeMessage::Parse(m) => {
804 0 : *unnamed_query_string = m.query_string;
805 0 : self.write_message_noflush(&BeMessage::ParseComplete)?;
806 : }
807 :
808 : FeMessage::Describe(_) => {
809 0 : self.write_message_noflush(&BeMessage::ParameterDescription)?
810 0 : .write_message_noflush(&BeMessage::NoData)?;
811 : }
812 :
813 : FeMessage::Bind(_) => {
814 0 : self.write_message_noflush(&BeMessage::BindComplete)?;
815 : }
816 :
817 : FeMessage::Close(_) => {
818 0 : self.write_message_noflush(&BeMessage::CloseComplete)?;
819 : }
820 :
821 : FeMessage::Execute(_) => {
822 0 : let query_string = cstr_to_str(unnamed_query_string)?;
823 0 : trace!("got execute {query_string:?}");
824 0 : if let Err(e) = handler.process_query(self, query_string).await {
825 0 : log_query_error(query_string, &e);
826 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
827 0 : &e.to_string(),
828 0 : Some(e.pg_error_code()),
829 0 : ))?;
830 0 : }
831 : // NOTE there is no ReadyForQuery message. This handler is used
832 : // for basebackup and it uses CopyOut which doesn't require
833 : // ReadyForQuery message and backend just switches back to
834 : // processing mode after sending CopyDone or ErrorResponse.
835 : }
836 :
837 : FeMessage::Sync => {
838 0 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
839 : }
840 :
841 : FeMessage::Terminate => {
842 0 : return Ok(ProcessMsgResult::Break);
843 : }
844 :
845 : // We prefer explicit pattern matching to wildcards, because
846 : // this helps us spot the places where new variants are missing
847 : FeMessage::CopyData(_)
848 : | FeMessage::CopyDone
849 : | FeMessage::CopyFail
850 : | FeMessage::PasswordMessage(_) => {
851 0 : return Err(QueryError::Other(anyhow::anyhow!(
852 0 : "unexpected message type: {msg:?}",
853 0 : )));
854 : }
855 : }
856 :
857 2 : Ok(ProcessMsgResult::Continue)
858 0 : }
859 :
860 : /// - Log as info/error result of handling COPY stream and send back
861 : /// ErrorResponse if that makes sense.
862 : /// - Shutdown the stream if we got Terminate.
863 : /// - Then close the connection because we don't handle exiting from COPY
864 : /// stream normally.
865 0 : pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
866 : use CopyStreamHandlerEnd::*;
867 :
868 0 : let expected_end = match &end {
869 0 : ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF | Cancelled => true,
870 : // The timeline doesn't exist and we have been requested to not auto-create it.
871 : // Compute requests for timelines that haven't been created yet
872 : // might reach us before the storcon request to create those timelines.
873 0 : TimelineNoCreate => true,
874 0 : CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
875 0 : if is_expected_io_error(io_error) =>
876 : {
877 0 : true
878 : }
879 0 : _ => false,
880 : };
881 0 : if expected_end {
882 0 : info!("terminated: {:#}", end);
883 : } else {
884 0 : error!("terminated: {:?}", end);
885 : }
886 :
887 : // Note: no current usages ever send this
888 0 : if let CopyDone = &end {
889 0 : if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
890 0 : error!("failed to send CopyDone: {}", e);
891 0 : }
892 0 : }
893 :
894 0 : let err_to_send_and_errcode = match &end {
895 0 : ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
896 0 : Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
897 : // Note: CopyFail in duplex copy is somewhat unexpected (at least to
898 : // PG walsender; evidently and per my docs reading client should
899 : // finish it with CopyDone). It is not a problem to recover from it
900 : // finishing the stream in both directions like we do, but note that
901 : // sync rust-postgres client (which we don't use anymore) hangs if
902 : // socket is not closed here.
903 : // https://github.com/sfackler/rust-postgres/issues/755
904 : // https://github.com/neondatabase/neon/issues/935
905 : //
906 : // Currently, the version of tokio_postgres replication patch we use
907 : // sends this when it closes the stream (e.g. pageserver decided to
908 : // switch conn to another safekeeper and client gets dropped).
909 : // Moreover, seems like 'connection' task errors with 'unexpected
910 : // message from server' when it receives ErrorResponse (anything but
911 : // CopyData/CopyDone) back.
912 0 : CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
913 :
914 : // When cancelled, send no response: we must not risk blocking on sending that response
915 0 : Cancelled => None,
916 0 : _ => None,
917 : };
918 0 : if let Some((err, errcode)) = err_to_send_and_errcode {
919 0 : if let Err(ee) = self
920 0 : .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
921 0 : .await
922 : {
923 0 : error!("failed to send ErrorResponse: {}", ee);
924 0 : }
925 0 : }
926 :
927 : // Proper COPY stream finishing to continue using the connection is not
928 : // implemented at the server side (we don't need it so far). To prevent
929 : // further usages of the connection, close it.
930 0 : self.framed.shutdown().await.ok();
931 0 : self.state = ProtoState::Closed;
932 0 : }
933 : }
934 :
935 : pub struct PostgresBackendReader<IO> {
936 : reader: FramedReader<MaybeTlsStream<IO>>,
937 : closed: bool, // true if received error closing the connection
938 : }
939 :
940 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
941 : /// Read full message or return None if connection is cleanly closed with no
942 : /// unprocessed data.
943 0 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
944 0 : match self.reader.read_message().await {
945 0 : Ok(m) => {
946 0 : trace!("read msg {:?}", m);
947 0 : Ok(m)
948 : }
949 0 : Err(e) => {
950 0 : self.closed = true;
951 0 : Err(e)
952 : }
953 : }
954 0 : }
955 :
956 : /// Get CopyData contents of the next message in COPY stream or error
957 : /// closing it. The error type is wider than actual errors which can happen
958 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
959 : /// current callers.
960 0 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
961 0 : match self.read_message().await? {
962 0 : Some(msg) => match msg {
963 0 : FeMessage::CopyData(m) => Ok(m),
964 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
965 0 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
966 0 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
967 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
968 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {msg:?}")),
969 0 : ))),
970 : },
971 0 : None => Err(CopyStreamHandlerEnd::EOF),
972 : }
973 0 : }
974 : }
975 :
976 : ///
977 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
978 : /// messages.
979 : ///
980 : pub struct CopyDataWriter<'a, IO> {
981 : pgb: &'a mut PostgresBackend<IO>,
982 : }
983 :
984 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'_, IO> {
985 0 : fn poll_write(
986 0 : self: Pin<&mut Self>,
987 0 : cx: &mut std::task::Context<'_>,
988 0 : buf: &[u8],
989 0 : ) -> Poll<Result<usize, std::io::Error>> {
990 0 : let this = self.get_mut();
991 :
992 : // It's not strictly required to flush between each message, but makes it easier
993 : // to view in wireshark, and usually the messages that the callers write are
994 : // decently-sized anyway.
995 0 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
996 0 : return Poll::Ready(Err(err));
997 0 : }
998 :
999 : // CopyData
1000 : // XXX: if the input is large, we should split it into multiple messages.
1001 : // Not sure what the threshold should be, but the ultimate hard limit is that
1002 : // the length cannot exceed u32.
1003 0 : this.pgb
1004 0 : .write_message_noflush(&BeMessage::CopyData(buf))
1005 : // write_message only writes to the buffer, so it can fail iff the
1006 : // message is invaid, but CopyData can't be invalid.
1007 0 : .map_err(|_| io::Error::other("failed to serialize CopyData"))?;
1008 :
1009 0 : Poll::Ready(Ok(buf.len()))
1010 0 : }
1011 :
1012 0 : fn poll_flush(
1013 0 : self: Pin<&mut Self>,
1014 0 : cx: &mut std::task::Context<'_>,
1015 0 : ) -> Poll<Result<(), std::io::Error>> {
1016 0 : let this = self.get_mut();
1017 0 : this.pgb.poll_flush(cx)
1018 0 : }
1019 :
1020 0 : fn poll_shutdown(
1021 0 : self: Pin<&mut Self>,
1022 0 : cx: &mut std::task::Context<'_>,
1023 0 : ) -> Poll<Result<(), std::io::Error>> {
1024 0 : let this = self.get_mut();
1025 0 : this.pgb.poll_flush(cx)
1026 0 : }
1027 : }
1028 :
1029 0 : pub fn short_error(e: &QueryError) -> String {
1030 0 : match e {
1031 0 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
1032 0 : QueryError::Reconnect => "reconnect".to_string(),
1033 0 : QueryError::Shutdown => "shutdown".to_string(),
1034 0 : QueryError::NotFound(_) => "not found".to_string(),
1035 0 : QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
1036 0 : QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
1037 0 : QueryError::Other(e) => format!("{e:#}"),
1038 : }
1039 0 : }
1040 :
1041 0 : fn log_query_error(query: &str, e: &QueryError) {
1042 : // If you want to change the log level of a specific error, also re-categorize it in `BasebackupQueryTimeOngoingRecording`.
1043 0 : match e {
1044 0 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
1045 0 : if is_expected_io_error(io_error) {
1046 0 : info!("query handler for '{query}' failed with expected io error: {io_error}");
1047 : } else {
1048 0 : error!("query handler for '{query}' failed with io error: {io_error}");
1049 : }
1050 : }
1051 0 : QueryError::Disconnected(other_connection_error) => {
1052 0 : error!(
1053 0 : "query handler for '{query}' failed with connection error: {other_connection_error:?}"
1054 : )
1055 : }
1056 : QueryError::SimulatedConnectionError => {
1057 0 : error!("query handler for query '{query}' failed due to a simulated connection error")
1058 : }
1059 : QueryError::Reconnect => {
1060 0 : info!("query handler for '{query}' requested client to reconnect")
1061 : }
1062 : QueryError::Shutdown => {
1063 0 : info!("query handler for '{query}' cancelled during tenant shutdown")
1064 : }
1065 0 : QueryError::NotFound(reason) => {
1066 0 : info!("query handler for '{query}' entity not found: {reason}")
1067 : }
1068 0 : QueryError::Unauthorized(e) => {
1069 0 : warn!("query handler for '{query}' failed with authentication error: {e}");
1070 : }
1071 0 : QueryError::Other(e) => {
1072 0 : error!("query handler for '{query}' failed: {e:?}");
1073 : }
1074 : }
1075 0 : }
1076 :
1077 : /// Something finishing handling of COPY stream, see handle_copy_stream_end.
1078 : /// This is not always a real error, but it allows to use ? and thiserror impls.
1079 : #[derive(thiserror::Error, Debug)]
1080 : pub enum CopyStreamHandlerEnd {
1081 : /// Handler initiates the end of streaming.
1082 : #[error("{0}")]
1083 : ServerInitiated(String),
1084 : #[error("received CopyDone")]
1085 : CopyDone,
1086 : #[error("received CopyFail")]
1087 : CopyFail,
1088 : #[error("received Terminate")]
1089 : Terminate,
1090 : #[error("EOF on COPY stream")]
1091 : EOF,
1092 : #[error("timeline not found, and allow_timeline_creation is false")]
1093 : TimelineNoCreate,
1094 : /// The connection was lost
1095 : #[error("connection error: {0}")]
1096 : Disconnected(#[from] ConnectionError),
1097 : #[error("Shutdown")]
1098 : Cancelled,
1099 : /// Some other error
1100 : #[error(transparent)]
1101 : Other(#[from] anyhow::Error),
1102 : }
|