Line data Source code
1 : //! MaybeTlsStream.
2 : //!
3 : //! Represents a stream that may or may not be encrypted with TLS.
4 : use crate::tls::{ChannelBinding, TlsStream};
5 : use std::io;
6 : use std::pin::Pin;
7 : use std::task::{Context, Poll};
8 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9 :
/// A stream that may or may not be encrypted with TLS.
///
/// `S` is the type of the plain (unencrypted) stream and `T` is the type of
/// the TLS-wrapped stream; exactly one of them is held at a time.
pub enum MaybeTlsStream<S, T> {
    /// An unencrypted stream, used as-is.
    Raw(S),
    /// An encrypted stream.
    Tls(T),
}
17 :
18 : impl<S, T> AsyncRead for MaybeTlsStream<S, T>
19 : where
20 : S: AsyncRead + Unpin,
21 : T: AsyncRead + Unpin,
22 : {
23 72 : fn poll_read(
24 72 : mut self: Pin<&mut Self>,
25 72 : cx: &mut Context<'_>,
26 72 : buf: &mut ReadBuf<'_>,
27 72 : ) -> Poll<io::Result<()>> {
28 72 : match &mut *self {
29 4 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
30 68 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
31 : }
32 72 : }
33 : }
34 :
35 : impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
36 : where
37 : S: AsyncWrite + Unpin,
38 : T: AsyncWrite + Unpin,
39 : {
40 36 : fn poll_write(
41 36 : mut self: Pin<&mut Self>,
42 36 : cx: &mut Context<'_>,
43 36 : buf: &[u8],
44 36 : ) -> Poll<io::Result<usize>> {
45 36 : match &mut *self {
46 2 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
47 34 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
48 : }
49 36 : }
50 :
51 36 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
52 36 : match &mut *self {
53 2 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
54 34 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
55 : }
56 36 : }
57 :
58 0 : fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
59 0 : match &mut *self {
60 0 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
61 0 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
62 : }
63 0 : }
64 : }
65 :
66 : impl<S, T> TlsStream for MaybeTlsStream<S, T>
67 : where
68 : S: AsyncRead + AsyncWrite + Unpin,
69 : T: TlsStream + Unpin,
70 : {
71 12 : fn channel_binding(&self) -> ChannelBinding {
72 12 : match self {
73 0 : MaybeTlsStream::Raw(_) => ChannelBinding::none(),
74 12 : MaybeTlsStream::Tls(s) => s.channel_binding(),
75 : }
76 12 : }
77 : }
|