Line data Source code
1 : //! TLS support.
2 :
3 : use std::error::Error;
4 : use std::future::Future;
5 : use std::pin::Pin;
6 : use std::task::{Context, Poll};
7 : use std::{fmt, io};
8 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9 :
pub(crate) mod private {
    /// Zero-sized token type that can only be named from inside this crate.
    ///
    /// Because the enclosing module is `pub(crate)`, requiring this token as a
    /// parameter keeps a trait method from being called by outside code.
    pub struct ForcePrivateApi;
}
13 :
/// Channel binding information returned from a TLS handshake.
///
/// Derives `Debug` (like the other public types in this module) and `Clone` so
/// the binding can be logged and duplicated by callers.
#[derive(Debug, Clone)]
pub struct ChannelBinding {
    /// Raw `tls-server-end-point` channel binding data; `None` when the
    /// binding carries no information.
    pub(crate) tls_server_end_point: Option<Vec<u8>>,
}
18 :
19 : impl ChannelBinding {
20 : /// Creates a `ChannelBinding` containing no information.
21 0 : pub fn none() -> ChannelBinding {
22 0 : ChannelBinding {
23 0 : tls_server_end_point: None,
24 0 : }
25 0 : }
26 :
27 : /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information.
28 12 : pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
29 12 : ChannelBinding {
30 12 : tls_server_end_point: Some(tls_server_end_point),
31 12 : }
32 12 : }
33 : }
34 :
35 : /// A constructor of `TlsConnect`ors.
36 : ///
37 : /// Requires the `runtime` Cargo feature (enabled by default).
38 : pub trait MakeTlsConnect<S> {
39 : /// The stream type created by the `TlsConnect` implementation.
40 : type Stream: TlsStream + Unpin;
41 : /// The `TlsConnect` implementation created by this type.
42 : type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
43 : /// The error type returned by the `TlsConnect` implementation.
44 : type Error: Into<Box<dyn Error + Sync + Send>>;
45 :
46 : /// Creates a new `TlsConnect`or.
47 : ///
48 : /// The domain name is provided for certificate verification and SNI.
49 : fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
50 : }
51 :
52 : /// An asynchronous function wrapping a stream in a TLS session.
53 : pub trait TlsConnect<S> {
54 : /// The stream returned by the future.
55 : type Stream: TlsStream + Unpin;
56 : /// The error returned by the future.
57 : type Error: Into<Box<dyn Error + Sync + Send>>;
58 : /// The future returned by the connector.
59 : type Future: Future<Output = Result<Self::Stream, Self::Error>>;
60 :
61 : /// Returns a future performing a TLS handshake over the stream.
62 : fn connect(self, stream: S) -> Self::Future;
63 :
64 : #[doc(hidden)]
65 0 : fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
66 0 : true
67 0 : }
68 : }
69 :
70 : /// A TLS-wrapped connection to a PostgreSQL database.
71 : pub trait TlsStream: AsyncRead + AsyncWrite {
72 : /// Returns channel binding information for the session.
73 : fn channel_binding(&self) -> ChannelBinding;
74 : }
75 :
/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
///
/// This can be used when `sslmode` is `disable` or `prefer`: no TLS backend is
/// available, so any attempt to actually negotiate TLS fails with `NoTlsError`.
#[derive(Debug, Default, Copy, Clone)]
pub struct NoTls;
81 :
82 : impl<S> MakeTlsConnect<S> for NoTls {
83 : type Stream = NoTlsStream;
84 : type TlsConnect = NoTls;
85 : type Error = NoTlsError;
86 :
87 0 : fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
88 0 : Ok(NoTls)
89 0 : }
90 : }
91 :
92 : impl<S> TlsConnect<S> for NoTls {
93 : type Stream = NoTlsStream;
94 : type Error = NoTlsError;
95 : type Future = NoTlsFuture;
96 :
97 0 : fn connect(self, _: S) -> NoTlsFuture {
98 0 : NoTlsFuture(())
99 0 : }
100 :
101 1 : fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
102 1 : false
103 1 : }
104 : }
105 :
/// The future returned by `NoTls`.
///
/// Resolves to an error as soon as it is polled. It is `#[must_use]` like
/// other futures, because constructing it without polling does nothing.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct NoTlsFuture(());
108 :
109 : impl Future for NoTlsFuture {
110 : type Output = Result<NoTlsStream, NoTlsError>;
111 :
112 0 : fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
113 0 : Poll::Ready(Err(NoTlsError(())))
114 0 : }
115 : }
116 :
/// The TLS "stream" type produced by the `NoTls` connector.
///
/// Since `NoTls` doesn't support TLS, this type is uninhabited. No value of it
/// can be constructed, so every trait method implemented on it is unreachable.
/// It derives `Debug` for consistency with `NoTls` and `NoTlsError`.
#[derive(Debug)]
pub enum NoTlsStream {}
121 :
122 : impl AsyncRead for NoTlsStream {
123 0 : fn poll_read(
124 0 : self: Pin<&mut Self>,
125 0 : _: &mut Context<'_>,
126 0 : _: &mut ReadBuf<'_>,
127 0 : ) -> Poll<io::Result<()>> {
128 0 : match *self {}
129 : }
130 : }
131 :
132 : impl AsyncWrite for NoTlsStream {
133 0 : fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
134 0 : match *self {}
135 : }
136 :
137 0 : fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
138 0 : match *self {}
139 : }
140 :
141 0 : fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
142 0 : match *self {}
143 : }
144 : }
145 :
146 : impl TlsStream for NoTlsStream {
147 : fn channel_binding(&self) -> ChannelBinding {
148 : match *self {}
149 : }
150 : }
151 :
/// The error returned by `NoTls`.
///
/// `NoTlsFuture` yields this when it is polled, because no TLS implementation
/// is configured. The private `()` field prevents construction outside this module.
#[derive(Debug)]
pub struct NoTlsError(());
155 :
156 : impl fmt::Display for NoTlsError {
157 0 : fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
158 0 : fmt.write_str("no TLS implementation configured")
159 0 : }
160 : }
161 :
162 : impl Error for NoTlsError {}
|