Line data Source code
1 : //! Server-side asynchronous Postgres connection, as limited as we need.
2 : //! To use, create PostgresBackend and run() it, passing the Handler
3 : //! implementation determining how to process the queries. Currently its API
4 : //! is rather narrow, but we can extend it once required.
5 : use anyhow::Context;
6 : use bytes::Bytes;
7 : use futures::pin_mut;
8 : use serde::{Deserialize, Serialize};
9 : use std::io::ErrorKind;
10 : use std::net::SocketAddr;
11 : use std::pin::Pin;
12 : use std::sync::Arc;
13 : use std::task::{ready, Poll};
14 : use std::{fmt, io};
15 : use std::{future::Future, str::FromStr};
16 : use tokio::io::{AsyncRead, AsyncWrite};
17 : use tokio_rustls::TlsAcceptor;
18 : use tracing::{debug, error, info, trace};
19 :
20 : use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
21 : use pq_proto::{
22 : BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
23 : SQLSTATE_SUCCESSFUL_COMPLETION,
24 : };
25 :
26 : /// An error, occurred during query processing:
27 : /// either during the connection ([`ConnectionError`]) or before/after it.
28 502 : #[derive(thiserror::Error, Debug)]
29 : pub enum QueryError {
30 : /// The connection was lost while processing the query.
31 : #[error(transparent)]
32 : Disconnected(#[from] ConnectionError),
33 : /// Some other error
34 : #[error(transparent)]
35 : Other(#[from] anyhow::Error),
36 : }
37 :
38 : impl From<io::Error> for QueryError {
39 327 : fn from(e: io::Error) -> Self {
40 327 : Self::Disconnected(ConnectionError::Io(e))
41 327 : }
42 : }
43 :
44 : impl QueryError {
45 91 : pub fn pg_error_code(&self) -> &'static [u8; 5] {
46 91 : match self {
47 6 : Self::Disconnected(_) => b"08006", // connection failure
48 85 : Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
49 : }
50 91 : }
51 : }
52 :
/// Tells whether an I/O error is an ordinary consequence of a network problem
/// or of the client closing the connection.
///
/// Such errors can happen during normal operation and do not point at a bug in
/// our code. Returns `true` for broken pipe, connection refused / aborted /
/// reset and timeouts; `false` for every other error kind.
pub fn is_expected_io_error(e: &io::Error) -> bool {
    const EXPECTED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::TimedOut,
    ];
    EXPECTED_KINDS.contains(&e.kind())
}
63 :
64 : #[async_trait::async_trait]
65 : pub trait Handler<IO> {
66 : /// Handle single query.
67 : /// postgres_backend will issue ReadyForQuery after calling this (this
68 : /// might be not what we want after CopyData streaming, but currently we don't
69 : /// care). It will also flush out the output buffer.
70 : async fn process_query(
71 : &mut self,
72 : pgb: &mut PostgresBackend<IO>,
73 : query_string: &str,
74 : ) -> Result<(), QueryError>;
75 :
76 : /// Called on startup packet receival, allows to process params.
77 : ///
78 : /// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
79 : /// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
80 : /// to override whole init logic in implementations.
81 5 : fn startup(
82 5 : &mut self,
83 5 : _pgb: &mut PostgresBackend<IO>,
84 5 : _sm: &FeStartupPacket,
85 5 : ) -> Result<(), QueryError> {
86 5 : Ok(())
87 5 : }
88 :
89 : /// Check auth jwt
90 0 : fn check_auth_jwt(
91 0 : &mut self,
92 0 : _pgb: &mut PostgresBackend<IO>,
93 0 : _jwt_response: &[u8],
94 0 : ) -> Result<(), QueryError> {
95 0 : Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
96 0 : }
97 : }
98 :
/// PostgresBackend protocol state.
///
/// XXX: The order of the variants matters: states are compared with `<`
/// (derived `PartialOrd`) to tell how far the handshake has progressed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
    /// Nothing happened yet.
    Initialization,
    /// Encryption handshake is done; waiting for encrypted Startup message.
    Encrypted,
    /// Waiting for password (auth token).
    Authentication,
    /// Performed handshake and auth, ReadyForQuery is issued.
    Established,
    /// The connection is finished; no further reads are attempted.
    Closed,
}
113 :
/// Outcome of handling one frontend message in the main message loop.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProcessMsgResult {
    /// Keep reading further messages.
    Continue,
    /// Leave the message loop.
    Break,
}
119 :
120 : /// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
121 : pub enum MaybeTlsStream<IO> {
122 : Unencrypted(IO),
123 : Tls(Box<tokio_rustls::server::TlsStream<IO>>),
124 : }
125 :
126 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
127 8640428 : fn poll_write(
128 8640428 : self: Pin<&mut Self>,
129 8640428 : cx: &mut std::task::Context<'_>,
130 8640428 : buf: &[u8],
131 8640428 : ) -> Poll<io::Result<usize>> {
132 8640428 : match self.get_mut() {
133 8640426 : Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
134 2 : Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
135 : }
136 8640428 : }
137 8654170 : fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
138 8654170 : match self.get_mut() {
139 8654167 : Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
140 3 : Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
141 : }
142 8654170 : }
143 8294 : fn poll_shutdown(
144 8294 : self: Pin<&mut Self>,
145 8294 : cx: &mut std::task::Context<'_>,
146 8294 : ) -> Poll<io::Result<()>> {
147 8294 : match self.get_mut() {
148 8294 : Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
149 0 : Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
150 : }
151 8294 : }
152 : }
153 : impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
154 18994316 : fn poll_read(
155 18994316 : self: Pin<&mut Self>,
156 18994316 : cx: &mut std::task::Context<'_>,
157 18994316 : buf: &mut tokio::io::ReadBuf<'_>,
158 18994316 : ) -> Poll<io::Result<()>> {
159 18994316 : match self.get_mut() {
160 18994311 : Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
161 5 : Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
162 : }
163 18994316 : }
164 : }
165 :
166 39718 : #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
167 : pub enum AuthType {
168 : Trust,
169 : // This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
170 : NeonJWT,
171 : }
172 :
173 : impl FromStr for AuthType {
174 : type Err = anyhow::Error;
175 :
176 1888 : fn from_str(s: &str) -> Result<Self, Self::Err> {
177 1888 : match s {
178 1888 : "Trust" => Ok(Self::Trust),
179 36 : "NeonJWT" => Ok(Self::NeonJWT),
180 0 : _ => anyhow::bail!("invalid value \"{s}\" for auth type"),
181 : }
182 1888 : }
183 : }
184 :
185 : impl fmt::Display for AuthType {
186 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 3038 : f.write_str(match self {
188 2984 : AuthType::Trust => "Trust",
189 54 : AuthType::NeonJWT => "NeonJWT",
190 : })
191 3038 : }
192 : }
193 :
194 : /// Either full duplex Framed or write only half; the latter is left in
195 : /// PostgresBackend after call to `split`. In principle we could always store a
196 : /// pair of splitted handles, but that would force to to pay splitting price
197 : /// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
198 : enum MaybeWriteOnly<IO> {
199 : Full(Framed<MaybeTlsStream<IO>>),
200 : WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
201 : Broken, // temporary value palmed off during the split
202 : }
203 :
204 : impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
205 15920 : async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
206 15920 : match self {
207 15920 : MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
208 : MaybeWriteOnly::WriteOnly(_) => {
209 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
210 : }
211 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
212 : }
213 15920 : }
214 :
215 4653304 : async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
216 4653304 : match self {
217 4653304 : MaybeWriteOnly::Full(framed) => framed.read_message().await,
218 : MaybeWriteOnly::WriteOnly(_) => {
219 0 : Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
220 : }
221 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
222 : }
223 4653177 : }
224 :
225 8635171 : fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
226 8635171 : match self {
227 5688444 : MaybeWriteOnly::Full(framed) => framed.write_message(msg),
228 2946727 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
229 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
230 : }
231 8635171 : }
232 :
233 8654851 : async fn flush(&mut self) -> io::Result<()> {
234 8654851 : match self {
235 5708124 : MaybeWriteOnly::Full(framed) => framed.flush().await,
236 2946727 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
237 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
238 : }
239 8646209 : }
240 :
241 8620 : async fn shutdown(&mut self) -> io::Result<()> {
242 8620 : match self {
243 8620 : MaybeWriteOnly::Full(framed) => framed.shutdown().await,
244 0 : MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
245 0 : MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
246 : }
247 8620 : }
248 : }
249 :
250 : pub struct PostgresBackend<IO> {
251 : framed: MaybeWriteOnly<IO>,
252 :
253 : pub state: ProtoState,
254 :
255 : auth_type: AuthType,
256 :
257 : peer_addr: SocketAddr,
258 : pub tls_config: Option<Arc<rustls::ServerConfig>>,
259 : }
260 :
261 : pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
262 :
263 0 : pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
264 0 : let mut query_string = query_string.to_vec();
265 0 : if let Some(ch) = query_string.last() {
266 0 : if *ch == 0 {
267 0 : query_string.pop();
268 0 : }
269 0 : }
270 0 : query_string
271 0 : }
272 :
273 : /// Cast a byte slice to a string slice, dropping null terminator if there's one.
274 9834 : fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
275 9834 : let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
276 9834 : std::str::from_utf8(without_null).map_err(|e| e.into())
277 9834 : }
278 :
279 : impl PostgresBackend<tokio::net::TcpStream> {
280 5 : pub fn new(
281 5 : socket: tokio::net::TcpStream,
282 5 : auth_type: AuthType,
283 5 : tls_config: Option<Arc<rustls::ServerConfig>>,
284 5 : ) -> io::Result<Self> {
285 5 : let peer_addr = socket.peer_addr()?;
286 5 : let stream = MaybeTlsStream::Unencrypted(socket);
287 5 :
288 5 : Ok(Self {
289 5 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
290 5 : state: ProtoState::Initialization,
291 5 : auth_type,
292 5 : tls_config,
293 5 : peer_addr,
294 5 : })
295 5 : }
296 : }
297 :
298 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
299 9054 : pub fn new_from_io(
300 9054 : socket: IO,
301 9054 : peer_addr: SocketAddr,
302 9054 : auth_type: AuthType,
303 9054 : tls_config: Option<Arc<rustls::ServerConfig>>,
304 9054 : ) -> io::Result<Self> {
305 9054 : let stream = MaybeTlsStream::Unencrypted(socket);
306 9054 :
307 9054 : Ok(Self {
308 9054 : framed: MaybeWriteOnly::Full(Framed::new(stream)),
309 9054 : state: ProtoState::Initialization,
310 9054 : auth_type,
311 9054 : tls_config,
312 9054 : peer_addr,
313 9054 : })
314 9054 : }
315 :
316 2935 : pub fn get_peer_addr(&self) -> &SocketAddr {
317 2935 : &self.peer_addr
318 2935 : }
319 :
320 : /// Read full message or return None if connection is cleanly closed with no
321 : /// unprocessed data.
322 4653256 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
323 4653153 : if let ProtoState::Closed = self.state {
324 65 : Ok(None)
325 : } else {
326 4653088 : match self.framed.read_message().await {
327 4652923 : Ok(m) => {
328 0 : trace!("read msg {:?}", m);
329 4652923 : Ok(m)
330 : }
331 38 : Err(e) => {
332 38 : // remember not to try to read anymore
333 38 : self.state = ProtoState::Closed;
334 38 : Err(e)
335 : }
336 : }
337 : }
338 4653026 : }
339 :
340 : /// Write message into internal output buffer, doesn't flush it. Technically
341 : /// error type can be only ProtocolError here (if, unlikely, serialization
342 : /// fails), but callers typically wrap it anyway.
343 : pub fn write_message_noflush(
344 : &mut self,
345 : message: &BeMessage<'_>,
346 : ) -> Result<&mut Self, ConnectionError> {
347 8635171 : self.framed.write_message_noflush(message)?;
348 8635171 : trace!("wrote msg {:?}", message);
349 8635171 : Ok(self)
350 8635171 : }
351 :
352 : /// Flush output buffer into the socket.
353 8654851 : pub async fn flush(&mut self) -> io::Result<()> {
354 8654851 : self.framed.flush().await
355 8646209 : }
356 :
357 : /// Polling version of `flush()`, saves the caller need to pin.
358 1054034 : pub fn poll_flush(
359 1054034 : &mut self,
360 1054034 : cx: &mut std::task::Context<'_>,
361 1054034 : ) -> Poll<Result<(), std::io::Error>> {
362 1054034 : let flush_fut = self.flush();
363 1054034 : pin_mut!(flush_fut);
364 1054034 : flush_fut.poll(cx)
365 1054034 : }
366 :
367 : /// Write message into internal output buffer and flush it to the stream.
368 2965945 : pub async fn write_message(
369 2965945 : &mut self,
370 2965945 : message: &BeMessage<'_>,
371 2965945 : ) -> Result<&mut Self, ConnectionError> {
372 2965945 : self.write_message_noflush(message)?;
373 2965945 : self.flush().await?;
374 2965925 : Ok(self)
375 2965931 : }
376 :
377 : /// Returns an AsyncWrite implementation that wraps all the data written
378 : /// to it in CopyData messages, and writes them to the connection
379 : ///
380 : /// The caller is responsible for sending CopyOutResponse and CopyDone messages.
381 662 : pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
382 662 : CopyDataWriter { pgb: self }
383 662 : }
384 :
385 : /// Wrapper for run_message_loop() that shuts down socket when we are done
386 9059 : pub async fn run<F, S>(
387 9059 : mut self,
388 9059 : handler: &mut impl Handler<IO>,
389 9059 : shutdown_watcher: F,
390 9059 : ) -> Result<(), QueryError>
391 9059 : where
392 9059 : F: Fn() -> S,
393 9059 : S: Future,
394 9059 : {
395 20363098 : let ret = self.run_message_loop(handler, shutdown_watcher).await;
396 : // socket might be already closed, e.g. if previously received error,
397 : // so ignore result.
398 8620 : self.framed.shutdown().await.ok();
399 8620 : ret
400 8620 : }
401 :
402 9059 : async fn run_message_loop<F, S>(
403 9059 : &mut self,
404 9059 : handler: &mut impl Handler<IO>,
405 9059 : shutdown_watcher: F,
406 9059 : ) -> Result<(), QueryError>
407 9059 : where
408 9059 : F: Fn() -> S,
409 9059 : S: Future,
410 9059 : {
411 0 : trace!("postgres backend to {:?} started", self.peer_addr);
412 :
413 9059 : tokio::select!(
414 : biased;
415 :
416 : _ = shutdown_watcher() => {
417 : // We were requested to shut down.
418 0 : tracing::info!("shutdown request received during handshake");
419 : return Ok(())
420 : },
421 :
422 9059 : result = self.handshake(handler) => {
423 : // Handshake complete.
424 : result?;
425 : if self.state == ProtoState::Closed {
426 : return Ok(()); // EOF during handshake
427 : }
428 : }
429 : );
430 :
431 : // Authentication completed
432 9053 : let mut query_string = Bytes::new();
433 22765 : while let Some(msg) = tokio::select!(
434 : biased;
435 : _ = shutdown_watcher() => {
436 : // We were requested to shut down.
437 96 : tracing::info!("shutdown request received in run_message_loop");
438 : Ok(None)
439 : },
440 22667 : msg = self.read_message() => { msg },
441 33 : )? {
442 0 : trace!("got message {:?}", msg);
443 :
444 20334774 : let result = self.process_message(handler, msg, &mut query_string).await;
445 15694 : self.flush().await?;
446 15368 : match result? {
447 : ProcessMsgResult::Continue => {
448 13712 : self.flush().await?;
449 13712 : continue;
450 : }
451 1521 : ProcessMsgResult::Break => break,
452 : }
453 : }
454 :
455 0 : trace!("postgres backend to {:?} exited", self.peer_addr);
456 8120 : Ok(())
457 8620 : }
458 :
459 : /// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
460 1 : async fn tls_upgrade(
461 1 : src: MaybeTlsStream<IO>,
462 1 : tls_config: Arc<rustls::ServerConfig>,
463 1 : ) -> anyhow::Result<MaybeTlsStream<IO>> {
464 1 : match src {
465 1 : MaybeTlsStream::Unencrypted(s) => {
466 1 : let acceptor = TlsAcceptor::from(tls_config);
467 2 : let tls_stream = acceptor.accept(s).await?;
468 1 : Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
469 : }
470 : MaybeTlsStream::Tls(_) => {
471 0 : anyhow::bail!("TLS already started");
472 : }
473 : }
474 1 : }
475 :
476 1 : async fn start_tls(&mut self) -> anyhow::Result<()> {
477 1 : // temporary replace stream with fake to cook TLS one, Indiana Jones style
478 1 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
479 1 : MaybeWriteOnly::Full(framed) => {
480 1 : let tls_config = self
481 1 : .tls_config
482 1 : .as_ref()
483 1 : .context("start_tls called without conf")?
484 1 : .clone();
485 1 : let tls_framed = framed
486 1 : .map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
487 2 : .await?;
488 : // push back ready TLS stream
489 1 : self.framed = MaybeWriteOnly::Full(tls_framed);
490 1 : Ok(())
491 : }
492 : MaybeWriteOnly::WriteOnly(_) => {
493 0 : anyhow::bail!("TLS upgrade attempt in split state")
494 : }
495 0 : MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
496 : }
497 1 : }
498 :
499 : /// Split off owned read part from which messages can be read in different
500 : /// task/thread.
501 2935 : pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
502 2935 : // temporary replace stream with fake to cook split one, Indiana Jones style
503 2935 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
504 2935 : MaybeWriteOnly::Full(framed) => {
505 2935 : let (reader, writer) = framed.split();
506 2935 : self.framed = MaybeWriteOnly::WriteOnly(writer);
507 2935 : Ok(PostgresBackendReader {
508 2935 : reader,
509 2935 : closed: false,
510 2935 : })
511 : }
512 : MaybeWriteOnly::WriteOnly(_) => {
513 0 : anyhow::bail!("PostgresBackend is already split")
514 : }
515 0 : MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
516 : }
517 2935 : }
518 :
519 : /// Join read part back.
520 2565 : pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
521 2565 : // temporary replace stream with fake to cook joined one, Indiana Jones style
522 2565 : match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
523 : MaybeWriteOnly::Full(_) => {
524 0 : anyhow::bail!("PostgresBackend is not split")
525 : }
526 2565 : MaybeWriteOnly::WriteOnly(writer) => {
527 2565 : let joined = Framed::unsplit(reader.reader, writer);
528 2565 : self.framed = MaybeWriteOnly::Full(joined);
529 2565 : // if reader encountered connection error, do not attempt reading anymore
530 2565 : if reader.closed {
531 303 : self.state = ProtoState::Closed;
532 2262 : }
533 2565 : Ok(())
534 : }
535 0 : MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
536 : }
537 2565 : }
538 :
539 : /// Perform handshake with the client, transitioning to Established.
540 : /// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
541 9059 : async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
542 24979 : while self.state < ProtoState::Authentication {
543 15920 : match self.framed.read_startup_message().await? {
544 15920 : Some(msg) => {
545 15920 : self.process_startup_message(handler, msg).await?;
546 : }
547 : None => {
548 0 : trace!(
549 0 : "postgres backend to {:?} received EOF during handshake",
550 0 : self.peer_addr
551 0 : );
552 0 : self.state = ProtoState::Closed;
553 0 : return Ok(());
554 : }
555 : }
556 : }
557 :
558 : // Perform auth, if needed.
559 9059 : if self.state == ProtoState::Authentication {
560 216 : match self.framed.read_message().await? {
561 211 : Some(FeMessage::PasswordMessage(m)) => {
562 211 : assert!(self.auth_type == AuthType::NeonJWT);
563 :
564 211 : let (_, jwt_response) = m.split_last().context("protocol violation")?;
565 :
566 211 : if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
567 1 : self.write_message_noflush(&BeMessage::ErrorResponse(
568 1 : &e.to_string(),
569 1 : Some(e.pg_error_code()),
570 1 : ))?;
571 1 : return Err(e);
572 210 : }
573 210 :
574 210 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
575 210 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
576 210 : .write_message(&BeMessage::ReadyForQuery)
577 0 : .await?;
578 210 : self.state = ProtoState::Established;
579 : }
580 0 : Some(m) => {
581 0 : return Err(QueryError::Other(anyhow::anyhow!(
582 0 : "Unexpected message {:?} while waiting for handshake",
583 0 : m
584 0 : )));
585 : }
586 : None => {
587 0 : trace!(
588 0 : "postgres backend to {:?} received EOF during auth",
589 0 : self.peer_addr
590 0 : );
591 5 : self.state = ProtoState::Closed;
592 5 : return Ok(());
593 : }
594 : }
595 8843 : }
596 :
597 9053 : Ok(())
598 9059 : }
599 :
600 : /// Process startup packet:
601 : /// - transition to Established if auth type is trust
602 : /// - transition to Authentication if auth type is NeonJWT.
603 : /// - or perform TLS handshake -- then need to call this again to receive
604 : /// actual startup packet.
605 15920 : async fn process_startup_message(
606 15920 : &mut self,
607 15920 : handler: &mut impl Handler<IO>,
608 15920 : msg: FeStartupPacket,
609 15920 : ) -> Result<(), QueryError> {
610 15920 : assert!(self.state < ProtoState::Authentication);
611 15920 : let have_tls = self.tls_config.is_some();
612 15920 : match msg {
613 : FeStartupPacket::SslRequest => {
614 0 : debug!("SSL requested");
615 :
616 6861 : self.write_message(&BeMessage::EncryptionResponse(have_tls))
617 0 : .await?;
618 :
619 6861 : if have_tls {
620 2 : self.start_tls().await?;
621 1 : self.state = ProtoState::Encrypted;
622 6860 : }
623 : }
624 : FeStartupPacket::GssEncRequest => {
625 0 : debug!("GSS requested");
626 0 : self.write_message(&BeMessage::EncryptionResponse(false))
627 0 : .await?;
628 : }
629 : FeStartupPacket::StartupMessage { .. } => {
630 9059 : if have_tls && !matches!(self.state, ProtoState::Encrypted) {
631 0 : self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
632 0 : .await?;
633 0 : return Err(QueryError::Other(anyhow::anyhow!(
634 0 : "client did not connect with TLS"
635 0 : )));
636 9059 : }
637 9059 :
638 9059 : // NB: startup() may change self.auth_type -- we are using that in proxy code
639 9059 : // to bypass auth for new users.
640 9059 : handler.startup(self, &msg)?;
641 :
642 9059 : match self.auth_type {
643 : AuthType::Trust => {
644 8843 : self.write_message_noflush(&BeMessage::AuthenticationOk)?
645 8843 : .write_message_noflush(&BeMessage::CLIENT_ENCODING)?
646 8843 : .write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
647 : // The async python driver requires a valid server_version
648 8843 : .write_message_noflush(&BeMessage::server_version("14.1"))?
649 8843 : .write_message(&BeMessage::ReadyForQuery)
650 0 : .await?;
651 8843 : self.state = ProtoState::Established;
652 : }
653 : AuthType::NeonJWT => {
654 216 : self.write_message(&BeMessage::AuthenticationCleartextPassword)
655 0 : .await?;
656 216 : self.state = ProtoState::Authentication;
657 : }
658 : }
659 : }
660 : FeStartupPacket::CancelRequest { .. } => {
661 0 : return Err(QueryError::Other(anyhow::anyhow!(
662 0 : "Unexpected CancelRequest message during handshake"
663 0 : )));
664 : }
665 : }
666 15920 : Ok(())
667 15920 : }
668 :
669 16131 : async fn process_message(
670 16131 : &mut self,
671 16131 : handler: &mut impl Handler<IO>,
672 16131 : msg: FeMessage,
673 16131 : unnamed_query_string: &mut Bytes,
674 16131 : ) -> Result<ProcessMsgResult, QueryError> {
675 16131 : // Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
676 16131 : // TODO: change that to proper top-level match of protocol state with separate message handling for each state
677 16131 : assert!(self.state == ProtoState::Established);
678 :
679 16131 : match msg {
680 9172 : FeMessage::Query(body) => {
681 : // remove null terminator
682 9172 : let query_string = cstr_to_str(&body)?;
683 :
684 0 : trace!("got query {query_string:?}");
685 20320920 : if let Err(e) = handler.process_query(self, query_string).await {
686 79 : log_query_error(query_string, &e);
687 79 : let short_error = short_error(&e);
688 79 : self.write_message_noflush(&BeMessage::ErrorResponse(
689 79 : &short_error,
690 79 : Some(e.pg_error_code()),
691 79 : ))?;
692 8656 : }
693 8735 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
694 : }
695 :
696 662 : FeMessage::Parse(m) => {
697 662 : *unnamed_query_string = m.query_string;
698 662 : self.write_message_noflush(&BeMessage::ParseComplete)?;
699 : }
700 :
701 : FeMessage::Describe(_) => {
702 662 : self.write_message_noflush(&BeMessage::ParameterDescription)?
703 662 : .write_message_noflush(&BeMessage::NoData)?;
704 : }
705 :
706 : FeMessage::Bind(_) => {
707 662 : self.write_message_noflush(&BeMessage::BindComplete)?;
708 : }
709 :
710 : FeMessage::Close(_) => {
711 661 : self.write_message_noflush(&BeMessage::CloseComplete)?;
712 : }
713 :
714 : FeMessage::Execute(_) => {
715 662 : let query_string = cstr_to_str(unnamed_query_string)?;
716 0 : trace!("got execute {query_string:?}");
717 13854 : if let Err(e) = handler.process_query(self, query_string).await {
718 9 : log_query_error(query_string, &e);
719 9 : self.write_message_noflush(&BeMessage::ErrorResponse(
720 9 : &e.to_string(),
721 9 : Some(e.pg_error_code()),
722 9 : ))?;
723 653 : }
724 : // NOTE there is no ReadyForQuery message. This handler is used
725 : // for basebackup and it uses CopyOut which doesn't require
726 : // ReadyForQuery message and backend just switches back to
727 : // processing mode after sending CopyDone or ErrorResponse.
728 : }
729 :
730 : FeMessage::Sync => {
731 1994 : self.write_message_noflush(&BeMessage::ReadyForQuery)?;
732 : }
733 :
734 : FeMessage::Terminate => {
735 1521 : return Ok(ProcessMsgResult::Break);
736 : }
737 :
738 : // We prefer explicit pattern matching to wildcards, because
739 : // this helps us spot the places where new variants are missing
740 : FeMessage::CopyData(_)
741 : | FeMessage::CopyDone
742 : | FeMessage::CopyFail
743 : | FeMessage::PasswordMessage(_) => {
744 135 : return Err(QueryError::Other(anyhow::anyhow!(
745 135 : "unexpected message type: {msg:?}",
746 135 : )));
747 : }
748 : }
749 :
750 14038 : Ok(ProcessMsgResult::Continue)
751 15694 : }
752 :
753 : /// Log as info/error result of handling COPY stream and send back
754 : /// ErrorResponse if that makes sense. Shutdown the stream if we got
755 : /// Terminate. TODO: transition into waiting for Sync msg if we initiate the
756 : /// close.
757 2565 : pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
758 : use CopyStreamHandlerEnd::*;
759 :
760 2565 : let expected_end = match &end {
761 2255 : ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
762 309 : CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
763 309 : if is_expected_io_error(io_error) =>
764 309 : {
765 309 : true
766 : }
767 1 : _ => false,
768 : };
769 2565 : if expected_end {
770 2564 : info!("terminated: {:#}", end);
771 : } else {
772 1 : error!("terminated: {:?}", end);
773 : }
774 :
775 : // Note: no current usages ever send this
776 2565 : if let CopyDone = &end {
777 0 : if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
778 0 : error!("failed to send CopyDone: {}", e);
779 0 : }
780 2565 : }
781 :
782 2565 : if let Terminate = &end {
783 65 : self.state = ProtoState::Closed;
784 2500 : }
785 :
786 2565 : let err_to_send_and_errcode = match &end {
787 133 : ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
788 1 : Other(_) => Some((format!("{end:#}"), SQLSTATE_INTERNAL_ERROR)),
789 : // Note: CopyFail in duplex copy is somewhat unexpected (at least to
790 : // PG walsender; evidently and per my docs reading client should
791 : // finish it with CopyDone). It is not a problem to recover from it
792 : // finishing the stream in both directions like we do, but note that
793 : // sync rust-postgres client (which we don't use anymore) hangs if
794 : // socket is not closed here.
795 : // https://github.com/sfackler/rust-postgres/issues/755
796 : // https://github.com/neondatabase/neon/issues/935
797 : //
798 : // Currently, the version of tokio_postgres replication patch we use
799 : // sends this when it closes the stream (e.g. pageserver decided to
800 : // switch conn to another safekeeper and client gets dropped).
801 : // Moreover, seems like 'connection' task errors with 'unexpected
802 : // message from server' when it receives ErrorResponse (anything but
803 : // CopyData/CopyDone) back.
804 19 : CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
805 2412 : _ => None,
806 : };
807 2565 : if let Some((err, errcode)) = err_to_send_and_errcode {
808 153 : if let Err(ee) = self
809 153 : .write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
810 0 : .await
811 : {
812 0 : error!("failed to send ErrorResponse: {}", ee);
813 153 : }
814 2412 : }
815 2565 : }
816 : }
817 :
818 : pub struct PostgresBackendReader<IO> {
819 : reader: FramedReader<MaybeTlsStream<IO>>,
820 : closed: bool, // true if received error closing the connection
821 : }
822 :
823 : impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
824 : /// Read full message or return None if connection is cleanly closed with no
825 : /// unprocessed data.
826 3387810 : pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
827 7312479 : match self.reader.read_message().await {
828 3386998 : Ok(m) => {
829 0 : trace!("read msg {:?}", m);
830 3386998 : Ok(m)
831 : }
832 303 : Err(e) => {
833 303 : self.closed = true;
834 303 : Err(e)
835 : }
836 : }
837 3387301 : }
838 :
839 : /// Get CopyData contents of the next message in COPY stream or error
840 : /// closing it. The error type is wider than actual errors which can happen
841 : /// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
842 : /// current callers.
843 3387810 : pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
844 7312479 : match self.read_message().await? {
845 3384960 : Some(msg) => match msg {
846 3384876 : FeMessage::CopyData(m) => Ok(m),
847 0 : FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
848 19 : FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
849 65 : FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
850 0 : _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
851 0 : ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
852 0 : ))),
853 : },
854 2038 : None => Err(CopyStreamHandlerEnd::EOF),
855 : }
856 3387301 : }
857 : }
858 :
859 : ///
860 : /// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
861 : /// messages.
862 : ///
863 :
864 : pub struct CopyDataWriter<'a, IO> {
865 : pgb: &'a mut PostgresBackend<IO>,
866 : }
867 :
868 : impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
869 1017389 : fn poll_write(
870 1017389 : self: Pin<&mut Self>,
871 1017389 : cx: &mut std::task::Context<'_>,
872 1017389 : buf: &[u8],
873 1017389 : ) -> Poll<Result<usize, std::io::Error>> {
874 1017389 : let this = self.get_mut();
875 :
876 : // It's not strictly required to flush between each message, but makes it easier
877 : // to view in wireshark, and usually the messages that the callers write are
878 : // decently-sized anyway.
879 1017389 : if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
880 0 : return Poll::Ready(Err(err));
881 1008955 : }
882 1008955 :
883 1008955 : // CopyData
884 1008955 : // XXX: if the input is large, we should split it into multiple messages.
885 1008955 : // Not sure what the threshold should be, but the ultimate hard limit is that
886 1008955 : // the length cannot exceed u32.
887 1008955 : this.pgb
888 1008955 : .write_message_noflush(&BeMessage::CopyData(buf))
889 1008955 : // write_message only writes to the buffer, so it can fail iff the
890 1008955 : // message is invaid, but CopyData can't be invalid.
891 1008955 : .map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
892 :
893 1008955 : Poll::Ready(Ok(buf.len()))
894 1017389 : }
895 :
896 36440 : fn poll_flush(
897 36440 : self: Pin<&mut Self>,
898 36440 : cx: &mut std::task::Context<'_>,
899 36440 : ) -> Poll<Result<(), std::io::Error>> {
900 36440 : let this = self.get_mut();
901 36440 : this.pgb.poll_flush(cx)
902 36440 : }
903 :
904 205 : fn poll_shutdown(
905 205 : self: Pin<&mut Self>,
906 205 : cx: &mut std::task::Context<'_>,
907 205 : ) -> Poll<Result<(), std::io::Error>> {
908 205 : let this = self.get_mut();
909 205 : this.pgb.poll_flush(cx)
910 205 : }
911 : }
912 :
913 79 : pub fn short_error(e: &QueryError) -> String {
914 79 : match e {
915 6 : QueryError::Disconnected(connection_error) => connection_error.to_string(),
916 73 : QueryError::Other(e) => format!("{e:#}"),
917 : }
918 79 : }
919 :
920 : fn log_query_error(query: &str, e: &QueryError) {
921 6 : match e {
922 6 : QueryError::Disconnected(ConnectionError::Io(io_error)) => {
923 6 : if is_expected_io_error(io_error) {
924 6 : info!("query handler for '{query}' failed with expected io error: {io_error}");
925 : } else {
926 0 : error!("query handler for '{query}' failed with io error: {io_error}");
927 : }
928 : }
929 0 : QueryError::Disconnected(other_connection_error) => {
930 0 : error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
931 : }
932 82 : QueryError::Other(e) => {
933 82 : error!("query handler for '{query}' failed: {e:?}");
934 : }
935 : }
936 88 : }
937 :
938 : /// Something finishing handling of COPY stream, see handle_copy_stream_end.
939 : /// This is not always a real error, but it allows to use ? and thiserror impls.
940 2717 : #[derive(thiserror::Error, Debug)]
941 : pub enum CopyStreamHandlerEnd {
942 : /// Handler initiates the end of streaming.
943 : #[error("{0}")]
944 : ServerInitiated(String),
945 : #[error("received CopyDone")]
946 : CopyDone,
947 : #[error("received CopyFail")]
948 : CopyFail,
949 : #[error("received Terminate")]
950 : Terminate,
951 : #[error("EOF on COPY stream")]
952 : EOF,
953 : /// The connection was lost
954 : #[error("connection error: {0}")]
955 : Disconnected(#[from] ConnectionError),
956 : /// Some other error
957 : #[error(transparent)]
958 : Other(#[from] anyhow::Error),
959 : }
|