Line data Source code
1 : //! MaybeTlsStream.
2 : //!
3 : //! Represents a stream that may or may not be encrypted with TLS.
4 : use std::io;
5 : use std::pin::Pin;
6 : use std::task::{Context, Poll};
7 :
8 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9 :
10 : use crate::tls::{ChannelBinding, TlsStream};
11 :
/// A connection stream that is either used as-is or encrypted with TLS.
///
/// `S` is the type held when no encryption is in effect, and `T` is the type
/// held when TLS is in use. The I/O trait impls for this type forward every
/// call to whichever variant is currently present.
pub enum MaybeTlsStream<S, T> {
    /// The underlying stream, used without any encryption.
    Raw(S),
    /// A stream that is encrypted with TLS.
    Tls(T),
}
19 :
20 : impl<S, T> AsyncRead for MaybeTlsStream<S, T>
21 : where
22 : S: AsyncRead + Unpin,
23 : T: AsyncRead + Unpin,
24 : {
25 72 : fn poll_read(
26 72 : mut self: Pin<&mut Self>,
27 72 : cx: &mut Context<'_>,
28 72 : buf: &mut ReadBuf<'_>,
29 72 : ) -> Poll<io::Result<()>> {
30 72 : match &mut *self {
31 4 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
32 68 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
33 : }
34 0 : }
35 : }
36 :
37 : impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
38 : where
39 : S: AsyncWrite + Unpin,
40 : T: AsyncWrite + Unpin,
41 : {
42 36 : fn poll_write(
43 36 : mut self: Pin<&mut Self>,
44 36 : cx: &mut Context<'_>,
45 36 : buf: &[u8],
46 36 : ) -> Poll<io::Result<usize>> {
47 36 : match &mut *self {
48 2 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
49 34 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
50 : }
51 0 : }
52 :
53 36 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
54 36 : match &mut *self {
55 2 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
56 34 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
57 : }
58 0 : }
59 :
60 0 : fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
61 0 : match &mut *self {
62 0 : MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
63 0 : MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
64 : }
65 0 : }
66 : }
67 :
68 : impl<S, T> TlsStream for MaybeTlsStream<S, T>
69 : where
70 : S: AsyncRead + AsyncWrite + Unpin,
71 : T: TlsStream + Unpin,
72 : {
73 12 : fn channel_binding(&self) -> ChannelBinding {
74 12 : match self {
75 0 : MaybeTlsStream::Raw(_) => ChannelBinding::none(),
76 12 : MaybeTlsStream::Tls(s) => s.channel_binding(),
77 : }
78 0 : }
79 : }
|