Line data Source code
1 : use std::future::poll_fn;
2 : use std::io;
3 : use std::pin::Pin;
4 : use std::task::{Context, Poll, ready};
5 :
6 : use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
7 : use tracing::info;
8 :
/// Progress of one direction (reader -> writer) of a bidirectional copy.
#[derive(Debug)]
enum TransferState {
    /// Still copying; the buffer holds in-flight data and the running byte count.
    Running(CopyBuffer),
    /// No more data will be copied; holds the number of bytes written while
    /// the writer is being shut down.
    ShuttingDown(u64),
    /// The writer has been shut down; holds the final number of bytes written.
    Done(u64),
}
15 :
/// Which half of a one-directional copy failed.
///
/// `Read` means reading from the source stream failed. `Write` means writing,
/// flushing, or shutting down the destination stream failed.
#[derive(Debug)]
pub(crate) enum ErrorDirection {
    Read(io::Error),
    Write(io::Error),
}

/// The peer whose connection produced an I/O error during the proxy copy.
#[derive(Debug)]
pub enum ErrorSource {
    /// The error happened on the client connection.
    Client(io::Error),
    /// The error happened on the compute connection.
    Compute(io::Error),
}

impl ErrorDirection {
    /// Hands the inner error to `on_read` or `on_write`, depending on which
    /// side of the copy failed.
    fn attribute(
        self,
        on_read: fn(io::Error) -> ErrorSource,
        on_write: fn(io::Error) -> ErrorSource,
    ) -> ErrorSource {
        match self {
            Self::Read(inner) => on_read(inner),
            Self::Write(inner) => on_write(inner),
        }
    }
}

impl ErrorSource {
    /// Attributes an error from the client -> compute direction.
    ///
    /// In that direction the client is read from and compute is written to.
    fn from_client(err: ErrorDirection) -> ErrorSource {
        err.attribute(ErrorSource::Client, ErrorSource::Compute)
    }

    /// Attributes an error from the compute -> client direction.
    ///
    /// In that direction compute is read from and the client is written to.
    fn from_compute(err: ErrorDirection) -> ErrorSource {
        err.attribute(ErrorSource::Compute, ErrorSource::Client)
    }
}
42 :
43 5 : fn transfer_one_direction<A, B>(
44 5 : cx: &mut Context<'_>,
45 5 : state: &mut TransferState,
46 5 : r: &mut A,
47 5 : w: &mut B,
48 5 : ) -> Poll<Result<u64, ErrorDirection>>
49 5 : where
50 5 : A: AsyncRead + AsyncWrite + Unpin + ?Sized,
51 5 : B: AsyncRead + AsyncWrite + Unpin + ?Sized,
52 : {
53 5 : let mut r = Pin::new(r);
54 5 : let mut w = Pin::new(w);
55 : loop {
56 12 : match state {
57 4 : TransferState::Running(buf) => {
58 4 : let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
59 3 : *state = TransferState::ShuttingDown(count);
60 : }
61 4 : TransferState::ShuttingDown(count) => {
62 4 : ready!(w.as_mut().poll_shutdown(cx)).map_err(ErrorDirection::Write)?;
63 4 : *state = TransferState::Done(*count);
64 : }
65 4 : TransferState::Done(count) => return Poll::Ready(Ok(*count)),
66 : }
67 : }
68 5 : }
69 :
70 : #[tracing::instrument(skip_all)]
71 : pub async fn copy_bidirectional_client_compute<Client, Compute>(
72 : client: &mut Client,
73 : compute: &mut Compute,
74 : ) -> Result<(u64, u64), ErrorSource>
75 : where
76 : Client: AsyncRead + AsyncWrite + Unpin + ?Sized,
77 : Compute: AsyncRead + AsyncWrite + Unpin + ?Sized,
78 : {
79 : let mut client_to_compute = TransferState::Running(CopyBuffer::new());
80 : let mut compute_to_client = TransferState::Running(CopyBuffer::new());
81 :
82 2 : poll_fn(|cx| {
83 2 : let mut client_to_compute_result =
84 2 : transfer_one_direction(cx, &mut client_to_compute, client, compute)
85 2 : .map_err(ErrorSource::from_client)?;
86 2 : let mut compute_to_client_result =
87 2 : transfer_one_direction(cx, &mut compute_to_client, compute, client)
88 2 : .map_err(ErrorSource::from_compute)?;
89 :
90 : // TODO: 1 info log, with a enum label for close direction.
91 :
92 : // Early termination checks from compute to client.
93 2 : if let TransferState::Done(_) = compute_to_client
94 2 : && let TransferState::Running(buf) = &client_to_compute
95 : {
96 1 : info!("Compute is done, terminate client");
97 : // Initiate shutdown
98 1 : client_to_compute = TransferState::ShuttingDown(buf.amt);
99 : client_to_compute_result =
100 1 : transfer_one_direction(cx, &mut client_to_compute, client, compute)
101 1 : .map_err(ErrorSource::from_client)?;
102 1 : }
103 :
104 : // Early termination checks from client to compute.
105 2 : if let TransferState::Done(_) = client_to_compute
106 2 : && let TransferState::Running(buf) = &compute_to_client
107 : {
108 0 : info!("Client is done, terminate compute");
109 : // Initiate shutdown
110 0 : compute_to_client = TransferState::ShuttingDown(buf.amt);
111 : compute_to_client_result =
112 0 : transfer_one_direction(cx, &mut compute_to_client, compute, client)
113 0 : .map_err(ErrorSource::from_compute)?;
114 2 : }
115 :
116 : // It is not a problem if ready! returns early ... (comment remains the same)
117 2 : let client_to_compute = ready!(client_to_compute_result);
118 2 : let compute_to_client = ready!(compute_to_client_result);
119 :
120 2 : Poll::Ready(Ok((client_to_compute, compute_to_client)))
121 2 : })
122 : .await
123 : }
124 :
/// Buffer and progress state for copying one direction of a connection.
///
/// Invariant between polls: the bytes in `buf[pos..cap]` have been read from
/// the reader but not yet written to the writer.
#[derive(Debug)]
pub(super) struct CopyBuffer {
    /// Set once a read into the buffer added no new bytes (treated as EOF).
    read_done: bool,
    /// Set after a successful write.
    ///
    /// Cleared after the writer is flushed because the reader had no data ready.
    need_flush: bool,
    /// Index of the first byte in `buf` that has not been written yet.
    pos: usize,
    /// Index one past the last byte read into `buf`.
    cap: usize,
    /// Total number of bytes written to the writer so far.
    amt: u64,
    /// Backing storage; `CopyBuffer::new` allocates `DEFAULT_BUF_SIZE` zeroed bytes.
    buf: Box<[u8]>,
}
/// Size in bytes of each `CopyBuffer`'s backing storage.
const DEFAULT_BUF_SIZE: usize = 1024;
135 :
impl CopyBuffer {
    /// Creates an empty buffer (`DEFAULT_BUF_SIZE` zeroed bytes) with nothing
    /// read, written, or pending a flush.
    pub(super) fn new() -> Self {
        Self {
            read_done: false,
            need_flush: false,
            pos: 0,
            cap: 0,
            amt: 0,
            buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
        }
    }

    /// Reads more data from `reader` into the unused tail `buf[cap..]`.
    ///
    /// On a successful read, `cap` is advanced to the new end of the filled data.
    /// A read that adds no bytes sets `read_done` (EOF).
    ///
    /// NOTE: every caller in this file only calls this while `cap < buf.len()`.
    /// With a full buffer, a zero-byte read would be indistinguishable from EOF.
    fn poll_fill_buf<R>(
        &mut self,
        cx: &mut Context<'_>,
        reader: Pin<&mut R>,
    ) -> Poll<io::Result<()>>
    where
        R: AsyncRead + ?Sized,
    {
        let me = &mut *self;
        let mut buf = ReadBuf::new(&mut me.buf);
        // Resume filling after the bytes we already hold.
        buf.set_filled(me.cap);

        let res = reader.poll_read(cx, &mut buf);
        if let Poll::Ready(Ok(())) = res {
            let filled_len = buf.filled().len();
            // No growth in the filled region means the reader reported EOF.
            me.read_done = me.cap == filled_len;
            me.cap = filled_len;
        }
        res
    }

    /// Attempts one write of the pending bytes `buf[pos..cap]` to `writer`.
    ///
    /// Returns the number of bytes the writer accepted. If the writer is not
    /// ready, this first tries to read more into the buffer, then returns
    /// `Poll::Pending`.
    ///
    /// Errors are tagged `ErrorDirection::Write` for writer failures and
    /// `ErrorDirection::Read` for reader failures during that top-up.
    fn poll_write_buf<R, W>(
        &mut self,
        cx: &mut Context<'_>,
        mut reader: Pin<&mut R>,
        mut writer: Pin<&mut W>,
    ) -> Poll<Result<usize, ErrorDirection>>
    where
        R: AsyncRead + ?Sized,
        W: AsyncWrite + ?Sized,
    {
        let me = &mut *self;
        match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
            Poll::Pending => {
                // Top up the buffer towards full if we can read a bit more
                // data - this should improve the chances of a large write
                if !me.read_done && me.cap < me.buf.len() {
                    ready!(me.poll_fill_buf(cx, reader.as_mut())).map_err(ErrorDirection::Read)?;
                }
                Poll::Pending
            }
            res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write),
        }
    }

    /// Copies from `reader` to `writer` until the reader reaches EOF and all
    /// buffered data has been written and flushed.
    ///
    /// Returns the total number of bytes written (`amt`). All progress is kept
    /// in `self`, so the call can be resumed after it returns `Poll::Pending`.
    ///
    /// Errors are tagged `ErrorDirection::Read` for reader failures and
    /// `ErrorDirection::Write` for write/flush failures. A writer that accepts
    /// zero bytes is reported as `io::ErrorKind::WriteZero`.
    pub(super) fn poll_copy<R, W>(
        &mut self,
        cx: &mut Context<'_>,
        mut reader: Pin<&mut R>,
        mut writer: Pin<&mut W>,
    ) -> Poll<Result<u64, ErrorDirection>>
    where
        R: AsyncRead + ?Sized,
        W: AsyncWrite + ?Sized,
    {
        loop {
            // If there is some space left in our buffer, then we try to read some
            // data to continue, thus maximizing the chances of a large write.
            if self.cap < self.buf.len() && !self.read_done {
                match self.poll_fill_buf(cx, reader.as_mut()) {
                    Poll::Ready(Ok(())) => (),
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(ErrorDirection::Read(err))),
                    Poll::Pending => {
                        // Ignore pending reads when our buffer is not empty, because
                        // we can try to write data immediately.
                        if self.pos == self.cap {
                            // Try flushing when the reader has no progress to avoid deadlock
                            // when the reader depends on buffered writer.
                            if self.need_flush {
                                ready!(writer.as_mut().poll_flush(cx))
                                    .map_err(ErrorDirection::Write)?;
                                self.need_flush = false;
                            }

                            return Poll::Pending;
                        }
                    }
                }
            }

            // If our buffer has some data, let's write it out!
            while self.pos < self.cap {
                let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
                if i == 0 {
                    return Poll::Ready(Err(ErrorDirection::Write(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "write zero byte into writer",
                    ))));
                }
                self.pos += i;
                self.amt += i as u64;
                self.need_flush = true;
            }

            // If pos larger than cap, this loop will never stop.
            // In particular, user's wrong poll_write implementation returning
            // incorrect written length may lead to thread blocking.
            debug_assert!(
                self.pos <= self.cap,
                "writer returned length larger than input slice"
            );

            // All data has been written, the buffer can be considered empty again
            self.pos = 0;
            self.cap = 0;

            // If we've written all the data and we've seen EOF, flush out the
            // data and finish the transfer.
            if self.read_done {
                ready!(writer.as_mut().poll_flush(cx)).map_err(ErrorDirection::Write)?;
                return Poll::Ready(Ok(self.amt));
            }
        }
    }
}
263 :
#[cfg(test)]
mod tests {
    use tokio::io::AsyncWriteExt;

    use super::*;

    #[tokio::test]
    async fn test_client_to_compute() {
        // In-memory duplex pipes: the `*_client` end plays the remote peer and
        // the `*_proxy` end is handed to the proxy. The argument to `duplex` is
        // the pipe's buffer capacity in bytes.
        let (mut client_client, mut client_proxy) = tokio::io::duplex(8);
        let (mut compute_proxy, mut compute_client) = tokio::io::duplex(32);

        // Both peers send a short message that fits in their pipe, then close
        // their write side, so each direction reaches EOF on its own.
        client_client.write_all(b"hello").await.unwrap();
        client_client.shutdown().await.unwrap();
        compute_client.write_all(b"Neon").await.unwrap();
        compute_client.shutdown().await.unwrap();

        let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
            .await
            .unwrap();

        // Both messages are relayed in full.
        let (client_to_compute_count, compute_to_client_count) = result;
        assert_eq!(client_to_compute_count, 5); // "hello"
        assert_eq!(compute_to_client_count, 4); // "Neon"
    }

    #[tokio::test]
    async fn test_compute_to_client() {
        // In-memory duplex pipes: the `*_client` end plays the remote peer and
        // the `*_proxy` end is handed to the proxy. The argument to `duplex` is
        // the pipe's buffer capacity in bytes.
        let (mut client_client, mut client_proxy) = tokio::io::duplex(32);
        let (mut compute_proxy, mut compute_client) = tokio::io::duplex(8);

        // Compute sends a reply and closes its side.
        compute_client.write_all(b"hello").await.unwrap();
        compute_client.shutdown().await.unwrap();
        // The client's 24-byte message is larger than the 8-byte compute pipe,
        // and `compute_client` never reads it. The client -> compute direction
        // therefore cannot finish on its own and is still running when
        // compute is done.
        client_client
            .write_all(b"Neon Serverless Postgres")
            .await
            .unwrap();

        let result = copy_bidirectional_client_compute(&mut client_proxy, &mut compute_proxy)
            .await
            .unwrap();

        // Compute's reply is relayed in full. Client -> compute is cut short
        // once compute is done, so at most one pipe's worth was written.
        let (client_to_compute_count, compute_to_client_count) = result;
        assert_eq!(compute_to_client_count, 5); // "hello"
        assert!(client_to_compute_count <= 8);
    }
}
|