Line data Source code
1 : use bytes::BytesMut;
2 : use postgres_protocol2::message::frontend;
3 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4 :
5 : use crate::Error;
6 : use crate::config::SslMode;
7 : use crate::maybe_tls_stream::MaybeTlsStream;
8 : use crate::tls::TlsConnect;
9 : use crate::tls::private::ForcePrivateApi;
10 :
11 15 : pub async fn connect_tls<S, T>(
12 15 : mut stream: S,
13 15 : mode: SslMode,
14 15 : tls: T,
15 15 : ) -> Result<MaybeTlsStream<S, T::Stream>, Error>
16 15 : where
17 15 : S: AsyncRead + AsyncWrite + Unpin,
18 15 : T: TlsConnect<S>,
19 15 : {
20 1 : match mode {
21 1 : SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
22 1 : SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
23 1 : return Ok(MaybeTlsStream::Raw(stream));
24 : }
25 13 : SslMode::Prefer | SslMode::Require => {}
26 13 : }
27 13 :
28 13 : let mut buf = BytesMut::new();
29 13 : frontend::ssl_request(&mut buf);
30 13 : stream.write_all(&buf).await.map_err(Error::io)?;
31 :
32 13 : let mut buf = [0];
33 13 : stream.read_exact(&mut buf).await.map_err(Error::io)?;
34 :
35 13 : if buf[0] != b'S' {
36 0 : if SslMode::Require == mode {
37 0 : return Err(Error::tls("server does not support TLS".into()));
38 : } else {
39 0 : return Ok(MaybeTlsStream::Raw(stream));
40 : }
41 0 : }
42 :
43 13 : let stream = tls
44 13 : .connect(stream)
45 13 : .await
46 13 : .map_err(|e| Error::tls(e.into()))?;
47 :
48 13 : Ok(MaybeTlsStream::Tls(stream))
49 0 : }
|