Line data Source code
1 : use std::cmp::Ordering;
2 : use std::collections::{BinaryHeap, VecDeque};
3 : use std::fmt::{self, Debug};
4 : use std::ops::DerefMut;
5 : use std::sync::{Arc, mpsc};
6 :
7 : use parking_lot::lock_api::{MappedMutexGuard, MutexGuard};
8 : use parking_lot::{Mutex, RawMutex};
9 : use rand::rngs::StdRng;
10 : use tracing::debug;
11 :
12 : use super::chan::Chan;
13 : use super::proto::AnyMessage;
14 : use crate::executor::{self, ThreadContext};
15 : use crate::options::NetworkOptions;
16 : use crate::proto::{NetEvent, NodeEvent};
17 :
18 : pub struct NetworkTask {
19 : options: Arc<NetworkOptions>,
20 : connections: Mutex<Vec<VirtualConnection>>,
21 : /// min-heap of connections having something to deliver.
22 : events: Mutex<BinaryHeap<Event>>,
23 : task_context: Arc<ThreadContext>,
24 : }
25 :
26 : impl NetworkTask {
27 528 : pub fn start_new(options: Arc<NetworkOptions>, tx: mpsc::Sender<Arc<NetworkTask>>) {
28 528 : let ctx = executor::get_thread_ctx();
29 528 : let task = Arc::new(Self {
30 528 : options,
31 528 : connections: Mutex::new(Vec::new()),
32 528 : events: Mutex::new(BinaryHeap::new()),
33 528 : task_context: ctx,
34 528 : });
35 528 :
36 528 : // send the task upstream
37 528 : tx.send(task.clone()).unwrap();
38 528 :
39 528 : // start the task
40 528 : task.start();
41 528 : }
42 :
43 35875 : pub fn start_new_connection(self: &Arc<Self>, rng: StdRng, dst_accept: Chan<NodeEvent>) -> TCP {
44 35875 : let now = executor::now();
45 35875 : let connection_id = self.connections.lock().len();
46 35875 :
47 35875 : let vc = VirtualConnection {
48 35875 : connection_id,
49 35875 : dst_accept,
50 35875 : dst_sockets: [Chan::new(), Chan::new()],
51 35875 : state: Mutex::new(ConnectionState {
52 35875 : buffers: [NetworkBuffer::new(None), NetworkBuffer::new(Some(now))],
53 35875 : rng,
54 35875 : }),
55 35875 : };
56 35875 : vc.schedule_timeout(self);
57 35875 : vc.send_connect(self);
58 35875 :
59 35875 : let recv_chan = vc.dst_sockets[0].clone();
60 35875 : self.connections.lock().push(vc);
61 35875 :
62 35875 : TCP {
63 35875 : net: self.clone(),
64 35875 : conn_id: connection_id,
65 35875 : dir: 0,
66 35875 : recv_chan,
67 35875 : }
68 35875 : }
69 : }
70 :
71 : // private functions
72 : impl NetworkTask {
73 : /// Schedule to wakeup network task (self) `after_ms` later to deliver
74 : /// messages of connection `id`.
75 184561 : fn schedule(&self, id: usize, after_ms: u64) {
76 184561 : self.events.lock().push(Event {
77 184561 : time: executor::now() + after_ms,
78 184561 : conn_id: id,
79 184561 : });
80 184561 : self.task_context.schedule_wakeup(after_ms);
81 184561 : }
82 :
83 : /// Get locked connection `id`.
84 245993 : fn get(&self, id: usize) -> MappedMutexGuard<'_, RawMutex, VirtualConnection> {
85 245993 : MutexGuard::map(self.connections.lock(), |connections| {
86 245993 : connections.get_mut(id).unwrap()
87 245993 : })
88 245993 : }
89 :
90 173255 : fn collect_pending_events(&self, now: u64, vec: &mut Vec<Event>) {
91 173255 : vec.clear();
92 173255 : let mut events = self.events.lock();
93 345983 : while let Some(event) = events.peek() {
94 343646 : if event.time > now {
95 170918 : break;
96 172728 : }
97 172728 : let event = events.pop().unwrap();
98 172728 : vec.push(event);
99 : }
100 173255 : }
101 :
102 528 : fn start(self: &Arc<Self>) {
103 528 : debug!("started network task");
104 :
105 528 : let mut events = Vec::new();
106 : loop {
107 173255 : let now = executor::now();
108 173255 : self.collect_pending_events(now, &mut events);
109 :
110 173255 : for event in events.drain(..) {
111 172728 : let conn = self.get(event.conn_id);
112 172728 : conn.process(self);
113 172728 : }
114 :
115 : // block until wakeup
116 173255 : executor::yield_me(-1);
117 : }
118 : }
119 : }
120 :
/// Selects one of the two message flows of a connection:
/// `0` is node 0 -> node 1, `1` is node 1 -> node 0.
type MessageDirection = u8;
124 :
125 1311 : fn sender_str(dir: MessageDirection) -> &'static str {
126 1311 : match dir {
127 222 : 0 => "client",
128 1089 : 1 => "server",
129 0 : _ => unreachable!(),
130 : }
131 1311 : }
132 :
133 355 : fn receiver_str(dir: MessageDirection) -> &'static str {
134 355 : match dir {
135 163 : 0 => "server",
136 192 : 1 => "client",
137 0 : _ => unreachable!(),
138 : }
139 355 : }
140 :
141 : /// Virtual connection between two nodes.
142 : /// Node 0 is the creator of the connection (client),
143 : /// and node 1 is the acceptor (server).
144 : struct VirtualConnection {
145 : connection_id: usize,
146 : /// one-off chan, used to deliver Accept message to dst
147 : dst_accept: Chan<NodeEvent>,
148 : /// message sinks
149 : dst_sockets: [Chan<NetEvent>; 2],
150 : state: Mutex<ConnectionState>,
151 : }
152 :
153 : struct ConnectionState {
154 : buffers: [NetworkBuffer; 2],
155 : rng: StdRng,
156 : }
157 :
158 : impl VirtualConnection {
159 : /// Notify the future about the possible timeout.
160 106104 : fn schedule_timeout(&self, net: &NetworkTask) {
161 106104 : if let Some(timeout) = net.options.keepalive_timeout {
162 106104 : net.schedule(self.connection_id, timeout);
163 106104 : }
164 106104 : }
165 :
166 : /// Send the handshake (Accept) to the server.
167 35875 : fn send_connect(&self, net: &NetworkTask) {
168 35875 : let now = executor::now();
169 35875 : let mut state = self.state.lock();
170 35875 : let delay = net.options.connect_delay.delay(&mut state.rng);
171 35875 : let buffer = &mut state.buffers[0];
172 35875 : assert!(buffer.buf.is_empty());
173 35875 : assert!(!buffer.recv_closed);
174 35875 : assert!(!buffer.send_closed);
175 35875 : assert!(buffer.last_recv.is_none());
176 :
177 35875 : let delay = if let Some(ms) = delay {
178 28163 : ms
179 : } else {
180 7712 : debug!("NET: TCP #{} dropped connect", self.connection_id);
181 7712 : buffer.send_closed = true;
182 7712 : return;
183 : };
184 :
185 : // Send a message into the future.
186 28163 : buffer
187 28163 : .buf
188 28163 : .push_back((now + delay, AnyMessage::InternalConnect));
189 28163 : net.schedule(self.connection_id, delay);
190 35875 : }
191 :
192 : /// Transmit some of the messages from the buffer to the nodes.
193 172728 : fn process(&self, net: &Arc<NetworkTask>) {
194 172728 : let now = executor::now();
195 172728 :
196 172728 : let mut state = self.state.lock();
197 :
198 518184 : for direction in 0..2 {
199 345456 : self.process_direction(
200 345456 : net,
201 345456 : state.deref_mut(),
202 345456 : now,
203 345456 : direction as MessageDirection,
204 345456 : &self.dst_sockets[direction ^ 1],
205 345456 : );
206 345456 : }
207 :
208 : // Close the one side of the connection by timeout if the node
209 : // has not received any messages for a long time.
210 172728 : if let Some(timeout) = net.options.keepalive_timeout {
211 172728 : let mut to_close = [false, false];
212 518184 : for direction in 0..2 {
213 345456 : let buffer = &mut state.buffers[direction];
214 345456 : if buffer.recv_closed {
215 73634 : continue;
216 271822 : }
217 271822 : if let Some(last_recv) = buffer.last_recv {
218 245390 : if now - last_recv >= timeout {
219 54312 : debug!(
220 0 : "NET: connection {} timed out at {}",
221 0 : self.connection_id,
222 0 : receiver_str(direction as MessageDirection)
223 : );
224 54312 : let node_idx = direction ^ 1;
225 54312 : to_close[node_idx] = true;
226 191078 : }
227 26432 : }
228 : }
229 172728 : drop(state);
230 :
231 345456 : for (node_idx, should_close) in to_close.iter().enumerate() {
232 345456 : if *should_close {
233 54312 : self.close(node_idx);
234 291144 : }
235 : }
236 0 : }
237 172728 : }
238 :
239 : /// Process messages in the buffer in the given direction.
240 345456 : fn process_direction(
241 345456 : &self,
242 345456 : net: &Arc<NetworkTask>,
243 345456 : state: &mut ConnectionState,
244 345456 : now: u64,
245 345456 : direction: MessageDirection,
246 345456 : to_socket: &Chan<NetEvent>,
247 345456 : ) {
248 345456 : let buffer = &mut state.buffers[direction as usize];
249 345456 : if buffer.recv_closed {
250 73634 : assert!(buffer.buf.is_empty());
251 271822 : }
252 :
253 415685 : while !buffer.buf.is_empty() && buffer.buf.front().unwrap().0 <= now {
254 70229 : let msg = buffer.buf.pop_front().unwrap().1;
255 70229 :
256 70229 : buffer.last_recv = Some(now);
257 70229 : self.schedule_timeout(net);
258 70229 :
259 70229 : if let AnyMessage::InternalConnect = msg {
260 26796 : // TODO: assert to_socket is the server
261 26796 : let server_to_client = TCP {
262 26796 : net: net.clone(),
263 26796 : conn_id: self.connection_id,
264 26796 : dir: direction ^ 1,
265 26796 : recv_chan: to_socket.clone(),
266 26796 : };
267 26796 : // special case, we need to deliver new connection to a separate channel
268 26796 : self.dst_accept.send(NodeEvent::Accept(server_to_client));
269 43433 : } else {
270 43433 : to_socket.send(NetEvent::Message(msg));
271 43433 : }
272 : }
273 345456 : }
274 :
275 : /// Try to send a message to the buffer, optionally dropping it and
276 : /// determining delivery timestamp.
277 69249 : fn send(&self, net: &NetworkTask, direction: MessageDirection, msg: AnyMessage) {
278 69249 : let now = executor::now();
279 69249 : let mut state = self.state.lock();
280 :
281 69249 : let (delay, close) = if let Some(ms) = net.options.send_delay.delay(&mut state.rng) {
282 63442 : (ms, false)
283 : } else {
284 5807 : (0, true)
285 : };
286 :
287 69249 : let buffer = &mut state.buffers[direction as usize];
288 69249 : if buffer.send_closed {
289 7264 : debug!(
290 0 : "NET: TCP #{} dropped message {:?} (broken pipe)",
291 : self.connection_id, msg
292 : );
293 7264 : return;
294 61985 : }
295 61985 :
296 61985 : if close {
297 4525 : debug!(
298 0 : "NET: TCP #{} dropped message {:?} (pipe just broke)",
299 : self.connection_id, msg
300 : );
301 4525 : buffer.send_closed = true;
302 4525 : return;
303 57460 : }
304 57460 :
305 57460 : if buffer.recv_closed {
306 7166 : debug!(
307 0 : "NET: TCP #{} dropped message {:?} (recv closed)",
308 : self.connection_id, msg
309 : );
310 7166 : return;
311 50294 : }
312 50294 :
313 50294 : // Send a message into the future.
314 50294 : buffer.buf.push_back((now + delay, msg));
315 50294 : net.schedule(self.connection_id, delay);
316 69249 : }
317 :
318 : /// Close the connection. Only one side of the connection will be closed,
319 : /// and no further messages will be delivered. The other side will not be notified.
320 58328 : fn close(&self, node_idx: usize) {
321 58328 : let mut state = self.state.lock();
322 58328 : let recv_buffer = &mut state.buffers[1 ^ node_idx];
323 58328 : if recv_buffer.recv_closed {
324 294 : debug!(
325 0 : "NET: TCP #{} closed twice at {}",
326 0 : self.connection_id,
327 0 : sender_str(node_idx as MessageDirection),
328 : );
329 294 : return;
330 58034 : }
331 58034 :
332 58034 : debug!(
333 0 : "NET: TCP #{} closed at {}",
334 0 : self.connection_id,
335 0 : sender_str(node_idx as MessageDirection),
336 : );
337 58034 : recv_buffer.recv_closed = true;
338 58034 : for msg in recv_buffer.buf.drain(..) {
339 5280 : debug!(
340 0 : "NET: TCP #{} dropped message {:?} (closed)",
341 : self.connection_id, msg
342 : );
343 : }
344 :
345 58034 : let send_buffer = &mut state.buffers[node_idx];
346 58034 : send_buffer.send_closed = true;
347 58034 : drop(state);
348 58034 :
349 58034 : // TODO: notify the other side?
350 58034 :
351 58034 : self.dst_sockets[node_idx].send(NetEvent::Closed);
352 58328 : }
353 : }
354 :
355 : struct NetworkBuffer {
356 : /// Messages paired with time of delivery
357 : buf: VecDeque<(u64, AnyMessage)>,
358 : /// True if the connection is closed on the receiving side,
359 : /// i.e. no more messages from the buffer will be delivered.
360 : recv_closed: bool,
361 : /// True if the connection is closed on the sending side,
362 : /// i.e. no more messages will be added to the buffer.
363 : send_closed: bool,
364 : /// Last time a message was delivered from the buffer.
365 : /// If None, it means that the server is the receiver and
366 : /// it has not yet aware of this connection (i.e. has not
367 : /// received the Accept).
368 : last_recv: Option<u64>,
369 : }
370 :
371 : impl NetworkBuffer {
372 71750 : fn new(last_recv: Option<u64>) -> Self {
373 71750 : Self {
374 71750 : buf: VecDeque::new(),
375 71750 : recv_closed: false,
376 71750 : send_closed: false,
377 71750 : last_recv,
378 71750 : }
379 71750 : }
380 : }
381 :
382 : /// Single end of a bidirectional network stream without reordering (TCP-like).
383 : /// Reads are implemented using channels, writes go to the buffer inside VirtualConnection.
384 : pub struct TCP {
385 : net: Arc<NetworkTask>,
386 : conn_id: usize,
387 : dir: MessageDirection,
388 : recv_chan: Chan<NetEvent>,
389 : }
390 :
391 : impl Debug for TCP {
392 911 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393 911 : write!(f, "TCP #{} ({})", self.conn_id, sender_str(self.dir),)
394 911 : }
395 : }
396 :
397 : impl TCP {
398 : /// Send a message to the other side. It's guaranteed that it will not arrive
399 : /// before the arrival of all messages sent earlier.
400 69249 : pub fn send(&self, msg: AnyMessage) {
401 69249 : let conn = self.net.get(self.conn_id);
402 69249 : conn.send(&self.net, self.dir, msg);
403 69249 : }
404 :
405 : /// Get a channel to receive incoming messages.
406 453543 : pub fn recv_chan(&self) -> Chan<NetEvent> {
407 453543 : self.recv_chan.clone()
408 453543 : }
409 :
410 310765 : pub fn connection_id(&self) -> usize {
411 310765 : self.conn_id
412 310765 : }
413 :
414 4016 : pub fn close(&self) {
415 4016 : let conn = self.net.get(self.conn_id);
416 4016 : conn.close(self.dir as usize);
417 4016 : }
418 : }
/// A wakeup request for the network task: connection `conn_id` should be
/// processed at `time`.
#[derive(PartialEq, Eq)]
struct Event {
    time: u64,
    conn_id: usize,
}

// `BinaryHeap` is a max-heap, but the earliest event has to come out first.
// The comparison is therefore inverted on purpose: `other` is compared
// against `self`, first by time and then by connection id.
impl PartialOrd for Event {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Event {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .time
            .cmp(&self.time)
            .then_with(|| other.conn_id.cmp(&self.conn_id))
    }
}
|