LCOV - code coverage report
Current view: top level - libs/vm_monitor/src - dispatcher.rs (source / functions) Coverage Total Hit
Test: 1b0a6a0c05cee5a7de360813c8034804e105ce1c.info Lines: 0.0 % 52 0
Test Date: 2025-03-12 00:01:28 Functions: 0.0 % 4 0

            Line data    Source code
       1              : //! Managing the websocket connection and other signals in the monitor.
       2              : //!
       3              : //! Contains types that manage the interaction (not data interchange, see `protocol`)
       4              : //! between agent and monitor, allowing us to to process and send messages in a
       5              : //! straightforward way. The dispatcher also manages that signals that come from
       6              : //! the cgroup (requesting upscale), and the signals that go to the cgroup
       7              : //! (notifying it of upscale).
       8              : 
       9              : use anyhow::{Context, bail};
      10              : use axum::extract::ws::{Message, Utf8Bytes, WebSocket};
      11              : use futures::stream::{SplitSink, SplitStream};
      12              : use futures::{SinkExt, StreamExt};
      13              : use tracing::{debug, info};
      14              : 
      15              : use crate::protocol::{
      16              :     OutboundMsg, OutboundMsgKind, PROTOCOL_MAX_VERSION, PROTOCOL_MIN_VERSION, ProtocolRange,
      17              :     ProtocolResponse, ProtocolVersion,
      18              : };
      19              : 
      20              : /// The central handler for all communications in the monitor.
      21              : ///
      22              : /// The dispatcher has two purposes:
      23              : /// 1. Manage the connection to the agent, sending and receiving messages.
      24              : /// 2. Communicate with the cgroup manager, notifying it when upscale is received,
      25              : ///    and sending a message to the agent when the cgroup manager requests
      26              : ///    upscale.
      27              : #[derive(Debug)]
      28              : pub struct Dispatcher {
      29              :     /// We read agent messages of of `source`
      30              :     pub(crate) source: SplitStream<WebSocket>,
      31              : 
      32              :     /// We send messages to the agent through `sink`
      33              :     sink: SplitSink<WebSocket, Message>,
      34              : 
      35              :     /// The protocol version we have agreed to use with the agent. This is negotiated
      36              :     /// during the creation of the dispatcher, and should be the highest shared protocol
      37              :     /// version.
      38              :     ///
      39              :     // NOTE: currently unused, but will almost certainly be used in the futures
      40              :     // as the protocol changes
      41              :     #[allow(unused)]
      42              :     pub(crate) proto_version: ProtocolVersion,
      43              : }
      44              : 
      45              : impl Dispatcher {
      46              :     /// Creates a new dispatcher using the passed-in connection.
      47              :     ///
      48              :     /// Performs a negotiation with the agent to determine the highest protocol
      49              :     /// version that both support. This consists of two steps:
      50              :     /// 1. Wait for the agent to sent the range of protocols it supports.
      51              :     /// 2. Send a protocol version that works for us as well, or an error if there
      52              :     ///    is no compatible version.
      53            0 :     pub async fn new(stream: WebSocket) -> anyhow::Result<Self> {
      54            0 :         let (mut sink, mut source) = stream.split();
      55            0 : 
      56            0 :         // Figure out the highest protocol version we both support
      57            0 :         info!("waiting for agent to send protocol version range");
      58            0 :         let Some(message) = source.next().await else {
      59            0 :             bail!("websocket connection closed while performing protocol handshake")
      60              :         };
      61              : 
      62            0 :         let message = message.context("failed to read protocol version range off connection")?;
      63              : 
      64            0 :         let Message::Text(message_text) = message else {
      65              :             // All messages should be in text form, since we don't do any
      66              :             // pinging/ponging. See nhooyr/websocket's implementation and the
      67              :             // agent for more info
      68            0 :             bail!("received non-text message during proocol handshake: {message:?}")
      69              :         };
      70              : 
      71            0 :         let monitor_range = ProtocolRange {
      72            0 :             min: PROTOCOL_MIN_VERSION,
      73            0 :             max: PROTOCOL_MAX_VERSION,
      74            0 :         };
      75              : 
      76            0 :         let agent_range: ProtocolRange = serde_json::from_str(&message_text)
      77            0 :             .context("failed to deserialize protocol version range")?;
      78              : 
      79            0 :         info!(range = ?agent_range, "received protocol version range");
      80              : 
      81            0 :         let highest_shared_version = match monitor_range.highest_shared_version(&agent_range) {
      82            0 :             Ok(version) => {
      83            0 :                 sink.send(Message::Text(Utf8Bytes::from(
      84            0 :                     serde_json::to_string(&ProtocolResponse::Version(version)).unwrap(),
      85            0 :                 )))
      86            0 :                 .await
      87            0 :                 .context("failed to notify agent of negotiated protocol version")?;
      88            0 :                 version
      89              :             }
      90            0 :             Err(e) => {
      91            0 :                 sink.send(Message::Text(Utf8Bytes::from(
      92            0 :                     serde_json::to_string(&ProtocolResponse::Error(format!(
      93            0 :                         "Received protocol version range {} which does not overlap with {}",
      94            0 :                         agent_range, monitor_range
      95            0 :                     )))
      96            0 :                     .unwrap(),
      97            0 :                 )))
      98            0 :                 .await
      99            0 :                 .context("failed to notify agent of no overlap between protocol version ranges")?;
     100            0 :                 Err(e).context("error determining suitable protocol version range")?
     101              :             }
     102              :         };
     103              : 
     104            0 :         Ok(Self {
     105            0 :             sink,
     106            0 :             source,
     107            0 :             proto_version: highest_shared_version,
     108            0 :         })
     109            0 :     }
     110              : 
     111              :     /// Send a message to the agent.
     112              :     ///
     113              :     /// Although this function is small, it has one major benefit: it is the only
     114              :     /// way to send data accross the connection, and you can only pass in a proper
     115              :     /// `MonitorMessage`. Without safeguards like this, it's easy to accidentally
     116              :     /// serialize the wrong thing and send it, since `self.sink.send` will take
     117              :     /// any string.
     118            0 :     pub async fn send(&mut self, message: OutboundMsg) -> anyhow::Result<()> {
     119            0 :         if matches!(&message.inner, OutboundMsgKind::HealthCheck { .. }) {
     120            0 :             debug!(?message, "sending message");
     121              :         } else {
     122            0 :             info!(?message, "sending message");
     123              :         }
     124              : 
     125            0 :         let json = serde_json::to_string(&message).context("failed to serialize message")?;
     126            0 :         self.sink
     127            0 :             .send(Message::Text(Utf8Bytes::from(json)))
     128            0 :             .await
     129            0 :             .context("stream error sending message")
     130            0 :     }
     131              : }
        

Generated by: LCOV version 2.1-beta