Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : #![deny(unsafe_code)]
6 : #![deny(clippy::undocumented_unsafe_blocks)]
7 : use anyhow::Context;
8 : use bytes::Bytes;
9 : use serde::{Deserialize, Serialize};
10 : use std::io::ErrorKind;
11 : use std::net::SocketAddr;
12 : use std::pin::Pin;
13 : use std::sync::Arc;
14 : use std::task::{ready, Poll};
15 : use std::{fmt, io};
16 : use std::{future::Future, str::FromStr};
17 : use tokio::io::{AsyncRead, AsyncWrite};
18 : use tokio_rustls::TlsAcceptor;
19 : use tokio_util::sync::CancellationToken;
20 : use tracing::{debug, error, info, trace, warn};
21 :
22 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
23 : use pq_proto::{
24 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
25 : SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
26 : };
27 :
/// An error that occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
    /// The connection was lost while processing the query.
    #[error(transparent)]
    Disconnected(#[from] ConnectionError),
    /// We were instructed to shutdown while processing the query
    #[error("Shutting down")]
    Shutdown,
    /// Query handler indicated that client should reconnect
    #[error("Server requested reconnect")]
    Reconnect,
    /// Query named an entity that was not found
    #[error("Not found: {0}")]
    NotFound(std::borrow::Cow<'static, str>),
    /// Authentication failure
    #[error("Unauthorized: {0}")]
    Unauthorized(std::borrow::Cow<'static, str>),
    /// An artificial connection failure. It is reported with the same SQLSTATE
    /// as a real connection failure and is propagated by the message loop
    /// without sending an ErrorResponse.
    /// NOTE(review): presumably injected by handlers for testing -- confirm.
    #[error("Simulated Connection Error")]
    SimulatedConnectionError,
    /// Some other error
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
53 :
54 : impl From<io::Error> for QueryError {
55 0 : fn from(e: io::Error) -> Self {
56 0 : Self::Disconnected(ConnectionError::Io(e))
57 0 : }
58 : }
59 :
60 : impl QueryError {
61 0 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
62 0 : match self {
63 0 : Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
64 0 : Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
65 0 : Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
66 0 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
67 : }
68 0 : }
69 : }
70 :
/// Tells whether an I/O error is an ordinary result of network trouble or of
/// the peer going away.
///
/// Such errors occur during normal operation and do not point at a bug in our
/// code. The kinds treated as expected are: `BrokenPipe`, `ConnectionRefused`,
/// `ConnectionAborted`, `ConnectionReset` and `TimedOut`.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    use io::ErrorKind as Kind;

    let kind = e.kind();
    matches!(
        kind,
        Kind::BrokenPipe
            | Kind::ConnectionRefused
            | Kind::ConnectionAborted
            | Kind::ConnectionReset
            | Kind::TimedOut
    )
}
83 :
/// Callbacks through which a [`PostgresBackend`] delegates query processing
/// and connection setup decisions to its user.
pub trait Handler<IO> {
    /// Handle single query.
    /// postgres_backend will issue ReadyForQuery after calling this (this
    /// might be not what we want after CopyData streaming, but currently we don't
    /// care). It will also flush out the output buffer.
    fn process_query(
        &mut self,
        pgb: &mut PostgresBackend<IO>,
        query_string: &str,
    ) -> impl Future<Output = Result<(), QueryError>>;

    /// Called on startup packet receival, allows to process params.
    ///
    /// Returning `Err` aborts the handshake: the error is propagated before
    /// any authentication response is written. The default implementation
    /// accepts every startup packet.
    ///
    /// The proxy code uses this hook to skip auth when creating new users.
    /// That is quite a hacky and ad-hoc solution; maybe we could allow
    /// implementations to override the whole init logic instead.
    fn startup(
        &mut self,
        _pgb: &mut PostgresBackend<IO>,
        _sm: &FeStartupPacket,
    ) -> Result<(), QueryError> {
        Ok(())
    }

    /// Check auth jwt.
    ///
    /// `_jwt_response` is the payload of the client's password message with
    /// its last byte stripped. The default implementation rejects every
    /// token, so handlers used with JWT auth must override it.
    fn check_auth_jwt(
        &mut self,
        _pgb: &mut PostgresBackend<IO>,
        _jwt_response: &[u8],
    ) -> Result<(), QueryError> {
        Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
    }
}
117 :
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters: `PartialOrd` is derived, and
/// the handshake code compares states with `<` (e.g.
/// `state < ProtoState::Authentication`) to decide whether it is still
/// waiting for a startup packet.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
    /// Nothing happened yet.
    Initialization,
    /// Encryption handshake is done; waiting for encrypted Startup message.
    Encrypted,
    /// Waiting for password (auth token).
    Authentication,
    /// Performed handshake and auth, ReadyForQuery is issued.
    Established,
    /// The connection hit EOF or an error, or was shut down on purpose;
    /// no further reads are attempted in this state.
    Closed,
}
132 :
/// Outcome of processing a single frontend message in the message loop.
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
    /// Keep reading and processing further messages.
    Continue,
    /// Leave the message loop.
    Break,
}
138 :
/// Either a plain stream or a TLS-encrypted one, implementing
/// AsyncRead + AsyncWrite by forwarding to whichever variant is active.
pub enum MaybeTlsStream<IO> {
    /// The raw underlying stream, no encryption.
    Unencrypted(IO),
    /// The underlying stream wrapped into a server-side rustls session
    /// (boxed).
    Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
144 :
145 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
146 5 : fn poll_write(
147 5 : self: Pin<&mut Self>,
148 5 : cx: &mut std::task::Context<'_>,
149 5 : buf: &[u8],
150 5 : ) -> Poll<io::Result<usize>> {
151 5 : match self.get_mut() {
152 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
153 2 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
154 : }
155 5 : }
156 5 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
157 5 : match self.get_mut() {
158 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
159 2 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
160 : }
161 5 : }
162 0 : fn poll_shutdown(
163 0 : self: Pin<&mut Self>,
164 0 : cx: &mut std::task::Context<'_>,
165 0 : ) -> Poll<io::Result<()>> {
166 0 : match self.get_mut() {
167 0 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
168 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
169 : }
170 0 : }
171 : }
172 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
173 12 : fn poll_read(
174 12 : self: Pin<&mut Self>,
175 12 : cx: &mut std::task::Context<'_>,
176 12 : buf: &mut tokio::io::ReadBuf<'_>,
177 12 : ) -> Poll<io::Result<()>> {
178 12 : match self.get_mut() {
179 7 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
180 5 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
181 : }
182 12 : }
183 : }
184 :
/// How the backend authenticates a client after its startup packet.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
    /// No authentication: AuthenticationOk is sent right after the startup packet.
    Trust,
    // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
    NeonJWT,
}
191 :
192 : impl FromStr for AuthType {
193 : type Err = anyhow::Error;
194 :
195 0 : fn from_str(s: &str) -> Result<Self, Self::Err> {
196 0 : match s {
197 0 : "Trust" => Ok(Self::Trust),
198 0 : "NeonJWT" => Ok(Self::NeonJWT),
199 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
200 : }
201 0 : }
202 : }
203 :
204 : impl fmt::Display for AuthType {
205 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 0 : f.write_str(match self {
207 0 : AuthType::Trust => "Trust",
208 0 : AuthType::NeonJWT => "NeonJWT",
209 : })
210 0 : }
211 : }
212 :
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of split handles, but that would force us to pay the splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly<IO> {
    /// Both directions are available.
    Full(Framed<MaybeTlsStream<IO>>),
    /// Only the write half is here; the read half was handed out by `split`.
    WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
    Broken, // temporary value palmed off during the split
}
222 :
223 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
224 3 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
225 3 : match self {
226 3 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
227 : MaybeWriteOnly::WriteOnly(_) => {
228 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
229 : }
230 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
231 : }
232 3 : }
233 :
234 4 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
235 4 : match self {
236 4 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
237 : MaybeWriteOnly::WriteOnly(_) => {
238 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
239 : }
240 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
241 : }
242 2 : }
243 :
244 19 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
245 19 : match self {
246 19 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
247 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
248 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
249 : }
250 19 : }
251 :
252 5 : async fn flush(&mut self) -> io::Result<()> {
253 5 : match self {
254 5 : MaybeWriteOnly::Full(framed) => framed.flush().await,
255 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
256 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
257 : }
258 5 : }
259 :
260 : /// Cancellation safe as long as the underlying IO is cancellation safe.
261 0 : async fn shutdown(&mut self) -> io::Result<()> {
262 0 : match self {
263 0 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
264 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
265 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
266 : }
267 0 : }
268 : }
269 :
/// Server side of a single Postgres protocol connection over `IO`.
pub struct PostgresBackend<IO> {
    /// Message framing over the (possibly TLS) stream; may hold only the
    /// write half while the backend is split.
    framed: MaybeWriteOnly<IO>,

    /// Current protocol state of the connection.
    pub state: ProtoState,

    /// How clients are authenticated after the startup packet.
    auth_type: AuthType,

    /// Address of the connected client.
    peer_addr: SocketAddr,
    /// TLS configuration; when set, startup packets are accepted only after
    /// the connection has been upgraded to TLS.
    pub tls_config: Option<Arc<rustls::ServerConfig>>,
}

/// [`PostgresBackend`] over a plain tokio TCP stream.
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
282 :
283 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
284 2 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
285 2 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
286 2 : std::str::from_utf8(without_null).map_err(|e| e.into())
287 2 : }
288 :
289 : impl PostgresBackend<tokio::net::TcpStream> {
290 2 : pub fn new(
291 2 : socket: tokio::net::TcpStream,
292 2 : auth_type: AuthType,
293 2 : tls_config: Option<Arc<rustls::ServerConfig>>,
294 2 : ) -> io::Result<Self> {
295 2 : let peer_addr = socket.peer_addr()?;
296 2 : let stream = MaybeTlsStream::Unencrypted(socket);
297 2 :
298 2 : Ok(Self {
299 2 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
300 2 : state: ProtoState::Initialization,
301 2 : auth_type,
302 2 : tls_config,
303 2 : peer_addr,
304 2 : })
305 2 : }
306 : }
307 :
308 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
309 0 : pub fn new_from_io(
310 0 : socket: IO,
311 0 : peer_addr: SocketAddr,
312 0 : auth_type: AuthType,
313 0 : tls_config: Option<Arc<rustls::ServerConfig>>,
314 0 : ) -> io::Result<Self> {
315 0 : let stream = MaybeTlsStream::Unencrypted(socket);
316 0 :
317 0 : Ok(Self {
318 0 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
319 0 : state: ProtoState::Initialization,
320 0 : auth_type,
321 0 : tls_config,
322 0 : peer_addr,
323 0 : })
324 0 : }
325 :
/// Address of the connected client, as recorded at construction time.
pub fn get_peer_addr(&self) -> &SocketAddr {
    &self.peer_addr
}
329 :
330 : /// Read full message or return None if connection is cleanly closed with no
331 : /// unprocessed data.
332 4 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
333 4 : if let ProtoState::Closed = self.state {
334 0 : Ok(None)
335 : } else {
336 4 : match self.framed.read_message().await {
337 2 : Ok(m) => {
338 2 : trace!("read msg {:?}", m);
339 2 : Ok(m)
340 : }
341 0 : Err(e) => {
342 0 : // remember not to try to read anymore
343 0 : self.state = ProtoState::Closed;
344 0 : Err(e)
345 : }
346 : }
347 : }
348 2 : }
349 :
350 : /// Write message into internal output buffer, doesn't flush it. Technically
351 : /// error type can be only ProtocolError here (if, unlikely, serialization
352 : /// fails), but callers typically wrap it anyway.
353 19 : pub fn write_message_noflush(
354 19 : &mut self,
355 19 : message: &BeMessage<'_>,
356 19 : ) -> Result<&mut Self, ConnectionError> {
357 19 : self.framed.write_message_noflush(message)?;
358 19 : trace!("wrote msg {:?}", message);
359 19 : Ok(self)
360 19 : }
361 :
/// Flush output buffer into the socket.
///
/// Panics if the underlying framed stream is in its temporary `Broken` state.
pub async fn flush(&mut self) -> io::Result<()> {
    self.framed.flush().await
}
366 :
/// Polling version of `flush()`, saves the caller need to pin.
///
/// Each call builds a fresh `flush()` future, pins it on the stack, polls it
/// once and drops it on return.
/// NOTE(review): this relies on flush progress living in the framed buffer
/// rather than in the future itself -- confirm against `Framed::flush`.
pub fn poll_flush(
    &mut self,
    cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
    let flush_fut = std::pin::pin!(self.flush());
    flush_fut.poll(cx)
}
375 :
376 : /// Write message into internal output buffer and flush it to the stream.
377 3 : pub async fn write_message(
378 3 : &mut self,
379 3 : message: &BeMessage<'_>,
380 3 : ) -> Result<&mut Self, ConnectionError> {
381 3 : self.write_message_noflush(message)?;
382 3 : self.flush().await?;
383 3 : Ok(self)
384 3 : }
385 :
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
///
/// The writer mutably borrows this backend for as long as it lives.
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
    CopyDataWriter { pgb: self }
}
393 :
394 : /// Wrapper for run_message_loop() that shuts down socket when we are done
395 2 : pub async fn run(
396 2 : mut self,
397 2 : handler: &mut impl Handler<IO>,
398 2 : cancel: &CancellationToken,
399 2 : ) -> Result<(), QueryError> {
400 7 : let ret = self.run_message_loop(handler, cancel).await;
401 :
402 0 : tokio::select! {
403 0 : _ = cancel.cancelled() => {
404 0 : // do nothing; we most likely got already stopped by shutdown and will log it next.
405 0 : }
406 0 : _ = self.framed.shutdown() => {
407 0 : // socket might be already closed, e.g. if previously received error,
408 0 : // so ignore result.
409 0 : },
410 : }
411 :
412 0 : match ret {
413 0 : Ok(()) => Ok(()),
414 : Err(QueryError::Shutdown) => {
415 0 : info!("Stopped due to shutdown");
416 0 : Ok(())
417 : }
418 : Err(QueryError::Reconnect) => {
419 : // Dropping out of this loop implicitly disconnects
420 0 : info!("Stopped due to handler reconnect request");
421 0 : Ok(())
422 : }
423 0 : Err(QueryError::Disconnected(e)) => {
424 0 : info!("Disconnected ({e:#})");
425 : // Disconnection is not an error: we just use it that way internally to drop
426 : // out of loops.
427 0 : Ok(())
428 : }
429 0 : e => e,
430 : }
431 0 : }
432 :
/// Perform the handshake and then process frontend messages until the
/// client disconnects, a message asks to stop, an error occurs, or `cancel`
/// fires.
///
/// Cancellation is checked first (`biased`) at each wait point: during the
/// handshake, while waiting for the next message, and while flushing a
/// response. In all three cases `QueryError::Shutdown` is returned.
async fn run_message_loop(
    &mut self,
    handler: &mut impl Handler<IO>,
    cancel: &CancellationToken,
) -> Result<(), QueryError> {
    trace!("postgres backend to {:?} started", self.peer_addr);

    tokio::select!(
        biased;

        _ = cancel.cancelled() => {
            // We were requested to shut down.
            tracing::info!("shutdown request received during handshake");
            return Err(QueryError::Shutdown)
        },

        handshake_r = self.handshake(handler) => {
            handshake_r?;
        }
    );

    // Authentication completed
    // Query text remembered from the last Parse message; Execute runs it.
    let mut query_string = Bytes::new();
    while let Some(msg) = tokio::select!(
        biased;
        _ = cancel.cancelled() => {
            // We were requested to shut down.
            tracing::info!("shutdown request received in run_message_loop");
            return Err(QueryError::Shutdown)
        },
        msg = self.read_message() => { msg },
    )? {
        trace!("got message {:?}", msg);

        // The result is inspected only after the flush below, so that any
        // response (including an ErrorResponse) already placed in the output
        // buffer is sent before an error is propagated.
        let result = self.process_message(handler, msg, &mut query_string).await;
        tokio::select!(
            biased;
            _ = cancel.cancelled() => {
                // We were requested to shut down.
                tracing::info!("shutdown request received during response flush");

                // If we exited process_message with a shutdown error, there may be
                // some valid response content in our transmit buffer: permit sending
                // this within a short timeout. This is a best effort thing so we don't
                // care about the result.
                tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();

                return Err(QueryError::Shutdown)
            },
            flush_r = self.flush() => {
                flush_r?;
            }
        );

        match result? {
            ProcessMsgResult::Continue => {
                continue;
            }
            ProcessMsgResult::Break => break,
        }
    }

    trace!("postgres backend to {:?} exited", self.peer_addr);
    Ok(())
}
498 :
499 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
500 1 : async fn tls_upgrade(
501 1 : src: MaybeTlsStream<IO>,
502 1 : tls_config: Arc<rustls::ServerConfig>,
503 1 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
504 1 : match src {
505 1 : MaybeTlsStream::Unencrypted(s) => {
506 1 : let acceptor = TlsAcceptor::from(tls_config);
507 2 : let tls_stream = acceptor.accept(s).await?;
508 1 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
509 : }
510 : MaybeTlsStream::Tls(_) => {
511 0 : anyhow::bail!("TLS already started");
512 : }
513 : }
514 1 : }
515 :
516 1 : async fn start_tls(&mut self) -> anyhow::Result<()> {
517 1 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
518 1 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
519 1 : MaybeWriteOnly::Full(framed) => {
520 1 : let tls_config = self
521 1 : .tls_config
522 1 : .as_ref()
523 1 : .context("start_tls called without conf")?
524 1 : .clone();
525 1 : let tls_framed = framed
526 1 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
527 2 : .await?;
528 : // push back ready TLS stream
529 1 : self.framed = MaybeWriteOnly::Full(tls_framed);
530 1 : Ok(())
531 : }
532 : MaybeWriteOnly::WriteOnly(_) => {
533 0 : anyhow::bail!("TLS upgrade attempt in split state")
534 : }
535 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
536 : }
537 1 : }
538 :
539 : /// Split off owned read part from which messages can be read in different
540 : /// task/thread.
541 0 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
542 0 : // temporary replace stream with fake to cook split one, Indiana Jones style
543 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
544 0 : MaybeWriteOnly::Full(framed) => {
545 0 : let (reader, writer) = framed.split();
546 0 : self.framed = MaybeWriteOnly::WriteOnly(writer);
547 0 : Ok(PostgresBackendReader {
548 0 : reader,
549 0 : closed: false,
550 0 : })
551 : }
552 : MaybeWriteOnly::WriteOnly(_) => {
553 0 : anyhow::bail!("PostgresBackend is already split")
554 : }
555 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
556 : }
557 0 : }
558 :
559 : /// Join read part back.
560 0 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
561 0 : // temporary replace stream with fake to cook joined one, Indiana Jones style
562 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
563 : MaybeWriteOnly::Full(_) => {
564 0 : anyhow::bail!("PostgresBackend is not split")
565 : }
566 0 : MaybeWriteOnly::WriteOnly(writer) => {
567 0 : let joined = Framed::unsplit(reader.reader, writer);
568 0 : self.framed = MaybeWriteOnly::Full(joined);
569 0 : // if reader encountered connection error, do not attempt reading anymore
570 0 : if reader.closed {
571 0 : self.state = ProtoState::Closed;
572 0 : }
573 0 : Ok(())
574 : }
575 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
576 : }
577 0 : }
578 :
/// Perform handshake with the client, transitioning to Established.
///
/// Startup packets are processed until the state reaches at least
/// `Authentication`; if it is exactly `Authentication`, one password
/// message carrying the JWT is then read and checked via the handler.
///
/// On EOF during the handshake or auth this is logged at trace level, the
/// state is set to `Closed` and `Err(QueryError::Disconnected)` is returned.
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
    while self.state < ProtoState::Authentication {
        match self.framed.read_startup_message().await? {
            Some(msg) => {
                self.process_startup_message(handler, msg).await?;
            }
            None => {
                trace!(
                    "postgres backend to {:?} received EOF during handshake",
                    self.peer_addr
                );
                self.state = ProtoState::Closed;
                return Err(QueryError::Disconnected(ConnectionError::Protocol(
                    ProtocolError::Protocol("EOF during handshake".to_string()),
                )));
            }
        }
    }

    // Perform auth, if needed.
    if self.state == ProtoState::Authentication {
        match self.framed.read_message().await? {
            Some(FeMessage::PasswordMessage(m)) => {
                assert!(self.auth_type == AuthType::NeonJWT);

                // Drop the last byte of the payload; an empty payload is a
                // protocol violation.
                // NOTE(review): presumably that byte is the NUL terminator -- it is not checked.
                let (_, jwt_response) = m.split_last().context("protocol violation")?;

                if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
                    // The error response is only buffered here, not flushed.
                    self.write_message_noflush(&BeMessage::ErrorResponse(
                        &short_error(&e),
                        Some(e.pg_error_code()),
                    ))?;
                    return Err(e);
                }

                self.write_message_noflush(&BeMessage::AuthenticationOk)?
                    .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
                    .write_message(&BeMessage::ReadyForQuery)
                    .await?;
                self.state = ProtoState::Established;
            }
            Some(m) => {
                return Err(QueryError::Other(anyhow::anyhow!(
                    "Unexpected message {:?} while waiting for handshake",
                    m
                )));
            }
            None => {
                trace!(
                    "postgres backend to {:?} received EOF during auth",
                    self.peer_addr
                );
                self.state = ProtoState::Closed;
                return Err(QueryError::Disconnected(ConnectionError::Protocol(
                    ProtocolError::Protocol("EOF during auth".to_string()),
                )));
            }
        }
    }

    Ok(())
}
643 :
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
///   actual startup packet.
///
/// A GSS encryption request is always declined, and a CancelRequest is
/// rejected with an error. When a TLS config is set, a StartupMessage that
/// arrives before the TLS upgrade is answered with an ErrorResponse and
/// rejected.
///
/// Panics if the state is already `Authentication` or later.
async fn process_startup_message(
    &mut self,
    handler: &mut impl Handler<IO>,
    msg: FeStartupPacket,
) -> Result<(), QueryError> {
    assert!(self.state < ProtoState::Authentication);
    let have_tls = self.tls_config.is_some();
    match msg {
        FeStartupPacket::SslRequest { direct } => {
            debug!("SSL requested");

            // For a non-direct request the client is told whether TLS is
            // available; a direct request gets no such response and is an
            // error when TLS is not configured.
            if !direct {
                self.write_message(&BeMessage::EncryptionResponse(have_tls))
                    .await?;
            } else if !have_tls {
                return Err(QueryError::Other(anyhow::anyhow!(
                    "direct SSL negotiation but no TLS support"
                )));
            }

            if have_tls {
                self.start_tls().await?;
                self.state = ProtoState::Encrypted;
            }
        }
        FeStartupPacket::GssEncRequest => {
            debug!("GSS requested");
            self.write_message(&BeMessage::EncryptionResponse(false))
                .await?;
        }
        FeStartupPacket::StartupMessage { .. } => {
            if have_tls && !matches!(self.state, ProtoState::Encrypted) {
                self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
                    .await?;
                return Err(QueryError::Other(anyhow::anyhow!(
                    "client did not connect with TLS"
                )));
            }

            // NB: startup() may change self.auth_type -- we are using that in proxy code
            // to bypass auth for new users.
            handler.startup(self, &msg)?;

            match self.auth_type {
                AuthType::Trust => {
                    self.write_message_noflush(&BeMessage::AuthenticationOk)?
                        .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
                        .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
                        // The async python driver requires a valid server_version
                        .write_message_noflush(&BeMessage::server_version("14.1"))?
                        .write_message(&BeMessage::ReadyForQuery)
                        .await?;
                    self.state = ProtoState::Established;
                }
                AuthType::NeonJWT => {
                    self.write_message(&BeMessage::AuthenticationCleartextPassword)
                        .await?;
                    self.state = ProtoState::Authentication;
                }
            }
        }
        FeStartupPacket::CancelRequest { .. } => {
            return Err(QueryError::Other(anyhow::anyhow!(
                "Unexpected CancelRequest message during handshake"
            )));
        }
    }
    Ok(())
}
718 :
/// Process one frontend message on an established connection.
///
/// Responses are only placed into the output buffer; flushing is done by
/// the caller. `unnamed_query_string` holds the query text of the last Parse
/// message, which a later Execute runs.
///
/// Returns `Break` on Terminate or when the handler reports `Shutdown` for
/// a simple query, `Continue` otherwise. Handler errors other than
/// `Shutdown`, `SimulatedConnectionError` and `Reconnect` are reported to
/// the client as an ErrorResponse and do not end the connection.
///
/// Panics if the state is not `Established`.
async fn process_message(
    &mut self,
    handler: &mut impl Handler<IO>,
    msg: FeMessage,
    unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
    // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
    // TODO: change that to proper top-level match of protocol state with separate message handling for each state
    assert!(self.state == ProtoState::Established);

    match msg {
        FeMessage::Query(body) => {
            // remove null terminator
            let query_string = cstr_to_str(&body)?;

            trace!("got query {query_string:?}");
            if let Err(e) = handler.process_query(self, query_string).await {
                match e {
                    QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
                    // Propagated as is, without sending anything to the client.
                    QueryError::SimulatedConnectionError => {
                        return Err(QueryError::SimulatedConnectionError)
                    }
                    err @ QueryError::Reconnect => {
                        // Instruct the client to reconnect, stop processing messages
                        // from this libpq connection and, finally, disconnect from the
                        // server side (returning an Err achieves the latter).
                        //
                        // Note the flushing is done by the caller.
                        let reconnect_error = short_error(&err);
                        self.write_message_noflush(&BeMessage::ErrorResponse(
                            &reconnect_error,
                            Some(err.pg_error_code()),
                        ))?;

                        return Err(err);
                    }
                    e => {
                        log_query_error(query_string, &e);
                        let short_error = short_error(&e);
                        self.write_message_noflush(&BeMessage::ErrorResponse(
                            &short_error,
                            Some(e.pg_error_code()),
                        ))?;
                    }
                }
            }
            self.write_message_noflush(&BeMessage::ReadyForQuery)?;
        }

        FeMessage::Parse(m) => {
            // Only the query text is remembered; nothing is prepared here.
            *unnamed_query_string = m.query_string;
            self.write_message_noflush(&BeMessage::ParseComplete)?;
        }

        FeMessage::Describe(_) => {
            self.write_message_noflush(&BeMessage::ParameterDescription)?
                .write_message_noflush(&BeMessage::NoData)?;
        }

        FeMessage::Bind(_) => {
            self.write_message_noflush(&BeMessage::BindComplete)?;
        }

        FeMessage::Close(_) => {
            self.write_message_noflush(&BeMessage::CloseComplete)?;
        }

        FeMessage::Execute(_) => {
            let query_string = cstr_to_str(unnamed_query_string)?;
            trace!("got execute {query_string:?}");
            // NOTE(review): unlike the simple Query path above, every handler
            // error here (including Shutdown and Reconnect) is logged, sent
            // to the client with its full `to_string()` text rather than
            // `short_error`, and processing continues -- confirm this
            // difference is intended.
            if let Err(e) = handler.process_query(self, query_string).await {
                log_query_error(query_string, &e);
                self.write_message_noflush(&BeMessage::ErrorResponse(
                    &e.to_string(),
                    Some(e.pg_error_code()),
                ))?;
            }
            // NOTE there is no ReadyForQuery message. This handler is used
            // for basebackup and it uses CopyOut which doesn't require
            // ReadyForQuery message and backend just switches back to
            // processing mode after sending CopyDone or ErrorResponse.
        }

        FeMessage::Sync => {
            self.write_message_noflush(&BeMessage::ReadyForQuery)?;
        }

        FeMessage::Terminate => {
            return Ok(ProcessMsgResult::Break);
        }

        // We prefer explicit pattern matching to wildcards, because
        // this helps us spot the places where new variants are missing
        FeMessage::CopyData(_)
        | FeMessage::CopyDone
        | FeMessage::CopyFail
        | FeMessage::PasswordMessage(_) => {
            return Err(QueryError::Other(anyhow::anyhow!(
                "unexpected message type: {msg:?}",
            )));
        }
    }

    Ok(ProcessMsgResult::Continue)
}
824 :
/// Finish a COPY stream that ended with `end`:
/// - Log the outcome as info (expected ends, including expected I/O errors)
///   or as error (everything else).
/// - Send CopyDone back if the end is `CopyDone`.
/// - Send an ErrorResponse for `ServerInitiated`, `Other` and `CopyFail` ends.
/// - Then always shut the stream down and mark the connection `Closed`,
///   because we don't handle exiting from COPY stream normally.
///
/// Failures to send are logged; the shutdown result is ignored.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
    use CopyStreamHandlerEnd::*;

    let expected_end = match &end {
        ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
        CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
            if is_expected_io_error(io_error) =>
        {
            true
        }
        _ => false,
    };
    if expected_end {
        info!("terminated: {:#}", end);
    } else {
        error!("terminated: {:?}", end);
    }

    // Note: no current usages ever send this
    if let CopyDone = &end {
        if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
            error!("failed to send CopyDone: {}", e);
        }
    }

    let err_to_send_and_errcode = match &end {
        ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
        Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
        // Note: CopyFail in duplex copy is somewhat unexpected (at least to
        // PG walsender; evidently and per my docs reading client should
        // finish it with CopyDone). It is not a problem to recover from it
        // finishing the stream in both directions like we do, but note that
        // sync rust-postgres client (which we don't use anymore) hangs if
        // socket is not closed here.
        // https://github.com/sfackler/rust-postgres/issues/755
        // https://github.com/neondatabase/neon/issues/935
        //
        // Currently, the version of tokio_postgres replication patch we use
        // sends this when it closes the stream (e.g. pageserver decided to
        // switch conn to another safekeeper and client gets dropped).
        // Moreover, seems like 'connection' task errors with 'unexpected
        // message from server' when it receives ErrorResponse (anything but
        // CopyData/CopyDone) back.
        CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
        _ => None,
    };
    if let Some((err, errcode)) = err_to_send_and_errcode {
        if let Err(ee) = self
            .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
            .await
        {
            error!("failed to send ErrorResponse: {}", ee);
        }
    }

    // Proper COPY stream finishing to continue using the connection is not
    // implemented at the server side (we don't need it so far). To prevent
    // further usages of the connection, close it.
    self.framed.shutdown().await.ok();
    self.state = ProtoState::Closed;
}
891 : }
892 :
/// Owned read half of a [`PostgresBackend`], produced by `split` and handed
/// back with `unsplit`.
pub struct PostgresBackendReader<IO> {
    /// Read half of the framed stream.
    reader: FramedReader<MaybeTlsStream<IO>>,
    closed: bool, // true if received error closing the connection
}
897 :
898 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
899 : /// Read full message or return None if connection is cleanly closed with no
900 : /// unprocessed data.
901 0 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
902 0 : match self.reader.read_message().await {
903 0 : Ok(m) => {
904 0 : trace!("read msg {:?}", m);
905 0 : Ok(m)
906 : }
907 0 : Err(e) => {
908 0 : self.closed = true;
909 0 : Err(e)
910 : }
911 : }
912 0 : }
913 :
914 : /// Get CopyData contents of the next message in COPY stream or error
915 : /// closing it. The error type is wider than actual errors which can happen
916 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
917 : /// current callers.
918 0 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
919 0 : match self.read_message().await? {
920 0 : Some(msg) => match msg {
921 0 : FeMessage::CopyData(m) => Ok(m),
922 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
923 0 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
924 0 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
925 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
926 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
927 0 : ))),
928 : },
929 0 : None => Err(CopyStreamHandlerEnd::EOF),
930 : }
931 0 : }
932 : }
933 :
934 : ///
935 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
936 : /// messages.
937 : ///
938 : pub struct CopyDataWriter<'a, IO> {
939 : pgb: &'a mut PostgresBackend<IO>,
940 : }
941 :
942 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'_, IO> {
943 0 : fn poll_write(
944 0 : self: Pin<&mut Self>,
945 0 : cx: &mut std::task::Context<'_>,
946 0 : buf: &[u8],
947 0 : ) -> Poll<Result<usize, std::io::Error>> {
948 0 : let this = self.get_mut();
949 :
950 : // It's not strictly required to flush between each message, but makes it easier
951 : // to view in wireshark, and usually the messages that the callers write are
952 : // decently-sized anyway.
953 0 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
954 0 : return Poll::Ready(Err(err));
955 0 : }
956 0 :
957 0 : // CopyData
958 0 : // XXX: if the input is large, we should split it into multiple messages.
959 0 : // Not sure what the threshold should be, but the ultimate hard limit is that
960 0 : // the length cannot exceed u32.
961 0 : this.pgb
962 0 : .write_message_noflush(&BeMessage::CopyData(buf))
963 0 : // write_message only writes to the buffer, so it can fail iff the
964 0 : // message is invaid, but CopyData can't be invalid.
965 0 : .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
966 :
967 0 : Poll::Ready(Ok(buf.len()))
968 0 : }
969 :
970 0 : fn poll_flush(
971 0 : self: Pin<&mut Self>,
972 0 : cx: &mut std::task::Context<'_>,
973 0 : ) -> Poll<Result<(), std::io::Error>> {
974 0 : let this = self.get_mut();
975 0 : this.pgb.poll_flush(cx)
976 0 : }
977 :
978 0 : fn poll_shutdown(
979 0 : self: Pin<&mut Self>,
980 0 : cx: &mut std::task::Context<'_>,
981 0 : ) -> Poll<Result<(), std::io::Error>> {
982 0 : let this = self.get_mut();
983 0 : this.pgb.poll_flush(cx)
984 0 : }
985 : }
986 :
987 0 : pub fn short_error(e: &QueryError) -> String {
988 0 : match e {
989 0 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
990 0 : QueryError::Reconnect => "reconnect".to_string(),
991 0 : QueryError::Shutdown => "shutdown".to_string(),
992 0 : QueryError::NotFound(_) => "not found".to_string(),
993 0 : QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
994 0 : QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
995 0 : QueryError::Other(e) => format!("{e:#}"),
996 : }
997 0 : }
998 :
999 0 : fn log_query_error(query: &str, e: &QueryError) {
1000 : // If you want to change the log level of a specific error, also re-categorize it in `BasebackupQueryTimeOngoingRecording`.
1001 0 : match e {
1002 0 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
1003 0 : if is_expected_io_error(io_error) {
1004 0 : info!("query handler for '{query}' failed with expected io error: {io_error}");
1005 : } else {
1006 0 : error!("query handler for '{query}' failed with io error: {io_error}");
1007 : }
1008 : }
1009 0 : QueryError::Disconnected(other_connection_error) => {
1010 0 : error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
1011 : }
1012 : QueryError::SimulatedConnectionError => {
1013 0 : error!("query handler for query '{query}' failed due to a simulated connection error")
1014 : }
1015 : QueryError::Reconnect => {
1016 0 : info!("query handler for '{query}' requested client to reconnect")
1017 : }
1018 : QueryError::Shutdown => {
1019 0 : info!("query handler for '{query}' cancelled during tenant shutdown")
1020 : }
1021 0 : QueryError::NotFound(reason) => {
1022 0 : info!("query handler for '{query}' entity not found: {reason}")
1023 : }
1024 0 : QueryError::Unauthorized(e) => {
1025 0 : warn!("query handler for '{query}' failed with authentication error: {e}");
1026 : }
1027 0 : QueryError::Other(e) => {
1028 0 : error!("query handler for '{query}' failed: {e:?}");
1029 : }
1030 : }
1031 0 : }
1032 :
1033 : /// Something finishing handling of COPY stream, see handle_copy_stream_end.
1034 : /// This is not always a real error, but it allows to use ? and thiserror impls.
1035 0 : #[derive(thiserror::Error, Debug)]
1036 : pub enum CopyStreamHandlerEnd {
1037 : /// Handler initiates the end of streaming.
1038 : #[error("{0}")]
1039 : ServerInitiated(String),
1040 : #[error("received CopyDone")]
1041 : CopyDone,
1042 : #[error("received CopyFail")]
1043 : CopyFail,
1044 : #[error("received Terminate")]
1045 : Terminate,
1046 : #[error("EOF on COPY stream")]
1047 : EOF,
1048 : /// The connection was lost
1049 : #[error("connection error: {0}")]
1050 : Disconnected(#[from] ConnectionError),
1051 : /// Some other error
1052 : #[error(transparent)]
1053 : Other(#[from] anyhow::Error),
1054 : }
|