Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : #![deny(unsafe_code)]
6 : #![deny(clippy::undocumented_unsafe_blocks)]
7 : use std::future::Future;
8 : use std::net::SocketAddr;
9 : use std::os::fd::{AsRawFd, RawFd};
10 : use std::pin::Pin;
11 : use std::str::FromStr;
12 : use std::sync::Arc;
13 : use std::task::{Poll, ready};
14 : use std::{fmt, io};
15 :
16 : use anyhow::Context;
17 : use bytes::Bytes;
18 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
19 : use pq_proto::{
20 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
21 : SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
22 : };
23 : use serde::{Deserialize, Serialize};
24 : use tokio::io::{AsyncRead, AsyncWrite};
25 : use tokio_rustls::TlsAcceptor;
26 : use tokio_util::sync::CancellationToken;
27 : use tracing::{debug, error, info, trace, warn};
28 :
/// An error that occurred during query processing: either a failure of the
/// client connection itself ([`ConnectionError`]) or one raised before/after
/// talking to the client.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
    /// The connection was lost while processing the query.
    #[error(transparent)]
    Disconnected(#[from] ConnectionError),
    /// We were instructed to shut down while processing the query.
    #[error("Shutting down")]
    Shutdown,
    /// The query handler indicated that the client should reconnect.
    #[error("Server requested reconnect")]
    Reconnect,
    /// The query named an entity that was not found.
    #[error("Not found: {0}")]
    NotFound(std::borrow::Cow<'static, str>),
    /// Authentication failure.
    #[error("Unauthorized: {0}")]
    Unauthorized(std::borrow::Cow<'static, str>),
    /// An artificial connection failure. It is reported to the client with the
    /// same SQLSTATE as a real disconnect and ends the message loop.
    /// NOTE(review): presumably raised by fault-injection/test hooks; confirm.
    #[error("Simulated Connection Error")]
    SimulatedConnectionError,
    /// Some other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
54 :
55 : impl From<io::Error> for QueryError {
56 0 : fn from(e: io::Error) -> Self {
57 0 : Self::Disconnected(ConnectionError::Io(e))
58 0 : }
59 : }
60 :
61 : impl QueryError {
62 0 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
63 0 : match self {
64 0 : Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
65 0 : Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
66 0 : Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
67 0 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
68 : }
69 0 : }
70 : }
71 :
/// Reports whether `e` is an ordinary consequence of network trouble or of the
/// peer closing the connection.
///
/// Such errors happen during normal operation and do not point at a bug in our
/// code; e.g. the COPY-stream teardown logs them at `info` instead of `error`.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    use io::ErrorKind;

    const EXPECTED_KINDS: [ErrorKind; 7] = [
        ErrorKind::HostUnreachable,
        ErrorKind::NetworkUnreachable,
        ErrorKind::BrokenPipe,
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionAborted,
        ErrorKind::ConnectionReset,
        ErrorKind::TimedOut,
    ];

    EXPECTED_KINDS.contains(&e.kind())
}
90 :
91 : pub trait Handler<IO> {
92 : /// Handle single query.
93 : /// postgres_backend will issue ReadyForQuery after calling this (this
94 : /// might be not what we want after CopyData streaming, but currently we don't
95 : /// care). It will also flush out the output buffer.
96 : fn process_query(
97 : &mut self,
98 : pgb: &mut PostgresBackend<IO>,
99 : query_string: &str,
100 : ) -> impl Future<Output = Result<(), QueryError>>;
101 :
102 : /// Called on startup packet receival, allows to process params.
103 : ///
104 : /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
105 : /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
106 : /// to override whole init logic in implementations.
107 2 : fn startup(
108 2 : &mut self,
109 2 : _pgb: &mut PostgresBackend<IO>,
110 2 : _sm: &FeStartupPacket,
111 2 : ) -> Result<(), QueryError> {
112 2 : Ok(())
113 0 : }
114 :
115 : /// Check auth jwt
116 0 : fn check_auth_jwt(
117 0 : &mut self,
118 0 : _pgb: &mut PostgresBackend<IO>,
119 0 : _jwt_response: &[u8],
120 0 : ) -> Result<(), QueryError> {
121 0 : Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
122 0 : }
123 : }
124 :
/// PostgresBackend protocol state.
///
/// XXX: The order of the constructors matters. The derived `PartialOrd`
/// follows declaration order, and the handshake compares states with `<`
/// (e.g. `state < ProtoState::Authentication` means "startup not finished yet").
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
    /// Nothing happened yet.
    Initialization,
    /// Encryption handshake is done; waiting for encrypted Startup message.
    Encrypted,
    /// Waiting for password (auth token).
    Authentication,
    /// Performed handshake and auth, ReadyForQuery is issued.
    Established,
    /// The connection must not be read from anymore: EOF or a read error was
    /// seen, or a COPY stream was terminated.
    Closed,
}
139 :
/// Outcome of handling one frontend message in the main message loop.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProcessMsgResult {
    /// Keep reading messages from the client.
    Continue,
    /// Leave the message loop; the session ends without an error.
    Break,
}
145 :
/// Either the plain underlying stream or a TLS-encrypted one. Implements
/// AsyncRead + AsyncWrite by delegating to whichever variant is active.
pub enum MaybeTlsStream<IO> {
    /// The transport as accepted, without encryption.
    Unencrypted(IO),
    /// The transport wrapped in a server-side rustls session.
    /// NOTE(review): presumably boxed to keep the enum small; confirm.
    Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
151 :
152 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
153 5 : fn poll_write(
154 5 : self: Pin<&mut Self>,
155 5 : cx: &mut std::task::Context<'_>,
156 5 : buf: &[u8],
157 5 : ) -> Poll<io::Result<usize>> {
158 5 : match self.get_mut() {
159 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
160 2 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
161 : }
162 0 : }
163 5 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
164 5 : match self.get_mut() {
165 3 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
166 2 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
167 : }
168 0 : }
169 0 : fn poll_shutdown(
170 0 : self: Pin<&mut Self>,
171 0 : cx: &mut std::task::Context<'_>,
172 0 : ) -> Poll<io::Result<()>> {
173 0 : match self.get_mut() {
174 0 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
175 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
176 : }
177 0 : }
178 : }
179 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
180 12 : fn poll_read(
181 12 : self: Pin<&mut Self>,
182 12 : cx: &mut std::task::Context<'_>,
183 12 : buf: &mut tokio::io::ReadBuf<'_>,
184 12 : ) -> Poll<io::Result<()>> {
185 12 : match self.get_mut() {
186 7 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
187 5 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
188 : }
189 0 : }
190 : }
191 :
/// Authentication method a client must go through on this backend.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
    /// No authentication: the handshake completes right after the startup packet.
    Trust,
    /// This mimics postgres's AuthenticationCleartextPassword but instead of a
    /// password expects a JWT.
    NeonJWT,
}
198 :
199 : impl FromStr for AuthType {
200 : type Err = anyhow::Error;
201 :
202 0 : fn from_str(s: &str) -> Result<Self, Self::Err> {
203 0 : match s {
204 0 : "Trust" => Ok(Self::Trust),
205 0 : "NeonJWT" => Ok(Self::NeonJWT),
206 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
207 : }
208 0 : }
209 : }
210 :
211 : impl fmt::Display for AuthType {
212 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 0 : f.write_str(match self {
214 0 : AuthType::Trust => "Trust",
215 0 : AuthType::NeonJWT => "NeonJWT",
216 : })
217 0 : }
218 : }
219 :
/// Either a full-duplex Framed or only its write half; the latter is left in
/// PostgresBackend after a call to `split`. In principle we could always store
/// a pair of split handles, but that would force us to pay the splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly<IO> {
    /// Both halves are owned here.
    Full(Framed<MaybeTlsStream<IO>>),
    /// Only the write half; the read half lives in a `PostgresBackendReader`.
    WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
    // Placeholder while the stream is temporarily moved out (split, unsplit,
    // TLS upgrade). If such an operation fails after the move, this value can
    // remain and the original stream is gone.
    Broken,
}
229 :
230 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
231 3 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
232 3 : match self {
233 3 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
234 : MaybeWriteOnly::WriteOnly(_) => {
235 0 : Err(io::Error::other("reading from write only half").into())
236 : }
237 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
238 : }
239 0 : }
240 :
241 4 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
242 4 : match self {
243 4 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
244 : MaybeWriteOnly::WriteOnly(_) => {
245 0 : Err(io::Error::other("reading from write only half").into())
246 : }
247 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
248 : }
249 0 : }
250 :
251 19 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
252 19 : match self {
253 19 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
254 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
255 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
256 : }
257 0 : }
258 :
259 5 : async fn flush(&mut self) -> io::Result<()> {
260 5 : match self {
261 5 : MaybeWriteOnly::Full(framed) => framed.flush().await,
262 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
263 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
264 : }
265 0 : }
266 :
267 : /// Cancellation safe as long as the underlying IO is cancellation safe.
268 0 : async fn shutdown(&mut self) -> io::Result<()> {
269 0 : match self {
270 0 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
271 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
272 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
273 : }
274 0 : }
275 : }
276 :
/// Server side of a single Postgres-protocol connection.
pub struct PostgresBackend<IO> {
    /// Raw file descriptor of the underlying socket, as supplied at construction.
    pub socket_fd: RawFd,
    /// Framed protocol stream; holds only the write half while split.
    framed: MaybeWriteOnly<IO>,

    /// Current protocol state.
    pub state: ProtoState,

    /// Authentication method required from the client.
    auth_type: AuthType,

    /// Address of the connected peer.
    peer_addr: SocketAddr,
    /// TLS server configuration. When set, a plain startup message is rejected
    /// with "must connect with TLS".
    pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
288 :
/// A [`PostgresBackend`] over a plain Tokio TCP stream.
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
290 :
291 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
292 2 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
293 2 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
294 2 : std::str::from_utf8(without_null).map_err(|e| e.into())
295 2 : }
296 :
297 : impl PostgresBackend<tokio::net::TcpStream> {
298 2 : pub fn new(
299 2 : socket: tokio::net::TcpStream,
300 2 : auth_type: AuthType,
301 2 : tls_config: Option<Arc<rustls::ServerConfig>>,
302 2 : ) -> io::Result<Self> {
303 2 : let peer_addr = socket.peer_addr()?;
304 2 : let socket_fd = socket.as_raw_fd();
305 2 : let stream = MaybeTlsStream::Unencrypted(socket);
306 :
307 2 : Ok(Self {
308 2 : socket_fd,
309 2 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
310 2 : state: ProtoState::Initialization,
311 2 : auth_type,
312 2 : tls_config,
313 2 : peer_addr,
314 2 : })
315 2 : }
316 : }
317 :
318 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
319 0 : pub fn new_from_io(
320 0 : socket_fd: RawFd,
321 0 : socket: IO,
322 0 : peer_addr: SocketAddr,
323 0 : auth_type: AuthType,
324 0 : tls_config: Option<Arc<rustls::ServerConfig>>,
325 0 : ) -> io::Result<Self> {
326 0 : let stream = MaybeTlsStream::Unencrypted(socket);
327 :
328 0 : Ok(Self {
329 0 : socket_fd,
330 0 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
331 0 : state: ProtoState::Initialization,
332 0 : auth_type,
333 0 : tls_config,
334 0 : peer_addr,
335 0 : })
336 0 : }
337 :
338 0 : pub fn get_peer_addr(&self) -> &SocketAddr {
339 0 : &self.peer_addr
340 0 : }
341 :
342 : /// Read full message or return None if connection is cleanly closed with no
343 : /// unprocessed data.
344 4 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
345 4 : if let ProtoState::Closed = self.state {
346 0 : Ok(None)
347 : } else {
348 4 : match self.framed.read_message().await {
349 2 : Ok(m) => {
350 2 : trace!("read msg {:?}", m);
351 2 : Ok(m)
352 : }
353 0 : Err(e) => {
354 : // remember not to try to read anymore
355 0 : self.state = ProtoState::Closed;
356 0 : Err(e)
357 : }
358 : }
359 : }
360 0 : }
361 :
362 : /// Write message into internal output buffer, doesn't flush it. Technically
363 : /// error type can be only ProtocolError here (if, unlikely, serialization
364 : /// fails), but callers typically wrap it anyway.
365 19 : pub fn write_message_noflush(
366 19 : &mut self,
367 19 : message: &BeMessage<'_>,
368 19 : ) -> Result<&mut Self, ConnectionError> {
369 19 : self.framed.write_message_noflush(message)?;
370 19 : trace!("wrote msg {:?}", message);
371 19 : Ok(self)
372 0 : }
373 :
374 : /// Flush output buffer into the socket.
375 5 : pub async fn flush(&mut self) -> io::Result<()> {
376 5 : self.framed.flush().await
377 0 : }
378 :
379 : /// Polling version of `flush()`, saves the caller need to pin.
380 0 : pub fn poll_flush(
381 0 : &mut self,
382 0 : cx: &mut std::task::Context<'_>,
383 0 : ) -> Poll<Result<(), std::io::Error>> {
384 0 : let flush_fut = std::pin::pin!(self.flush());
385 0 : flush_fut.poll(cx)
386 0 : }
387 :
388 : /// Write message into internal output buffer and flush it to the stream.
389 3 : pub async fn write_message(
390 3 : &mut self,
391 3 : message: &BeMessage<'_>,
392 3 : ) -> Result<&mut Self, ConnectionError> {
393 3 : self.write_message_noflush(message)?;
394 3 : self.flush().await?;
395 3 : Ok(self)
396 0 : }
397 :
398 : /// Returns an AsyncWrite implementation that wraps all the data written
399 : /// to it in CopyData messages, and writes them to the connection
400 : ///
401 : /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
402 0 : pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
403 0 : CopyDataWriter { pgb: self }
404 0 : }
405 :
406 : /// Wrapper for run_message_loop() that shuts down socket when we are done
407 2 : pub async fn run(
408 2 : mut self,
409 2 : handler: &mut impl Handler<IO>,
410 2 : cancel: &CancellationToken,
411 2 : ) -> Result<(), QueryError> {
412 2 : let ret = self.run_message_loop(handler, cancel).await;
413 :
414 0 : tokio::select! {
415 0 : _ = cancel.cancelled() => {
416 0 : // do nothing; we most likely got already stopped by shutdown and will log it next.
417 0 : }
418 0 : _ = self.framed.shutdown() => {
419 0 : // socket might be already closed, e.g. if previously received error,
420 0 : // so ignore result.
421 0 : },
422 : }
423 :
424 0 : match ret {
425 0 : Ok(()) => Ok(()),
426 : Err(QueryError::Shutdown) => {
427 0 : info!("Stopped due to shutdown");
428 0 : Ok(())
429 : }
430 : Err(QueryError::Reconnect) => {
431 : // Dropping out of this loop implicitly disconnects
432 0 : info!("Stopped due to handler reconnect request");
433 0 : Ok(())
434 : }
435 0 : Err(QueryError::Disconnected(e)) => {
436 0 : info!("Disconnected ({e:#})");
437 : // Disconnection is not an error: we just use it that way internally to drop
438 : // out of loops.
439 0 : Ok(())
440 : }
441 0 : e => e,
442 : }
443 0 : }
444 :
445 2 : async fn run_message_loop(
446 2 : &mut self,
447 2 : handler: &mut impl Handler<IO>,
448 2 : cancel: &CancellationToken,
449 2 : ) -> Result<(), QueryError> {
450 2 : trace!("postgres backend to {:?} started", self.peer_addr);
451 :
452 2 : tokio::select!(
453 : biased;
454 :
455 2 : _ = cancel.cancelled() => {
456 : // We were requested to shut down.
457 0 : tracing::info!("shutdown request received during handshake");
458 0 : return Err(QueryError::Shutdown)
459 : },
460 :
461 2 : handshake_r = self.handshake(handler) => {
462 2 : handshake_r?;
463 : }
464 : );
465 :
466 : // Authentication completed
467 2 : let mut query_string = Bytes::new();
468 4 : while let Some(msg) = tokio::select!(
469 : biased;
470 4 : _ = cancel.cancelled() => {
471 : // We were requested to shut down.
472 0 : tracing::info!("shutdown request received in run_message_loop");
473 0 : return Err(QueryError::Shutdown)
474 : },
475 4 : msg = self.read_message() => { msg },
476 0 : )? {
477 2 : trace!("got message {:?}", msg);
478 :
479 2 : let result = self.process_message(handler, msg, &mut query_string).await;
480 2 : tokio::select!(
481 : biased;
482 2 : _ = cancel.cancelled() => {
483 : // We were requested to shut down.
484 0 : tracing::info!("shutdown request received during response flush");
485 :
486 : // If we exited process_message with a shutdown error, there may be
487 : // some valid response content on in our transmit buffer: permit sending
488 : // this within a short timeout. This is a best effort thing so we don't
489 : // care about the result.
490 0 : tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
491 :
492 0 : return Err(QueryError::Shutdown)
493 : },
494 2 : flush_r = self.flush() => {
495 2 : flush_r?;
496 : }
497 : );
498 :
499 2 : match result? {
500 : ProcessMsgResult::Continue => {
501 2 : continue;
502 : }
503 0 : ProcessMsgResult::Break => break,
504 : }
505 : }
506 :
507 0 : trace!("postgres backend to {:?} exited", self.peer_addr);
508 0 : Ok(())
509 0 : }
510 :
511 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
512 1 : async fn tls_upgrade(
513 1 : src: MaybeTlsStream<IO>,
514 1 : tls_config: Arc<rustls::ServerConfig>,
515 1 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
516 1 : match src {
517 1 : MaybeTlsStream::Unencrypted(s) => {
518 1 : let acceptor = TlsAcceptor::from(tls_config);
519 1 : let tls_stream = acceptor.accept(s).await?;
520 1 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
521 : }
522 : MaybeTlsStream::Tls(_) => {
523 0 : anyhow::bail!("TLS already started");
524 : }
525 : }
526 0 : }
527 :
528 1 : async fn start_tls(&mut self) -> anyhow::Result<()> {
529 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
530 1 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
531 1 : MaybeWriteOnly::Full(framed) => {
532 1 : let tls_config = self
533 1 : .tls_config
534 1 : .as_ref()
535 1 : .context("start_tls called without conf")?
536 1 : .clone();
537 1 : let tls_framed = framed
538 1 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
539 1 : .await?;
540 : // push back ready TLS stream
541 1 : self.framed = MaybeWriteOnly::Full(tls_framed);
542 1 : Ok(())
543 : }
544 : MaybeWriteOnly::WriteOnly(_) => {
545 0 : anyhow::bail!("TLS upgrade attempt in split state")
546 : }
547 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
548 : }
549 0 : }
550 :
551 : /// Split off owned read part from which messages can be read in different
552 : /// task/thread.
553 0 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
554 : // temporary replace stream with fake to cook split one, Indiana Jones style
555 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
556 0 : MaybeWriteOnly::Full(framed) => {
557 0 : let (reader, writer) = framed.split();
558 0 : self.framed = MaybeWriteOnly::WriteOnly(writer);
559 0 : Ok(PostgresBackendReader {
560 0 : reader,
561 0 : closed: false,
562 0 : })
563 : }
564 : MaybeWriteOnly::WriteOnly(_) => {
565 0 : anyhow::bail!("PostgresBackend is already split")
566 : }
567 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
568 : }
569 0 : }
570 :
571 : /// Join read part back.
572 0 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
573 : // temporary replace stream with fake to cook joined one, Indiana Jones style
574 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
575 : MaybeWriteOnly::Full(_) => {
576 0 : anyhow::bail!("PostgresBackend is not split")
577 : }
578 0 : MaybeWriteOnly::WriteOnly(writer) => {
579 0 : let joined = Framed::unsplit(reader.reader, writer);
580 0 : self.framed = MaybeWriteOnly::Full(joined);
581 : // if reader encountered connection error, do not attempt reading anymore
582 0 : if reader.closed {
583 0 : self.state = ProtoState::Closed;
584 0 : }
585 0 : Ok(())
586 : }
587 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
588 : }
589 0 : }
590 :
591 : /// Perform handshake with the client, transitioning to Established.
592 : /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
593 2 : async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
594 5 : while self.state < ProtoState::Authentication {
595 3 : match self.framed.read_startup_message().await? {
596 3 : Some(msg) => {
597 3 : self.process_startup_message(handler, msg).await?;
598 : }
599 : None => {
600 0 : trace!(
601 0 : "postgres backend to {:?} received EOF during handshake",
602 : self.peer_addr
603 : );
604 0 : self.state = ProtoState::Closed;
605 0 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
606 0 : ProtocolError::Protocol("EOF during handshake".to_string()),
607 0 : )));
608 : }
609 : }
610 : }
611 :
612 : // Perform auth, if needed.
613 2 : if self.state == ProtoState::Authentication {
614 0 : match self.framed.read_message().await? {
615 0 : Some(FeMessage::PasswordMessage(m)) => {
616 0 : assert!(self.auth_type == AuthType::NeonJWT);
617 :
618 0 : let (_, jwt_response) = m.split_last().context("protocol violation")?;
619 :
620 0 : if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
621 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
622 0 : &short_error(&e),
623 0 : Some(e.pg_error_code()),
624 0 : ))?;
625 0 : return Err(e);
626 0 : }
627 :
628 0 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
629 0 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
630 0 : .write_message(&BeMessage::ReadyForQuery)
631 0 : .await?;
632 0 : self.state = ProtoState::Established;
633 : }
634 0 : Some(m) => {
635 0 : return Err(QueryError::Other(anyhow::anyhow!(
636 0 : "Unexpected message {:?} while waiting for handshake",
637 0 : m
638 0 : )));
639 : }
640 : None => {
641 0 : trace!(
642 0 : "postgres backend to {:?} received EOF during auth",
643 : self.peer_addr
644 : );
645 0 : self.state = ProtoState::Closed;
646 0 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
647 0 : ProtocolError::Protocol("EOF during auth".to_string()),
648 0 : )));
649 : }
650 : }
651 0 : }
652 :
653 2 : Ok(())
654 0 : }
655 :
656 : /// Process startup packet:
657 : /// - transition to Established if auth type is trust
658 : /// - transition to Authentication if auth type is NeonJWT.
659 : /// - or perform TLS handshake -- then need to call this again to receive
660 : /// actual startup packet.
661 3 : async fn process_startup_message(
662 3 : &mut self,
663 3 : handler: &mut impl Handler<IO>,
664 3 : msg: FeStartupPacket,
665 3 : ) -> Result<(), QueryError> {
666 3 : assert!(self.state < ProtoState::Authentication);
667 3 : let have_tls = self.tls_config.is_some();
668 3 : match msg {
669 1 : FeStartupPacket::SslRequest { direct } => {
670 1 : debug!("SSL requested");
671 :
672 1 : if !direct {
673 1 : self.write_message(&BeMessage::EncryptionResponse(have_tls))
674 1 : .await?;
675 0 : } else if !have_tls {
676 0 : return Err(QueryError::Other(anyhow::anyhow!(
677 0 : "direct SSL negotiation but no TLS support"
678 0 : )));
679 0 : }
680 :
681 1 : if have_tls {
682 1 : self.start_tls().await?;
683 1 : self.state = ProtoState::Encrypted;
684 0 : }
685 : }
686 : FeStartupPacket::GssEncRequest => {
687 0 : debug!("GSS requested");
688 0 : self.write_message(&BeMessage::EncryptionResponse(false))
689 0 : .await?;
690 : }
691 : FeStartupPacket::StartupMessage { .. } => {
692 2 : if have_tls && !matches!(self.state, ProtoState::Encrypted) {
693 0 : self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
694 0 : .await?;
695 0 : return Err(QueryError::Other(anyhow::anyhow!(
696 0 : "client did not connect with TLS"
697 0 : )));
698 0 : }
699 :
700 : // NB: startup() may change self.auth_type -- we are using that in proxy code
701 : // to bypass auth for new users.
702 2 : handler.startup(self, &msg)?;
703 :
704 2 : match self.auth_type {
705 : AuthType::Trust => {
706 2 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
707 2 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
708 2 : .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
709 : // The async python driver requires a valid server_version
710 2 : .write_message_noflush(&BeMessage::server_version("14.1"))?
711 2 : .write_message(&BeMessage::ReadyForQuery)
712 2 : .await?;
713 2 : self.state = ProtoState::Established;
714 : }
715 : AuthType::NeonJWT => {
716 0 : self.write_message(&BeMessage::AuthenticationCleartextPassword)
717 0 : .await?;
718 0 : self.state = ProtoState::Authentication;
719 : }
720 : }
721 : }
722 : FeStartupPacket::CancelRequest { .. } => {
723 0 : return Err(QueryError::Other(anyhow::anyhow!(
724 0 : "Unexpected CancelRequest message during handshake"
725 0 : )));
726 : }
727 : }
728 3 : Ok(())
729 0 : }
730 :
731 : // Proto looks like this:
732 : // FeMessage::Query("pagestream_v2{FeMessage::CopyData(PagesetreamFeMessage::GetPage(..))}")
733 :
734 2 : async fn process_message(
735 2 : &mut self,
736 2 : handler: &mut impl Handler<IO>,
737 2 : msg: FeMessage,
738 2 : unnamed_query_string: &mut Bytes,
739 2 : ) -> Result<ProcessMsgResult, QueryError> {
740 : // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
741 : // TODO: change that to proper top-level match of protocol state with separate message handling for each state
742 2 : assert!(self.state == ProtoState::Established);
743 :
744 2 : match msg {
745 2 : FeMessage::Query(body) => {
746 : // remove null terminator
747 2 : let query_string = cstr_to_str(&body)?;
748 :
749 2 : trace!("got query {query_string:?}");
750 2 : if let Err(e) = handler.process_query(self, query_string).await {
751 0 : match e {
752 0 : QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
753 : QueryError::SimulatedConnectionError => {
754 0 : return Err(QueryError::SimulatedConnectionError);
755 : }
756 0 : err @ QueryError::Reconnect => {
757 : // Instruct the client to reconnect, stop processing messages
758 : // from this libpq connection and, finally, disconnect from the
759 : // server side (returning an Err achieves the later).
760 : //
761 : // Note the flushing is done by the caller.
762 0 : let reconnect_error = short_error(&err);
763 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
764 0 : &reconnect_error,
765 0 : Some(err.pg_error_code()),
766 0 : ))?;
767 :
768 0 : return Err(err);
769 : }
770 0 : e => {
771 0 : log_query_error(query_string, &e);
772 0 : let short_error = short_error(&e);
773 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
774 0 : &short_error,
775 0 : Some(e.pg_error_code()),
776 0 : ))?;
777 : }
778 : }
779 0 : }
780 2 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
781 : }
782 :
783 0 : FeMessage::Parse(m) => {
784 0 : *unnamed_query_string = m.query_string;
785 0 : self.write_message_noflush(&BeMessage::ParseComplete)?;
786 : }
787 :
788 : FeMessage::Describe(_) => {
789 0 : self.write_message_noflush(&BeMessage::ParameterDescription)?
790 0 : .write_message_noflush(&BeMessage::NoData)?;
791 : }
792 :
793 : FeMessage::Bind(_) => {
794 0 : self.write_message_noflush(&BeMessage::BindComplete)?;
795 : }
796 :
797 : FeMessage::Close(_) => {
798 0 : self.write_message_noflush(&BeMessage::CloseComplete)?;
799 : }
800 :
801 : FeMessage::Execute(_) => {
802 0 : let query_string = cstr_to_str(unnamed_query_string)?;
803 0 : trace!("got execute {query_string:?}");
804 0 : if let Err(e) = handler.process_query(self, query_string).await {
805 0 : log_query_error(query_string, &e);
806 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
807 0 : &e.to_string(),
808 0 : Some(e.pg_error_code()),
809 0 : ))?;
810 0 : }
811 : // NOTE there is no ReadyForQuery message. This handler is used
812 : // for basebackup and it uses CopyOut which doesn't require
813 : // ReadyForQuery message and backend just switches back to
814 : // processing mode after sending CopyDone or ErrorResponse.
815 : }
816 :
817 : FeMessage::Sync => {
818 0 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
819 : }
820 :
821 : FeMessage::Terminate => {
822 0 : return Ok(ProcessMsgResult::Break);
823 : }
824 :
825 : // We prefer explicit pattern matching to wildcards, because
826 : // this helps us spot the places where new variants are missing
827 : FeMessage::CopyData(_)
828 : | FeMessage::CopyDone
829 : | FeMessage::CopyFail
830 : | FeMessage::PasswordMessage(_) => {
831 0 : return Err(QueryError::Other(anyhow::anyhow!(
832 0 : "unexpected message type: {msg:?}",
833 0 : )));
834 : }
835 : }
836 :
837 2 : Ok(ProcessMsgResult::Continue)
838 0 : }
839 :
840 : /// - Log as info/error result of handling COPY stream and send back
841 : /// ErrorResponse if that makes sense.
842 : /// - Shutdown the stream if we got Terminate.
843 : /// - Then close the connection because we don't handle exiting from COPY
844 : /// stream normally.
845 0 : pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
846 : use CopyStreamHandlerEnd::*;
847 :
848 0 : let expected_end = match &end {
849 0 : ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF | Cancelled => true,
850 : // The timeline doesn't exist and we have been requested to not auto-create it.
851 : // Compute requests for timelines that haven't been created yet
852 : // might reach us before the storcon request to create those timelines.
853 0 : TimelineNoCreate => true,
854 0 : CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
855 0 : if is_expected_io_error(io_error) =>
856 : {
857 0 : true
858 : }
859 0 : _ => false,
860 : };
861 0 : if expected_end {
862 0 : info!("terminated: {:#}", end);
863 : } else {
864 0 : error!("terminated: {:?}", end);
865 : }
866 :
867 : // Note: no current usages ever send this
868 0 : if let CopyDone = &end {
869 0 : if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
870 0 : error!("failed to send CopyDone: {}", e);
871 0 : }
872 0 : }
873 :
874 0 : let err_to_send_and_errcode = match &end {
875 0 : ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
876 0 : Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
877 : // Note: CopyFail in duplex copy is somewhat unexpected (at least to
878 : // PG walsender; evidently and per my docs reading client should
879 : // finish it with CopyDone). It is not a problem to recover from it
880 : // finishing the stream in both directions like we do, but note that
881 : // sync rust-postgres client (which we don't use anymore) hangs if
882 : // socket is not closed here.
883 : // https://github.com/sfackler/rust-postgres/issues/755
884 : // https://github.com/neondatabase/neon/issues/935
885 : //
886 : // Currently, the version of tokio_postgres replication patch we use
887 : // sends this when it closes the stream (e.g. pageserver decided to
888 : // switch conn to another safekeeper and client gets dropped).
889 : // Moreover, seems like 'connection' task errors with 'unexpected
890 : // message from server' when it receives ErrorResponse (anything but
891 : // CopyData/CopyDone) back.
892 0 : CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
893 :
894 : // When cancelled, send no response: we must not risk blocking on sending that response
895 0 : Cancelled => None,
896 0 : _ => None,
897 : };
898 0 : if let Some((err, errcode)) = err_to_send_and_errcode {
899 0 : if let Err(ee) = self
900 0 : .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
901 0 : .await
902 : {
903 0 : error!("failed to send ErrorResponse: {}", ee);
904 0 : }
905 0 : }
906 :
907 : // Proper COPY stream finishing to continue using the connection is not
908 : // implemented at the server side (we don't need it so far). To prevent
909 : // further usages of the connection, close it.
910 0 : self.framed.shutdown().await.ok();
911 0 : self.state = ProtoState::Closed;
912 0 : }
913 : }
914 :
915 : pub struct PostgresBackendReader<IO> {
916 : reader: FramedReader<MaybeTlsStream<IO>>,
917 : closed: bool, // true if received error closing the connection
918 : }
919 :
920 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
921 : /// Read full message or return None if connection is cleanly closed with no
922 : /// unprocessed data.
923 0 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
924 0 : match self.reader.read_message().await {
925 0 : Ok(m) => {
926 0 : trace!("read msg {:?}", m);
927 0 : Ok(m)
928 : }
929 0 : Err(e) => {
930 0 : self.closed = true;
931 0 : Err(e)
932 : }
933 : }
934 0 : }
935 :
936 : /// Get CopyData contents of the next message in COPY stream or error
937 : /// closing it. The error type is wider than actual errors which can happen
938 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
939 : /// current callers.
940 0 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
941 0 : match self.read_message().await? {
942 0 : Some(msg) => match msg {
943 0 : FeMessage::CopyData(m) => Ok(m),
944 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
945 0 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
946 0 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
947 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
948 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {msg:?}")),
949 0 : ))),
950 : },
951 0 : None => Err(CopyStreamHandlerEnd::EOF),
952 : }
953 0 : }
954 : }
955 :
956 : ///
957 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
958 : /// messages.
959 : ///
960 : pub struct CopyDataWriter<'a, IO> {
961 : pgb: &'a mut PostgresBackend<IO>,
962 : }
963 :
964 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'_, IO> {
965 0 : fn poll_write(
966 0 : self: Pin<&mut Self>,
967 0 : cx: &mut std::task::Context<'_>,
968 0 : buf: &[u8],
969 0 : ) -> Poll<Result<usize, std::io::Error>> {
970 0 : let this = self.get_mut();
971 :
972 : // It's not strictly required to flush between each message, but makes it easier
973 : // to view in wireshark, and usually the messages that the callers write are
974 : // decently-sized anyway.
975 0 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
976 0 : return Poll::Ready(Err(err));
977 0 : }
978 :
979 : // CopyData
980 : // XXX: if the input is large, we should split it into multiple messages.
981 : // Not sure what the threshold should be, but the ultimate hard limit is that
982 : // the length cannot exceed u32.
983 0 : this.pgb
984 0 : .write_message_noflush(&BeMessage::CopyData(buf))
985 : // write_message only writes to the buffer, so it can fail iff the
986 : // message is invaid, but CopyData can't be invalid.
987 0 : .map_err(|_| io::Error::other("failed to serialize CopyData"))?;
988 :
989 0 : Poll::Ready(Ok(buf.len()))
990 0 : }
991 :
992 0 : fn poll_flush(
993 0 : self: Pin<&mut Self>,
994 0 : cx: &mut std::task::Context<'_>,
995 0 : ) -> Poll<Result<(), std::io::Error>> {
996 0 : let this = self.get_mut();
997 0 : this.pgb.poll_flush(cx)
998 0 : }
999 :
1000 0 : fn poll_shutdown(
1001 0 : self: Pin<&mut Self>,
1002 0 : cx: &mut std::task::Context<'_>,
1003 0 : ) -> Poll<Result<(), std::io::Error>> {
1004 0 : let this = self.get_mut();
1005 0 : this.pgb.poll_flush(cx)
1006 0 : }
1007 : }
1008 :
1009 0 : pub fn short_error(e: &QueryError) -> String {
1010 0 : match e {
1011 0 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
1012 0 : QueryError::Reconnect => "reconnect".to_string(),
1013 0 : QueryError::Shutdown => "shutdown".to_string(),
1014 0 : QueryError::NotFound(_) => "not found".to_string(),
1015 0 : QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
1016 0 : QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
1017 0 : QueryError::Other(e) => format!("{e:#}"),
1018 : }
1019 0 : }
1020 :
1021 0 : fn log_query_error(query: &str, e: &QueryError) {
1022 : // If you want to change the log level of a specific error, also re-categorize it in `BasebackupQueryTimeOngoingRecording`.
1023 0 : match e {
1024 0 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
1025 0 : if is_expected_io_error(io_error) {
1026 0 : info!("query handler for '{query}' failed with expected io error: {io_error}");
1027 : } else {
1028 0 : error!("query handler for '{query}' failed with io error: {io_error}");
1029 : }
1030 : }
1031 0 : QueryError::Disconnected(other_connection_error) => {
1032 0 : error!(
1033 0 : "query handler for '{query}' failed with connection error: {other_connection_error:?}"
1034 : )
1035 : }
1036 : QueryError::SimulatedConnectionError => {
1037 0 : error!("query handler for query '{query}' failed due to a simulated connection error")
1038 : }
1039 : QueryError::Reconnect => {
1040 0 : info!("query handler for '{query}' requested client to reconnect")
1041 : }
1042 : QueryError::Shutdown => {
1043 0 : info!("query handler for '{query}' cancelled during tenant shutdown")
1044 : }
1045 0 : QueryError::NotFound(reason) => {
1046 0 : info!("query handler for '{query}' entity not found: {reason}")
1047 : }
1048 0 : QueryError::Unauthorized(e) => {
1049 0 : warn!("query handler for '{query}' failed with authentication error: {e}");
1050 : }
1051 0 : QueryError::Other(e) => {
1052 0 : error!("query handler for '{query}' failed: {e:?}");
1053 : }
1054 : }
1055 0 : }
1056 :
1057 : /// Something finishing handling of COPY stream, see handle_copy_stream_end.
1058 : /// This is not always a real error, but it allows to use ? and thiserror impls.
1059 : #[derive(thiserror::Error, Debug)]
1060 : pub enum CopyStreamHandlerEnd {
1061 : /// Handler initiates the end of streaming.
1062 : #[error("{0}")]
1063 : ServerInitiated(String),
1064 : #[error("received CopyDone")]
1065 : CopyDone,
1066 : #[error("received CopyFail")]
1067 : CopyFail,
1068 : #[error("received Terminate")]
1069 : Terminate,
1070 : #[error("EOF on COPY stream")]
1071 : EOF,
1072 : #[error("timeline not found, and allow_timeline_creation is false")]
1073 : TimelineNoCreate,
1074 : /// The connection was lost
1075 : #[error("connection error: {0}")]
1076 : Disconnected(#[from] ConnectionError),
1077 : #[error("Shutdown")]
1078 : Cancelled,
1079 : /// Some other error
1080 : #[error(transparent)]
1081 : Other(#[from] anyhow::Error),
1082 : }
|