Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : #![deny(unsafe_code)]
6 : #![deny(clippy::undocumented_unsafe_blocks)]
7 : use anyhow::Context;
8 : use bytes::Bytes;
9 : use serde::{Deserialize, Serialize};
10 : use std::io::ErrorKind;
11 : use std::net::SocketAddr;
12 : use std::pin::Pin;
13 : use std::sync::Arc;
14 : use std::task::{ready, Poll};
15 : use std::{fmt, io};
16 : use std::{future::Future, str::FromStr};
17 : use tokio::io::{AsyncRead, AsyncWrite};
18 : use tokio_rustls::TlsAcceptor;
19 : use tracing::{debug, error, info, trace, warn};
20 :
21 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
22 : use pq_proto::{
23 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
24 : SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
25 : };
26 :
27 : /// An error that occurred during query processing:
28 : /// either during the connection ([`ConnectionError`]) or before/after it.
///
/// `pg_error_code()` maps each variant to the SQLSTATE sent to the client.
29 0 : #[derive(thiserror::Error, Debug)]
30 : pub enum QueryError {
31 : /// The connection was lost while processing the query.
32 : #[error(transparent)]
33 : Disconnected(#[from] ConnectionError),
34 : /// We were instructed to shut down while processing the query.
35 : #[error("Shutting down")]
36 : Shutdown,
37 : /// Query handler indicated that the client should reconnect.
38 : #[error("Server requested reconnect")]
39 : Reconnect,
40 : /// Query named an entity that was not found.
41 : #[error("Not found: {0}")]
42 : NotFound(std::borrow::Cow<'static, str>),
43 : /// Authentication failure.
44 : #[error("Unauthorized: {0}")]
45 : Unauthorized(std::borrow::Cow<'static, str>),
/// Simulated connection failure. `process_message` propagates it as-is,
/// without writing an ErrorResponse for the client.
/// NOTE(review): presumably raised only by tests/failpoints -- confirm.
46 : #[error("Simulated Connection Error")]
47 : SimulatedConnectionError,
48 : /// Some other error.
49 : #[error(transparent)]
50 : Other(#[from] anyhow::Error),
51 : }
52 :
53 : impl From<io::Error> for QueryError {
54 0 : fn from(e: io::Error) -> Self {
55 0 : Self::Disconnected(ConnectionError::Io(e))
56 0 : }
57 : }
58 :
59 : impl QueryError {
60 0 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
61 0 : match self {
62 0 : Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
63 0 : Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
64 0 : Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
65 0 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
66 : }
67 0 : }
68 : }
69 :
/// Tells whether an I/O error is an ordinary side effect of network trouble
/// or of the peer closing the connection.
///
/// Such errors occur during normal operation and do not point at a bug in
/// our code: broken pipe, connection refused/aborted/reset, and timeouts.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    let kind = e.kind();
    matches!(
        kind,
        io::ErrorKind::BrokenPipe
            | io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::TimedOut
    )
}
80 :
/// Callbacks through which `PostgresBackend` delegates query processing,
/// startup-packet inspection and JWT validation to its user.
80 : #[async_trait::async_trait]
81 : pub trait Handler<IO> {
82 : /// Handle single query.
83 : /// postgres_backend will issue ReadyForQuery after calling this (this
84 : /// might be not what we want after CopyData streaming, but currently we don't
85 : /// care). It will also flush out the output buffer.
///
/// Note: ReadyForQuery is only written for the simple `Query` message; the
/// `Execute` path in `process_message` deliberately does not send it.
86 : async fn process_query(
87 : &mut self,
88 : pgb: &mut PostgresBackend<IO>,
89 : query_string: &str,
90 : ) -> Result<(), QueryError>;
91 :
92 : /// Called when the startup packet is received; allows processing its params.
93 : ///
94 : /// Returning an error aborts the handshake (the error is propagated with `?`).
95 : /// The default implementation accepts every startup packet.
96 : /// NOTE(review): earlier docs described an `Ok(false)` "skip auth" result used by the proxy; the current `Result<(), QueryError>` signature cannot express it.
97 4 : fn startup(
98 4 : &mut self,
99 4 : _pgb: &mut PostgresBackend<IO>,
100 4 : _sm: &FeStartupPacket,
101 4 : ) -> Result<(), QueryError> {
102 4 : Ok(())
103 4 : }
104 :
105 : /// Check the JWT sent by the client in its PasswordMessage.
///
/// The default implementation rejects every token, so handlers used with
/// `AuthType::NeonJWT` must override it.
106 0 : fn check_auth_jwt(
107 0 : &mut self,
108 0 : _pgb: &mut PostgresBackend<IO>,
109 0 : _jwt_response: &[u8],
110 0 : ) -> Result<(), QueryError> {
111 0 : Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
112 0 : }
113 : }
115 :
115 : /// PostgresBackend protocol state.
116 : /// XXX: The order of the constructors matters: the derived `PartialOrd` is used
/// to compare states (e.g. `state < ProtoState::Authentication` while the
/// handshake is still reading startup packets).
117 : #[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
118 : pub enum ProtoState {
119 : /// Nothing happened yet.
120 : Initialization,
121 : /// Encryption handshake is done; waiting for encrypted Startup message.
122 : Encrypted,
123 : /// Waiting for password (auth token).
124 : Authentication,
125 : /// Performed handshake and auth, ReadyForQuery is issued.
126 : Established,
/// No further reads are attempted: set after a read error, on EOF during
/// handshake/auth, and after a COPY stream has ended.
127 : Closed,
128 : }
130 :
/// Outcome of processing one frontend message in the message loop.
130 : #[derive(Clone, Copy)]
131 : pub enum ProcessMsgResult {
/// Keep reading messages from the client.
132 : Continue,
/// Leave the message loop (client sent Terminate, or the handler reported shutdown).
133 : Break,
134 : }
136 :
136 : /// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
137 : pub enum MaybeTlsStream<IO> {
/// Plaintext stream, as accepted.
138 : Unencrypted(IO),
/// Stream after a completed server-side TLS handshake (boxed).
139 : Tls(Box<tokio_rustls::server::TlsStream<IO>>),
140 : }
142 :
143 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
144 10 : fn poll_write(
145 10 : self: Pin<&mut Self>,
146 10 : cx: &mut std::task::Context<'_>,
147 10 : buf: &[u8],
148 10 : ) -> Poll<io::Result<usize>> {
149 10 : match self.get_mut() {
150 6 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
151 4 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
152 : }
153 10 : }
154 10 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
155 10 : match self.get_mut() {
156 6 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
157 4 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
158 : }
159 10 : }
160 0 : fn poll_shutdown(
161 0 : self: Pin<&mut Self>,
162 0 : cx: &mut std::task::Context<'_>,
163 0 : ) -> Poll<io::Result<()>> {
164 0 : match self.get_mut() {
165 0 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
166 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
167 : }
168 0 : }
169 : }
170 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
171 24 : fn poll_read(
172 24 : self: Pin<&mut Self>,
173 24 : cx: &mut std::task::Context<'_>,
174 24 : buf: &mut tokio::io::ReadBuf<'_>,
175 24 : ) -> Poll<io::Result<()>> {
176 24 : match self.get_mut() {
177 14 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
178 10 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
179 : }
180 24 : }
181 : }
182 :
/// How clients are authenticated after the startup packet.
/// Parsed from / printed as the strings "Trust" and "NeonJWT".
183 0 : #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
184 : pub enum AuthType {
/// No authentication: the connection is established right after startup.
185 : Trust,
186 : // This mimics postgres's AuthenticationCleartextPassword, but expects a JWT instead of a password.
187 : NeonJWT,
188 : }
189 :
190 : impl FromStr for AuthType {
191 : type Err = anyhow::Error;
192 :
193 0 : fn from_str(s: &str) -> Result<Self, Self::Err> {
194 0 : match s {
195 0 : "Trust" => Ok(Self::Trust),
196 0 : "NeonJWT" => Ok(Self::NeonJWT),
197 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
198 : }
199 0 : }
200 : }
201 :
202 : impl fmt::Display for AuthType {
203 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
204 0 : f.write_str(match self {
205 0 : AuthType::Trust => "Trust",
206 0 : AuthType::NeonJWT => "NeonJWT",
207 : })
208 0 : }
209 : }
210 :
211 : /// Either full duplex Framed or write only half; the latter is left in
212 : /// PostgresBackend after call to `split`. In principle we could always store a
213 : /// pair of split handles, but that would force us to pay the splitting price
214 : /// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
215 : enum MaybeWriteOnly<IO> {
/// Both directions are available.
216 : Full(Framed<MaybeTlsStream<IO>>),
/// Only the write half; the read half was handed out by `split`.
217 : WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
218 : Broken, // temporary placeholder left in place while the real value is moved out (split/unsplit/TLS upgrade)
219 : }
220 :
221 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
222 6 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
223 6 : match self {
224 6 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
225 : MaybeWriteOnly::WriteOnly(_) => {
226 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
227 : }
228 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
229 : }
230 6 : }
231 :
232 8 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
233 8 : match self {
234 8 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
235 : MaybeWriteOnly::WriteOnly(_) => {
236 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
237 : }
238 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
239 : }
240 4 : }
241 :
242 38 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
243 38 : match self {
244 38 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
245 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
246 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
247 : }
248 38 : }
249 :
250 10 : async fn flush(&mut self) -> io::Result<()> {
251 10 : match self {
252 10 : MaybeWriteOnly::Full(framed) => framed.flush().await,
253 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
254 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
255 : }
256 10 : }
257 :
258 : /// Cancellation safe as long as the underlying IO is cancellation safe.
259 0 : async fn shutdown(&mut self) -> io::Result<()> {
260 0 : match self {
261 0 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
262 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
263 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
264 : }
265 0 : }
266 : }
267 :
/// Server side of a single Postgres protocol connection over `IO`.
267 : pub struct PostgresBackend<IO> {
// Framed stream: full duplex, or write-only while a reader is split off.
268 : framed: MaybeWriteOnly<IO>,
269 :
/// Current protocol state; `Closed` stops further reads.
270 : pub state: ProtoState,
271 :
// Authentication mode applied after the startup packet.
272 : auth_type: AuthType,
273 :
// Remote address of the client (used in logs, exposed via `get_peer_addr`).
274 : peer_addr: SocketAddr,
/// TLS configuration; when `Some`, clients are required to connect with TLS.
275 : pub tls_config: Option<Arc<rustls::ServerConfig>>,
276 : }
278 :
/// [`PostgresBackend`] over a plain tokio TCP socket.
278 : pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
280 :
281 0 : pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
282 0 : let mut query_string = query_string.to_vec();
283 0 : if let Some(ch) = query_string.last() {
284 0 : if *ch == 0 {
285 0 : query_string.pop();
286 0 : }
287 0 : }
288 0 : query_string
289 0 : }
290 :
291 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
292 4 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
293 4 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
294 4 : std::str::from_utf8(without_null).map_err(|e| e.into())
295 4 : }
296 :
297 : impl PostgresBackend<tokio::net::TcpStream> {
298 4 : pub fn new(
299 4 : socket: tokio::net::TcpStream,
300 4 : auth_type: AuthType,
301 4 : tls_config: Option<Arc<rustls::ServerConfig>>,
302 4 : ) -> io::Result<Self> {
303 4 : let peer_addr = socket.peer_addr()?;
304 4 : let stream = MaybeTlsStream::Unencrypted(socket);
305 4 :
306 4 : Ok(Self {
307 4 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
308 4 : state: ProtoState::Initialization,
309 4 : auth_type,
310 4 : tls_config,
311 4 : peer_addr,
312 4 : })
313 4 : }
314 : }
315 :
316 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
317 0 : pub fn new_from_io(
318 0 : socket: IO,
319 0 : peer_addr: SocketAddr,
320 0 : auth_type: AuthType,
321 0 : tls_config: Option<Arc<rustls::ServerConfig>>,
322 0 : ) -> io::Result<Self> {
323 0 : let stream = MaybeTlsStream::Unencrypted(socket);
324 0 :
325 0 : Ok(Self {
326 0 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
327 0 : state: ProtoState::Initialization,
328 0 : auth_type,
329 0 : tls_config,
330 0 : peer_addr,
331 0 : })
332 0 : }
333 :
/// Returns the client address recorded when the backend was created.
334 0 : pub fn get_peer_addr(&self) -> &SocketAddr {
335 0 : &self.peer_addr
336 0 : }
337 :
338 : /// Read full message or return None if connection is cleanly closed with no
339 : /// unprocessed data.
340 8 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
341 8 : if let ProtoState::Closed = self.state {
342 0 : Ok(None)
343 : } else {
344 8 : match self.framed.read_message().await {
345 4 : Ok(m) => {
346 4 : trace!("read msg {:?}", m);
347 4 : Ok(m)
348 : }
349 0 : Err(e) => {
350 0 : // remember not to try to read anymore
351 0 : self.state = ProtoState::Closed;
352 0 : Err(e)
353 : }
354 : }
355 : }
356 4 : }
357 :
358 : /// Write message into internal output buffer, doesn't flush it. Technically
359 : /// error type can be only ProtocolError here (if, unlikely, serialization
360 : /// fails), but callers typically wrap it anyway.
///
/// Returns `&mut Self` so that several writes can be chained with `?`.
361 38 : pub fn write_message_noflush(
362 38 : &mut self,
363 38 : message: &BeMessage<'_>,
364 38 : ) -> Result<&mut Self, ConnectionError> {
365 38 : self.framed.write_message_noflush(message)?;
366 38 : trace!("wrote msg {:?}", message);
367 38 : Ok(self)
368 38 : }
369 :
370 : /// Flush output buffer into the socket.
///
/// Panics if the framed stream is in its temporary `Broken` state.
371 10 : pub async fn flush(&mut self) -> io::Result<()> {
372 10 : self.framed.flush().await
373 10 : }
374 :
375 : /// Polling version of `flush()`, saves the caller need to pin.
///
/// Each call builds a fresh `flush()` future, polls it exactly once and
/// drops it, so no future state is carried over between polls.
/// NOTE(review): this relies on `flush()` being safe to drop while pending
/// and restart later -- confirm against the `Framed` implementation.
376 0 : pub fn poll_flush(
377 0 : &mut self,
378 0 : cx: &mut std::task::Context<'_>,
379 0 : ) -> Poll<Result<(), std::io::Error>> {
380 0 : let flush_fut = std::pin::pin!(self.flush());
381 0 : flush_fut.poll(cx)
382 0 : }
383 :
384 : /// Write message into internal output buffer and flush it to the stream.
///
/// Everything buffered earlier via `write_message_noflush` is flushed as well.
385 6 : pub async fn write_message(
386 6 : &mut self,
387 6 : message: &BeMessage<'_>,
388 6 : ) -> Result<&mut Self, ConnectionError> {
389 6 : self.write_message_noflush(message)?;
390 6 : self.flush().await?;
391 6 : Ok(self)
392 6 : }
393 :
394 : /// Returns an AsyncWrite implementation that wraps all the data written
395 : /// to it in CopyData messages, and writes them to the connection.
396 : ///
397 : /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
/// The writer mutably borrows this backend for as long as it lives.
398 0 : pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
399 0 : CopyDataWriter { pgb: self }
400 0 : }
401 :
402 : /// Wrapper for run_message_loop() that shuts down socket when we are done
///
/// Consumes the backend. `shutdown_watcher` is a factory for futures that
/// complete when shutdown is requested; it is invoked once per wait point.
///
/// `Shutdown`, `Reconnect` and `Disconnected` results of the loop are logged
/// at info level and reported as `Ok(())`; any other error is returned as is.
403 4 : pub async fn run<F, S>(
404 4 : mut self,
405 4 : handler: &mut impl Handler<IO>,
406 4 : shutdown_watcher: F,
407 4 : ) -> Result<(), QueryError>
408 4 : where
409 4 : F: Fn() -> S + Clone,
410 4 : S: Future,
411 4 : {
412 4 : let ret = self
413 4 : .run_message_loop(handler, shutdown_watcher.clone())
414 14 : .await;
415 :
// Shut the socket down regardless of how the loop ended, but do not let
// that delay a requested shutdown: whichever branch finishes first wins.
416 : tokio::select! {
417 : _ = shutdown_watcher() => {
418 : // do nothing; we most likely got already stopped by shutdown and will log it next.
419 : }
420 : _ = self.framed.shutdown() => {
421 : // socket might be already closed, e.g. if previously received error,
422 : // so ignore result.
423 : },
424 : }
425 :
426 0 : match ret {
427 0 : Ok(()) => Ok(()),
428 : Err(QueryError::Shutdown) => {
429 0 : info!("Stopped due to shutdown");
430 0 : Ok(())
431 : }
432 : Err(QueryError::Reconnect) => {
433 : // Dropping out of this loop implicitly disconnects
434 0 : info!("Stopped due to handler reconnect request");
435 0 : Ok(())
436 : }
437 0 : Err(QueryError::Disconnected(e)) => {
438 0 : info!("Disconnected ({e:#})");
439 : // Disconnection is not an error: we just use it that way internally to drop
440 : // out of loops.
441 0 : Ok(())
442 : }
443 0 : e => e,
444 : }
445 0 : }
446 :
/// Performs the handshake, then reads and processes frontend messages until
/// the client disconnects or terminates, an error occurs, or shutdown is
/// requested.
///
/// All three `select!`s are `biased`, so a pending shutdown request always
/// takes precedence over protocol work. Shutdown is reported as
/// `Err(QueryError::Shutdown)`.
447 4 : async fn run_message_loop<F, S>(
448 4 : &mut self,
449 4 : handler: &mut impl Handler<IO>,
450 4 : shutdown_watcher: F,
451 4 : ) -> Result<(), QueryError>
452 4 : where
453 4 : F: Fn() -> S,
454 4 : S: Future,
455 4 : {
456 4 : trace!("postgres backend to {:?} started", self.peer_addr);
457 :
458 : tokio::select!(
459 : biased;
460 :
461 : _ = shutdown_watcher() => {
462 : // We were requested to shut down.
463 : tracing::info!("shutdown request received during handshake");
464 : return Err(QueryError::Shutdown)
465 : },
466 :
467 : handshake_r = self.handshake(handler) => {
468 : handshake_r?;
469 : }
470 : );
471 :
472 : // Authentication completed
// `query_string` holds the text of the most recent Parse message; a later
// Execute message runs it.
473 4 : let mut query_string = Bytes::new();
474 4 : while let Some(msg) = tokio::select!(
475 : biased;
476 : _ = shutdown_watcher() => {
477 : // We were requested to shut down.
478 : tracing::info!("shutdown request received in run_message_loop");
479 : return Err(QueryError::Shutdown)
480 : },
481 : msg = self.read_message() => { msg },
482 0 : )? {
483 4 : trace!("got message {:?}", msg);
484 :
// The result is inspected only after the flush below, so that whatever
// was buffered (including an ErrorResponse) still reaches the client.
485 4 : let result = self.process_message(handler, msg, &mut query_string).await;
486 : tokio::select!(
487 : biased;
488 : _ = shutdown_watcher() => {
489 : // We were requested to shut down.
490 : tracing::info!("shutdown request received during response flush");
491 :
492 : // If we exited process_message with a shutdown error, there may be
493 : // some valid response content in our transmit buffer: permit sending
494 : // this within a short timeout. This is a best effort thing so we don't
495 : // care about the result.
496 : tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
497 :
498 : return Err(QueryError::Shutdown)
499 : },
500 : flush_r = self.flush() => {
501 : flush_r?;
502 : }
503 : );
504 :
505 4 : match result? {
506 : ProcessMsgResult::Continue => {
507 4 : continue;
508 : }
509 0 : ProcessMsgResult::Break => break,
510 : }
511 : }
512 :
513 0 : trace!("postgres backend to {:?} exited", self.peer_addr);
514 0 : Ok(())
515 0 : }
516 :
517 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
518 2 : async fn tls_upgrade(
519 2 : src: MaybeTlsStream<IO>,
520 2 : tls_config: Arc<rustls::ServerConfig>,
521 2 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
522 2 : match src {
523 2 : MaybeTlsStream::Unencrypted(s) => {
524 2 : let acceptor = TlsAcceptor::from(tls_config);
525 4 : let tls_stream = acceptor.accept(s).await?;
526 2 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
527 : }
528 : MaybeTlsStream::Tls(_) => {
529 0 : anyhow::bail!("TLS already started");
530 : }
531 : }
532 2 : }
533 :
534 2 : async fn start_tls(&mut self) -> anyhow::Result<()> {
535 2 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
536 2 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
537 2 : MaybeWriteOnly::Full(framed) => {
538 2 : let tls_config = self
539 2 : .tls_config
540 2 : .as_ref()
541 2 : .context("start_tls called without conf")?
542 2 : .clone();
543 2 : let tls_framed = framed
544 2 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
545 4 : .await?;
546 : // push back ready TLS stream
547 2 : self.framed = MaybeWriteOnly::Full(tls_framed);
548 2 : Ok(())
549 : }
550 : MaybeWriteOnly::WriteOnly(_) => {
551 0 : anyhow::bail!("TLS upgrade attempt in split state")
552 : }
553 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
554 : }
555 2 : }
556 :
557 : /// Split off owned read part from which messages can be read in different
558 : /// task/thread.
559 0 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
560 0 : // temporary replace stream with fake to cook split one, Indiana Jones style
561 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
562 0 : MaybeWriteOnly::Full(framed) => {
563 0 : let (reader, writer) = framed.split();
564 0 : self.framed = MaybeWriteOnly::WriteOnly(writer);
565 0 : Ok(PostgresBackendReader {
566 0 : reader,
567 0 : closed: false,
568 0 : })
569 : }
570 : MaybeWriteOnly::WriteOnly(_) => {
571 0 : anyhow::bail!("PostgresBackend is already split")
572 : }
573 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
574 : }
575 0 : }
576 :
577 : /// Join read part back.
578 0 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
579 0 : // temporary replace stream with fake to cook joined one, Indiana Jones style
580 0 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
581 : MaybeWriteOnly::Full(_) => {
582 0 : anyhow::bail!("PostgresBackend is not split")
583 : }
584 0 : MaybeWriteOnly::WriteOnly(writer) => {
585 0 : let joined = Framed::unsplit(reader.reader, writer);
586 0 : self.framed = MaybeWriteOnly::Full(joined);
587 0 : // if reader encountered connection error, do not attempt reading anymore
588 0 : if reader.closed {
589 0 : self.state = ProtoState::Closed;
590 0 : }
591 0 : Ok(())
592 : }
593 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
594 : }
595 0 : }
596 :
597 : /// Perform handshake with the client, transitioning to Established.
598 : /// In case of EOF during handshake or auth, logs this, sets state to Closed and returns
/// `Err(QueryError::Disconnected(..))` (which `run` reports as a normal disconnect).
///
/// Phase 1 reads startup packets until the state reaches at least
/// `Authentication`. Phase 2 runs only if the state is exactly
/// `Authentication`: it expects one PasswordMessage carrying the JWT.
599 4 : async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
600 10 : while self.state < ProtoState::Authentication {
601 6 : match self.framed.read_startup_message().await? {
602 6 : Some(msg) => {
603 6 : self.process_startup_message(handler, msg).await?;
604 : }
605 : None => {
606 0 : trace!(
607 0 : "postgres backend to {:?} received EOF during handshake",
608 : self.peer_addr
609 : );
610 0 : self.state = ProtoState::Closed;
611 0 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
612 0 : ProtocolError::Protocol("EOF during handshake".to_string()),
613 0 : )));
614 : }
615 : }
616 : }
617 :
618 : // Perform auth, if needed.
619 4 : if self.state == ProtoState::Authentication {
620 0 : match self.framed.read_message().await? {
621 0 : Some(FeMessage::PasswordMessage(m)) => {
622 0 : assert!(self.auth_type == AuthType::NeonJWT);
623 :
// Drop the last byte of the message; an empty message is a protocol violation.
// NOTE(review): the dropped byte is presumably the NUL terminator -- confirm.
624 0 : let (_, jwt_response) = m.split_last().context("protocol violation")?;
625 :
626 0 : if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
// The ErrorResponse is only buffered here, not flushed.
627 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
628 0 : &short_error(&e),
629 0 : Some(e.pg_error_code()),
630 0 : ))?;
631 0 : return Err(e);
632 0 : }
633 0 :
634 0 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
635 0 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
636 0 : .write_message(&BeMessage::ReadyForQuery)
637 0 : .await?;
638 0 : self.state = ProtoState::Established;
639 : }
640 0 : Some(m) => {
641 0 : return Err(QueryError::Other(anyhow::anyhow!(
642 0 : "Unexpected message {:?} while waiting for handshake",
643 0 : m
644 0 : )));
645 : }
646 : None => {
647 0 : trace!(
648 0 : "postgres backend to {:?} received EOF during auth",
649 : self.peer_addr
650 : );
651 0 : self.state = ProtoState::Closed;
652 0 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
653 0 : ProtocolError::Protocol("EOF during auth".to_string()),
654 0 : )));
655 : }
656 : }
657 4 : }
658 :
659 4 : Ok(())
660 4 : }
661 :
662 : /// Process startup packet:
663 : /// - transition to Established if auth type is trust
664 : /// - transition to Authentication if auth type is NeonJWT.
665 : /// - or perform TLS handshake -- then need to call this again to receive
666 : /// actual startup packet.
///
/// When a TLS config is set, a StartupMessage that arrives before the
/// connection is encrypted is rejected with an ErrorResponse. GSS encryption
/// requests are always declined, and a CancelRequest is an error.
///
/// Panics if the state is already `Authentication` or later.
667 6 : async fn process_startup_message(
668 6 : &mut self,
669 6 : handler: &mut impl Handler<IO>,
670 6 : msg: FeStartupPacket,
671 6 : ) -> Result<(), QueryError> {
672 6 : assert!(self.state < ProtoState::Authentication);
673 6 : let have_tls = self.tls_config.is_some();
674 6 : match msg {
675 : FeStartupPacket::SslRequest => {
676 2 : debug!("SSL requested");
677 :
// Tell the client whether TLS is available; this answer is flushed before the handshake starts.
678 2 : self.write_message(&BeMessage::EncryptionResponse(have_tls))
679 0 : .await?;
680 :
681 2 : if have_tls {
682 4 : self.start_tls().await?;
683 2 : self.state = ProtoState::Encrypted;
684 0 : }
685 : }
686 : FeStartupPacket::GssEncRequest => {
687 0 : debug!("GSS requested");
688 0 : self.write_message(&BeMessage::EncryptionResponse(false))
689 0 : .await?;
690 : }
691 : FeStartupPacket::StartupMessage { .. } => {
692 4 : if have_tls && !matches!(self.state, ProtoState::Encrypted) {
693 0 : self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
694 0 : .await?;
695 0 : return Err(QueryError::Other(anyhow::anyhow!(
696 0 : "client did not connect with TLS"
697 0 : )));
698 4 : }
699 4 :
700 4 : // NB: startup() may change self.auth_type -- we are using that in proxy code
701 4 : // to bypass auth for new users.
702 4 : handler.startup(self, &msg)?;
703 :
704 4 : match self.auth_type {
705 : AuthType::Trust => {
706 4 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
707 4 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
708 4 : .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
709 : // The async python driver requires a valid server_version
710 4 : .write_message_noflush(&BeMessage::server_version("14.1"))?
711 4 : .write_message(&BeMessage::ReadyForQuery)
712 0 : .await?;
713 4 : self.state = ProtoState::Established;
714 : }
715 : AuthType::NeonJWT => {
// Ask for the "password" (the JWT); `handshake` reads and checks it next.
716 0 : self.write_message(&BeMessage::AuthenticationCleartextPassword)
717 0 : .await?;
718 0 : self.state = ProtoState::Authentication;
719 : }
720 : }
721 : }
722 : FeStartupPacket::CancelRequest { .. } => {
723 0 : return Err(QueryError::Other(anyhow::anyhow!(
724 0 : "Unexpected CancelRequest message during handshake"
725 0 : )));
726 : }
727 : }
728 6 : Ok(())
729 6 : }
730 :
/// Processes one frontend message in the Established state, buffering the
/// response without flushing it (the message loop flushes afterwards).
///
/// `unnamed_query_string` stores the query text of the latest Parse message
/// so that a following Execute can run it.
///
/// Returns `Break` on Terminate or when a simple query's handler reports
/// `QueryError::Shutdown`; otherwise `Continue`.
///
/// Panics if the state is not `Established`.
731 4 : async fn process_message(
732 4 : &mut self,
733 4 : handler: &mut impl Handler<IO>,
734 4 : msg: FeMessage,
735 4 : unnamed_query_string: &mut Bytes,
736 4 : ) -> Result<ProcessMsgResult, QueryError> {
737 4 : // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
738 4 : // TODO: change that to proper top-level match of protocol state with separate message handling for each state
739 4 : assert!(self.state == ProtoState::Established);
740 :
741 4 : match msg {
742 4 : FeMessage::Query(body) => {
743 : // remove null terminator
744 4 : let query_string = cstr_to_str(&body)?;
745 :
746 4 : trace!("got query {query_string:?}");
747 4 : if let Err(e) = handler.process_query(self, query_string).await {
748 0 : match e {
749 0 : QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
750 : QueryError::SimulatedConnectionError => {
751 0 : return Err(QueryError::SimulatedConnectionError)
752 : }
753 0 : e => {
// Any other error is logged and reported to the client; the connection stays usable.
754 0 : log_query_error(query_string, &e);
755 0 : let short_error = short_error(&e);
756 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
757 0 : &short_error,
758 0 : Some(e.pg_error_code()),
759 0 : ))?;
760 : }
761 : }
762 4 : }
763 4 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
764 : }
765 :
766 0 : FeMessage::Parse(m) => {
767 0 : *unnamed_query_string = m.query_string;
768 0 : self.write_message_noflush(&BeMessage::ParseComplete)?;
769 : }
770 :
771 : FeMessage::Describe(_) => {
772 0 : self.write_message_noflush(&BeMessage::ParameterDescription)?
773 0 : .write_message_noflush(&BeMessage::NoData)?;
774 : }
775 :
776 : FeMessage::Bind(_) => {
777 0 : self.write_message_noflush(&BeMessage::BindComplete)?;
778 : }
779 :
780 : FeMessage::Close(_) => {
781 0 : self.write_message_noflush(&BeMessage::CloseComplete)?;
782 : }
783 :
784 : FeMessage::Execute(_) => {
785 0 : let query_string = cstr_to_str(unnamed_query_string)?;
786 0 : trace!("got execute {query_string:?}");
// NOTE(review): unlike the Query arm, every handler error here (including
// Shutdown) is sent to the client, using the full `to_string()` text
// rather than `short_error` -- confirm this difference is intended.
787 0 : if let Err(e) = handler.process_query(self, query_string).await {
788 0 : log_query_error(query_string, &e);
789 0 : self.write_message_noflush(&BeMessage::ErrorResponse(
790 0 : &e.to_string(),
791 0 : Some(e.pg_error_code()),
792 0 : ))?;
793 0 : }
794 : // NOTE there is no ReadyForQuery message. This handler is used
795 : // for basebackup and it uses CopyOut which doesn't require
796 : // ReadyForQuery message and backend just switches back to
797 : // processing mode after sending CopyDone or ErrorResponse.
798 : }
799 :
800 : FeMessage::Sync => {
801 0 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
802 : }
803 :
804 : FeMessage::Terminate => {
805 0 : return Ok(ProcessMsgResult::Break);
806 : }
807 :
808 : // We prefer explicit pattern matching to wildcards, because
809 : // this helps us spot the places where new variants are missing
810 : FeMessage::CopyData(_)
811 : | FeMessage::CopyDone
812 : | FeMessage::CopyFail
813 : | FeMessage::PasswordMessage(_) => {
814 0 : return Err(QueryError::Other(anyhow::anyhow!(
815 0 : "unexpected message type: {msg:?}",
816 0 : )));
817 : }
818 : }
819 :
820 4 : Ok(ProcessMsgResult::Continue)
821 4 : }
822 :
823 : /// - Log as info/error result of handling COPY stream and send back
824 : /// ErrorResponse if that makes sense.
825 : /// - Send CopyDone back if the stream ended with CopyDone.
826 : /// - Then close the connection because we don't handle exiting from COPY
827 : /// stream normally.
///
/// Failures to send messages are logged, not returned. Afterwards the
/// backend state is `Closed`.
828 0 : pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
829 : use CopyStreamHandlerEnd::*;
830 :
// Expected ends are logged at info level, everything else at error level.
831 0 : let expected_end = match &end {
832 0 : ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
833 0 : CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
834 0 : if is_expected_io_error(io_error) =>
835 0 : {
836 0 : true
837 : }
838 0 : _ => false,
839 : };
840 0 : if expected_end {
841 0 : info!("terminated: {:#}", end);
842 : } else {
843 0 : error!("terminated: {:?}", end);
844 : }
845 :
846 : // Note: no current usages ever send this
847 0 : if let CopyDone = &end {
848 0 : if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
849 0 : error!("failed to send CopyDone: {}", e);
850 0 : }
851 0 : }
852 :
// Decide whether an ErrorResponse goes back to the client, and with which SQLSTATE.
853 0 : let err_to_send_and_errcode = match &end {
854 0 : ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
855 0 : Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
856 : // Note: CopyFail in duplex copy is somewhat unexpected (at least to
857 : // PG walsender; evidently and per my docs reading client should
858 : // finish it with CopyDone). It is not a problem to recover from it
859 : // finishing the stream in both directions like we do, but note that
860 : // sync rust-postgres client (which we don't use anymore) hangs if
861 : // socket is not closed here.
862 : // https://github.com/sfackler/rust-postgres/issues/755
863 : // https://github.com/neondatabase/neon/issues/935
864 : //
865 : // Currently, the version of tokio_postgres replication patch we use
866 : // sends this when it closes the stream (e.g. pageserver decided to
867 : // switch conn to another safekeeper and client gets dropped).
868 : // Moreover, seems like 'connection' task errors with 'unexpected
869 : // message from server' when it receives ErrorResponse (anything but
870 : // CopyData/CopyDone) back.
871 0 : CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
872 0 : _ => None,
873 : };
874 0 : if let Some((err, errcode)) = err_to_send_and_errcode {
875 0 : if let Err(ee) = self
876 0 : .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
877 0 : .await
878 : {
879 0 : error!("failed to send ErrorResponse: {}", ee);
880 0 : }
881 0 : }
882 :
883 : // Proper COPY stream finishing to continue using the connection is not
884 : // implemented at the server side (we don't need it so far). To prevent
885 : // further usages of the connection, close it.
886 0 : self.framed.shutdown().await.ok();
887 0 : self.state = ProtoState::Closed;
888 0 : }
889 : }
890 :
/// Owned read half of a `PostgresBackend`, produced by `split` and handed
/// back via `unsplit`.
891 : pub struct PostgresBackendReader<IO> {
892 : reader: FramedReader<MaybeTlsStream<IO>>,
893 : closed: bool, // true once a read returned an error; `unsplit` then marks the backend Closed
894 : }
895 :
896 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
897 : /// Read full message or return None if connection is cleanly closed with no
898 : /// unprocessed data.
899 0 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
900 0 : match self.reader.read_message().await {
901 0 : Ok(m) => {
902 0 : trace!("read msg {:?}", m);
903 0 : Ok(m)
904 : }
905 0 : Err(e) => {
906 0 : self.closed = true;
907 0 : Err(e)
908 : }
909 : }
910 0 : }
911 :
912 : /// Get CopyData contents of the next message in COPY stream or error
913 : /// closing it. The error type is wider than actual errors which can happen
914 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
915 : /// current callers.
916 0 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
917 0 : match self.read_message().await? {
918 0 : Some(msg) => match msg {
919 0 : FeMessage::CopyData(m) => Ok(m),
920 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
921 0 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
922 0 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
923 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
924 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
925 0 : ))),
926 : },
927 0 : None => Err(CopyStreamHandlerEnd::EOF),
928 : }
929 0 : }
930 : }
931 :
932 : ///
933 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
934 : /// messages.
935 : ///
936 :
937 : pub struct CopyDataWriter<'a, IO> {
938 : pgb: &'a mut PostgresBackend<IO>,
939 : }
940 :
941 : impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
942 0 : fn poll_write(
943 0 : self: Pin<&mut Self>,
944 0 : cx: &mut std::task::Context<'_>,
945 0 : buf: &[u8],
946 0 : ) -> Poll<Result<usize, std::io::Error>> {
947 0 : let this = self.get_mut();
948 :
949 : // It's not strictly required to flush between each message, but makes it easier
950 : // to view in wireshark, and usually the messages that the callers write are
951 : // decently-sized anyway.
952 0 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
953 0 : return Poll::Ready(Err(err));
954 0 : }
955 0 :
956 0 : // CopyData
957 0 : // XXX: if the input is large, we should split it into multiple messages.
958 0 : // Not sure what the threshold should be, but the ultimate hard limit is that
959 0 : // the length cannot exceed u32.
960 0 : this.pgb
961 0 : .write_message_noflush(&BeMessage::CopyData(buf))
962 0 : // write_message only writes to the buffer, so it can fail iff the
963 0 : // message is invaid, but CopyData can't be invalid.
964 0 : .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
965 :
966 0 : Poll::Ready(Ok(buf.len()))
967 0 : }
968 :
969 0 : fn poll_flush(
970 0 : self: Pin<&mut Self>,
971 0 : cx: &mut std::task::Context<'_>,
972 0 : ) -> Poll<Result<(), std::io::Error>> {
973 0 : let this = self.get_mut();
974 0 : this.pgb.poll_flush(cx)
975 0 : }
976 :
977 0 : fn poll_shutdown(
978 0 : self: Pin<&mut Self>,
979 0 : cx: &mut std::task::Context<'_>,
980 0 : ) -> Poll<Result<(), std::io::Error>> {
981 0 : let this = self.get_mut();
982 0 : this.pgb.poll_flush(cx)
983 0 : }
984 : }
985 :
986 0 : pub fn short_error(e: &QueryError) -> String {
987 0 : match e {
988 0 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
989 0 : QueryError::Reconnect => "reconnect".to_string(),
990 0 : QueryError::Shutdown => "shutdown".to_string(),
991 0 : QueryError::NotFound(_) => "not found".to_string(),
992 0 : QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
993 0 : QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
994 0 : QueryError::Other(e) => format!("{e:#}"),
995 : }
996 0 : }
997 :
998 0 : fn log_query_error(query: &str, e: &QueryError) {
999 0 : match e {
1000 0 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
1001 0 : if is_expected_io_error(io_error) {
1002 0 : info!("query handler for '{query}' failed with expected io error: {io_error}");
1003 : } else {
1004 0 : error!("query handler for '{query}' failed with io error: {io_error}");
1005 : }
1006 : }
1007 0 : QueryError::Disconnected(other_connection_error) => {
1008 0 : error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
1009 : }
1010 : QueryError::SimulatedConnectionError => {
1011 0 : error!("query handler for query '{query}' failed due to a simulated connection error")
1012 : }
1013 : QueryError::Reconnect => {
1014 0 : info!("query handler for '{query}' requested client to reconnect")
1015 : }
1016 : QueryError::Shutdown => {
1017 0 : info!("query handler for '{query}' cancelled during tenant shutdown")
1018 : }
1019 0 : QueryError::NotFound(reason) => {
1020 0 : info!("query handler for '{query}' entity not found: {reason}")
1021 : }
1022 0 : QueryError::Unauthorized(e) => {
1023 0 : warn!("query handler for '{query}' failed with authentication error: {e}");
1024 : }
1025 0 : QueryError::Other(e) => {
1026 0 : error!("query handler for '{query}' failed: {e:?}");
1027 : }
1028 : }
1029 0 : }
1030 :
1031 : /// Something finishing handling of COPY stream, see handle_copy_stream_end.
1032 : /// This is not always a real error, but it allows to use ? and thiserror impls.
1033 0 : #[derive(thiserror::Error, Debug)]
1034 : pub enum CopyStreamHandlerEnd {
1035 : /// Handler initiates the end of streaming.
1036 : #[error("{0}")]
1037 : ServerInitiated(String),
1038 : #[error("received CopyDone")]
1039 : CopyDone,
1040 : #[error("received CopyFail")]
1041 : CopyFail,
1042 : #[error("received Terminate")]
1043 : Terminate,
1044 : #[error("EOF on COPY stream")]
1045 : EOF,
1046 : /// The connection was lost
1047 : #[error("connection error: {0}")]
1048 : Disconnected(#[from] ConnectionError),
1049 : /// Some other error
1050 : #[error(transparent)]
1051 : Other(#[from] anyhow::Error),
1052 : }
|