Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : #![deny(unsafe_code)]
6 : #![deny(clippy::undocumented_unsafe_blocks)]
7 : use anyhow::Context;
8 : use bytes::Bytes;
9 : use serde::{Deserialize, Serialize};
10 : use std::io::ErrorKind;
11 : use std::net::SocketAddr;
12 : use std::os::fd::AsRawFd;
13 : use std::os::fd::RawFd;
14 : use std::pin::Pin;
15 : use std::sync::Arc;
16 : use std::task::{ready, Poll};
17 : use std::{fmt, io};
18 : use std::{future::Future, str::FromStr};
19 : use tokio::io::{AsyncRead, AsyncWrite};
20 : use tokio_rustls::TlsAcceptor;
21 : use tokio_util::sync::CancellationToken;
22 : use tracing::{debug, error, info, trace, warn};
23 :
24 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
25 : use pq_proto::{
26 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
27 : SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
28 : };
29 :
/// An error that occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
    /// The connection was lost while processing the query.
    #[error(transparent)]
    Disconnected(#[from] ConnectionError),
    /// We were instructed to shutdown while processing the query
    #[error("Shutting down")]
    Shutdown,
    /// Query handler indicated that client should reconnect
    #[error("Server requested reconnect")]
    Reconnect,
    /// Query named an entity that was not found
    #[error("Not found: {0}")]
    NotFound(std::borrow::Cow<'static, str>),
    /// Authentication failure
    #[error("Unauthorized: {0}")]
    Unauthorized(std::borrow::Cow<'static, str>),
    /// An artificially injected connection failure. It is reported with the
    /// same SQLSTATE as a real connection failure, and `process_message`
    /// propagates it without sending an `ErrorResponse` to the client.
    /// NOTE(review): presumably raised by handlers for fault injection in
    /// tests -- confirm against the handler implementations.
    #[error("Simulated Connection Error")]
    SimulatedConnectionError,
    /// Some other error
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
55 :
56 : impl From<io::Error> for QueryError {
57 0 : fn from(e: io::Error) -> Self {
58 0 : Self::Disconnected(ConnectionError::Io(e))
59 0 : }
60 : }
61 :
62 : impl QueryError {
63 0 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
64 0 : match self {
65 0 : Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
66 0 : Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
67 0 : Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
68 0 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
69 : }
70 0 : }
71 : }
72 :
/// Tells whether an I/O error is the ordinary result of a flaky network or of
/// the peer going away, as opposed to a symptom of a bug on our side.
///
/// Such errors are part of normal operation and callers typically log them at
/// a lower severity.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    // Error kinds that a well-behaved server routinely observes.
    const EXPECTED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::TimedOut,
    ];

    EXPECTED_KINDS.contains(&e.kind())
}
85 :
/// Callbacks through which [`PostgresBackend`] delegates the handling of
/// startup packets, JWT authentication and queries.
pub trait Handler<IO> {
    /// Handle single query.
    /// postgres_backend will issue ReadyForQuery after calling this (this
    /// might be not what we want after CopyData streaming, but currently we don't
    /// care). It will also flush out the output buffer.
    fn process_query(
        &mut self,
        pgb: &mut PostgresBackend<IO>,
        query_string: &str,
    ) -> impl Future<Output = Result<(), QueryError>>;

    /// Called when a startup message is received; allows processing its params.
    ///
    /// Returning an error aborts the handshake: the error is propagated out of
    /// the startup packet processing. The default implementation accepts every
    /// startup packet.
    ///
    /// NOTE(review): a comment at the call site says implementations may
    /// change the backend's auth type from here (used by proxy code to bypass
    /// auth for new users); no API for that is visible in this part of the
    /// file -- confirm.
    fn startup(
        &mut self,
        _pgb: &mut PostgresBackend<IO>,
        _sm: &FeStartupPacket,
    ) -> Result<(), QueryError> {
        Ok(())
    }

    /// Check auth jwt
    ///
    /// `_jwt_response` is the payload of the client's password message without
    /// its last byte. The default implementation rejects every token.
    fn check_auth_jwt(
        &mut self,
        _pgb: &mut PostgresBackend<IO>,
        _jwt_response: &[u8],
    ) -> Result<(), QueryError> {
        Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
    }
}
119 :
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters: `PartialOrd` is derived, and the
/// handshake code compares states with `<` (e.g.
/// `state < ProtoState::Authentication`) to decide whether startup packets
/// are still expected.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
    /// Nothing happened yet.
    Initialization,
    /// Encryption handshake is done; waiting for encrypted Startup message.
    Encrypted,
    /// Waiting for password (auth token).
    Authentication,
    /// Performed handshake and auth, ReadyForQuery is issued.
    Established,
    /// The connection must not be read from anymore: set after a read error,
    /// EOF during handshake/auth, or the end of a COPY stream.
    Closed,
}
134 :
/// Outcome of processing a single frontend message, telling the message loop
/// what to do next.
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
    /// Keep reading and processing messages.
    Continue,
    /// Leave the message loop.
    Break,
}
140 :
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream<IO> {
    /// The raw, unencrypted underlying IO.
    Unencrypted(IO),
    /// Server-side TLS session established over the underlying IO.
    Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
146 :
147 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
148 5 : fn poll_write(
149 5 : self: Pin<&mut Self>,
150 5 : cx: &mut std::task::Context<'_>,
151 5 : buf: &[u8],
152 5 : ) -> Poll<io::Result<usize>> {
153 5 : match self.get_mut() {
154 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
155 2 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
156 : }
157 5 : }
158 5 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
159 5 : match self.get_mut() {
160 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
161 2 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
162 : }
163 5 : }
164 0 : fn poll_shutdown(
165 0 : self: Pin<&mut Self>,
166 0 : cx: &mut std::task::Context<'_>,
167 0 : ) -> Poll<io::Result<()>> {
168 0 : match self.get_mut() {
169 0 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
170 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
171 : }
172 0 : }
173 : }
174 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
175 12 : fn poll_read(
176 12 : self: Pin<&mut Self>,
177 12 : cx: &mut std::task::Context<'_>,
178 12 : buf: &mut tokio::io::ReadBuf<'_>,
179 12 : ) -> Poll<io::Result<()>> {
180 12 : match self.get_mut() {
181 7 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
182 5 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
183 : }
184 12 : }
185 : }
186 :
/// How clients of a [`PostgresBackend`] are authenticated.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
    /// No authentication: the connection is established right after the
    /// startup message.
    Trust,
    // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
    NeonJWT,
}
193 :
194 : impl FromStr for AuthType {
195 : type Err = anyhow::Error;
196 :
197 0 : fn from_str(s: &str) -> Result<Self, Self::Err> {
198 0 : match s {
199 0 : "Trust" => Ok(Self::Trust),
200 0 : "NeonJWT" => Ok(Self::NeonJWT),
201 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
202 : }
203 0 : }
204 : }
205 :
206 : impl fmt::Display for AuthType {
207 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 0 : f.write_str(match self {
209 0 : AuthType::Trust => "Trust",
210 0 : AuthType::NeonJWT => "NeonJWT",
211 : })
212 0 : }
213 : }
214 :
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of split handles, but that would force us to pay the splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly<IO> {
    /// Both directions are available.
    Full(Framed<MaybeTlsStream<IO>>),
    /// Only the write half is here; every read attempt fails with an error.
    WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
    /// Placeholder; any IO attempted in this state panics.
    Broken, // temporary value palmed off during the split
}
224 :
225 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
226 3 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
227 3 : match self {
228 3 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
229 : MaybeWriteOnly::WriteOnly(_) => {
230 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
231 : }
232 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
233 : }
234 3 : }
235 :
236 4 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
237 4 : match self {
238 4 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
239 : MaybeWriteOnly::WriteOnly(_) => {
240 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
241 : }
242 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
243 : }
244 2 : }
245 :
246 19 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
247 19 : match self {
248 19 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
249 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
250 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
251 : }
252 19 : }
253 :
254 5 : async fn flush(&mut self) -> io::Result<()> {
255 5 : match self {
256 5 : MaybeWriteOnly::Full(framed) => framed.flush().await,
257 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
258 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
259 : }
260 5 : }
261 :
262 : /// Cancellation safe as long as the underlying IO is cancellation safe.
263 0 : async fn shutdown(&mut self) -> io::Result<()> {
264 0 : match self {
265 0 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
266 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
267 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
268 : }
269 0 : }
270 : }
271 :
/// Server side of a single Postgres protocol connection over `IO`.
pub struct PostgresBackend<IO> {
    /// Raw file descriptor supplied at construction (for TCP backends, the
    /// descriptor of the accepted socket).
    pub socket_fd: RawFd,
    /// Message framing over the (possibly TLS-wrapped) stream; holds only the
    /// write half while the backend is split.
    framed: MaybeWriteOnly<IO>,

    /// Current protocol state of the connection.
    pub state: ProtoState,

    /// Authentication performed after the startup message.
    auth_type: AuthType,

    /// Address of the remote peer.
    peer_addr: SocketAddr,
    /// When set, SSL requests are accepted using this config and startup
    /// messages arriving over an unencrypted connection are rejected.
    pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
283 :
/// [`PostgresBackend`] running directly over a tokio TCP stream.
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
285 :
286 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
287 2 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
288 2 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
289 2 : std::str::from_utf8(without_null).map_err(|e| e.into())
290 2 : }
291 :
292 : impl PostgresBackend<tokio::net::TcpStream> {
293 2 : pub fn new(
294 2 : socket: tokio::net::TcpStream,
295 2 : auth_type: AuthType,
296 2 : tls_config: Option<Arc<rustls::ServerConfig>>,
297 2 : ) -> io::Result<Self> {
298 2 : let peer_addr = socket.peer_addr()?;
299 2 : let socket_fd = socket.as_raw_fd();
300 2 : let stream = MaybeTlsStream::Unencrypted(socket);
301 2 :
302 2 : Ok(Self {
303 2 : socket_fd,
304 2 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
305 2 : state: ProtoState::Initialization,
306 2 : auth_type,
307 2 : tls_config,
308 2 : peer_addr,
309 2 : })
310 2 : }
311 : }
312 :
313 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
314 0 : pub fn new_from_io(
315 0 : socket_fd: RawFd,
316 0 : socket: IO,
317 0 : peer_addr: SocketAddr,
318 0 : auth_type: AuthType,
319 0 : tls_config: Option<Arc<rustls::ServerConfig>>,
320 0 : ) -> io::Result<Self> {
321 0 : let stream = MaybeTlsStream::Unencrypted(socket);
322 0 :
323 0 : Ok(Self {
324 0 : socket_fd,
325 0 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
326 0 : state: ProtoState::Initialization,
327 0 : auth_type,
328 0 : tls_config,
329 0 : peer_addr,
330 0 : })
331 0 : }
332 :
/// Address of the remote peer, as recorded when the backend was created.
pub fn get_peer_addr(&self) -> &SocketAddr {
    &self.peer_addr
}
336 :
337 : /// Read full message or return None if connection is cleanly closed with no
338 : /// unprocessed data.
339 4 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
340 4 : if let ProtoState::Closed = self.state {
341 0 : Ok(None)
342 : } else {
343 4 : match self.framed.read_message().await {
344 2 : Ok(m) => {
345 2 : trace!("read msg {:?}", m);
346 2 : Ok(m)
347 : }
348 0 : Err(e) => {
349 0 : // remember not to try to read anymore
350 0 : self.state = ProtoState::Closed;
351 0 : Err(e)
352 : }
353 : }
354 : }
355 2 : }
356 :
357 : /// Write message into internal output buffer, doesn't flush it. Technically
358 : /// error type can be only ProtocolError here (if, unlikely, serialization
359 : /// fails), but callers typically wrap it anyway.
360 19 : pub fn write_message_noflush(
361 19 : &mut self,
362 19 : message: &BeMessage<'_>,
363 19 : ) -> Result<&mut Self, ConnectionError> {
364 19 : self.framed.write_message_noflush(message)?;
365 19 : trace!("wrote msg {:?}", message);
366 19 : Ok(self)
367 19 : }
368 :
/// Flush output buffer into the socket.
///
/// Panics if the framing is in its temporary `Broken` state.
pub async fn flush(&mut self) -> io::Result<()> {
    self.framed.flush().await
}
373 :
374 : /// Polling version of `flush()`, saves the caller need to pin.
375 0 : pub fn poll_flush(
376 0 : &mut self,
377 0 : cx: &mut std::task::Context<'_>,
378 0 : ) -> Poll<Result<(), std::io::Error>> {
379 0 : let flush_fut = std::pin::pin!(self.flush());
380 0 : flush_fut.poll(cx)
381 0 : }
382 :
383 : /// Write message into internal output buffer and flush it to the stream.
384 3 : pub async fn write_message(
385 3 : &mut self,
386 3 : message: &BeMessage<'_>,
387 3 : ) -> Result<&mut Self, ConnectionError> {
388 3 : self.write_message_noflush(message)?;
389 3 : self.flush().await?;
390 3 : Ok(self)
391 3 : }
392 :
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
///
/// The returned writer mutably borrows this backend.
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
    CopyDataWriter { pgb: self }
}
400 :
401 : /// Wrapper for run_message_loop() that shuts down socket when we are done
402 2 : pub async fn run(
403 2 : mut self,
404 2 : handler: &mut impl Handler<IO>,
405 2 : cancel: &CancellationToken,
406 2 : ) -> Result<(), QueryError> {
407 2 : let ret = self.run_message_loop(handler, cancel).await;
408 :
409 0 : tokio::select! {
410 0 : _ = cancel.cancelled() => {
411 0 : // do nothing; we most likely got already stopped by shutdown and will log it next.
412 0 : }
413 0 : _ = self.framed.shutdown() => {
414 0 : // socket might be already closed, e.g. if previously received error,
415 0 : // so ignore result.
416 0 : },
417 : }
418 :
419 0 : match ret {
420 0 : Ok(()) => Ok(()),
421 : Err(QueryError::Shutdown) => {
422 0 : info!("Stopped due to shutdown");
423 0 : Ok(())
424 : }
425 : Err(QueryError::Reconnect) => {
426 : // Dropping out of this loop implicitly disconnects
427 0 : info!("Stopped due to handler reconnect request");
428 0 : Ok(())
429 : }
430 0 : Err(QueryError::Disconnected(e)) => {
431 0 : info!("Disconnected ({e:#})");
432 : // Disconnection is not an error: we just use it that way internally to drop
433 : // out of loops.
434 0 : Ok(())
435 : }
436 0 : e => e,
437 : }
438 0 : }
439 :
/// Performs the handshake and then reads and processes frontend messages
/// until EOF, a `Break` result, an error, or cancellation.
///
/// Every `select!` below is `biased` with the cancellation branch first,
/// so a pending shutdown request always wins over IO progress and is
/// reported as `QueryError::Shutdown`.
async fn run_message_loop(
    &mut self,
    handler: &mut impl Handler<IO>,
    cancel: &CancellationToken,
) -> Result<(), QueryError> {
    trace!("postgres backend to {:?} started", self.peer_addr);

    tokio::select!(
        biased;

        _ = cancel.cancelled() => {
            // We were requested to shut down.
            tracing::info!("shutdown request received during handshake");
            return Err(QueryError::Shutdown)
        },

        handshake_r = self.handshake(handler) => {
            handshake_r?;
        }
    );

    // Authentication completed
    let mut query_string = Bytes::new();
    while let Some(msg) = tokio::select!(
        biased;
        _ = cancel.cancelled() => {
            // We were requested to shut down.
            tracing::info!("shutdown request received in run_message_loop");
            return Err(QueryError::Shutdown)
        },
        msg = self.read_message() => { msg },
    )? {
        trace!("got message {:?}", msg);

        // The result is inspected only after the flush below, so that any
        // response buffered by process_message (including an ErrorResponse)
        // is sent before an error is propagated.
        let result = self.process_message(handler, msg, &mut query_string).await;
        tokio::select!(
            biased;
            _ = cancel.cancelled() => {
                // We were requested to shut down.
                tracing::info!("shutdown request received during response flush");

                // If we exited process_message with a shutdown error, there may be
                // some valid response content in our transmit buffer: permit sending
                // this within a short timeout. This is a best effort thing so we don't
                // care about the result.
                tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();

                // Note that `result` is discarded on this path.
                return Err(QueryError::Shutdown)
            },
            flush_r = self.flush() => {
                flush_r?;
            }
        );

        match result? {
            ProcessMsgResult::Continue => {
                continue;
            }
            ProcessMsgResult::Break => break,
        }
    }

    trace!("postgres backend to {:?} exited", self.peer_addr);
    Ok(())
}
505 :
506 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
507 1 : async fn tls_upgrade(
508 1 : src: MaybeTlsStream<IO>,
509 1 : tls_config: Arc<rustls::ServerConfig>,
510 1 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
511 1 : match src {
512 1 : MaybeTlsStream::Unencrypted(s) => {
513 1 : let acceptor = TlsAcceptor::from(tls_config);
514 1 : let tls_stream = acceptor.accept(s).await?;
515 1 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
516 : }
517 : MaybeTlsStream::Tls(_) => {
518 0 : anyhow::bail!("TLS already started");
519 : }
520 : }
521 1 : }
522 :
523 1 : async fn start_tls(&mut self) -> anyhow::Result<()> {
524 1 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
525 1 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
526 1 : MaybeWriteOnly::Full(framed) => {
527 1 : let tls_config = self
528 1 : .tls_config
529 1 : .as_ref()
530 1 : .context("start_tls called without conf")?
531 1 : .clone();
532 1 : let tls_framed = framed
533 1 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
534 1 : .await?;
535 : // push back ready TLS stream
536 1 : self.framed = MaybeWriteOnly::Full(tls_framed);
537 1 : Ok(())
538 : }
539 : MaybeWriteOnly::WriteOnly(_) => {
540 0 : anyhow::bail!("TLS upgrade attempt in split state")
541 : }
542 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
543 : }
544 1 : }
545 :
546 : /// Split off owned read part from which messages can be read in different
547 : /// task/thread.
548 0 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
549 0 : // temporary replace stream with fake to cook split one, Indiana Jones style
550 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
551 0 : MaybeWriteOnly::Full(framed) => {
552 0 : let (reader, writer) = framed.split();
553 0 : self.framed = MaybeWriteOnly::WriteOnly(writer);
554 0 : Ok(PostgresBackendReader {
555 0 : reader,
556 0 : closed: false,
557 0 : })
558 : }
559 : MaybeWriteOnly::WriteOnly(_) => {
560 0 : anyhow::bail!("PostgresBackend is already split")
561 : }
562 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
563 : }
564 0 : }
565 :
566 : /// Join read part back.
567 0 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
568 0 : // temporary replace stream with fake to cook joined one, Indiana Jones style
569 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
570 : MaybeWriteOnly::Full(_) => {
571 0 : anyhow::bail!("PostgresBackend is not split")
572 : }
573 0 : MaybeWriteOnly::WriteOnly(writer) => {
574 0 : let joined = Framed::unsplit(reader.reader, writer);
575 0 : self.framed = MaybeWriteOnly::Full(joined);
576 0 : // if reader encountered connection error, do not attempt reading anymore
577 0 : if reader.closed {
578 0 : self.state = ProtoState::Closed;
579 0 : }
580 0 : Ok(())
581 : }
582 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
583 : }
584 0 : }
585 :
/// Perform handshake with the client, transitioning to Established.
///
/// Startup packets are processed until the state reaches at least
/// `Authentication`; if authentication is required, a single password
/// message carrying the JWT is then expected.
///
/// In case of EOF during handshake or auth this is logged, the state is set
/// to Closed and `QueryError::Disconnected` is returned.
///
/// Panics if a password message arrives while the auth type is not
/// `NeonJWT`.
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
    while self.state < ProtoState::Authentication {
        match self.framed.read_startup_message().await? {
            Some(msg) => {
                self.process_startup_message(handler, msg).await?;
            }
            None => {
                trace!(
                    "postgres backend to {:?} received EOF during handshake",
                    self.peer_addr
                );
                self.state = ProtoState::Closed;
                return Err(QueryError::Disconnected(ConnectionError::Protocol(
                    ProtocolError::Protocol("EOF during handshake".to_string()),
                )));
            }
        }
    }

    // Perform auth, if needed.
    if self.state == ProtoState::Authentication {
        match self.framed.read_message().await? {
            Some(FeMessage::PasswordMessage(m)) => {
                assert!(self.auth_type == AuthType::NeonJWT);

                // Drop the last byte of the payload; an empty payload is a
                // protocol violation.
                let (_, jwt_response) = m.split_last().context("protocol violation")?;

                if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
                    // The error response is only buffered here, not flushed.
                    self.write_message_noflush(&BeMessage::ErrorResponse(
                        &short_error(&e),
                        Some(e.pg_error_code()),
                    ))?;
                    return Err(e);
                }

                self.write_message_noflush(&BeMessage::AuthenticationOk)?
                    .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
                    .write_message(&BeMessage::ReadyForQuery)
                    .await?;
                self.state = ProtoState::Established;
            }
            Some(m) => {
                return Err(QueryError::Other(anyhow::anyhow!(
                    "Unexpected message {:?} while waiting for handshake",
                    m
                )));
            }
            None => {
                trace!(
                    "postgres backend to {:?} received EOF during auth",
                    self.peer_addr
                );
                self.state = ProtoState::Closed;
                return Err(QueryError::Disconnected(ConnectionError::Protocol(
                    ProtocolError::Protocol("EOF during auth".to_string()),
                )));
            }
        }
    }

    Ok(())
}
650 :
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
///   actual startup packet.
///
/// GSS encryption requests are always declined, and cancel requests are
/// rejected with an error.
///
/// Panics if called once the state has reached `Authentication` or later.
async fn process_startup_message(
    &mut self,
    handler: &mut impl Handler<IO>,
    msg: FeStartupPacket,
) -> Result<(), QueryError> {
    assert!(self.state < ProtoState::Authentication);
    let have_tls = self.tls_config.is_some();
    match msg {
        FeStartupPacket::SslRequest { direct } => {
            debug!("SSL requested");

            // The encryption response is written only for non-direct
            // requests; a direct request without TLS support is an error.
            if !direct {
                self.write_message(&BeMessage::EncryptionResponse(have_tls))
                    .await?;
            } else if !have_tls {
                return Err(QueryError::Other(anyhow::anyhow!(
                    "direct SSL negotiation but no TLS support"
                )));
            }

            if have_tls {
                self.start_tls().await?;
                self.state = ProtoState::Encrypted;
            }
        }
        FeStartupPacket::GssEncRequest => {
            debug!("GSS requested");
            self.write_message(&BeMessage::EncryptionResponse(false))
                .await?;
        }
        FeStartupPacket::StartupMessage { .. } => {
            // With TLS configured, plaintext connections are refused.
            if have_tls && !matches!(self.state, ProtoState::Encrypted) {
                self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
                    .await?;
                return Err(QueryError::Other(anyhow::anyhow!(
                    "client did not connect with TLS"
                )));
            }

            // NB: startup() may change self.auth_type -- we are using that in proxy code
            // to bypass auth for new users.
            handler.startup(self, &msg)?;

            match self.auth_type {
                AuthType::Trust => {
                    self.write_message_noflush(&BeMessage::AuthenticationOk)?
                        .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
                        .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
                        // The async python driver requires a valid server_version
                        .write_message_noflush(&BeMessage::server_version("14.1"))?
                        .write_message(&BeMessage::ReadyForQuery)
                        .await?;
                    self.state = ProtoState::Established;
                }
                AuthType::NeonJWT => {
                    self.write_message(&BeMessage::AuthenticationCleartextPassword)
                        .await?;
                    self.state = ProtoState::Authentication;
                }
            }
        }
        FeStartupPacket::CancelRequest { .. } => {
            return Err(QueryError::Other(anyhow::anyhow!(
                "Unexpected CancelRequest message during handshake"
            )));
        }
    }
    Ok(())
}
725 :
// Proto looks like this:
// FeMessage::Query("pagestream_v2{FeMessage::CopyData(PagestreamFeMessage::GetPage(..))}")

/// Processes one frontend message received in the `Established` state.
///
/// Responses are only written into the output buffer; flushing is left to
/// the caller. `unnamed_query_string` holds the query text of the most
/// recent Parse message, which a later Execute message runs.
///
/// Returns whether the message loop should continue or stop.
///
/// Panics if the connection is not in the `Established` state.
async fn process_message(
    &mut self,
    handler: &mut impl Handler<IO>,
    msg: FeMessage,
    unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
    // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
    // TODO: change that to proper top-level match of protocol state with separate message handling for each state
    assert!(self.state == ProtoState::Established);

    match msg {
        FeMessage::Query(body) => {
            // remove null terminator
            let query_string = cstr_to_str(&body)?;

            trace!("got query {query_string:?}");
            if let Err(e) = handler.process_query(self, query_string).await {
                match e {
                    // Shutdown stops the loop without reporting an error.
                    QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
                    // Propagated as is: no ErrorResponse is written.
                    QueryError::SimulatedConnectionError => {
                        return Err(QueryError::SimulatedConnectionError)
                    }
                    err @ QueryError::Reconnect => {
                        // Instruct the client to reconnect, stop processing messages
                        // from this libpq connection and, finally, disconnect from the
                        // server side (returning an Err achieves the latter).
                        //
                        // Note the flushing is done by the caller.
                        let reconnect_error = short_error(&err);
                        self.write_message_noflush(&BeMessage::ErrorResponse(
                            &reconnect_error,
                            Some(err.pg_error_code()),
                        ))?;

                        return Err(err);
                    }
                    // Any other error is logged and reported to the client,
                    // and the connection stays usable.
                    e => {
                        log_query_error(query_string, &e);
                        let short_error = short_error(&e);
                        self.write_message_noflush(&BeMessage::ErrorResponse(
                            &short_error,
                            Some(e.pg_error_code()),
                        ))?;
                    }
                }
            }
            self.write_message_noflush(&BeMessage::ReadyForQuery)?;
        }

        FeMessage::Parse(m) => {
            // Remember the query text for a subsequent Execute.
            *unnamed_query_string = m.query_string;
            self.write_message_noflush(&BeMessage::ParseComplete)?;
        }

        FeMessage::Describe(_) => {
            self.write_message_noflush(&BeMessage::ParameterDescription)?
                .write_message_noflush(&BeMessage::NoData)?;
        }

        FeMessage::Bind(_) => {
            self.write_message_noflush(&BeMessage::BindComplete)?;
        }

        FeMessage::Close(_) => {
            self.write_message_noflush(&BeMessage::CloseComplete)?;
        }

        FeMessage::Execute(_) => {
            let query_string = cstr_to_str(unnamed_query_string)?;
            trace!("got execute {query_string:?}");
            // Unlike the simple Query path above, every handler error is
            // reported to the client here (using the full error text) and
            // processing continues.
            if let Err(e) = handler.process_query(self, query_string).await {
                log_query_error(query_string, &e);
                self.write_message_noflush(&BeMessage::ErrorResponse(
                    &e.to_string(),
                    Some(e.pg_error_code()),
                ))?;
            }
            // NOTE there is no ReadyForQuery message. This handler is used
            // for basebackup and it uses CopyOut which doesn't require
            // ReadyForQuery message and backend just switches back to
            // processing mode after sending CopyDone or ErrorResponse.
        }

        FeMessage::Sync => {
            self.write_message_noflush(&BeMessage::ReadyForQuery)?;
        }

        FeMessage::Terminate => {
            return Ok(ProcessMsgResult::Break);
        }

        // We prefer explicit pattern matching to wildcards, because
        // this helps us spot the places where new variants are missing
        FeMessage::CopyData(_)
        | FeMessage::CopyDone
        | FeMessage::CopyFail
        | FeMessage::PasswordMessage(_) => {
            return Err(QueryError::Other(anyhow::anyhow!(
                "unexpected message type: {msg:?}",
            )));
        }
    }

    Ok(ProcessMsgResult::Continue)
}
834 :
/// - Log as info/error result of handling COPY stream and send back
///   ErrorResponse if that makes sense.
/// - Send CopyDone back if the stream ended with CopyDone.
/// - Then always shut down the stream and mark the connection Closed,
///   because we don't handle exiting from COPY stream normally.
///
/// Failures to send the responses are logged and otherwise ignored.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
    use CopyStreamHandlerEnd::*;

    // Expected ends are logged at info level, everything else as an error.
    let expected_end = match &end {
        ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF | Cancelled => true,
        CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
            if is_expected_io_error(io_error) =>
        {
            true
        }
        _ => false,
    };
    if expected_end {
        info!("terminated: {:#}", end);
    } else {
        error!("terminated: {:?}", end);
    }

    // Note: no current usages ever send this
    if let CopyDone = &end {
        if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
            error!("failed to send CopyDone: {}", e);
        }
    }

    // Text and SQLSTATE of the ErrorResponse to send, if any.
    let err_to_send_and_errcode = match &end {
        ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
        Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
        // Note: CopyFail in duplex copy is somewhat unexpected (at least to
        // PG walsender; evidently and per my docs reading client should
        // finish it with CopyDone). It is not a problem to recover from it
        // finishing the stream in both directions like we do, but note that
        // sync rust-postgres client (which we don't use anymore) hangs if
        // socket is not closed here.
        // https://github.com/sfackler/rust-postgres/issues/755
        // https://github.com/neondatabase/neon/issues/935
        //
        // Currently, the version of tokio_postgres replication patch we use
        // sends this when it closes the stream (e.g. pageserver decided to
        // switch conn to another safekeeper and client gets dropped).
        // Moreover, seems like 'connection' task errors with 'unexpected
        // message from server' when it receives ErrorResponse (anything but
        // CopyData/CopyDone) back.
        CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),

        // When cancelled, send no response: we must not risk blocking on sending that response
        Cancelled => None,
        _ => None,
    };
    if let Some((err, errcode)) = err_to_send_and_errcode {
        if let Err(ee) = self
            .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
            .await
        {
            error!("failed to send ErrorResponse: {}", ee);
        }
    }

    // Proper COPY stream finishing to continue using the connection is not
    // implemented at the server side (we don't need it so far). To prevent
    // further usages of the connection, close it.
    self.framed.shutdown().await.ok();
    self.state = ProtoState::Closed;
}
904 : }
905 :
/// Reader over the frontend message stream of a connection.
///
/// Wraps a [`FramedReader`] on top of the connection's `MaybeTlsStream` and
/// remembers whether a read error has already been observed.
// NOTE(review): presumably this is the read half obtained by splitting a
// `PostgresBackend`; the split itself is not visible in this part of the file.
pub struct PostgresBackendReader<IO> {
    /// Framed reader that decodes frontend messages from the stream.
    reader: FramedReader<MaybeTlsStream<IO>>,
    /// `true` once reading has returned an error, i.e. the connection is
    /// considered closed.
    closed: bool,
}
910 :
911 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
912 : /// Read full message or return None if connection is cleanly closed with no
913 : /// unprocessed data.
914 0 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
915 0 : match self.reader.read_message().await {
916 0 : Ok(m) => {
917 0 : trace!("read msg {:?}", m);
918 0 : Ok(m)
919 : }
920 0 : Err(e) => {
921 0 : self.closed = true;
922 0 : Err(e)
923 : }
924 : }
925 0 : }
926 :
927 : /// Get CopyData contents of the next message in COPY stream or error
928 : /// closing it. The error type is wider than actual errors which can happen
929 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
930 : /// current callers.
931 0 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
932 0 : match self.read_message().await? {
933 0 : Some(msg) => match msg {
934 0 : FeMessage::CopyData(m) => Ok(m),
935 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
936 0 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
937 0 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
938 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
939 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
940 0 : ))),
941 : },
942 0 : None => Err(CopyStreamHandlerEnd::EOF),
943 : }
944 0 : }
945 : }
946 :
///
/// A [`tokio::io::AsyncWrite`] implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a, IO> {
    /// Backend whose connection the CopyData messages are written to and
    /// flushed through.
    pgb: &'a mut PostgresBackend<IO>,
}
954 :
955 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'_, IO> {
956 0 : fn poll_write(
957 0 : self: Pin<&mut Self>,
958 0 : cx: &mut std::task::Context<'_>,
959 0 : buf: &[u8],
960 0 : ) -> Poll<Result<usize, std::io::Error>> {
961 0 : let this = self.get_mut();
962 :
963 : // It's not strictly required to flush between each message, but makes it easier
964 : // to view in wireshark, and usually the messages that the callers write are
965 : // decently-sized anyway.
966 0 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
967 0 : return Poll::Ready(Err(err));
968 0 : }
969 0 :
970 0 : // CopyData
971 0 : // XXX: if the input is large, we should split it into multiple messages.
972 0 : // Not sure what the threshold should be, but the ultimate hard limit is that
973 0 : // the length cannot exceed u32.
974 0 : this.pgb
975 0 : .write_message_noflush(&BeMessage::CopyData(buf))
976 0 : // write_message only writes to the buffer, so it can fail iff the
977 0 : // message is invaid, but CopyData can't be invalid.
978 0 : .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
979 :
980 0 : Poll::Ready(Ok(buf.len()))
981 0 : }
982 :
983 0 : fn poll_flush(
984 0 : self: Pin<&mut Self>,
985 0 : cx: &mut std::task::Context<'_>,
986 0 : ) -> Poll<Result<(), std::io::Error>> {
987 0 : let this = self.get_mut();
988 0 : this.pgb.poll_flush(cx)
989 0 : }
990 :
991 0 : fn poll_shutdown(
992 0 : self: Pin<&mut Self>,
993 0 : cx: &mut std::task::Context<'_>,
994 0 : ) -> Poll<Result<(), std::io::Error>> {
995 0 : let this = self.get_mut();
996 0 : this.pgb.poll_flush(cx)
997 0 : }
998 : }
999 :
1000 0 : pub fn short_error(e: &QueryError) -> String {
1001 0 : match e {
1002 0 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
1003 0 : QueryError::Reconnect => "reconnect".to_string(),
1004 0 : QueryError::Shutdown => "shutdown".to_string(),
1005 0 : QueryError::NotFound(_) => "not found".to_string(),
1006 0 : QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
1007 0 : QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
1008 0 : QueryError::Other(e) => format!("{e:#}"),
1009 : }
1010 0 : }
1011 :
1012 0 : fn log_query_error(query: &str, e: &QueryError) {
1013 : // If you want to change the log level of a specific error, also re-categorize it in `BasebackupQueryTimeOngoingRecording`.
1014 0 : match e {
1015 0 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
1016 0 : if is_expected_io_error(io_error) {
1017 0 : info!("query handler for '{query}' failed with expected io error: {io_error}");
1018 : } else {
1019 0 : error!("query handler for '{query}' failed with io error: {io_error}");
1020 : }
1021 : }
1022 0 : QueryError::Disconnected(other_connection_error) => {
1023 0 : error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
1024 : }
1025 : QueryError::SimulatedConnectionError => {
1026 0 : error!("query handler for query '{query}' failed due to a simulated connection error")
1027 : }
1028 : QueryError::Reconnect => {
1029 0 : info!("query handler for '{query}' requested client to reconnect")
1030 : }
1031 : QueryError::Shutdown => {
1032 0 : info!("query handler for '{query}' cancelled during tenant shutdown")
1033 : }
1034 0 : QueryError::NotFound(reason) => {
1035 0 : info!("query handler for '{query}' entity not found: {reason}")
1036 : }
1037 0 : QueryError::Unauthorized(e) => {
1038 0 : warn!("query handler for '{query}' failed with authentication error: {e}");
1039 : }
1040 0 : QueryError::Other(e) => {
1041 0 : error!("query handler for '{query}' failed: {e:?}");
1042 : }
1043 : }
1044 0 : }
1045 :
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
    /// Handler initiates the end of streaming.
    /// The contained string is used verbatim as the display message.
    #[error("{0}")]
    ServerInitiated(String),
    /// The client sent `CopyDone`.
    #[error("received CopyDone")]
    CopyDone,
    /// The client sent `CopyFail`.
    #[error("received CopyFail")]
    CopyFail,
    /// The client sent `Terminate`.
    #[error("received Terminate")]
    Terminate,
    /// Reading from the COPY stream yielded no further message.
    #[error("EOF on COPY stream")]
    EOF,
    /// The connection was lost
    #[error("connection error: {0}")]
    Disconnected(#[from] ConnectionError),
    /// Handling was cancelled; note that this is displayed as "Shutdown".
    #[error("Shutdown")]
    Cancelled,
    /// Some other error
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
|