Line data Source code
1 : use futures::FutureExt;
2 : use postgres_client::config::SslMode;
3 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
4 : use postgres_client::tls::{MakeTlsConnect, TlsConnect};
5 : use rustls::pki_types::InvalidDnsNameError;
6 : use thiserror::Error;
7 : use tokio::io::{AsyncRead, AsyncWrite};
8 :
9 : use crate::pqproto::request_tls;
10 : use crate::proxy::connect_compute::TlsNegotiation;
11 : use crate::proxy::retry::CouldRetry;
12 :
13 : #[derive(Debug, Error)]
14 : pub enum TlsError {
15 : #[error(transparent)]
16 : Dns(#[from] InvalidDnsNameError),
17 : #[error(transparent)]
18 : Connection(#[from] std::io::Error),
19 : #[error("TLS required but not provided")]
20 : Required,
21 : }
22 :
23 : impl CouldRetry for TlsError {
24 0 : fn could_retry(&self) -> bool {
25 0 : match self {
26 0 : TlsError::Dns(_) => false,
27 0 : TlsError::Connection(err) => err.could_retry(),
28 : // perhaps compute didn't realise it supports TLS?
29 0 : TlsError::Required => true,
30 : }
31 0 : }
32 : }
33 :
34 0 : pub async fn connect_tls<S, T>(
35 0 : mut stream: S,
36 0 : mode: SslMode,
37 0 : tls: &T,
38 0 : host: &str,
39 0 : negotiation: TlsNegotiation,
40 0 : ) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
41 0 : where
42 0 : S: AsyncRead + AsyncWrite + Unpin + Send,
43 0 : T: MakeTlsConnect<
44 0 : S,
45 0 : Error = InvalidDnsNameError,
46 0 : TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
47 0 : >,
48 0 : {
49 0 : match mode {
50 0 : SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
51 0 : SslMode::Prefer | SslMode::Require => {}
52 : }
53 :
54 0 : match negotiation {
55 : // No TLS request needed
56 0 : TlsNegotiation::Direct => {}
57 : // TLS request successful
58 0 : TlsNegotiation::Postgres if request_tls(&mut stream).await? => {}
59 : // TLS request failed but is required
60 0 : TlsNegotiation::Postgres if SslMode::Require == mode => return Err(TlsError::Required),
61 : // TLS request failed but is not required
62 0 : TlsNegotiation::Postgres => return Ok(MaybeTlsStream::Raw(stream)),
63 : }
64 :
65 : Ok(MaybeTlsStream::Tls(
66 0 : tls.make_tls_connect(host)?.connect(stream).boxed().await?,
67 : ))
68 0 : }
|