Line data Source code
1 : //! Managing the websocket connection and other signals in the monitor.
2 : //!
//! Contains types that manage the interaction (not data interchange, see `protocol`)
//! between agent and monitor, allowing us to process and send messages in a
//! straightforward way. The dispatcher also manages the signals that come from
6 : //! the cgroup (requesting upscale), and the signals that go to the cgroup
7 : //! (notifying it of upscale).
8 :
9 : use anyhow::{Context, bail};
10 : use axum::extract::ws::{Message, Utf8Bytes, WebSocket};
11 : use futures::stream::{SplitSink, SplitStream};
12 : use futures::{SinkExt, StreamExt};
13 : use tracing::{debug, info};
14 :
15 : use crate::protocol::{
16 : OutboundMsg, OutboundMsgKind, PROTOCOL_MAX_VERSION, PROTOCOL_MIN_VERSION, ProtocolRange,
17 : ProtocolResponse, ProtocolVersion,
18 : };
19 :
20 : /// The central handler for all communications in the monitor.
21 : ///
22 : /// The dispatcher has two purposes:
23 : /// 1. Manage the connection to the agent, sending and receiving messages.
24 : /// 2. Communicate with the cgroup manager, notifying it when upscale is received,
25 : /// and sending a message to the agent when the cgroup manager requests
26 : /// upscale.
27 : #[derive(Debug)]
28 : pub struct Dispatcher {
29 : /// We read agent messages of of `source`
30 : pub(crate) source: SplitStream<WebSocket>,
31 :
32 : /// We send messages to the agent through `sink`
33 : sink: SplitSink<WebSocket, Message>,
34 :
35 : /// The protocol version we have agreed to use with the agent. This is negotiated
36 : /// during the creation of the dispatcher, and should be the highest shared protocol
37 : /// version.
38 : ///
39 : // NOTE: currently unused, but will almost certainly be used in the futures
40 : // as the protocol changes
41 : #[allow(unused)]
42 : pub(crate) proto_version: ProtocolVersion,
43 : }
44 :
45 : impl Dispatcher {
46 : /// Creates a new dispatcher using the passed-in connection.
47 : ///
48 : /// Performs a negotiation with the agent to determine the highest protocol
49 : /// version that both support. This consists of two steps:
50 : /// 1. Wait for the agent to sent the range of protocols it supports.
51 : /// 2. Send a protocol version that works for us as well, or an error if there
52 : /// is no compatible version.
53 0 : pub async fn new(stream: WebSocket) -> anyhow::Result<Self> {
54 0 : let (mut sink, mut source) = stream.split();
55 0 :
56 0 : // Figure out the highest protocol version we both support
57 0 : info!("waiting for agent to send protocol version range");
58 0 : let Some(message) = source.next().await else {
59 0 : bail!("websocket connection closed while performing protocol handshake")
60 : };
61 :
62 0 : let message = message.context("failed to read protocol version range off connection")?;
63 :
64 0 : let Message::Text(message_text) = message else {
65 : // All messages should be in text form, since we don't do any
66 : // pinging/ponging. See nhooyr/websocket's implementation and the
67 : // agent for more info
68 0 : bail!("received non-text message during proocol handshake: {message:?}")
69 : };
70 :
71 0 : let monitor_range = ProtocolRange {
72 0 : min: PROTOCOL_MIN_VERSION,
73 0 : max: PROTOCOL_MAX_VERSION,
74 0 : };
75 :
76 0 : let agent_range: ProtocolRange = serde_json::from_str(&message_text)
77 0 : .context("failed to deserialize protocol version range")?;
78 :
79 0 : info!(range = ?agent_range, "received protocol version range");
80 :
81 0 : let highest_shared_version = match monitor_range.highest_shared_version(&agent_range) {
82 0 : Ok(version) => {
83 0 : sink.send(Message::Text(Utf8Bytes::from(
84 0 : serde_json::to_string(&ProtocolResponse::Version(version)).unwrap(),
85 0 : )))
86 0 : .await
87 0 : .context("failed to notify agent of negotiated protocol version")?;
88 0 : version
89 : }
90 0 : Err(e) => {
91 0 : sink.send(Message::Text(Utf8Bytes::from(
92 0 : serde_json::to_string(&ProtocolResponse::Error(format!(
93 0 : "Received protocol version range {} which does not overlap with {}",
94 0 : agent_range, monitor_range
95 0 : )))
96 0 : .unwrap(),
97 0 : )))
98 0 : .await
99 0 : .context("failed to notify agent of no overlap between protocol version ranges")?;
100 0 : Err(e).context("error determining suitable protocol version range")?
101 : }
102 : };
103 :
104 0 : Ok(Self {
105 0 : sink,
106 0 : source,
107 0 : proto_version: highest_shared_version,
108 0 : })
109 0 : }
110 :
111 : /// Send a message to the agent.
112 : ///
113 : /// Although this function is small, it has one major benefit: it is the only
114 : /// way to send data accross the connection, and you can only pass in a proper
115 : /// `MonitorMessage`. Without safeguards like this, it's easy to accidentally
116 : /// serialize the wrong thing and send it, since `self.sink.send` will take
117 : /// any string.
118 0 : pub async fn send(&mut self, message: OutboundMsg) -> anyhow::Result<()> {
119 0 : if matches!(&message.inner, OutboundMsgKind::HealthCheck { .. }) {
120 0 : debug!(?message, "sending message");
121 : } else {
122 0 : info!(?message, "sending message");
123 : }
124 :
125 0 : let json = serde_json::to_string(&message).context("failed to serialize message")?;
126 0 : self.sink
127 0 : .send(Message::Text(Utf8Bytes::from(json)))
128 0 : .await
129 0 : .context("stream error sending message")
130 0 : }
131 : }
|