Line data Source code
1 : use futures::FutureExt;
2 : use postgres_client::config::SslMode;
3 : use postgres_client::maybe_tls_stream::MaybeTlsStream;
4 : use postgres_client::tls::{MakeTlsConnect, TlsConnect};
5 : use rustls::pki_types::InvalidDnsNameError;
6 : use thiserror::Error;
7 : use tokio::io::{AsyncRead, AsyncWrite};
8 :
9 : use crate::pqproto::request_tls;
10 : use crate::proxy::retry::CouldRetry;
11 :
12 : #[derive(Debug, Error)]
13 : pub enum TlsError {
14 : #[error(transparent)]
15 : Dns(#[from] InvalidDnsNameError),
16 : #[error(transparent)]
17 : Connection(#[from] std::io::Error),
18 : #[error("TLS required but not provided")]
19 : Required,
20 : }
21 :
22 : impl CouldRetry for TlsError {
23 0 : fn could_retry(&self) -> bool {
24 0 : match self {
25 0 : TlsError::Dns(_) => false,
26 0 : TlsError::Connection(err) => err.could_retry(),
27 : // perhaps compute didn't realise it supports TLS?
28 0 : TlsError::Required => true,
29 : }
30 0 : }
31 : }
32 :
33 0 : pub async fn connect_tls<S, T>(
34 0 : mut stream: S,
35 0 : mode: SslMode,
36 0 : tls: &T,
37 0 : host: &str,
38 0 : ) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
39 0 : where
40 0 : S: AsyncRead + AsyncWrite + Unpin + Send,
41 0 : T: MakeTlsConnect<
42 0 : S,
43 0 : Error = InvalidDnsNameError,
44 0 : TlsConnect: TlsConnect<S, Error = std::io::Error, Future: Send>,
45 0 : >,
46 0 : {
47 0 : match mode {
48 0 : SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
49 0 : SslMode::Prefer | SslMode::Require => {}
50 : }
51 :
52 0 : if !request_tls(&mut stream).await? {
53 0 : if SslMode::Require == mode {
54 0 : return Err(TlsError::Required);
55 0 : }
56 :
57 0 : return Ok(MaybeTlsStream::Raw(stream));
58 0 : }
59 :
60 : Ok(MaybeTlsStream::Tls(
61 0 : tls.make_tls_connect(host)?.connect(stream).boxed().await?,
62 : ))
63 0 : }
|