LCOV - code coverage report
Current view: top level - libs/vm_monitor/src - dispatcher.rs (source / functions) Coverage Total Hit
Test: 8ff8efadb0253cf618c612650348666c0c564111.info Lines: 0.0 % 52 0
Test Date: 2024-11-20 17:53:50 Functions: 0.0 % 4 0

            Line data    Source code
       1              : //! Managing the websocket connection and other signals in the monitor.
       2              : //!
       3              : //! Contains types that manage the interaction (not data interchange, see `protocol`)
       4              : //! between agent and monitor, allowing us to process and send messages in a
       5              : //! straightforward way. The dispatcher also manages the signals that come from
       6              : //! the cgroup (requesting upscale), and the signals that go to the cgroup
       7              : //! (notifying it of upscale).
       8              : 
       9              : use anyhow::{bail, Context};
      10              : use axum::extract::ws::{Message, WebSocket};
      11              : use futures::{
      12              :     stream::{SplitSink, SplitStream},
      13              :     SinkExt, StreamExt,
      14              : };
      15              : use tracing::{debug, info};
      16              : 
      17              : use crate::protocol::{
      18              :     OutboundMsg, OutboundMsgKind, ProtocolRange, ProtocolResponse, ProtocolVersion,
      19              :     PROTOCOL_MAX_VERSION, PROTOCOL_MIN_VERSION,
      20              : };
      21              : 
      22              : /// The central handler for all communications in the monitor.
      23              : ///
      24              : /// The dispatcher has two purposes:
      25              : /// 1. Manage the connection to the agent, sending and receiving messages.
      26              : /// 2. Communicate with the cgroup manager, notifying it when upscale is received,
      27              : ///    and sending a message to the agent when the cgroup manager requests
      28              : ///    upscale.
      29              : #[derive(Debug)]
      30              : pub struct Dispatcher {
      31              :     /// We read agent messages of of `source`
      32              :     pub(crate) source: SplitStream<WebSocket>,
      33              : 
      34              :     /// We send messages to the agent through `sink`
      35              :     sink: SplitSink<WebSocket, Message>,
      36              : 
      37              :     /// The protocol version we have agreed to use with the agent. This is negotiated
      38              :     /// during the creation of the dispatcher, and should be the highest shared protocol
      39              :     /// version.
      40              :     ///
      41              :     // NOTE: currently unused, but will almost certainly be used in the futures
      42              :     // as the protocol changes
      43              :     #[allow(unused)]
      44              :     pub(crate) proto_version: ProtocolVersion,
      45              : }
      46              : 
      47              : impl Dispatcher {
      48              :     /// Creates a new dispatcher using the passed-in connection.
      49              :     ///
      50              :     /// Performs a negotiation with the agent to determine the highest protocol
      51              :     /// version that both support. This consists of two steps:
      52              :     /// 1. Wait for the agent to sent the range of protocols it supports.
      53              :     /// 2. Send a protocol version that works for us as well, or an error if there
      54              :     ///    is no compatible version.
      55            0 :     pub async fn new(stream: WebSocket) -> anyhow::Result<Self> {
      56            0 :         let (mut sink, mut source) = stream.split();
      57            0 : 
      58            0 :         // Figure out the highest protocol version we both support
      59            0 :         info!("waiting for agent to send protocol version range");
      60            0 :         let Some(message) = source.next().await else {
      61            0 :             bail!("websocket connection closed while performing protocol handshake")
      62              :         };
      63              : 
      64            0 :         let message = message.context("failed to read protocol version range off connection")?;
      65              : 
      66            0 :         let Message::Text(message_text) = message else {
      67              :             // All messages should be in text form, since we don't do any
      68              :             // pinging/ponging. See nhooyr/websocket's implementation and the
      69              :             // agent for more info
      70            0 :             bail!("received non-text message during proocol handshake: {message:?}")
      71              :         };
      72              : 
      73            0 :         let monitor_range = ProtocolRange {
      74            0 :             min: PROTOCOL_MIN_VERSION,
      75            0 :             max: PROTOCOL_MAX_VERSION,
      76            0 :         };
      77              : 
      78            0 :         let agent_range: ProtocolRange = serde_json::from_str(&message_text)
      79            0 :             .context("failed to deserialize protocol version range")?;
      80              : 
      81            0 :         info!(range = ?agent_range, "received protocol version range");
      82              : 
      83            0 :         let highest_shared_version = match monitor_range.highest_shared_version(&agent_range) {
      84            0 :             Ok(version) => {
      85            0 :                 sink.send(Message::Text(
      86            0 :                     serde_json::to_string(&ProtocolResponse::Version(version)).unwrap(),
      87            0 :                 ))
      88            0 :                 .await
      89            0 :                 .context("failed to notify agent of negotiated protocol version")?;
      90            0 :                 version
      91              :             }
      92            0 :             Err(e) => {
      93            0 :                 sink.send(Message::Text(
      94            0 :                     serde_json::to_string(&ProtocolResponse::Error(format!(
      95            0 :                         "Received protocol version range {} which does not overlap with {}",
      96            0 :                         agent_range, monitor_range
      97            0 :                     )))
      98            0 :                     .unwrap(),
      99            0 :                 ))
     100            0 :                 .await
     101            0 :                 .context("failed to notify agent of no overlap between protocol version ranges")?;
     102            0 :                 Err(e).context("error determining suitable protocol version range")?
     103              :             }
     104              :         };
     105              : 
     106            0 :         Ok(Self {
     107            0 :             sink,
     108            0 :             source,
     109            0 :             proto_version: highest_shared_version,
     110            0 :         })
     111            0 :     }
     112              : 
     113              :     /// Send a message to the agent.
     114              :     ///
     115              :     /// Although this function is small, it has one major benefit: it is the only
     116              :     /// way to send data accross the connection, and you can only pass in a proper
     117              :     /// `MonitorMessage`. Without safeguards like this, it's easy to accidentally
     118              :     /// serialize the wrong thing and send it, since `self.sink.send` will take
     119              :     /// any string.
     120            0 :     pub async fn send(&mut self, message: OutboundMsg) -> anyhow::Result<()> {
     121            0 :         if matches!(&message.inner, OutboundMsgKind::HealthCheck { .. }) {
     122            0 :             debug!(?message, "sending message");
     123              :         } else {
     124            0 :             info!(?message, "sending message");
     125              :         }
     126              : 
     127            0 :         let json = serde_json::to_string(&message).context("failed to serialize message")?;
     128            0 :         self.sink
     129            0 :             .send(Message::Text(json))
     130            0 :             .await
     131            0 :             .context("stream error sending message")
     132            0 :     }
     133              : }
        

Generated by: LCOV version 2.1-beta