TLA Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : use anyhow::Context;
6 : use bytes::Bytes;
7 : use futures::pin_mut;
8 : use serde::{Deserialize, Serialize};
9 : use std::io::ErrorKind;
10 : use std::net::SocketAddr;
11 : use std::pin::Pin;
12 : use std::sync::Arc;
13 : use std::task::{ready, Poll};
14 : use std::{fmt, io};
15 : use std::{future::Future, str::FromStr};
16 : use tokio::io::{AsyncRead, AsyncWrite};
17 : use tokio_rustls::TlsAcceptor;
18 : use tracing::{debug, error, info, trace};
19 :
20 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
21 : use pq_proto::{
22 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
23 : SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
24 : };
25 :
/// An error that occurred during query processing: either during the
/// connection ([`ConnectionError`]) or before/after it.
///
/// `QueryError::pg_error_code` maps each variant to the SQLSTATE sent back to
/// the client in an ErrorResponse.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
    /// The connection was lost while processing the query.
    /// `PostgresBackend::run` logs this and treats it as a normal exit.
    #[error(transparent)]
    Disconnected(#[from] ConnectionError),
    /// We were instructed to shut down while processing the query.
    #[error("Shutting down")]
    Shutdown,
    /// Some other error.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
39 : }
40 :
41 : impl From<io::Error> for QueryError {
42 360 : fn from(e: io::Error) -> Self {
43 360 : Self::Disconnected(ConnectionError::Io(e))
44 360 : }
45 : }
46 :
47 : impl QueryError {
48 167 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
49 167 : match self {
50 7 : Self::Disconnected(_) => b"08006", // connection failure
51 109 : Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
52 51 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
53 : }
54 167 : }
55 : }
56 :
/// Reports whether `e` is an ordinary consequence of network trouble or of the
/// client going away (broken pipe, refused/aborted/reset connection, timeout).
/// Such errors occur during normal operation and do not indicate a bug here.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    const EXPECTED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::TimedOut,
    ];
    EXPECTED_KINDS.contains(&e.kind())
}
67 :
68 : #[async_trait::async_trait]
69 : pub trait Handler<IO> {
70 : /// Handle single query.
71 : /// postgres_backend will issue ReadyForQuery after calling this (this
72 : /// might be not what we want after CopyData streaming, but currently we don't
73 : /// care). It will also flush out the output buffer.
74 : async fn process_query(
75 : &mut self,
76 : pgb: &mut PostgresBackend<IO>,
77 : query_string: &str,
78 : ) -> Result<(), QueryError>;
79 :
80 : /// Called on startup packet receival, allows to process params.
81 : ///
82 : /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
83 : /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
84 : /// to override whole init logic in implementations.
85 5 : fn startup(
86 5 : &mut self,
87 5 : _pgb: &mut PostgresBackend<IO>,
88 5 : _sm: &FeStartupPacket,
89 5 : ) -> Result<(), QueryError> {
90 5 : Ok(())
91 5 : }
92 :
93 : /// Check auth jwt
94 UBC 0 : fn check_auth_jwt(
95 0 : &mut self,
96 0 : _pgb: &mut PostgresBackend<IO>,
97 0 : _jwt_response: &[u8],
98 0 : ) -> Result<(), QueryError> {
99 0 : Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
100 0 : }
101 : }
102 :
/// PostgresBackend protocol state.
///
/// XXX: The order of the variants matters: the derived `PartialOrd` is used
/// for comparisons such as `state < ProtoState::Authentication` while driving
/// the handshake, so variants must stay in protocol-progress order.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
    /// Nothing happened yet.
    Initialization,
    /// Encryption handshake is done; waiting for encrypted Startup message.
    Encrypted,
    /// Waiting for password (auth token).
    Authentication,
    /// Performed handshake and auth, ReadyForQuery is issued.
    Established,
    /// The connection is finished (EOF, read error, or Terminate during COPY).
    /// `PostgresBackend::read_message` returns `Ok(None)` in this state
    /// instead of reading.
    Closed,
}
117 :
/// Outcome of processing one frontend message: keep reading messages
/// (`Continue`) or leave the message loop (`Break`, returned on Terminate).
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
    Continue,
    Break,
}
123 :
/// Either the plain underlying `IO` stream or a TLS-encrypted stream wrapping
/// it. Implements AsyncRead + AsyncWrite by forwarding to the active variant.
pub enum MaybeTlsStream<IO> {
    /// Plain, unencrypted stream (the initial state of every connection).
    Unencrypted(IO),
    /// Stream after a successful server-side TLS handshake.
    Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
129 :
130 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
131 CBC 8034540 : fn poll_write(
132 8034540 : self: Pin<&mut Self>,
133 8034540 : cx: &mut std::task::Context<'_>,
134 8034540 : buf: &[u8],
135 8034540 : ) -> Poll<io::Result<usize>> {
136 8034540 : match self.get_mut() {
137 8034538 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
138 2 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
139 : }
140 8034540 : }
141 8039271 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
142 8039271 : match self.get_mut() {
143 8039269 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
144 2 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
145 : }
146 8039271 : }
147 7850 : fn poll_shutdown(
148 7850 : self: Pin<&mut Self>,
149 7850 : cx: &mut std::task::Context<'_>,
150 7850 : ) -> Poll<io::Result<()>> {
151 7850 : match self.get_mut() {
152 7850 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
153 UBC 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
154 : }
155 CBC 7850 : }
156 : }
157 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
158 18389325 : fn poll_read(
159 18389325 : self: Pin<&mut Self>,
160 18389325 : cx: &mut std::task::Context<'_>,
161 18389325 : buf: &mut tokio::io::ReadBuf<'_>,
162 18389325 : ) -> Poll<io::Result<()>> {
163 18389325 : match self.get_mut() {
164 18389320 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
165 5 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
166 : }
167 18389325 : }
168 : }
169 :
/// Authentication method the backend requests after the startup packet.
/// Textual form is `"Trust"` / `"NeonJWT"` (see the `FromStr` and `Display` impls).
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
    /// No authentication: AuthenticationOk is sent right after StartupMessage.
    Trust,
    /// This mimics postgres's AuthenticationCleartextPassword but instead of
    /// a password expects a JWT.
    NeonJWT,
}
176 :
177 : impl FromStr for AuthType {
178 : type Err = anyhow::Error;
179 :
180 1818 : fn from_str(s: &str) -> Result<Self, Self::Err> {
181 1818 : match s {
182 1818 : "Trust" => Ok(Self::Trust),
183 36 : "NeonJWT" => Ok(Self::NeonJWT),
184 UBC 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
185 : }
186 CBC 1818 : }
187 : }
188 :
189 : impl fmt::Display for AuthType {
190 2938 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 2938 : f.write_str(match self {
192 2884 : AuthType::Trust => "Trust",
193 54 : AuthType::NeonJWT => "NeonJWT",
194 : })
195 2938 : }
196 : }
197 :
/// Either a full-duplex Framed or only its write half. The latter is left in
/// PostgresBackend after a call to `split`. In principle we could always store
/// a pair of split handles, but that would force us to pay the splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly<IO> {
    Full(Framed<MaybeTlsStream<IO>>),
    WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
    /// Placeholder swapped in temporarily (via `std::mem::replace`) while the
    /// real value is being transformed by split/unsplit/TLS upgrade.
    Broken,
}
207 :
208 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
209 15299 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
210 15299 : match self {
211 15299 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
212 : MaybeWriteOnly::WriteOnly(_) => {
213 UBC 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
214 : }
215 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
216 : }
217 CBC 15299 : }
218 :
219 3831530 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
220 3831530 : match self {
221 3831530 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
222 : MaybeWriteOnly::WriteOnly(_) => {
223 UBC 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
224 : }
225 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
226 : }
227 CBC 3831381 : }
228 :
229 8033135 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
230 8033135 : match self {
231 4783778 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
232 3249357 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
233 UBC 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
234 : }
235 CBC 8033135 : }
236 :
237 8039458 : async fn flush(&mut self) -> io::Result<()> {
238 8039458 : match self {
239 4790101 : MaybeWriteOnly::Full(framed) => framed.flush().await,
240 3249357 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
241 UBC 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
242 : }
243 CBC 8031794 : }
244 :
245 8210 : async fn shutdown(&mut self) -> io::Result<()> {
246 8210 : match self {
247 8210 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
248 UBC 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
249 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
250 : }
251 CBC 8210 : }
252 : }
253 :
/// Server side of a single Postgres wire-protocol connection.
pub struct PostgresBackend<IO> {
    /// Framed stream; only the write half while the reader is split off
    /// (see `split` / `unsplit`).
    framed: MaybeWriteOnly<IO>,

    /// Current protocol state; drives the handshake and stops reads once Closed.
    pub state: ProtoState,

    /// Authentication method requested after the StartupMessage.
    auth_type: AuthType,

    /// Address of the connected client.
    peer_addr: SocketAddr,
    /// When set, an SSLRequest is answered with a TLS upgrade, and a
    /// StartupMessage on an unencrypted connection is rejected.
    pub tls_config: Option<Arc<rustls::ServerConfig>>,
}

/// Backend over a plain tokio TCP stream.
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
266 :
267 UBC 0 : pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
268 0 : let mut query_string = query_string.to_vec();
269 0 : if let Some(ch) = query_string.last() {
270 0 : if *ch == 0 {
271 0 : query_string.pop();
272 0 : }
273 0 : }
274 0 : query_string
275 0 : }
276 :
277 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
278 CBC 9384 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
279 9384 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
280 9384 : std::str::from_utf8(without_null).map_err(|e| e.into())
281 9384 : }
282 :
283 : impl PostgresBackend<tokio::net::TcpStream> {
284 5 : pub fn new(
285 5 : socket: tokio::net::TcpStream,
286 5 : auth_type: AuthType,
287 5 : tls_config: Option<Arc<rustls::ServerConfig>>,
288 5 : ) -> io::Result<Self> {
289 5 : let peer_addr = socket.peer_addr()?;
290 5 : let stream = MaybeTlsStream::Unencrypted(socket);
291 5 :
292 5 : Ok(Self {
293 5 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
294 5 : state: ProtoState::Initialization,
295 5 : auth_type,
296 5 : tls_config,
297 5 : peer_addr,
298 5 : })
299 5 : }
300 : }
301 :
302 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
303 8680 : pub fn new_from_io(
304 8680 : socket: IO,
305 8680 : peer_addr: SocketAddr,
306 8680 : auth_type: AuthType,
307 8680 : tls_config: Option<Arc<rustls::ServerConfig>>,
308 8680 : ) -> io::Result<Self> {
309 8680 : let stream = MaybeTlsStream::Unencrypted(socket);
310 8680 :
311 8680 : Ok(Self {
312 8680 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
313 8680 : state: ProtoState::Initialization,
314 8680 : auth_type,
315 8680 : tls_config,
316 8680 : peer_addr,
317 8680 : })
318 8680 : }
319 :
320 2796 : pub fn get_peer_addr(&self) -> &SocketAddr {
321 2796 : &self.peer_addr
322 2796 : }
323 :
324 : /// Read full message or return None if connection is cleanly closed with no
325 : /// unprocessed data.
326 3831402 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
327 3831400 : if let ProtoState::Closed = self.state {
328 95 : Ok(None)
329 : } else {
330 3831305 : match self.framed.read_message().await {
331 3831110 : Ok(m) => {
332 UBC 0 : trace!("read msg {:?}", m);
333 CBC 3831110 : Ok(m)
334 : }
335 46 : Err(e) => {
336 46 : // remember not to try to read anymore
337 46 : self.state = ProtoState::Closed;
338 46 : Err(e)
339 : }
340 : }
341 : }
342 3831251 : }
343 :
344 : /// Write message into internal output buffer, doesn't flush it. Technically
345 : /// error type can be only ProtocolError here (if, unlikely, serialization
346 : /// fails), but callers typically wrap it anyway.
347 : pub fn write_message_noflush(
348 : &mut self,
349 : message: &BeMessage<'_>,
350 : ) -> Result<&mut Self, ConnectionError> {
351 8033135 : self.framed.write_message_noflush(message)?;
352 8033135 : trace!("wrote msg {:?}", message);
353 8033135 : Ok(self)
354 8033135 : }
355 :
356 : /// Flush output buffer into the socket.
357 8039571 : pub async fn flush(&mut self) -> io::Result<()> {
358 8039458 : self.framed.flush().await
359 8031794 : }
360 :
361 : /// Polling version of `flush()`, saves the caller need to pin.
362 972105 : pub fn poll_flush(
363 972105 : &mut self,
364 972105 : cx: &mut std::task::Context<'_>,
365 972105 : ) -> Poll<Result<(), std::io::Error>> {
366 972105 : let flush_fut = self.flush();
367 972105 : pin_mut!(flush_fut);
368 972105 : flush_fut.poll(cx)
369 972105 : }
370 :
371 : /// Write message into internal output buffer and flush it to the stream.
372 3267740 : pub async fn write_message(
373 3267740 : &mut self,
374 3267740 : message: &BeMessage<'_>,
375 3267740 : ) -> Result<&mut Self, ConnectionError> {
376 3267740 : self.write_message_noflush(message)?;
377 3267740 : self.flush().await?;
378 3267716 : Ok(self)
379 3267729 : }
380 :
381 : /// Returns an AsyncWrite implementation that wraps all the data written
382 : /// to it in CopyData messages, and writes them to the connection
383 : ///
384 : /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
385 640 : pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
386 640 : CopyDataWriter { pgb: self }
387 640 : }
388 :
389 : /// Wrapper for run_message_loop() that shuts down socket when we are done
390 8685 : pub async fn run<F, S>(
391 8685 : mut self,
392 8685 : handler: &mut impl Handler<IO>,
393 8685 : shutdown_watcher: F,
394 8685 : ) -> Result<(), QueryError>
395 8685 : where
396 8685 : F: Fn() -> S,
397 8685 : S: Future,
398 8685 : {
399 17074572 : let ret = self.run_message_loop(handler, shutdown_watcher).await;
400 : // socket might be already closed, e.g. if previously received error,
401 : // so ignore result.
402 8210 : self.framed.shutdown().await.ok();
403 575 : match ret {
404 7635 : Ok(()) => Ok(()),
405 : Err(QueryError::Shutdown) => {
406 109 : info!("Stopped due to shutdown");
407 109 : Ok(())
408 : }
409 406 : Err(QueryError::Disconnected(e)) => {
410 406 : info!("Disconnected ({e:#})");
411 : // Disconnection is not an error: we just use it that way internally to drop
412 : // out of loops.
413 406 : Ok(())
414 : }
415 60 : e => e,
416 : }
417 8210 : }
418 :
419 8685 : async fn run_message_loop<F, S>(
420 8685 : &mut self,
421 8685 : handler: &mut impl Handler<IO>,
422 8685 : shutdown_watcher: F,
423 8685 : ) -> Result<(), QueryError>
424 8685 : where
425 8685 : F: Fn() -> S,
426 8685 : S: Future,
427 8685 : {
428 UBC 0 : trace!("postgres backend to {:?} started", self.peer_addr);
429 :
430 CBC 8685 : tokio::select!(
431 : biased;
432 :
433 : _ = shutdown_watcher() => {
434 : // We were requested to shut down.
435 UBC 0 : tracing::info!("shutdown request received during handshake");
436 : return Err(QueryError::Shutdown)
437 : },
438 :
439 CBC 8685 : handshake_r = self.handshake(handler) => {
440 : handshake_r?;
441 : }
442 : );
443 :
444 : // Authentication completed
445 8677 : let mut query_string = Bytes::new();
446 21605 : while let Some(msg) = tokio::select!(
447 : biased;
448 : _ = shutdown_watcher() => {
449 : // We were requested to shut down.
450 UBC 0 : tracing::info!("shutdown request received in run_message_loop");
451 : return Err(QueryError::Shutdown)
452 : },
453 CBC 21603 : msg = self.read_message() => { msg },
454 39 : )? {
455 UBC 0 : trace!("got message {:?}", msg);
456 :
457 CBC 17047547 : let result = self.process_message(handler, msg, &mut query_string).await;
458 14886 : tokio::select!(
459 : biased;
460 : _ = shutdown_watcher() => {
461 : // We were requested to shut down.
462 109 : tracing::info!("shutdown request received during response flush");
463 :
464 : // If we exited process_message with a shutdown error, there may be
465 : // some valid response content on in our transmit buffer: permit sending
466 : // this within a short timeout. This is a best effort thing so we don't
467 : // care about the result.
468 : tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
469 :
470 : return Err(QueryError::Shutdown)
471 : },
472 14777 : flush_r = self.flush() => {
473 : flush_r?;
474 : }
475 : );
476 :
477 14417 : match result? {
478 : ProcessMsgResult::Continue => {
479 12928 : continue;
480 : }
481 1430 : ProcessMsgResult::Break => break,
482 : }
483 : }
484 :
485 UBC 0 : trace!("postgres backend to {:?} exited", self.peer_addr);
486 CBC 7635 : Ok(())
487 8210 : }
488 :
489 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
490 1 : async fn tls_upgrade(
491 1 : src: MaybeTlsStream<IO>,
492 1 : tls_config: Arc<rustls::ServerConfig>,
493 1 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
494 1 : match src {
495 1 : MaybeTlsStream::Unencrypted(s) => {
496 1 : let acceptor = TlsAcceptor::from(tls_config);
497 2 : let tls_stream = acceptor.accept(s).await?;
498 1 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
499 : }
500 : MaybeTlsStream::Tls(_) => {
501 UBC 0 : anyhow::bail!("TLS already started");
502 : }
503 : }
504 CBC 1 : }
505 :
506 1 : async fn start_tls(&mut self) -> anyhow::Result<()> {
507 1 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
508 1 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
509 1 : MaybeWriteOnly::Full(framed) => {
510 1 : let tls_config = self
511 1 : .tls_config
512 1 : .as_ref()
513 1 : .context("start_tls called without conf")?
514 1 : .clone();
515 1 : let tls_framed = framed
516 1 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
517 2 : .await?;
518 : // push back ready TLS stream
519 1 : self.framed = MaybeWriteOnly::Full(tls_framed);
520 1 : Ok(())
521 : }
522 : MaybeWriteOnly::WriteOnly(_) => {
523 UBC 0 : anyhow::bail!("TLS upgrade attempt in split state")
524 : }
525 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
526 : }
527 CBC 1 : }
528 :
529 : /// Split off owned read part from which messages can be read in different
530 : /// task/thread.
531 2796 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
532 2796 : // temporary replace stream with fake to cook split one, Indiana Jones style
533 2796 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
534 2796 : MaybeWriteOnly::Full(framed) => {
535 2796 : let (reader, writer) = framed.split();
536 2796 : self.framed = MaybeWriteOnly::WriteOnly(writer);
537 2796 : Ok(PostgresBackendReader {
538 2796 : reader,
539 2796 : closed: false,
540 2796 : })
541 : }
542 : MaybeWriteOnly::WriteOnly(_) => {
543 UBC 0 : anyhow::bail!("PostgresBackend is already split")
544 : }
545 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
546 : }
547 CBC 2796 : }
548 :
549 : /// Join read part back.
550 2399 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
551 2399 : // temporary replace stream with fake to cook joined one, Indiana Jones style
552 2399 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
553 : MaybeWriteOnly::Full(_) => {
554 UBC 0 : anyhow::bail!("PostgresBackend is not split")
555 : }
556 CBC 2399 : MaybeWriteOnly::WriteOnly(writer) => {
557 2399 : let joined = Framed::unsplit(reader.reader, writer);
558 2399 : self.framed = MaybeWriteOnly::Full(joined);
559 2399 : // if reader encountered connection error, do not attempt reading anymore
560 2399 : if reader.closed {
561 334 : self.state = ProtoState::Closed;
562 2065 : }
563 2399 : Ok(())
564 : }
565 UBC 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
566 : }
567 CBC 2399 : }
568 :
569 : /// Perform handshake with the client, transitioning to Established.
570 : /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
571 8685 : async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
572 23983 : while self.state < ProtoState::Authentication {
573 15299 : match self.framed.read_startup_message().await? {
574 15298 : Some(msg) => {
575 15298 : self.process_startup_message(handler, msg).await?;
576 : }
577 : None => {
578 UBC 0 : trace!(
579 0 : "postgres backend to {:?} received EOF during handshake",
580 0 : self.peer_addr
581 0 : );
582 GBC 1 : self.state = ProtoState::Closed;
583 1 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
584 1 : ProtocolError::Protocol("EOF during handshake".to_string()),
585 1 : )));
586 : }
587 : }
588 : }
589 :
590 : // Perform auth, if needed.
591 CBC 8684 : if self.state == ProtoState::Authentication {
592 225 : match self.framed.read_message().await? {
593 219 : Some(FeMessage::PasswordMessage(m)) => {
594 219 : assert!(self.auth_type == AuthType::NeonJWT);
595 :
596 219 : let (_, jwt_response) = m.split_last().context("protocol violation")?;
597 :
598 219 : if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
599 1 : self.write_message_noflush(&BeMessage::ErrorResponse(
600 1 : &e.to_string(),
601 1 : Some(e.pg_error_code()),
602 1 : ))?;
603 1 : return Err(e);
604 218 : }
605 218 :
606 218 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
607 218 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
608 218 : .write_message(&BeMessage::ReadyForQuery)
609 UBC 0 : .await?;
610 CBC 218 : self.state = ProtoState::Established;
611 : }
612 UBC 0 : Some(m) => {
613 0 : return Err(QueryError::Other(anyhow::anyhow!(
614 0 : "Unexpected message {:?} while waiting for handshake",
615 0 : m
616 0 : )));
617 : }
618 : None => {
619 0 : trace!(
620 0 : "postgres backend to {:?} received EOF during auth",
621 0 : self.peer_addr
622 0 : );
623 CBC 6 : self.state = ProtoState::Closed;
624 6 : return Err(QueryError::Disconnected(ConnectionError::Protocol(
625 6 : ProtocolError::Protocol("EOF during auth".to_string()),
626 6 : )));
627 : }
628 : }
629 8459 : }
630 :
631 8677 : Ok(())
632 8685 : }
633 :
634 : /// Process startup packet:
635 : /// - transition to Established if auth type is trust
636 : /// - transition to Authentication if auth type is NeonJWT.
637 : /// - or perform TLS handshake -- then need to call this again to receive
638 : /// actual startup packet.
639 15298 : async fn process_startup_message(
640 15298 : &mut self,
641 15298 : handler: &mut impl Handler<IO>,
642 15298 : msg: FeStartupPacket,
643 15298 : ) -> Result<(), QueryError> {
644 15298 : assert!(self.state < ProtoState::Authentication);
645 15298 : let have_tls = self.tls_config.is_some();
646 15298 : match msg {
647 : FeStartupPacket::SslRequest => {
648 UBC 0 : debug!("SSL requested");
649 :
650 CBC 6614 : self.write_message(&BeMessage::EncryptionResponse(have_tls))
651 UBC 0 : .await?;
652 :
653 CBC 6614 : if have_tls {
654 2 : self.start_tls().await?;
655 1 : self.state = ProtoState::Encrypted;
656 6613 : }
657 : }
658 : FeStartupPacket::GssEncRequest => {
659 UBC 0 : debug!("GSS requested");
660 0 : self.write_message(&BeMessage::EncryptionResponse(false))
661 0 : .await?;
662 : }
663 : FeStartupPacket::StartupMessage { .. } => {
664 CBC 8684 : if have_tls && !matches!(self.state, ProtoState::Encrypted) {
665 UBC 0 : self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
666 0 : .await?;
667 0 : return Err(QueryError::Other(anyhow::anyhow!(
668 0 : "client did not connect with TLS"
669 0 : )));
670 CBC 8684 : }
671 8684 :
672 8684 : // NB: startup() may change self.auth_type -- we are using that in proxy code
673 8684 : // to bypass auth for new users.
674 8684 : handler.startup(self, &msg)?;
675 :
676 8684 : match self.auth_type {
677 : AuthType::Trust => {
678 8459 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
679 8459 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
680 8459 : .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
681 : // The async python driver requires a valid server_version
682 8459 : .write_message_noflush(&BeMessage::server_version("14.1"))?
683 8459 : .write_message(&BeMessage::ReadyForQuery)
684 UBC 0 : .await?;
685 CBC 8459 : self.state = ProtoState::Established;
686 : }
687 : AuthType::NeonJWT => {
688 225 : self.write_message(&BeMessage::AuthenticationCleartextPassword)
689 UBC 0 : .await?;
690 CBC 225 : self.state = ProtoState::Authentication;
691 : }
692 : }
693 : }
694 : FeStartupPacket::CancelRequest { .. } => {
695 UBC 0 : return Err(QueryError::Other(anyhow::anyhow!(
696 0 : "Unexpected CancelRequest message during handshake"
697 0 : )));
698 : }
699 : }
700 CBC 15298 : Ok(())
701 15298 : }
702 :
703 15359 : async fn process_message(
704 15359 : &mut self,
705 15359 : handler: &mut impl Handler<IO>,
706 15359 : msg: FeMessage,
707 15359 : unnamed_query_string: &mut Bytes,
708 15359 : ) -> Result<ProcessMsgResult, QueryError> {
709 15359 : // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
710 15359 : // TODO: change that to proper top-level match of protocol state with separate message handling for each state
711 15359 : assert!(self.state == ProtoState::Established);
712 :
713 15359 : match msg {
714 8744 : FeMessage::Query(body) => {
715 : // remove null terminator
716 8744 : let query_string = cstr_to_str(&body)?;
717 :
718 UBC 0 : trace!("got query {query_string:?}");
719 CBC 17023429 : if let Err(e) = handler.process_query(self, query_string).await {
720 155 : log_query_error(query_string, &e);
721 155 : let short_error = short_error(&e);
722 155 : self.write_message_noflush(&BeMessage::ErrorResponse(
723 155 : &short_error,
724 155 : Some(e.pg_error_code()),
725 155 : ))?;
726 8116 : }
727 8271 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
728 : }
729 :
730 640 : FeMessage::Parse(m) => {
731 640 : *unnamed_query_string = m.query_string;
732 640 : self.write_message_noflush(&BeMessage::ParseComplete)?;
733 : }
734 :
735 : FeMessage::Describe(_) => {
736 640 : self.write_message_noflush(&BeMessage::ParameterDescription)?
737 640 : .write_message_noflush(&BeMessage::NoData)?;
738 : }
739 :
740 : FeMessage::Bind(_) => {
741 640 : self.write_message_noflush(&BeMessage::BindComplete)?;
742 : }
743 :
744 : FeMessage::Close(_) => {
745 639 : self.write_message_noflush(&BeMessage::CloseComplete)?;
746 : }
747 :
748 : FeMessage::Execute(_) => {
749 640 : let query_string = cstr_to_str(unnamed_query_string)?;
750 UBC 0 : trace!("got execute {query_string:?}");
751 CBC 24118 : if let Err(e) = handler.process_query(self, query_string).await {
752 9 : log_query_error(query_string, &e);
753 9 : self.write_message_noflush(&BeMessage::ErrorResponse(
754 9 : &e.to_string(),
755 9 : Some(e.pg_error_code()),
756 9 : ))?;
757 631 : }
758 : // NOTE there is no ReadyForQuery message. This handler is used
759 : // for basebackup and it uses CopyOut which doesn't require
760 : // ReadyForQuery message and backend just switches back to
761 : // processing mode after sending CopyDone or ErrorResponse.
762 : }
763 :
764 : FeMessage::Sync => {
765 1927 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
766 : }
767 :
768 : FeMessage::Terminate => {
769 1430 : return Ok(ProcessMsgResult::Break);
770 : }
771 :
772 : // We prefer explicit pattern matching to wildcards, because
773 : // this helps us spot the places where new variants are missing
774 : FeMessage::CopyData(_)
775 : | FeMessage::CopyDone
776 : | FeMessage::CopyFail
777 : | FeMessage::PasswordMessage(_) => {
778 59 : return Err(QueryError::Other(anyhow::anyhow!(
779 59 : "unexpected message type: {msg:?}",
780 59 : )));
781 : }
782 : }
783 :
784 13397 : Ok(ProcessMsgResult::Continue)
785 14886 : }
786 :
787 : /// Log as info/error result of handling COPY stream and send back
788 : /// ErrorResponse if that makes sense. Shutdown the stream if we got
789 : /// Terminate. TODO: transition into waiting for Sync msg if we initiate the
790 : /// close.
791 2399 : pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
792 : use CopyStreamHandlerEnd::*;
793 :
794 2399 : let expected_end = match &end {
795 2051 : ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
796 347 : CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
797 347 : if is_expected_io_error(io_error) =>
798 347 : {
799 347 : true
800 : }
801 1 : _ => false,
802 : };
803 2399 : if expected_end {
804 2398 : info!("terminated: {:#}", end);
805 : } else {
806 1 : error!("terminated: {:?}", end);
807 : }
808 :
809 : // Note: no current usages ever send this
810 2399 : if let CopyDone = &end {
811 UBC 0 : if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
812 0 : error!("failed to send CopyDone: {}", e);
813 0 : }
814 CBC 2399 : }
815 :
816 2399 : if let Terminate = &end {
817 95 : self.state = ProtoState::Closed;
818 2304 : }
819 :
820 2399 : let err_to_send_and_errcode = match &end {
821 57 : ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
822 1 : Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
823 : // Note: CopyFail in duplex copy is somewhat unexpected (at least to
824 : // PG walsender; evidently and per my docs reading client should
825 : // finish it with CopyDone). It is not a problem to recover from it
826 : // finishing the stream in both directions like we do, but note that
827 : // sync rust-postgres client (which we don't use anymore) hangs if
828 : // socket is not closed here.
829 : // https://github.com/sfackler/rust-postgres/issues/755
830 : // https://github.com/neondatabase/neon/issues/935
831 : //
832 : // Currently, the version of tokio_postgres replication patch we use
833 : // sends this when it closes the stream (e.g. pageserver decided to
834 : // switch conn to another safekeeper and client gets dropped).
835 : // Moreover, seems like 'connection' task errors with 'unexpected
836 : // message from server' when it receives ErrorResponse (anything but
837 : // CopyData/CopyDone) back.
838 13 : CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
839 2328 : _ => None,
840 : };
841 2399 : if let Some((err, errcode)) = err_to_send_and_errcode {
842 71 : if let Err(ee) = self
843 71 : .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
844 UBC 0 : .await
845 : {
846 0 : error!("failed to send ErrorResponse: {}", ee);
847 CBC 71 : }
848 2328 : }
849 2399 : }
850 : }
851 :
/// Owned read half of a split `PostgresBackend` (see `PostgresBackend::split`);
/// rejoined with `PostgresBackend::unsplit`.
pub struct PostgresBackendReader<IO> {
    reader: FramedReader<MaybeTlsStream<IO>>,
    /// True once a read returned an error; `unsplit` then marks the backend Closed.
    closed: bool,
}
856 :
857 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
858 : /// Read full message or return None if connection is cleanly closed with no
859 : /// unprocessed data.
860 3700009 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
861 8011620 : match self.reader.read_message().await {
862 3699208 : Ok(m) => {
863 UBC 0 : trace!("read msg {:?}", m);
864 CBC 3699208 : Ok(m)
865 : }
866 334 : Err(e) => {
867 334 : self.closed = true;
868 334 : Err(e)
869 : }
870 : }
871 3699542 : }
872 :
873 : /// Get CopyData contents of the next message in COPY stream or error
874 : /// closing it. The error type is wider than actual errors which can happen
875 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
876 : /// current callers.
877 3700009 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
878 8011620 : match self.read_message().await? {
879 3697322 : Some(msg) => match msg {
880 3697214 : FeMessage::CopyData(m) => Ok(m),
881 UBC 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
882 CBC 13 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
883 95 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
884 UBC 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
885 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
886 0 : ))),
887 : },
888 CBC 1886 : None => Err(CopyStreamHandlerEnd::EOF),
889 : }
890 3699542 : }
891 : }
892 :
893 : ///
894 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
895 : /// messages.
896 : ///
897 :
898 : pub struct CopyDataWriter<'a, IO> {
899 : pgb: &'a mut PostgresBackend<IO>,
900 : }
901 :
902 : impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
903 935770 : fn poll_write(
904 935770 : self: Pin<&mut Self>,
905 935770 : cx: &mut std::task::Context<'_>,
906 935770 : buf: &[u8],
907 935770 : ) -> Poll<Result<usize, std::io::Error>> {
908 935770 : let this = self.get_mut();
909 :
910 : // It's not strictly required to flush between each message, but makes it easier
911 : // to view in wireshark, and usually the messages that the callers write are
912 : // decently-sized anyway.
913 935770 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
914 UBC 0 : return Poll::Ready(Err(err));
915 CBC 928234 : }
916 928234 :
917 928234 : // CopyData
918 928234 : // XXX: if the input is large, we should split it into multiple messages.
919 928234 : // Not sure what the threshold should be, but the ultimate hard limit is that
920 928234 : // the length cannot exceed u32.
921 928234 : this.pgb
922 928234 : .write_message_noflush(&BeMessage::CopyData(buf))
923 928234 : // write_message only writes to the buffer, so it can fail iff the
924 928234 : // message is invaid, but CopyData can't be invalid.
925 928234 : .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
926 :
927 928234 : Poll::Ready(Ok(buf.len()))
928 935770 : }
929 :
930 36112 : fn poll_flush(
931 36112 : self: Pin<&mut Self>,
932 36112 : cx: &mut std::task::Context<'_>,
933 36112 : ) -> Poll<Result<(), std::io::Error>> {
934 36112 : let this = self.get_mut();
935 36112 : this.pgb.poll_flush(cx)
936 36112 : }
937 :
938 223 : fn poll_shutdown(
939 223 : self: Pin<&mut Self>,
940 223 : cx: &mut std::task::Context<'_>,
941 223 : ) -> Poll<Result<(), std::io::Error>> {
942 223 : let this = self.get_mut();
943 223 : this.pgb.poll_flush(cx)
944 223 : }
945 : }
946 :
947 155 : pub fn short_error(e: &QueryError) -> String {
948 155 : match e {
949 7 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
950 109 : QueryError::Shutdown => "shutdown".to_string(),
951 39 : QueryError::Other(e) => format!("{e:#}"),
952 : }
953 155 : }
954 :
955 : fn log_query_error(query: &str, e: &QueryError) {
956 7 : match e {
957 7 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
958 7 : if is_expected_io_error(io_error) {
959 7 : info!("query handler for '{query}' failed with expected io error: {io_error}");
960 : } else {
961 UBC 0 : error!("query handler for '{query}' failed with io error: {io_error}");
962 : }
963 : }
964 0 : QueryError::Disconnected(other_connection_error) => {
965 0 : error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
966 : }
967 : QueryError::Shutdown => {
968 CBC 109 : info!("query handler for '{query}' cancelled during tenant shutdown")
969 : }
970 48 : QueryError::Other(e) => {
971 48 : error!("query handler for '{query}' failed: {e:?}");
972 : }
973 : }
974 164 : }
975 :
976 : /// Something finishing handling of COPY stream, see handle_copy_stream_end.
977 : /// This is not always a real error, but it allows to use ? and thiserror impls.
978 2469 : #[derive(thiserror::Error, Debug)]
979 : pub enum CopyStreamHandlerEnd {
980 : /// Handler initiates the end of streaming.
981 : #[error("{0}")]
982 : ServerInitiated(String),
983 : #[error("received CopyDone")]
984 : CopyDone,
985 : #[error("received CopyFail")]
986 : CopyFail,
987 : #[error("received Terminate")]
988 : Terminate,
989 : #[error("EOF on COPY stream")]
990 : EOF,
991 : /// The connection was lost
992 : #[error("connection error: {0}")]
993 : Disconnected(#[from] ConnectionError),
994 : /// Some other error
995 : #[error(transparent)]
996 : Other(#[from] anyhow::Error),
997 : }
|