Line data Source code
1 : use std::convert::TryFrom;
2 : use std::sync::Arc;
3 :
4 : use postgres_client::tls::MakeTlsConnect;
5 : use rustls::pki_types::ServerName;
6 : use rustls::ClientConfig;
7 : use tokio::io::{AsyncRead, AsyncWrite};
8 :
9 : mod private {
10 : use std::future::Future;
11 : use std::io;
12 : use std::pin::Pin;
13 : use std::task::{Context, Poll};
14 :
15 : use postgres_client::tls::{ChannelBinding, TlsConnect};
16 : use rustls::pki_types::ServerName;
17 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
18 : use tokio_rustls::client::TlsStream;
19 : use tokio_rustls::TlsConnector;
20 :
21 : use crate::tls::TlsServerEndPoint;
22 :
23 : pub struct TlsConnectFuture<S> {
24 : inner: tokio_rustls::Connect<S>,
25 : }
26 :
27 : impl<S> Future for TlsConnectFuture<S>
28 : where
29 : S: AsyncRead + AsyncWrite + Unpin,
30 : {
31 : type Output = io::Result<RustlsStream<S>>;
32 :
33 40 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
34 40 : Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
35 40 : }
36 : }
37 :
38 : pub struct RustlsConnect(pub RustlsConnectData);
39 :
40 : pub struct RustlsConnectData {
41 : pub hostname: ServerName<'static>,
42 : pub connector: TlsConnector,
43 : }
44 :
45 : impl<S> TlsConnect<S> for RustlsConnect
46 : where
47 : S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
48 : {
49 : type Stream = RustlsStream<S>;
50 : type Error = io::Error;
51 : type Future = TlsConnectFuture<S>;
52 :
53 20 : fn connect(self, stream: S) -> Self::Future {
54 20 : TlsConnectFuture {
55 20 : inner: self.0.connector.connect(self.0.hostname, stream),
56 20 : }
57 20 : }
58 : }
59 :
60 : pub struct RustlsStream<S>(TlsStream<S>);
61 :
62 : impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
63 : where
64 : S: AsyncRead + AsyncWrite + Unpin,
65 : {
66 12 : fn channel_binding(&self) -> ChannelBinding {
67 12 : let (_, session) = self.0.get_ref();
68 12 : match session.peer_certificates() {
69 12 : Some([cert, ..]) => TlsServerEndPoint::new(cert)
70 12 : .ok()
71 12 : .and_then(|cb| match cb {
72 12 : TlsServerEndPoint::Sha256(hash) => Some(hash),
73 0 : TlsServerEndPoint::Undefined => None,
74 12 : })
75 12 : .map_or_else(ChannelBinding::none, |hash| {
76 12 : ChannelBinding::tls_server_end_point(hash.to_vec())
77 12 : }),
78 0 : _ => ChannelBinding::none(),
79 : }
80 12 : }
81 : }
82 :
83 : impl<S> AsyncRead for RustlsStream<S>
84 : where
85 : S: AsyncRead + AsyncWrite + Unpin,
86 : {
87 122 : fn poll_read(
88 122 : mut self: Pin<&mut Self>,
89 122 : cx: &mut Context<'_>,
90 122 : buf: &mut ReadBuf<'_>,
91 122 : ) -> Poll<tokio::io::Result<()>> {
92 122 : Pin::new(&mut self.0).poll_read(cx, buf)
93 122 : }
94 : }
95 :
96 : impl<S> AsyncWrite for RustlsStream<S>
97 : where
98 : S: AsyncRead + AsyncWrite + Unpin,
99 : {
100 52 : fn poll_write(
101 52 : mut self: Pin<&mut Self>,
102 52 : cx: &mut Context<'_>,
103 52 : buf: &[u8],
104 52 : ) -> Poll<tokio::io::Result<usize>> {
105 52 : Pin::new(&mut self.0).poll_write(cx, buf)
106 52 : }
107 :
108 52 : fn poll_flush(
109 52 : mut self: Pin<&mut Self>,
110 52 : cx: &mut Context<'_>,
111 52 : ) -> Poll<tokio::io::Result<()>> {
112 52 : Pin::new(&mut self.0).poll_flush(cx)
113 52 : }
114 :
115 0 : fn poll_shutdown(
116 0 : mut self: Pin<&mut Self>,
117 0 : cx: &mut Context<'_>,
118 0 : ) -> Poll<tokio::io::Result<()>> {
119 0 : Pin::new(&mut self.0).poll_shutdown(cx)
120 0 : }
121 : }
122 : }
123 :
124 : /// A `MakeTlsConnect` implementation using `rustls`.
125 : ///
126 : /// That way you can connect to PostgreSQL using `rustls` as the TLS stack.
127 : #[derive(Clone)]
128 : pub struct MakeRustlsConnect {
129 : pub config: Arc<ClientConfig>,
130 : }
131 :
132 : impl MakeRustlsConnect {
133 : /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`.
134 : #[must_use]
135 20 : pub fn new(config: Arc<ClientConfig>) -> Self {
136 20 : Self { config }
137 20 : }
138 : }
139 :
140 : impl<S> MakeTlsConnect<S> for MakeRustlsConnect
141 : where
142 : S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
143 : {
144 : type Stream = private::RustlsStream<S>;
145 : type TlsConnect = private::RustlsConnect;
146 : type Error = rustls::pki_types::InvalidDnsNameError;
147 :
148 20 : fn make_tls_connect(&mut self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
149 20 : ServerName::try_from(hostname).map(|dns_name| {
150 20 : private::RustlsConnect(private::RustlsConnectData {
151 20 : hostname: dns_name.to_owned(),
152 20 : connector: Arc::clone(&self.config).into(),
153 20 : })
154 20 : })
155 20 : }
156 : }
|