TLA Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : #![deny(unsafe_code)]
6 : #![deny(clippy::undocumented_unsafe_blocks)]
7 : use anyhow::Context;
8 : use bytes::Bytes;
9 : use futures::pin_mut;
10 : use serde::{Deserialize, Serialize};
11 : use std::io::ErrorKind;
12 : use std::net::SocketAddr;
13 : use std::pin::Pin;
14 : use std::sync::Arc;
15 : use std::task::{ready, Poll};
16 : use std::{fmt, io};
17 : use std::{future::Future, str::FromStr};
18 : use tokio::io::{AsyncRead, AsyncWrite};
19 : use tokio_rustls::TlsAcceptor;
20 : use tracing::{debug, error, info, trace, warn};
21 :
22 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
23 : use pq_proto::{
24 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
25 : SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
26 : };
27 :
28 : /// An error that occurred during query processing:
29 : /// either during the connection ([`ConnectionError`]) or before/after it.
30 CBC 84 : #[derive(thiserror::Error, Debug)]
31 : pub enum QueryError {
32 : /// The connection was lost while processing the query.
33 : #[error(transparent)]
34 : Disconnected(#[from] ConnectionError),
35 : /// We were instructed to shutdown while processing the query
36 : #[error("Shutting down")]
37 : Shutdown,
38 : /// Query handler indicated that client should reconnect
39 : #[error("Server requested reconnect")]
40 : Reconnect,
41 : /// Query named an entity that was not found
42 : #[error("Not found: {0}")]
43 : NotFound(std::borrow::Cow<'static, str>),
44 : /// Authentication failure
45 : #[error("Unauthorized: {0}")]
46 : Unauthorized(std::borrow::Cow<'static, str>),
/// A connection error raised on purpose rather than by the network
/// (presumably by test/fault-injection code -- confirm against callers).
/// It is reported to the client with the same SQLSTATE as a real
/// connection failure.
47 : #[error("Simulated Connection Error")]
48 : SimulatedConnectionError,
49 : /// Some other error
50 : #[error(transparent)]
51 : Other(#[from] anyhow::Error),
52 : }
53 :
54 : impl From<io::Error> for QueryError {
55 322 : fn from(e: io::Error) -> Self {
56 322 : Self::Disconnected(ConnectionError::Io(e))
57 322 : }
58 : }
59 :
60 : impl QueryError {
61 567 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
62 567 : match self {
63 38 : Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
64 UBC 0 : Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
65 CBC 5 : Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
66 524 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
67 : }
68 567 : }
69 : }
70 :
/// Tells whether `e` is the kind of I/O error that routinely happens when the
/// network misbehaves or the peer goes away (broken pipe, refused / aborted /
/// reset connection, timeout). Such errors are part of normal operation and
/// do not point at a bug on our side.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    const EXPECTED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::TimedOut,
    ];
    EXPECTED_KINDS.contains(&e.kind())
}
81 :
/// Callbacks through which a `PostgresBackend` hands protocol events
/// (startup packet, JWT auth, queries) to the embedding service.
82 : #[async_trait::async_trait]
83 : pub trait Handler<IO> {
84 : /// Handle single query.
85 : /// postgres_backend will issue ReadyForQuery after calling this (this
86 : /// might be not what we want after CopyData streaming, but currently we don't
87 : /// care). It will also flush out the output buffer.
88 : async fn process_query(
89 : &mut self,
90 : pgb: &mut PostgresBackend<IO>,
91 : query_string: &str,
92 : ) -> Result<(), QueryError>;
93 :
94 : /// Called on receipt of the startup packet; allows processing its params.
95 : ///
96 : /// The default implementation accepts every startup packet. Returning an error aborts the
97 : /// handshake: `process_startup_message` propagates it to the caller. The authentication
98 : /// step that follows is chosen from the backend's `auth_type` after this call returns.
99 5 : fn startup(
100 5 : &mut self,
101 5 : _pgb: &mut PostgresBackend<IO>,
102 5 : _sm: &FeStartupPacket,
103 5 : ) -> Result<(), QueryError> {
104 5 : Ok(())
105 5 : }
106 :
107 : /// Check the JWT the client sent in its password message.
/// `_jwt_response` is the password payload with its last byte already
/// removed by `handshake` (presumably the null terminator).
/// The default implementation rejects every token.
108 UBC 0 : fn check_auth_jwt(
109 0 : &mut self,
110 0 : _pgb: &mut PostgresBackend<IO>,
111 0 : _jwt_response: &[u8],
112 0 : ) -> Result<(), QueryError> {
113 0 : Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
114 0 : }
115 : }
116 :
117 : /// PostgresBackend protocol state.
118 : /// XXX: The order of the variants matters: states are compared with the derived
/// `PartialOrd` (e.g. `state < ProtoState::Authentication` during handshake).
119 CBC 52265 : #[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
120 : pub enum ProtoState {
121 : /// Nothing happened yet.
122 : Initialization,
123 : /// Encryption handshake is done; waiting for encrypted Startup message.
124 : Encrypted,
125 : /// Waiting for password (auth token).
126 : Authentication,
127 : /// Performed handshake and auth, ReadyForQuery is issued.
128 : Established,
/// No further messages are read from the client; `read_message` returns
/// `Ok(None)` in this state. Set after a read error, EOF during
/// handshake/auth, or a Terminate ending a COPY stream.
129 : Closed,
130 : }
131 :
/// Outcome of processing one frontend message in the message loop.
132 UBC 0 : #[derive(Clone, Copy)]
133 : pub enum ProcessMsgResult {
/// Keep reading and processing further messages.
134 : Continue,
/// Leave the message loop (e.g. after Terminate).
135 : Break,
136 : }
137 :
138 : /// Either a plain (e.g. TCP) stream or a TLS-encrypted one, implementing AsyncRead + AsyncWrite.
139 : pub enum MaybeTlsStream<IO> {
/// The raw underlying stream, no encryption.
140 : Unencrypted(IO),
/// The underlying stream wrapped in a server-side TLS session.
141 : Tls(Box<tokio_rustls::server::TlsStream<IO>>),
142 : }
143 :
144 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
145 CBC 6493844 : fn poll_write(
146 6493844 : self: Pin<&mut Self>,
147 6493844 : cx: &mut std::task::Context<'_>,
148 6493844 : buf: &[u8],
149 6493844 : ) -> Poll<io::Result<usize>> {
150 6493844 : match self.get_mut() {
151 6493842 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
152 2 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
153 : }
154 6493844 : }
155 6508846 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
156 6508846 : match self.get_mut() {
157 6508844 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
158 2 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
159 : }
160 6508846 : }
161 10471 : fn poll_shutdown(
162 10471 : self: Pin<&mut Self>,
163 10471 : cx: &mut std::task::Context<'_>,
164 10471 : ) -> Poll<io::Result<()>> {
165 10471 : match self.get_mut() {
166 10471 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
167 UBC 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
168 : }
169 CBC 10471 : }
170 : }
171 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
172 13787171 : fn poll_read(
173 13787171 : self: Pin<&mut Self>,
174 13787171 : cx: &mut std::task::Context<'_>,
175 13787171 : buf: &mut tokio::io::ReadBuf<'_>,
176 13787171 : ) -> Poll<io::Result<()>> {
177 13787171 : match self.get_mut() {
178 13787166 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
179 5 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
180 : }
181 13787171 : }
182 : }
183 :
/// How clients are authenticated after the startup packet.
184 59558 : #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
185 : pub enum AuthType {
/// No authentication: the backend answers AuthenticationOk right away.
186 : Trust,
187 : // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
188 : NeonJWT,
189 : }
190 :
191 : impl FromStr for AuthType {
192 : type Err = anyhow::Error;
193 :
194 1812 : fn from_str(s: &str) -> Result<Self, Self::Err> {
195 1812 : match s {
196 1812 : "Trust" => Ok(Self::Trust),
197 44 : "NeonJWT" => Ok(Self::NeonJWT),
198 UBC 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
199 : }
200 CBC 1812 : }
201 : }
202 :
203 : impl fmt::Display for AuthType {
204 1812 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 1812 : f.write_str(match self {
206 1768 : AuthType::Trust => "Trust",
207 44 : AuthType::NeonJWT => "NeonJWT",
208 : })
209 1812 : }
210 : }
211 :
212 : /// Either full duplex Framed or write only half; the latter is left in
213 : /// PostgresBackend after call to `split`. In principle we could always store a
214 : /// pair of split handles, but that would force us to pay the splitting price
215 : /// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
216 : enum MaybeWriteOnly<IO> {
/// Read + write in one piece.
217 : Full(Framed<MaybeTlsStream<IO>>),
/// Only the write half; the read half was handed out by `split`.
218 : WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
219 : Broken, // temporary value palmed off during the split
220 : }
221 :
222 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
223 20528 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
224 20528 : match self {
225 20528 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
226 : MaybeWriteOnly::WriteOnly(_) => {
227 UBC 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
228 : }
229 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
230 : }
231 CBC 20528 : }
232 :
233 3758127 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
234 3758127 : match self {
235 3758127 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
236 : MaybeWriteOnly::WriteOnly(_) => {
237 UBC 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
238 : }
239 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
240 : }
241 CBC 3757922 : }
242 :
243 6507626 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
244 6507626 : match self {
245 4638125 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
246 1869501 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
247 UBC 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
248 : }
249 CBC 6507626 : }
250 :
251 6506561 : async fn flush(&mut self) -> io::Result<()> {
252 6506561 : match self {
253 4637060 : MaybeWriteOnly::Full(framed) => framed.flush().await,
254 1869501 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
255 UBC 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
256 : }
257 CBC 6498717 : }
258 :
259 : /// Cancellation safe as long as the underlying IO is cancellation safe.
260 10764 : async fn shutdown(&mut self) -> io::Result<()> {
261 10764 : match self {
262 10764 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
263 UBC 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
264 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
265 : }
266 CBC 10764 : }
267 : }
268 :
/// Server side of one Postgres-protocol connection over an `IO` stream.
269 : pub struct PostgresBackend<IO> {
// Framed stream; full duplex, or write-only while a reader is split off.
270 : framed: MaybeWriteOnly<IO>,
271 :
// Current protocol state; `Closed` stops further reads.
272 : pub state: ProtoState,
273 :
// Decides what follows the startup packet (see `process_startup_message`).
274 : auth_type: AuthType,
275 :
// Address of the remote peer, as given at construction time.
276 : peer_addr: SocketAddr,
// When set, SSL requests are accepted and startup without TLS is refused.
277 : pub tls_config: Option<Arc<rustls::ServerConfig>>,
278 : }
279 :
// Convenience alias for the common case of a backend over a tokio TCP socket.
280 : pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
281 :
282 UBC 0 : pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
283 0 : let mut query_string = query_string.to_vec();
284 0 : if let Some(ch) = query_string.last() {
285 0 : if *ch == 0 {
286 0 : query_string.pop();
287 0 : }
288 0 : }
289 0 : query_string
290 0 : }
291 :
292 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
293 CBC 11945 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
294 11945 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
295 11945 : std::str::from_utf8(without_null).map_err(|e| e.into())
296 11945 : }
297 :
298 : impl PostgresBackend<tokio::net::TcpStream> {
299 5 : pub fn new(
300 5 : socket: tokio::net::TcpStream,
301 5 : auth_type: AuthType,
302 5 : tls_config: Option<Arc<rustls::ServerConfig>>,
303 5 : ) -> io::Result<Self> {
304 5 : let peer_addr = socket.peer_addr()?;
305 5 : let stream = MaybeTlsStream::Unencrypted(socket);
306 5 :
307 5 : Ok(Self {
308 5 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
309 5 : state: ProtoState::Initialization,
310 5 : auth_type,
311 5 : tls_config,
312 5 : peer_addr,
313 5 : })
314 5 : }
315 : }
316 :
317 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
318 11204 : pub fn new_from_io(
319 11204 : socket: IO,
320 11204 : peer_addr: SocketAddr,
321 11204 : auth_type: AuthType,
322 11204 : tls_config: Option<Arc<rustls::ServerConfig>>,
323 11204 : ) -> io::Result<Self> {
324 11204 : let stream = MaybeTlsStream::Unencrypted(socket);
325 11204 :
326 11204 : Ok(Self {
327 11204 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
328 11204 : state: ProtoState::Initialization,
329 11204 : auth_type,
330 11204 : tls_config,
331 11204 : peer_addr,
332 11204 : })
333 11204 : }
334 :
/// Address of the remote peer, as recorded when the backend was created.
335 2479 : pub fn get_peer_addr(&self) -> &SocketAddr {
336 2479 : &self.peer_addr
337 2479 : }
338 :
339 : /// Read full message or return None if connection is cleanly closed with no
340 : /// unprocessed data.
341 3757934 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
342 3757933 : if let ProtoState::Closed = self.state {
343 21 : Ok(None)
344 : } else {
345 3757912 : match self.framed.read_message().await {
346 3757676 : Ok(m) => {
347 UBC 0 : trace!("read msg {:?}", m);
348 CBC 3757676 : Ok(m)
349 : }
350 31 : Err(e) => {
351 31 : // remember not to try to read anymore
352 31 : self.state = ProtoState::Closed;
353 31 : Err(e)
354 : }
355 : }
356 : }
357 3757728 : }
358 :
359 : /// Write message into internal output buffer, doesn't flush it. Technically
360 : /// error type can be only ProtocolError here (if, unlikely, serialization
361 : /// fails), but callers typically wrap it anyway.
362 6507626 : pub fn write_message_noflush(
363 6507626 : &mut self,
364 6507626 : message: &BeMessage<'_>,
365 6507626 : ) -> Result<&mut Self, ConnectionError> {
366 6507626 : self.framed.write_message_noflush(message)?;
367 6507626 : trace!("wrote msg {:?}", message);
368 6507626 : Ok(self)
369 6507626 : }
370 :
371 : /// Flush output buffer into the socket.
///
/// Panics if the framed stream is in its transient invalid (`Broken`) state.
372 6506561 : pub async fn flush(&mut self) -> io::Result<()> {
373 6506561 : self.framed.flush().await
374 6498717 : }
375 :
376 : /// Polling version of `flush()`, saves the caller need to pin.
///
/// A fresh `flush()` future is created, pinned on the stack and polled once
/// per call; it is dropped again when this function returns.
/// NOTE(review): this relies on `flush()` being safe to drop and restart
/// between polls -- confirm against the `Framed` implementation.
377 947031 : pub fn poll_flush(
378 947031 : &mut self,
379 947031 : cx: &mut std::task::Context<'_>,
380 947031 : ) -> Poll<Result<(), std::io::Error>> {
381 947031 : let flush_fut = self.flush();
382 947031 : pin_mut!(flush_fut);
383 947031 : flush_fut.poll(cx)
384 947031 : }
385 :
386 : /// Write message into internal output buffer and flush it to the stream.
387 1892785 : pub async fn write_message(
388 1892785 : &mut self,
389 1892785 : message: &BeMessage<'_>,
390 1892785 : ) -> Result<&mut Self, ConnectionError> {
391 1892785 : self.write_message_noflush(message)?;
392 1892785 : self.flush().await?;
393 1892743 : Ok(self)
394 1892763 : }
395 :
396 : /// Returns an AsyncWrite implementation that wraps all the data written
397 : /// to it in CopyData messages, and writes them to the connection.
398 : ///
399 : /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
/// The writer mutably borrows this backend for as long as it lives.
400 559 : pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
401 559 : CopyDataWriter { pgb: self }
402 559 : }
403 :
404 : /// Wrapper for run_message_loop() that shuts down socket when we are done
///
/// Consumes the backend. `shutdown_watcher` is a factory for futures that
/// resolve when the server wants to stop; it is called once here and again
/// inside the message loop.
///
/// `Shutdown`, `Reconnect` and `Disconnected` outcomes of the message loop
/// are logged at info level and reported as `Ok(())`; any other error is
/// returned unchanged.
405 11209 : pub async fn run<F, S>(
406 11209 : mut self,
407 11209 : handler: &mut impl Handler<IO>,
408 11209 : shutdown_watcher: F,
409 11209 : ) -> Result<(), QueryError>
410 11209 : where
411 11209 : F: Fn() -> S + Clone,
412 11209 : S: Future,
413 11209 : {
414 11209 : let ret = self
415 11209 : .run_message_loop(handler, shutdown_watcher.clone())
416 9905962 : .await;
417 :
// Close the socket, but give up on that as soon as shutdown is requested.
418 10764 : tokio::select! {
419 10765 : _ = shutdown_watcher() => {
420 10765 : // do nothing; we most likely got already stopped by shutdown and will log it next.
421 10765 : }
422 10765 : _ = self.framed.shutdown() => {
423 10765 : // socket might be already closed, e.g. if previously received error,
424 10765 : // so ignore result.
425 10765 : },
426 10765 : }
427 :
428 378 : match ret {
429 10386 : Ok(()) => Ok(()),
430 : Err(QueryError::Shutdown) => {
431 UBC 0 : info!("Stopped due to shutdown");
432 0 : Ok(())
433 : }
434 : Err(QueryError::Reconnect) => {
435 : // Dropping out of this loop implicitly disconnects
436 0 : info!("Stopped due to handler reconnect request");
437 0 : Ok(())
438 : }
439 CBC 320 : Err(QueryError::Disconnected(e)) => {
440 320 : info!("Disconnected ({e:#})");
441 : // Disconnection is not an error: we just use it that way internally to drop
442 : // out of loops.
443 320 : Ok(())
444 : }
445 58 : e => e,
446 : }
447 10764 : }
448 :
/// Performs the handshake and then processes frontend messages until the
/// client terminates, the connection ends, an error occurs or shutdown is
/// requested.
///
/// Every wait (handshake, reading a message, flushing the response) is
/// raced against `shutdown_watcher()` with `biased` selects, so a pending
/// shutdown request is always checked first and yields
/// `Err(QueryError::Shutdown)`. Note that `process_message` itself is
/// awaited without such a race.
449 11209 : async fn run_message_loop<F, S>(
450 11209 : &mut self,
451 11209 : handler: &mut impl Handler<IO>,
452 11209 : shutdown_watcher: F,
453 11209 : ) -> Result<(), QueryError>
454 11209 : where
455 11209 : F: Fn() -> S,
456 11209 : S: Future,
457 11209 : {
458 UBC 0 : trace!("postgres backend to {:?} started", self.peer_addr);
459 :
460 CBC 11209 : tokio::select!(
461 : biased;
462 :
463 : _ = shutdown_watcher() => {
464 : // We were requested to shut down.
465 UBC 0 : tracing::info!("shutdown request received during handshake");
466 : return Err(QueryError::Shutdown)
467 : },
468 :
469 CBC 11209 : handshake_r = self.handshake(handler) => {
470 : handshake_r?;
471 : }
472 : );
473 :
474 : // Authentication completed
// Holds the query text of the most recent Parse message, for a later Execute.
475 11203 : let mut query_string = Bytes::new();
476 26093 : while let Some(msg) = tokio::select!(
477 : biased;
478 : _ = shutdown_watcher() => {
479 : // We were requested to shut down.
480 UBC 0 : tracing::info!("shutdown request received in run_message_loop");
481 : return Err(QueryError::Shutdown)
482 : },
483 CBC 26091 : msg = self.read_message() => { msg },
484 22 : )? {
485 UBC 0 : trace!("got message {:?}", msg);
486 :
// The result is inspected only after the flush below, so that whatever
// was buffered (including an ErrorResponse) is sent first.
487 CBC 9872183 : let result = self.process_message(handler, msg, &mut query_string).await;
488 17148 : tokio::select!(
489 : biased;
490 : _ = shutdown_watcher() => {
491 : // We were requested to shut down.
492 UBC 0 : tracing::info!("shutdown request received during response flush");
493 :
494 : // If we exited process_message with a shutdown error, there may be
495 : // some valid response content in our transmit buffer: permit sending
496 : // this within a short timeout. This is a best effort thing so we don't
497 : // care about the result.
498 : tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
499 :
500 : return Err(QueryError::Shutdown)
501 : },
502 CBC 17148 : flush_r = self.flush() => {
503 : flush_r?;
504 : }
505 : );
506 :
507 16855 : match result? {
508 : ProcessMsgResult::Continue => {
509 14890 : continue;
510 : }
511 1908 : ProcessMsgResult::Break => break,
512 : }
513 : }
514 :
515 UBC 0 : trace!("postgres backend to {:?} exited", self.peer_addr);
516 CBC 10386 : Ok(())
517 10764 : }
518 :
519 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
520 1 : async fn tls_upgrade(
521 1 : src: MaybeTlsStream<IO>,
522 1 : tls_config: Arc<rustls::ServerConfig>,
523 1 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
524 1 : match src {
525 1 : MaybeTlsStream::Unencrypted(s) => {
526 1 : let acceptor = TlsAcceptor::from(tls_config);
527 2 : let tls_stream = acceptor.accept(s).await?;
528 1 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
529 : }
530 : MaybeTlsStream::Tls(_) => {
531 UBC 0 : anyhow::bail!("TLS already started");
532 : }
533 : }
534 CBC 1 : }
535 :
536 1 : async fn start_tls(&mut self) -> anyhow::Result<()> {
537 1 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
538 1 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
539 1 : MaybeWriteOnly::Full(framed) => {
540 1 : let tls_config = self
541 1 : .tls_config
542 1 : .as_ref()
543 1 : .context("start_tls called without conf")?
544 1 : .clone();
545 1 : let tls_framed = framed
546 1 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
547 2 : .await?;
548 : // push back ready TLS stream
549 1 : self.framed = MaybeWriteOnly::Full(tls_framed);
550 1 : Ok(())
551 : }
552 : MaybeWriteOnly::WriteOnly(_) => {
553 UBC 0 : anyhow::bail!("TLS upgrade attempt in split state")
554 : }
555 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
556 : }
557 CBC 1 : }
558 :
559 : /// Split off owned read part from which messages can be read in different
560 : /// task/thread.
561 2479 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
562 2479 : // temporary replace stream with fake to cook split one, Indiana Jones style
563 2479 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
564 2479 : MaybeWriteOnly::Full(framed) => {
565 2479 : let (reader, writer) = framed.split();
566 2479 : self.framed = MaybeWriteOnly::WriteOnly(writer);
567 2479 : Ok(PostgresBackendReader {
568 2479 : reader,
569 2479 : closed: false,
570 2479 : })
571 : }
572 : MaybeWriteOnly::WriteOnly(_) => {
573 UBC 0 : anyhow::bail!("PostgresBackend is already split")
574 : }
575 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
576 : }
577 CBC 2479 : }
578 :
579 : /// Join read part back.
580 2089 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
581 2089 : // temporary replace stream with fake to cook joined one, Indiana Jones style
582 2089 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
583 : MaybeWriteOnly::Full(_) => {
584 UBC 0 : anyhow::bail!("PostgresBackend is not split")
585 : }
586 CBC 2089 : MaybeWriteOnly::WriteOnly(writer) => {
587 2089 : let joined = Framed::unsplit(reader.reader, writer);
588 2089 : self.framed = MaybeWriteOnly::Full(joined);
589 2089 : // if reader encountered connection error, do not attempt reading anymore
590 2089 : if reader.closed {
591 215 : self.state = ProtoState::Closed;
592 1874 : }
593 2089 : Ok(())
594 : }
595 UBC 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
596 : }
597 CBC 2089 : }
598 :
599 : /// Perform handshake with the client, transitioning to Established.
600 : /// In case of EOF during handshake or auth, logs this, sets state to Closed and returns
/// `QueryError::Disconnected` carrying a protocol error.
///
/// Startup packets are processed until the state reaches at least
/// `Authentication`; if JWT auth is required, exactly one password message
/// is then read and checked through `Handler::check_auth_jwt`.
601 11209 : async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
602 31737 : while self.state < ProtoState::Authentication {
603 20528 : match self.framed.read_startup_message().await? {
604 20528 : Some(msg) => {
605 20528 : self.process_startup_message(handler, msg).await?;
606 : }
607 : None => {
608 UBC 0 : trace!(
609 0 : "postgres backend to {:?} received EOF during handshake",
610 0 : self.peer_addr
611 0 : );
612 LBC (1) : self.state = ProtoState::Closed;
613 (1) : return Err(QueryError::Disconnected(ConnectionError::Protocol(
614 (1) : ProtocolError::Protocol("EOF during handshake".to_string()),
615 (1) : )));
616 : }
617 : }
618 : }
619 :
620 : // Perform auth, if needed.
621 CBC 11209 : if self.state == ProtoState::Authentication {
622 215 : match self.framed.read_message().await? {
623 210 : Some(FeMessage::PasswordMessage(m)) => {
// The Authentication state is only entered for NeonJWT in
// process_startup_message.
624 210 : assert!(self.auth_type == AuthType::NeonJWT);
625 :
// Drop the last byte of the payload (presumably the null
// terminator); an empty payload is a protocol violation.
626 210 : let (_, jwt_response) = m.split_last().context("protocol violation")?;
627 :
628 210 : if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
// The ErrorResponse is only buffered here, not flushed.
629 1 : self.write_message_noflush(&BeMessage::ErrorResponse(
630 1 : &short_error(&e),
631 1 : Some(e.pg_error_code()),
632 1 : ))?;
633 1 : return Err(e);
634 209 : }
635 209 :
636 209 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
637 209 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
638 209 : .write_message(&BeMessage::ReadyForQuery)
639 UBC 0 : .await?;
640 CBC 209 : self.state = ProtoState::Established;
641 : }
642 UBC 0 : Some(m) => {
643 0 : return Err(QueryError::Other(anyhow::anyhow!(
644 0 : "Unexpected message {:?} while waiting for handshake",
645 0 : m
646 0 : )));
647 : }
648 : None => {
649 0 : trace!(
650 0 : "postgres backend to {:?} received EOF during auth",
651 0 : self.peer_addr
652 0 : );
653 CBC 5 : self.state = ProtoState::Closed;
654 5 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
655 5 : ProtocolError::Protocol("EOF during auth".to_string()),
656 5 : )));
657 : }
658 : }
659 10994 : }
660 :
661 11203 : Ok(())
662 11209 : }
663 :
664 : /// Process startup packet:
665 : /// - transition to Established if auth type is trust
666 : /// - transition to Authentication if auth type is NeonJWT.
667 : /// - or perform TLS handshake -- then need to call this again to receive
668 : /// actual startup packet.
///
/// An SSL request is answered with whether TLS is configured; a GSS
/// encryption request is always declined. When TLS is configured, a startup
/// message on a connection that is not encrypted is rejected with an
/// ErrorResponse. A CancelRequest is reported as an error.
///
/// Panics (assert) if the state is already `Authentication` or later.
669 20528 : async fn process_startup_message(
670 20528 : &mut self,
671 20528 : handler: &mut impl Handler<IO>,
672 20528 : msg: FeStartupPacket,
673 20528 : ) -> Result<(), QueryError> {
674 20528 : assert!(self.state < ProtoState::Authentication);
675 20528 : let have_tls = self.tls_config.is_some();
676 20528 : match msg {
677 : FeStartupPacket::SslRequest => {
678 UBC 0 : debug!("SSL requested");
679 :
680 CBC 9319 : self.write_message(&BeMessage::EncryptionResponse(have_tls))
681 UBC 0 : .await?;
682 :
683 CBC 9319 : if have_tls {
684 2 : self.start_tls().await?;
685 1 : self.state = ProtoState::Encrypted;
686 9318 : }
687 : }
688 : FeStartupPacket::GssEncRequest => {
689 UBC 0 : debug!("GSS requested");
690 0 : self.write_message(&BeMessage::EncryptionResponse(false))
691 0 : .await?;
692 : }
693 : FeStartupPacket::StartupMessage { .. } => {
694 CBC 11209 : if have_tls && !matches!(self.state, ProtoState::Encrypted) {
695 UBC 0 : self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
696 0 : .await?;
697 0 : return Err(QueryError::Other(anyhow::anyhow!(
698 0 : "client did not connect with TLS"
699 0 : )));
700 CBC 11209 : }
701 11209 :
702 11209 : // NB: startup() may change self.auth_type -- we are using that in proxy code
703 11209 : // to bypass auth for new users.
704 11209 : handler.startup(self, &msg)?;
705 :
706 11209 : match self.auth_type {
707 : AuthType::Trust => {
708 10994 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
709 10994 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
710 10994 : .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
711 : // The async python driver requires a valid server_version
712 10994 : .write_message_noflush(&BeMessage::server_version("14.1"))?
713 10994 : .write_message(&BeMessage::ReadyForQuery)
714 UBC 0 : .await?;
715 CBC 10994 : self.state = ProtoState::Established;
716 : }
717 : AuthType::NeonJWT => {
718 215 : self.write_message(&BeMessage::AuthenticationCleartextPassword)
719 UBC 0 : .await?;
720 CBC 215 : self.state = ProtoState::Authentication;
721 : }
722 : }
723 : }
724 : FeStartupPacket::CancelRequest { .. } => {
725 UBC 0 : return Err(QueryError::Other(anyhow::anyhow!(
726 0 : "Unexpected CancelRequest message during handshake"
727 0 : )));
728 : }
729 : }
730 CBC 20528 : Ok(())
731 20528 : }
732 :
/// Processes one frontend message in the Established state, writing any
/// response into the output buffer without flushing it (the caller flushes).
///
/// `unnamed_query_string` carries the query text from a Parse message to
/// the Execute message that follows it.
///
/// Returns `Break` on Terminate, or when a simple query fails with
/// `QueryError::Shutdown`; otherwise `Continue`. Copy and password messages
/// are unexpected here and produce an error.
///
/// Panics (assert) if the state is not `Established`.
733 17591 : async fn process_message(
734 17591 : &mut self,
735 17591 : handler: &mut impl Handler<IO>,
736 17591 : msg: FeMessage,
737 17591 : unnamed_query_string: &mut Bytes,
738 17591 : ) -> Result<ProcessMsgResult, QueryError> {
739 17591 : // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
740 17591 : // TODO: change that to proper top-level match of protocol state with separate message handling for each state
741 17591 : assert!(self.state == ProtoState::Established);
742 :
743 17591 : match msg {
744 11393 : FeMessage::Query(body) => {
745 : // remove null terminator
746 11393 : let query_string = cstr_to_str(&body)?;
747 :
748 UBC 0 : trace!("got query {query_string:?}");
749 CBC 9847474 : if let Err(e) = handler.process_query(self, query_string).await {
750 745 : match e {
751 181 : QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
752 : QueryError::SimulatedConnectionError => {
753 GBC 7 : return Err(QueryError::SimulatedConnectionError)
754 : }
755 CBC 557 : e => {
756 557 : log_query_error(query_string, &e);
757 557 : let short_error = short_error(&e);
758 557 : self.write_message_noflush(&BeMessage::ErrorResponse(
759 557 : &short_error,
760 557 : Some(e.pg_error_code()),
761 557 : ))?;
762 : }
763 : }
764 10206 : }
765 10762 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
766 : }
767 :
768 552 : FeMessage::Parse(m) => {
// Remember the query text; it is run when Execute arrives.
769 552 : *unnamed_query_string = m.query_string;
770 552 : self.write_message_noflush(&BeMessage::ParseComplete)?;
771 : }
772 :
773 : FeMessage::Describe(_) => {
774 552 : self.write_message_noflush(&BeMessage::ParameterDescription)?
775 552 : .write_message_noflush(&BeMessage::NoData)?;
776 : }
777 :
778 : FeMessage::Bind(_) => {
779 552 : self.write_message_noflush(&BeMessage::BindComplete)?;
780 : }
781 :
782 : FeMessage::Close(_) => {
783 551 : self.write_message_noflush(&BeMessage::CloseComplete)?;
784 : }
785 :
786 : FeMessage::Execute(_) => {
787 552 : let query_string = cstr_to_str(unnamed_query_string)?;
788 UBC 0 : trace!("got execute {query_string:?}");
// NOTE(review): unlike the Query branch above, this path reports
// `e.to_string()` rather than `short_error(&e)` and does not
// special-case Shutdown / SimulatedConnectionError -- confirm that
// this difference is intended.
789 CBC 24709 : if let Err(e) = handler.process_query(self, query_string).await {
790 8 : log_query_error(query_string, &e);
791 8 : self.write_message_noflush(&BeMessage::ErrorResponse(
792 8 : &e.to_string(),
793 8 : Some(e.pg_error_code()),
794 8 : ))?;
795 544 : }
796 : // NOTE there is no ReadyForQuery message. This handler is used
797 : // for basebackup and it uses CopyOut which doesn't require
798 : // ReadyForQuery message and backend just switches back to
799 : // processing mode after sending CopyDone or ErrorResponse.
800 : }
801 :
802 : FeMessage::Sync => {
803 1662 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
804 : }
805 :
806 : FeMessage::Terminate => {
807 1727 : return Ok(ProcessMsgResult::Break);
808 : }
809 :
810 : // We prefer explicit pattern matching to wildcards, because
811 : // this helps us spot the places where new variants are missing
812 : FeMessage::CopyData(_)
813 : | FeMessage::CopyDone
814 : | FeMessage::CopyFail
815 : | FeMessage::PasswordMessage(_) => {
816 50 : return Err(QueryError::Other(anyhow::anyhow!(
817 50 : "unexpected message type: {msg:?}",
818 50 : )));
819 : }
820 : }
821 :
822 15183 : Ok(ProcessMsgResult::Continue)
823 17148 : }
824 :
825 : /// Log as info/error result of handling COPY stream and send back
826 : /// ErrorResponse if that makes sense. Shutdown the stream if we got
827 : /// Terminate. TODO: transition into waiting for Sync msg if we initiate the
828 : /// close.
///
/// Never fails: errors while sending CopyDone / ErrorResponse are only logged.
829 2089 : pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
830 : use CopyStreamHandlerEnd::*;
831 :
// Orderly ends and disconnects caused by expected I/O errors are logged
// at info level, everything else at error level.
832 2089 : let expected_end = match &end {
833 1854 : ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
834 234 : CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
835 234 : if is_expected_io_error(io_error) =>
836 234 : {
837 234 : true
838 : }
839 1 : _ => false,
840 : };
841 2089 : if expected_end {
842 2088 : info!("terminated: {:#}", end);
843 : } else {
844 1 : error!("terminated: {:?}", end);
845 : }
846 :
847 : // Note: no current usages ever send this
848 2089 : if let CopyDone = &end {
849 UBC 0 : if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
850 0 : error!("failed to send CopyDone: {}", e);
851 0 : }
852 CBC 2089 : }
853 :
// After Terminate no further messages are read from the client.
854 2089 : if let Terminate = &end {
855 19 : self.state = ProtoState::Closed;
856 2070 : }
857 :
858 2089 : let err_to_send_and_errcode = match &end {
859 48 : ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
860 1 : Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
861 : // Note: CopyFail in duplex copy is somewhat unexpected (at least to
862 : // PG walsender; evidently and per my docs reading client should
863 : // finish it with CopyDone). It is not a problem to recover from it
864 : // finishing the stream in both directions like we do, but note that
865 : // sync rust-postgres client (which we don't use anymore) hangs if
866 : // socket is not closed here.
867 : // https://github.com/sfackler/rust-postgres/issues/755
868 : // https://github.com/neondatabase/neon/issues/935
869 : //
870 : // Currently, the version of tokio_postgres replication patch we use
871 : // sends this when it closes the stream (e.g. pageserver decided to
872 : // switch conn to another safekeeper and client gets dropped).
873 : // Moreover, seems like 'connection' task errors with 'unexpected
874 : // message from server' when it receives ErrorResponse (anything but
875 : // CopyData/CopyDone) back.
876 19 : CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
877 2021 : _ => None,
878 : };
879 2089 : if let Some((err, errcode)) = err_to_send_and_errcode {
880 68 : if let Err(ee) = self
881 68 : .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
882 GBC 1 : .await
883 : {
884 CBC 1 : error!("failed to send ErrorResponse: {}", ee);
885 67 : }
886 2021 : }
887 2089 : }
888 : }
889 :
890 : pub struct PostgresBackendReader<IO> {
891 : reader: FramedReader<MaybeTlsStream<IO>>,
892 : closed: bool, // true if received error closing the connection
893 : }
894 :
895 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
896 : /// Read full message or return None if connection is cleanly closed with no
897 : /// unprocessed data.
898 2549745 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
899 4817654 : match self.reader.read_message().await {
900 2549072 : Ok(m) => {
901 UBC 0 : trace!("read msg {:?}", m);
902 CBC 2549072 : Ok(m)
903 : }
904 215 : Err(e) => {
905 215 : self.closed = true;
906 215 : Err(e)
907 : }
908 : }
909 2549287 : }
910 :
911 : /// Get CopyData contents of the next message in COPY stream or error
912 : /// closing it. The error type is wider than actual errors which can happen
913 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
914 : /// current callers.
915 2549745 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
916 4817654 : match self.read_message().await? {
917 2547304 : Some(msg) => match msg {
918 2547266 : FeMessage::CopyData(m) => Ok(m),
919 UBC 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
920 CBC 19 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
921 19 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
922 UBC 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
923 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
924 0 : ))),
925 : },
926 CBC 1768 : None => Err(CopyStreamHandlerEnd::EOF),
927 : }
928 2549287 : }
929 : }
930 :
931 : ///
932 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
933 : /// messages.
934 : ///
935 :
936 : pub struct CopyDataWriter<'a, IO> {
937 : pgb: &'a mut PostgresBackend<IO>,
938 : }
939 :
940 : impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
941 908485 : fn poll_write(
942 908485 : self: Pin<&mut Self>,
943 908485 : cx: &mut std::task::Context<'_>,
944 908485 : buf: &[u8],
945 908485 : ) -> Poll<Result<usize, std::io::Error>> {
946 908485 : let this = self.get_mut();
947 :
948 : // It's not strictly required to flush between each message, but makes it easier
949 : // to view in wireshark, and usually the messages that the callers write are
950 : // decently-sized anyway.
951 908485 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
952 UBC 0 : return Poll::Ready(Err(err));
953 CBC 900834 : }
954 900834 :
955 900834 : // CopyData
956 900834 : // XXX: if the input is large, we should split it into multiple messages.
957 900834 : // Not sure what the threshold should be, but the ultimate hard limit is that
958 900834 : // the length cannot exceed u32.
959 900834 : this.pgb
960 900834 : .write_message_noflush(&BeMessage::CopyData(buf))
961 900834 : // write_message only writes to the buffer, so it can fail iff the
962 900834 : // message is invaid, but CopyData can't be invalid.
963 900834 : .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
964 :
965 900834 : Poll::Ready(Ok(buf.len()))
966 908485 : }
967 :
968 38382 : fn poll_flush(
969 38382 : self: Pin<&mut Self>,
970 38382 : cx: &mut std::task::Context<'_>,
971 38382 : ) -> Poll<Result<(), std::io::Error>> {
972 38382 : let this = self.get_mut();
973 38382 : this.pgb.poll_flush(cx)
974 38382 : }
975 :
976 164 : fn poll_shutdown(
977 164 : self: Pin<&mut Self>,
978 164 : cx: &mut std::task::Context<'_>,
979 164 : ) -> Poll<Result<(), std::io::Error>> {
980 164 : let this = self.get_mut();
981 164 : this.pgb.poll_flush(cx)
982 164 : }
983 : }
984 :
985 557 : pub fn short_error(e: &QueryError) -> String {
986 557 : match e {
987 38 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
988 UBC 0 : QueryError::Reconnect => "reconnect".to_string(),
989 0 : QueryError::Shutdown => "shutdown".to_string(),
990 0 : QueryError::NotFound(_) => "not found".to_string(),
991 CBC 5 : QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
992 UBC 0 : QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
993 CBC 514 : QueryError::Other(e) => format!("{e:#}"),
994 : }
995 557 : }
996 :
997 565 : fn log_query_error(query: &str, e: &QueryError) {
998 38 : match e {
999 38 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
1000 38 : if is_expected_io_error(io_error) {
1001 38 : info!("query handler for '{query}' failed with expected io error: {io_error}");
1002 : } else {
1003 UBC 0 : error!("query handler for '{query}' failed with io error: {io_error}");
1004 : }
1005 : }
1006 0 : QueryError::Disconnected(other_connection_error) => {
1007 0 : error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
1008 : }
1009 : QueryError::SimulatedConnectionError => {
1010 0 : error!("query handler for query '{query}' failed due to a simulated connection error")
1011 : }
1012 : QueryError::Reconnect => {
1013 0 : info!("query handler for '{query}' requested client to reconnect")
1014 : }
1015 : QueryError::Shutdown => {
1016 0 : info!("query handler for '{query}' cancelled during tenant shutdown")
1017 : }
1018 0 : QueryError::NotFound(reason) => {
1019 0 : info!("query handler for '{query}' entity not found: {reason}")
1020 : }
1021 CBC 4 : QueryError::Unauthorized(e) => {
1022 4 : warn!("query handler for '{query}' failed with authentication error: {e}");
1023 : }
1024 523 : QueryError::Other(e) => {
1025 523 : error!("query handler for '{query}' failed: {e:?}");
1026 : }
1027 : }
1028 565 : }
1029 :
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
    /// Handler initiates the end of streaming. The string is the message
    /// used as the Display text of this variant.
    #[error("{0}")]
    ServerInitiated(String),
    /// The client sent a CopyDone message.
    #[error("received CopyDone")]
    CopyDone,
    /// The client sent a CopyFail message.
    #[error("received CopyFail")]
    CopyFail,
    /// The client sent a Terminate message.
    #[error("received Terminate")]
    Terminate,
    /// No more messages: reading returned `None` instead of a message.
    #[error("EOF on COPY stream")]
    EOF,
    /// The connection was lost
    #[error("connection error: {0}")]
    Disconnected(#[from] ConnectionError),
    /// Some other error
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
|