Line data Source code
//! Managing the websocket connection and other signals in the monitor.
//!
//! Contains types that manage the interaction (not data interchange, see `protocol`)
//! between agent and monitor, allowing us to process and send messages in a
//! straightforward way. The dispatcher also manages the signals that come from
//! the cgroup (requesting upscale), and the signals that go to the cgroup
//! (notifying it of upscale).
8 :
9 : use anyhow::{bail, Context};
10 : use axum::extract::ws::{Message, WebSocket};
11 : use futures::{
12 : stream::{SplitSink, SplitStream},
13 : SinkExt, StreamExt,
14 : };
15 : use tokio::sync::mpsc;
16 : use tracing::info;
17 :
18 : use crate::cgroup::Sequenced;
19 : use crate::protocol::{
20 : OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, Resources, PROTOCOL_MAX_VERSION,
21 : PROTOCOL_MIN_VERSION,
22 : };
23 :
24 : /// The central handler for all communications in the monitor.
25 : ///
26 : /// The dispatcher has two purposes:
27 : /// 1. Manage the connection to the agent, sending and receiving messages.
28 : /// 2. Communicate with the cgroup manager, notifying it when upscale is received,
29 : /// and sending a message to the agent when the cgroup manager requests
30 : /// upscale.
31 0 : #[derive(Debug)]
32 : pub struct Dispatcher {
33 : /// We read agent messages of of `source`
34 : pub(crate) source: SplitStream<WebSocket>,
35 :
36 : /// We send messages to the agent through `sink`
37 : sink: SplitSink<WebSocket, Message>,
38 :
39 : /// Used to notify the cgroup when we are upscaled.
40 : pub(crate) notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
41 :
42 : /// When the cgroup requests upscale it will send on this channel. In response
43 : /// we send an `UpscaleRequst` to the agent.
44 : pub(crate) request_upscale_events: mpsc::Receiver<()>,
45 :
46 : /// The protocol version we have agreed to use with the agent. This is negotiated
47 : /// during the creation of the dispatcher, and should be the highest shared protocol
48 : /// version.
49 : ///
50 : // NOTE: currently unused, but will almost certainly be used in the futures
51 : // as the protocol changes
52 : #[allow(unused)]
53 : pub(crate) proto_version: ProtocolVersion,
54 : }
55 :
56 : impl Dispatcher {
57 : /// Creates a new dispatcher using the passed-in connection.
58 : ///
59 : /// Performs a negotiation with the agent to determine the highest protocol
60 : /// version that both support. This consists of two steps:
61 : /// 1. Wait for the agent to sent the range of protocols it supports.
62 : /// 2. Send a protocol version that works for us as well, or an error if there
63 : /// is no compatible version.
64 0 : pub async fn new(
65 0 : stream: WebSocket,
66 0 : notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
67 0 : request_upscale_events: mpsc::Receiver<()>,
68 0 : ) -> anyhow::Result<Self> {
69 0 : let (mut sink, mut source) = stream.split();
70 :
71 : // Figure out the highest protocol version we both support
72 0 : info!("waiting for agent to send protocol version range");
73 0 : let Some(message) = source.next().await else {
74 0 : bail!("websocket connection closed while performing protocol handshake")
75 : };
76 :
77 0 : let message = message.context("failed to read protocol version range off connection")?;
78 :
79 0 : let Message::Text(message_text) = message else {
80 : // All messages should be in text form, since we don't do any
81 : // pinging/ponging. See nhooyr/websocket's implementation and the
82 : // agent for more info
83 0 : bail!("received non-text message during proocol handshake: {message:?}")
84 : };
85 :
86 0 : let monitor_range = ProtocolRange {
87 0 : min: PROTOCOL_MIN_VERSION,
88 0 : max: PROTOCOL_MAX_VERSION,
89 0 : };
90 :
91 0 : let agent_range: ProtocolRange = serde_json::from_str(&message_text)
92 0 : .context("failed to deserialize protocol version range")?;
93 :
94 0 : info!(range = ?agent_range, "received protocol version range");
95 :
96 0 : let highest_shared_version = match monitor_range.highest_shared_version(&agent_range) {
97 0 : Ok(version) => {
98 0 : sink.send(Message::Text(
99 0 : serde_json::to_string(&ProtocolResponse::Version(version)).unwrap(),
100 0 : ))
101 0 : .await
102 0 : .context("failed to notify agent of negotiated protocol version")?;
103 0 : version
104 : }
105 0 : Err(e) => {
106 0 : sink.send(Message::Text(
107 0 : serde_json::to_string(&ProtocolResponse::Error(format!(
108 0 : "Received protocol version range {} which does not overlap with {}",
109 0 : agent_range, monitor_range
110 0 : )))
111 0 : .unwrap(),
112 0 : ))
113 0 : .await
114 0 : .context("failed to notify agent of no overlap between protocol version ranges")?;
115 0 : Err(e).context("error determining suitable protocol version range")?
116 : }
117 : };
118 :
119 0 : Ok(Self {
120 0 : sink,
121 0 : source,
122 0 : notify_upscale_events,
123 0 : request_upscale_events,
124 0 : proto_version: highest_shared_version,
125 0 : })
126 0 : }
127 :
128 : /// Notify the cgroup manager that we have received upscale and wait for
129 : /// the acknowledgement.
130 0 : #[tracing::instrument(skip_all, fields(?resources))]
131 : pub async fn notify_upscale(&self, resources: Sequenced<Resources>) -> anyhow::Result<()> {
132 : self.notify_upscale_events
133 : .send(resources)
134 : .await
135 : .context("failed to send resources and oneshot sender across channel")
136 : }
137 :
138 : /// Send a message to the agent.
139 : ///
140 : /// Although this function is small, it has one major benefit: it is the only
141 : /// way to send data accross the connection, and you can only pass in a proper
142 : /// `MonitorMessage`. Without safeguards like this, it's easy to accidentally
143 : /// serialize the wrong thing and send it, since `self.sink.send` will take
144 : /// any string.
145 0 : pub async fn send(&mut self, message: OutboundMsg) -> anyhow::Result<()> {
146 0 : info!(?message, "sending message");
147 0 : let json = serde_json::to_string(&message).context("failed to serialize message")?;
148 0 : self.sink
149 0 : .send(Message::Text(json))
150 0 : .await
151 0 : .context("stream error sending message")
152 0 : }
153 : }
|