Line data Source code
1 : //! TLS support.
2 :
3 : use std::error::Error;
4 : use std::future::Future;
5 : use std::pin::Pin;
6 : use std::task::{Context, Poll};
7 : use std::{fmt, io};
8 :
9 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10 :
/// Crate-internal helpers used to seal parts of the public TLS API.
pub(crate) mod private {
    /// Marker argument for hidden trait methods.
    ///
    /// The type is `pub`, but its enclosing module is only `pub(crate)`. Code outside this
    /// crate therefore cannot name it, and so cannot call methods that take it.
    pub struct ForcePrivateApi;
}
14 :
/// Channel binding information returned from a TLS handshake.
///
/// Currently this carries only the optional `tls-server-end-point` binding data, as raw bytes.
pub struct ChannelBinding {
    pub(crate) tls_server_end_point: Option<Vec<u8>>,
}

impl ChannelBinding {
    /// Returns a `ChannelBinding` that carries no binding data at all.
    pub fn none() -> ChannelBinding {
        Self {
            tls_server_end_point: None,
        }
    }

    /// Returns a `ChannelBinding` holding the given `tls-server-end-point` binding bytes.
    pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
        let tls_server_end_point = Some(tls_server_end_point);
        Self {
            tls_server_end_point,
        }
    }
}
35 :
36 : /// A constructor of `TlsConnect`ors.
37 : ///
38 : /// Requires the `runtime` Cargo feature (enabled by default).
39 : pub trait MakeTlsConnect<S> {
40 : /// The stream type created by the `TlsConnect` implementation.
41 : type Stream: TlsStream + Unpin;
42 : /// The `TlsConnect` implementation created by this type.
43 : type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
44 : /// The error type returned by the `TlsConnect` implementation.
45 : type Error: Into<Box<dyn Error + Sync + Send>>;
46 :
47 : /// Creates a new `TlsConnect`or.
48 : ///
49 : /// The domain name is provided for certificate verification and SNI.
50 : fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
51 : }
52 :
53 : /// An asynchronous function wrapping a stream in a TLS session.
54 : pub trait TlsConnect<S> {
55 : /// The stream returned by the future.
56 : type Stream: TlsStream + Unpin;
57 : /// The error returned by the future.
58 : type Error: Into<Box<dyn Error + Sync + Send>>;
59 : /// The future returned by the connector.
60 : type Future: Future<Output = Result<Self::Stream, Self::Error>>;
61 :
62 : /// Returns a future performing a TLS handshake over the stream.
63 : fn connect(self, stream: S) -> Self::Future;
64 :
65 : #[doc(hidden)]
66 0 : fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
67 0 : true
68 0 : }
69 : }
70 :
71 : /// A TLS-wrapped connection to a PostgreSQL database.
72 : pub trait TlsStream: AsyncRead + AsyncWrite {
73 : /// Returns channel binding information for the session.
74 : fn channel_binding(&self) -> ChannelBinding;
75 : }
76 :
/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
///
/// Its `connect` future always resolves to `Err(NoTlsError)`, so it is only suitable when TLS
/// is optional. That means when `sslmode` is `disable` or `prefer`.
#[derive(Debug, Copy, Clone)]
pub struct NoTls;
82 :
83 : impl<S> MakeTlsConnect<S> for NoTls {
84 : type Stream = NoTlsStream;
85 : type TlsConnect = NoTls;
86 : type Error = NoTlsError;
87 :
88 0 : fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
89 0 : Ok(NoTls)
90 0 : }
91 : }
92 :
93 : impl<S> TlsConnect<S> for NoTls {
94 : type Stream = NoTlsStream;
95 : type Error = NoTlsError;
96 : type Future = NoTlsFuture;
97 :
98 0 : fn connect(self, _: S) -> NoTlsFuture {
99 0 : NoTlsFuture(())
100 0 : }
101 :
102 1 : fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
103 1 : false
104 1 : }
105 : }
106 :
107 : /// The future returned by `NoTls`.
108 : pub struct NoTlsFuture(());
109 :
110 : impl Future for NoTlsFuture {
111 : type Output = Result<NoTlsStream, NoTlsError>;
112 :
113 0 : fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
114 0 : Poll::Ready(Err(NoTlsError(())))
115 0 : }
116 : }
117 :
118 : /// The TLS "stream" type produced by the `NoTls` connector.
119 : ///
120 : /// Since `NoTls` doesn't support TLS, this type is uninhabited.
121 : pub enum NoTlsStream {}
122 :
123 : impl AsyncRead for NoTlsStream {
124 0 : fn poll_read(
125 0 : self: Pin<&mut Self>,
126 0 : _: &mut Context<'_>,
127 0 : _: &mut ReadBuf<'_>,
128 0 : ) -> Poll<io::Result<()>> {
129 0 : match *self {}
130 : }
131 : }
132 :
133 : impl AsyncWrite for NoTlsStream {
134 0 : fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
135 0 : match *self {}
136 : }
137 :
138 0 : fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
139 0 : match *self {}
140 : }
141 :
142 0 : fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
143 0 : match *self {}
144 : }
145 : }
146 :
147 : impl TlsStream for NoTlsStream {
148 : fn channel_binding(&self) -> ChannelBinding {
149 : match *self {}
150 : }
151 : }
152 :
/// The error returned by `NoTls` when a TLS connection is attempted.
///
/// The private unit field prevents construction outside this module.
#[derive(Debug)]
pub struct NoTlsError(());

impl fmt::Display for NoTlsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "no TLS implementation configured")
    }
}

/// Leaf error: relies on the default `source()`, which returns `None`.
impl Error for NoTlsError {}
|