Line data Source code
1 : #![allow(dead_code, reason = "TODO: work in progress")]
2 :
3 : use std::pin::{Pin, pin};
4 : use std::sync::Arc;
5 : use std::sync::atomic::{AtomicUsize, Ordering};
6 : use std::task::{Context, Poll};
7 : use std::{fmt, io};
8 :
9 : use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
10 : use tokio::sync::mpsc;
11 :
/// Capacity of each direction's channel of opened-but-not-yet-accepted streams.
const STREAM_CHANNEL_SIZE: usize = 16;
/// Buffer size (in bytes) handed to `tokio::io::duplex` for every stream's in-memory pipe.
const MAX_STREAM_BUFFER_SIZE: usize = 4 * 1024;
14 :
15 : #[derive(Debug)]
16 : pub struct Connection {
17 : stream_sender: mpsc::Sender<Stream>,
18 : stream_receiver: mpsc::Receiver<Stream>,
19 : stream_id_counter: Arc<AtomicUsize>,
20 : }
21 :
22 : impl Connection {
23 1 : pub fn new() -> (Connection, Connection) {
24 1 : let (sender_a, receiver_a) = mpsc::channel(STREAM_CHANNEL_SIZE);
25 1 : let (sender_b, receiver_b) = mpsc::channel(STREAM_CHANNEL_SIZE);
26 1 :
27 1 : let stream_id_counter = Arc::new(AtomicUsize::new(1));
28 1 :
29 1 : let conn_a = Connection {
30 1 : stream_sender: sender_a,
31 1 : stream_receiver: receiver_b,
32 1 : stream_id_counter: Arc::clone(&stream_id_counter),
33 1 : };
34 1 : let conn_b = Connection {
35 1 : stream_sender: sender_b,
36 1 : stream_receiver: receiver_a,
37 1 : stream_id_counter,
38 1 : };
39 1 :
40 1 : (conn_a, conn_b)
41 1 : }
42 :
43 : #[inline]
44 1 : fn next_stream_id(&self) -> StreamId {
45 1 : StreamId(self.stream_id_counter.fetch_add(1, Ordering::Relaxed))
46 1 : }
47 :
48 : #[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))]
49 : pub async fn open_stream(&self) -> io::Result<Stream> {
50 : let (local, remote) = tokio::io::duplex(MAX_STREAM_BUFFER_SIZE);
51 : let stream_id = self.next_stream_id();
52 : tracing::Span::current().record("stream_id", stream_id.0);
53 :
54 : let local = Stream {
55 : inner: local,
56 : id: stream_id,
57 : };
58 : let remote = Stream {
59 : inner: remote,
60 : id: stream_id,
61 : };
62 :
63 : self.stream_sender
64 : .send(remote)
65 : .await
66 : .map_err(io::Error::other)?;
67 :
68 : Ok(local)
69 : }
70 :
71 : #[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))]
72 : pub async fn accept_stream(&mut self) -> io::Result<Option<Stream>> {
73 1 : Ok(self.stream_receiver.recv().await.inspect(|stream| {
74 1 : tracing::Span::current().record("stream_id", stream.id.0);
75 1 : }))
76 : }
77 : }
78 :
/// Numeric identifier of a stream, allocated by its [`Connection`] pair.
///
/// Both ends of a stream carry the same ID. IDs can be compared and hashed, for example to
/// key per-stream state in a map.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(usize);

impl fmt::Display for StreamId {
    /// Formats the raw numeric ID.
    ///
    /// Forwards to `usize`'s `Display` so width, fill and alignment flags (e.g. `{:>4}`)
    /// are honored.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
88 :
89 : // TODO: Proper closing. Currently Streams can outlive their Connections.
90 : // Carry WeakSender and check strong_count?
91 : #[derive(Debug)]
92 : pub struct Stream {
93 : inner: DuplexStream,
94 : id: StreamId,
95 : }
96 :
97 : impl Stream {
98 : #[inline]
99 0 : pub fn id(&self) -> StreamId {
100 0 : self.id
101 0 : }
102 : }
103 :
104 : impl AsyncRead for Stream {
105 : #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
106 : #[inline]
107 : fn poll_read(
108 : mut self: Pin<&mut Self>,
109 : cx: &mut Context<'_>,
110 : buf: &mut ReadBuf<'_>,
111 : ) -> Poll<io::Result<()>> {
112 : pin!(&mut self.inner).poll_read(cx, buf)
113 : }
114 : }
115 :
116 : impl AsyncWrite for Stream {
117 : #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
118 : #[inline]
119 : fn poll_write(
120 : mut self: Pin<&mut Self>,
121 : cx: &mut Context<'_>,
122 : buf: &[u8],
123 : ) -> Poll<Result<usize, io::Error>> {
124 : pin!(&mut self.inner).poll_write(cx, buf)
125 : }
126 :
127 : #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
128 : #[inline]
129 : fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
130 : pin!(&mut self.inner).poll_flush(cx)
131 : }
132 :
133 : #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
134 : #[inline]
135 : fn poll_shutdown(
136 : mut self: Pin<&mut Self>,
137 : cx: &mut Context<'_>,
138 : ) -> Poll<Result<(), io::Error>> {
139 : pin!(&mut self.inner).poll_shutdown(cx)
140 : }
141 :
142 : #[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
143 : #[inline]
144 : fn poll_write_vectored(
145 : mut self: Pin<&mut Self>,
146 : cx: &mut Context<'_>,
147 : bufs: &[io::IoSlice<'_>],
148 : ) -> Poll<Result<usize, io::Error>> {
149 : pin!(&mut self.inner).poll_write_vectored(cx, bufs)
150 : }
151 :
152 : #[inline]
153 0 : fn is_write_vectored(&self) -> bool {
154 0 : self.inner.is_write_vectored()
155 0 : }
156 : }
157 :
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// Runs an echo server on one endpoint and a client on the other.
    ///
    /// Checks that the client's bytes come back unchanged. Also checks that dropping the
    /// client's stream and connection ends the server's accept loop.
    #[tokio::test]
    async fn test_simple_roundtrip() {
        let (client, mut server) = Connection::new();

        let server_task = tokio::spawn(async move {
            while let Some(mut stream) = server.accept_stream().await.unwrap() {
                tokio::spawn(async move {
                    let mut buf = [0; 64];
                    loop {
                        match stream.read(&mut buf).await.unwrap() {
                            0 => break,
                            // `write` may accept only a prefix of the buffer. `write_all`
                            // guarantees every received byte is echoed back.
                            n => stream.write_all(&buf[..n]).await.unwrap(),
                        }
                    }
                });
            }
        });

        let mut stream = client.open_stream().await.unwrap();
        stream.write_all(b"hello!").await.unwrap();
        let mut buf = [0; 64];
        let n = stream.read(&mut buf).await.unwrap();
        assert_eq!(n, 6);
        assert_eq!(&buf[..n], b"hello!");

        // Dropping the client endpoint closes the server's receiving channel, so
        // `accept_stream` yields `None` and the server task finishes.
        drop(stream);
        drop(client);
        server_task.await.unwrap();
    }
}
|