Line data Source code
1 : use std::convert::TryFrom;
2 : use std::sync::Arc;
3 :
4 : use postgres_client::tls::MakeTlsConnect;
5 : use rustls::pki_types::{InvalidDnsNameError, ServerName};
6 : use tokio::io::{AsyncRead, AsyncWrite};
7 :
8 : use crate::config::ComputeConfig;
9 :
10 : mod private {
11 : use std::future::Future;
12 : use std::io;
13 : use std::pin::Pin;
14 : use std::task::{Context, Poll};
15 :
16 : use postgres_client::tls::{ChannelBinding, TlsConnect};
17 : use rustls::pki_types::ServerName;
18 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
19 : use tokio_rustls::TlsConnector;
20 : use tokio_rustls::client::TlsStream;
21 :
22 : use crate::tls::TlsServerEndPoint;
23 :
24 : pub struct TlsConnectFuture<S> {
25 : inner: tokio_rustls::Connect<S>,
26 : }
27 :
28 : impl<S> Future for TlsConnectFuture<S>
29 : where
30 : S: AsyncRead + AsyncWrite + Unpin,
31 : {
32 : type Output = io::Result<RustlsStream<S>>;
33 :
34 40 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
35 40 : Pin::new(&mut self.inner)
36 40 : .poll(cx)
37 40 : .map_ok(|s| RustlsStream(Box::new(s)))
38 40 : }
39 : }
40 :
41 : pub struct RustlsConnect(pub RustlsConnectData);
42 :
43 : pub struct RustlsConnectData {
44 : pub hostname: ServerName<'static>,
45 : pub connector: TlsConnector,
46 : }
47 :
48 : impl<S> TlsConnect<S> for RustlsConnect
49 : where
50 : S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
51 : {
52 : type Stream = RustlsStream<S>;
53 : type Error = io::Error;
54 : type Future = TlsConnectFuture<S>;
55 :
56 20 : fn connect(self, stream: S) -> Self::Future {
57 20 : TlsConnectFuture {
58 20 : inner: self.0.connector.connect(self.0.hostname, stream),
59 20 : }
60 20 : }
61 : }
62 :
63 : pub struct RustlsStream<S>(Box<TlsStream<S>>);
64 :
65 : impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
66 : where
67 : S: AsyncRead + AsyncWrite + Unpin,
68 : {
69 12 : fn channel_binding(&self) -> ChannelBinding {
70 12 : let (_, session) = self.0.get_ref();
71 12 : match session.peer_certificates() {
72 12 : Some([cert, ..]) => TlsServerEndPoint::new(cert)
73 12 : .ok()
74 12 : .and_then(|cb| match cb {
75 12 : TlsServerEndPoint::Sha256(hash) => Some(hash),
76 0 : TlsServerEndPoint::Undefined => None,
77 12 : })
78 12 : .map_or_else(ChannelBinding::none, |hash| {
79 12 : ChannelBinding::tls_server_end_point(hash.to_vec())
80 12 : }),
81 0 : _ => ChannelBinding::none(),
82 : }
83 12 : }
84 : }
85 :
86 : impl<S> AsyncRead for RustlsStream<S>
87 : where
88 : S: AsyncRead + AsyncWrite + Unpin,
89 : {
90 127 : fn poll_read(
91 127 : mut self: Pin<&mut Self>,
92 127 : cx: &mut Context<'_>,
93 127 : buf: &mut ReadBuf<'_>,
94 127 : ) -> Poll<tokio::io::Result<()>> {
95 127 : Pin::new(&mut self.0).poll_read(cx, buf)
96 127 : }
97 : }
98 :
99 : impl<S> AsyncWrite for RustlsStream<S>
100 : where
101 : S: AsyncRead + AsyncWrite + Unpin,
102 : {
103 52 : fn poll_write(
104 52 : mut self: Pin<&mut Self>,
105 52 : cx: &mut Context<'_>,
106 52 : buf: &[u8],
107 52 : ) -> Poll<tokio::io::Result<usize>> {
108 52 : Pin::new(&mut self.0).poll_write(cx, buf)
109 52 : }
110 :
111 52 : fn poll_flush(
112 52 : mut self: Pin<&mut Self>,
113 52 : cx: &mut Context<'_>,
114 52 : ) -> Poll<tokio::io::Result<()>> {
115 52 : Pin::new(&mut self.0).poll_flush(cx)
116 52 : }
117 :
118 0 : fn poll_shutdown(
119 0 : mut self: Pin<&mut Self>,
120 0 : cx: &mut Context<'_>,
121 0 : ) -> Poll<tokio::io::Result<()>> {
122 0 : Pin::new(&mut self.0).poll_shutdown(cx)
123 0 : }
124 : }
125 : }
126 :
127 : impl<S> MakeTlsConnect<S> for ComputeConfig
128 : where
129 : S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
130 : {
131 : type Stream = private::RustlsStream<S>;
132 : type TlsConnect = private::RustlsConnect;
133 : type Error = InvalidDnsNameError;
134 :
135 0 : fn make_tls_connect(&self, hostname: &str) -> Result<Self::TlsConnect, Self::Error> {
136 0 : make_tls_connect(&self.tls, hostname)
137 0 : }
138 : }
139 :
140 20 : pub fn make_tls_connect(
141 20 : tls: &Arc<rustls::ClientConfig>,
142 20 : hostname: &str,
143 20 : ) -> Result<private::RustlsConnect, InvalidDnsNameError> {
144 20 : ServerName::try_from(hostname).map(|dns_name| {
145 20 : private::RustlsConnect(private::RustlsConnectData {
146 20 : hostname: dns_name.to_owned(),
147 20 : connector: tls.clone().into(),
148 20 : })
149 20 : })
150 20 : }
|