Line data Source code
1 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2 : use tracing::info;
3 :
4 : use std::future::poll_fn;
5 : use std::io;
6 : use std::pin::Pin;
7 : use std::task::{ready, Context, Poll};
8 :
/// Progress of the copy in a single direction (reader -> writer).
///
/// `transfer_one_direction` advances a value of this type through
/// `Running` -> `ShuttingDown` -> `Done`.
#[derive(Debug)]
enum TransferState {
    /// Data is still being copied through the contained buffer.
    Running(CopyBuffer),
    /// Copying has stopped and the writer still has to be shut down.
    /// Holds the number of bytes copied so far.
    ShuttingDown(u64),
    /// The writer has been shut down. Holds the number of bytes copied.
    Done(u64),
}
15 :
16 10 : fn transfer_one_direction<A, B>(
17 10 : cx: &mut Context<'_>,
18 10 : state: &mut TransferState,
19 10 : r: &mut A,
20 10 : w: &mut B,
21 10 : ) -> Poll<io::Result<u64>>
22 10 : where
23 10 : A: AsyncRead + AsyncWrite + Unpin + ?Sized,
24 10 : B: AsyncRead + AsyncWrite + Unpin + ?Sized,
25 10 : {
26 10 : let mut r = Pin::new(r);
27 10 : let mut w = Pin::new(w);
28 24 : loop {
29 24 : match state {
30 8 : TransferState::Running(buf) => {
31 8 : let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
32 6 : *state = TransferState::ShuttingDown(count);
33 : }
34 8 : TransferState::ShuttingDown(count) => {
35 8 : ready!(w.as_mut().poll_shutdown(cx))?;
36 8 : *state = TransferState::Done(*count);
37 : }
38 8 : TransferState::Done(count) => return Poll::Ready(Ok(*count)),
39 : }
40 : }
41 10 : }
42 :
43 8 : #[tracing::instrument(skip_all)]
44 : pub async fn copy_bidirectional_client_compute<Client, Compute>(
45 : client: &mut Client,
46 : compute: &mut Compute,
47 : ) -> Result<(u64, u64), std::io::Error>
48 : where
49 : Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
50 : Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
51 : {
52 : let mut client_to_compute = TransferState::Running(CopyBuffer::new());
53 : let mut compute_to_client = TransferState::Running(CopyBuffer::new());
54 :
55 4 : poll_fn(|cx| {
56 4 : let mut client_to_compute_result =
57 4 : transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
58 4 : let mut compute_to_client_result =
59 4 : transfer_one_direction(cx, &mut compute_to_client, compute, client)?;
60 :
61 : // Early termination checks from compute to client.
62 4 : if let TransferState::Done(_) = compute_to_client {
63 4 : if let TransferState::Running(buf) = &client_to_compute {
64 2 : info!("Compute is done, terminate client");
65 : // Initiate shutdown
66 2 : client_to_compute = TransferState::ShuttingDown(buf.amt);
67 2 : client_to_compute_result =
68 2 : transfer_one_direction(cx, &mut client_to_compute, client, compute)?;
69 2 : }
70 0 : }
71 :
72 : // Early termination checks from compute to client.
73 4 : if let TransferState::Done(_) = client_to_compute {
74 4 : if let TransferState::Running(buf) = &compute_to_client {
75 0 : info!("Client is done, terminate compute");
76 : // Initiate shutdown
77 0 : compute_to_client = TransferState::ShuttingDown(buf.amt);
78 0 : compute_to_client_result =
79 0 : transfer_one_direction(cx, &mut compute_to_client, client, compute)?;
80 4 : }
81 0 : }
82 :
83 : // It is not a problem if ready! returns early ... (comment remains the same)
84 4 : let client_to_compute = ready!(client_to_compute_result);
85 4 : let compute_to_client = ready!(compute_to_client_result);
86 :
87 4 : Poll::Ready(Ok((client_to_compute, compute_to_client)))
88 4 : })
89 : .await
90 : }
91 :
/// Intermediate buffer and bookkeeping for copying bytes from a reader to a
/// writer in one direction.
#[derive(Debug)]
pub(super) struct CopyBuffer {
    /// Set when a read completes without adding any bytes to the buffer;
    /// no further reads are attempted once this is `true`.
    read_done: bool,
    /// Set after every successful write; cleared once the writer has been
    /// flushed while the reader was pending.
    need_flush: bool,
    /// Index in `buf` of the next byte to hand to the writer.
    pos: usize,
    /// Number of bytes at the start of `buf` that hold data read so far.
    cap: usize,
    /// Total number of bytes written to the writer.
    amt: u64,
    /// Backing storage for data in flight.
    buf: Box<[u8]>,
}
/// Size, in bytes, of the buffer allocated by `CopyBuffer::new`.
const DEFAULT_BUF_SIZE: usize = 1024;
102 :
103 : impl CopyBuffer {
104 8 : pub(super) fn new() -> Self {
105 8 : Self {
106 8 : read_done: false,
107 8 : need_flush: false,
108 8 : pos: 0,
109 8 : cap: 0,
110 8 : amt: 0,
111 8 : buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
112 8 : }
113 8 : }
114 :
115 16 : fn poll_fill_buf<R>(
116 16 : &mut self,
117 16 : cx: &mut Context<'_>,
118 16 : reader: Pin<&mut R>,
119 16 : ) -> Poll<io::Result<()>>
120 16 : where
121 16 : R: AsyncRead + ?Sized,
122 16 : {
123 16 : let me = &mut *self;
124 16 : let mut buf = ReadBuf::new(&mut me.buf);
125 16 : buf.set_filled(me.cap);
126 16 :
127 16 : let res = reader.poll_read(cx, &mut buf);
128 16 : if let Poll::Ready(Ok(())) = res {
129 14 : let filled_len = buf.filled().len();
130 14 : me.read_done = me.cap == filled_len;
131 14 : me.cap = filled_len;
132 14 : }
133 16 : res
134 16 : }
135 :
136 10 : fn poll_write_buf<R, W>(
137 10 : &mut self,
138 10 : cx: &mut Context<'_>,
139 10 : mut reader: Pin<&mut R>,
140 10 : mut writer: Pin<&mut W>,
141 10 : ) -> Poll<io::Result<usize>>
142 10 : where
143 10 : R: AsyncRead + ?Sized,
144 10 : W: AsyncWrite + ?Sized,
145 10 : {
146 10 : let me = &mut *self;
147 10 : match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
148 : Poll::Pending => {
149 : // Top up the buffer towards full if we can read a bit more
150 : // data - this should improve the chances of a large write
151 2 : if !me.read_done && me.cap < me.buf.len() {
152 2 : ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
153 0 : }
154 0 : Poll::Pending
155 : }
156 8 : res => res,
157 : }
158 10 : }
159 :
160 8 : pub(super) fn poll_copy<R, W>(
161 8 : &mut self,
162 8 : cx: &mut Context<'_>,
163 8 : mut reader: Pin<&mut R>,
164 8 : mut writer: Pin<&mut W>,
165 8 : ) -> Poll<io::Result<u64>>
166 8 : where
167 8 : R: AsyncRead + ?Sized,
168 8 : W: AsyncWrite + ?Sized,
169 8 : {
170 14 : loop {
171 14 : // If our buffer is empty, then we need to read some data to
172 14 : // continue.
173 14 : if self.pos == self.cap && !self.read_done {
174 14 : self.pos = 0;
175 14 : self.cap = 0;
176 14 :
177 14 : match self.poll_fill_buf(cx, reader.as_mut()) {
178 14 : Poll::Ready(Ok(())) => (),
179 0 : Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
180 : Poll::Pending => {
181 : // Try flushing when the reader has no progress to avoid deadlock
182 : // when the reader depends on buffered writer.
183 0 : if self.need_flush {
184 0 : ready!(writer.as_mut().poll_flush(cx))?;
185 0 : self.need_flush = false;
186 0 : }
187 :
188 0 : return Poll::Pending;
189 : }
190 : }
191 0 : }
192 :
193 : // If our buffer has some data, let's write it out!
194 22 : while self.pos < self.cap {
195 10 : let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
196 8 : if i == 0 {
197 0 : return Poll::Ready(Err(io::Error::new(
198 0 : io::ErrorKind::WriteZero,
199 0 : "write zero byte into writer",
200 0 : )));
201 8 : } else {
202 8 : self.pos += i;
203 8 : self.amt += i as u64;
204 8 : self.need_flush = true;
205 8 : }
206 : }
207 :
208 : // If pos larger than cap, this loop will never stop.
209 : // In particular, user's wrong poll_write implementation returning
210 : // incorrect written length may lead to thread blocking.
211 12 : debug_assert!(
212 12 : self.pos <= self.cap,
213 0 : "writer returned length larger than input slice"
214 : );
215 :
216 : // If we've written all the data and we've seen EOF, flush out the
217 : // data and finish the transfer.
218 12 : if self.pos == self.cap && self.read_done {
219 6 : ready!(writer.as_mut().poll_flush(cx))?;
220 6 : return Poll::Ready(Ok(self.amt));
221 6 : }
222 : }
223 8 : }
224 : }
225 :
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncWriteExt;

    /// Both peers send a short message and close their write side before the
    /// copy starts; each message must be counted in full.
    #[tokio::test]
    async fn test_client_to_compute() {
        // In-memory pipes stand in for the two connections.
        let (mut client_peer, mut client_proxy) = tokio::io::duplex(8);
        let (mut compute_proxy, mut compute_peer) = tokio::io::duplex(32);

        // Client sends "hello" and closes, then compute sends "Neon" and closes.
        client_peer.write_all(b"hello").await.unwrap();
        client_peer.shutdown().await.unwrap();
        compute_peer.write_all(b"Neon").await.unwrap();
        compute_peer.shutdown().await.unwrap();

        let (sent_to_compute, sent_to_client) =
            copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
                .await
                .unwrap();

        assert_eq!(sent_to_compute, 5); // "hello" was transferred
        assert_eq!(sent_to_client, 4); // "Neon" was transferred
    }

    /// Compute finishes while the client still has data queued and its write
    /// side open; the copy must terminate anyway.
    #[tokio::test]
    async fn test_compute_to_client() {
        // In-memory pipes stand in for the two connections.
        let (mut client_peer, mut client_proxy) = tokio::io::duplex(32);
        let (mut compute_proxy, mut compute_peer) = tokio::io::duplex(8);

        // Compute sends "hello" and closes; the client writes 24 bytes and
        // never shuts down.
        compute_peer.write_all(b"hello").await.unwrap();
        compute_peer.shutdown().await.unwrap();
        client_peer
            .write_all(b"Neon Serverless Postgres")
            .await
            .unwrap();

        let (sent_to_compute, sent_to_client) =
            copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
                .await
                .unwrap();

        assert_eq!(sent_to_client, 5); // "hello" was transferred
        assert!(sent_to_compute <= 8); // client data only partially transferred, or not at all
    }
}
|