Line data Source code
1 : //! Managing the websocket connection and other signals in the monitor.
2 : //!
3 : //! Contains types that manage the interaction (not data interchange, see `protocol`)
//! between agent and monitor, allowing us to process and send messages in a
//! straightforward way. The dispatcher also manages the signals that come from
6 : //! the cgroup (requesting upscale), and the signals that go to the cgroup
7 : //! (notifying it of upscale).
8 :
9 : use anyhow::{bail, Context};
10 : use axum::extract::ws::{Message, WebSocket};
11 : use futures::{
12 : stream::{SplitSink, SplitStream},
13 : SinkExt, StreamExt,
14 : };
15 : use tracing::{debug, info};
16 :
17 : use crate::protocol::{
18 : OutboundMsg, OutboundMsgKind, ProtocolRange, ProtocolResponse, ProtocolVersion,
19 : PROTOCOL_MAX_VERSION, PROTOCOL_MIN_VERSION,
20 : };
21 :
22 : /// The central handler for all communications in the monitor.
23 : ///
24 : /// The dispatcher has two purposes:
25 : /// 1. Manage the connection to the agent, sending and receiving messages.
26 : /// 2. Communicate with the cgroup manager, notifying it when upscale is received,
27 : /// and sending a message to the agent when the cgroup manager requests
28 : /// upscale.
29 : #[derive(Debug)]
30 : pub struct Dispatcher {
31 : /// We read agent messages of of `source`
32 : pub(crate) source: SplitStream<WebSocket>,
33 :
34 : /// We send messages to the agent through `sink`
35 : sink: SplitSink<WebSocket, Message>,
36 :
37 : /// The protocol version we have agreed to use with the agent. This is negotiated
38 : /// during the creation of the dispatcher, and should be the highest shared protocol
39 : /// version.
40 : ///
41 : // NOTE: currently unused, but will almost certainly be used in the futures
42 : // as the protocol changes
43 : #[allow(unused)]
44 : pub(crate) proto_version: ProtocolVersion,
45 : }
46 :
47 : impl Dispatcher {
48 : /// Creates a new dispatcher using the passed-in connection.
49 : ///
50 : /// Performs a negotiation with the agent to determine the highest protocol
51 : /// version that both support. This consists of two steps:
52 : /// 1. Wait for the agent to sent the range of protocols it supports.
53 : /// 2. Send a protocol version that works for us as well, or an error if there
54 : /// is no compatible version.
55 0 : pub async fn new(stream: WebSocket) -> anyhow::Result<Self> {
56 0 : let (mut sink, mut source) = stream.split();
57 0 :
58 0 : // Figure out the highest protocol version we both support
59 0 : info!("waiting for agent to send protocol version range");
60 0 : let Some(message) = source.next().await else {
61 0 : bail!("websocket connection closed while performing protocol handshake")
62 : };
63 :
64 0 : let message = message.context("failed to read protocol version range off connection")?;
65 :
66 0 : let Message::Text(message_text) = message else {
67 : // All messages should be in text form, since we don't do any
68 : // pinging/ponging. See nhooyr/websocket's implementation and the
69 : // agent for more info
70 0 : bail!("received non-text message during proocol handshake: {message:?}")
71 : };
72 :
73 0 : let monitor_range = ProtocolRange {
74 0 : min: PROTOCOL_MIN_VERSION,
75 0 : max: PROTOCOL_MAX_VERSION,
76 0 : };
77 :
78 0 : let agent_range: ProtocolRange = serde_json::from_str(&message_text)
79 0 : .context("failed to deserialize protocol version range")?;
80 :
81 0 : info!(range = ?agent_range, "received protocol version range");
82 :
83 0 : let highest_shared_version = match monitor_range.highest_shared_version(&agent_range) {
84 0 : Ok(version) => {
85 0 : sink.send(Message::Text(
86 0 : serde_json::to_string(&ProtocolResponse::Version(version)).unwrap(),
87 0 : ))
88 0 : .await
89 0 : .context("failed to notify agent of negotiated protocol version")?;
90 0 : version
91 : }
92 0 : Err(e) => {
93 0 : sink.send(Message::Text(
94 0 : serde_json::to_string(&ProtocolResponse::Error(format!(
95 0 : "Received protocol version range {} which does not overlap with {}",
96 0 : agent_range, monitor_range
97 0 : )))
98 0 : .unwrap(),
99 0 : ))
100 0 : .await
101 0 : .context("failed to notify agent of no overlap between protocol version ranges")?;
102 0 : Err(e).context("error determining suitable protocol version range")?
103 : }
104 : };
105 :
106 0 : Ok(Self {
107 0 : sink,
108 0 : source,
109 0 : proto_version: highest_shared_version,
110 0 : })
111 0 : }
112 :
113 : /// Send a message to the agent.
114 : ///
115 : /// Although this function is small, it has one major benefit: it is the only
116 : /// way to send data accross the connection, and you can only pass in a proper
117 : /// `MonitorMessage`. Without safeguards like this, it's easy to accidentally
118 : /// serialize the wrong thing and send it, since `self.sink.send` will take
119 : /// any string.
120 0 : pub async fn send(&mut self, message: OutboundMsg) -> anyhow::Result<()> {
121 0 : if matches!(&message.inner, OutboundMsgKind::HealthCheck { .. }) {
122 0 : debug!(?message, "sending message");
123 : } else {
124 0 : info!(?message, "sending message");
125 : }
126 :
127 0 : let json = serde_json::to_string(&message).context("failed to serialize message")?;
128 0 : self.sink
129 0 : .send(Message::Text(json))
130 0 : .await
131 0 : .context("stream error sending message")
132 0 : }
133 : }
|