LCOV - code coverage report
Current view: top level - libs/vm_monitor/src - dispatcher.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 0.0 % 56 0
Test Date: 2023-09-06 10:18:01 Functions: 0.0 % 11 0

            Line data    Source code
       1              : //! Managing the websocket connection and other signals in the monitor.
       2              : //!
       3              : //! Contains types that manage the interaction (not data interchange, see `protocol`)
       4              : //! between agent and monitor, allowing us to process and send messages in a
       5              : //! straightforward way. The dispatcher also manages the signals that come from
       6              : //! the cgroup (requesting upscale), and the signals that go to the cgroup
       7              : //! (notifying it of upscale).
       8              : 
       9              : use anyhow::{bail, Context};
      10              : use axum::extract::ws::{Message, WebSocket};
      11              : use futures::{
      12              :     stream::{SplitSink, SplitStream},
      13              :     SinkExt, StreamExt,
      14              : };
      15              : use tokio::sync::mpsc;
      16              : use tracing::info;
      17              : 
      18              : use crate::cgroup::Sequenced;
      19              : use crate::protocol::{
      20              :     OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, Resources, PROTOCOL_MAX_VERSION,
      21              :     PROTOCOL_MIN_VERSION,
      22              : };
      23              : 
      24              : /// The central handler for all communications in the monitor.
      25              : ///
      26              : /// The dispatcher has two purposes:
      27              : /// 1. Manage the connection to the agent, sending and receiving messages.
      28              : /// 2. Communicate with the cgroup manager, notifying it when upscale is received,
      29              : ///    and sending a message to the agent when the cgroup manager requests
      30              : ///    upscale.
      31            0 : #[derive(Debug)]
      32              : pub struct Dispatcher {
      33              :     /// We read agent messages off of `source`
      34              :     pub(crate) source: SplitStream<WebSocket>,
      35              : 
      36              :     /// We send messages to the agent through `sink`
      37              :     sink: SplitSink<WebSocket, Message>,
      38              : 
      39              :     /// Used to notify the cgroup when we are upscaled.
      40              :     pub(crate) notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
      41              : 
      42              :     /// When the cgroup requests upscale it will send on this channel. In response
      43              :     /// we send an `UpscaleRequest` to the agent.
      44              :     pub(crate) request_upscale_events: mpsc::Receiver<()>,
      45              : 
      46              :     /// The protocol version we have agreed to use with the agent. This is negotiated
      47              :     /// during the creation of the dispatcher, and should be the highest shared protocol
      48              :     /// version.
      49              :     ///
      50              :     // NOTE: currently unused, but will almost certainly be used in the future
      51              :     // as the protocol changes
      52              :     #[allow(unused)]
      53              :     pub(crate) proto_version: ProtocolVersion,
      54              : }
      55              : 
      56              : impl Dispatcher {
      57              :     /// Creates a new dispatcher using the passed-in connection.
      58              :     ///
      59              :     /// Performs a negotiation with the agent to determine the highest protocol
      60              :     /// version that both support. This consists of two steps:
      61              :     /// 1. Wait for the agent to send the range of protocols it supports.
      62              :     /// 2. Send a protocol version that works for us as well, or an error if there
      63              :     ///    is no compatible version.
      64            0 :     pub async fn new(
      65            0 :         stream: WebSocket,
      66            0 :         notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
      67            0 :         request_upscale_events: mpsc::Receiver<()>,
      68            0 :     ) -> anyhow::Result<Self> {
      69            0 :         let (mut sink, mut source) = stream.split();
      70              : 
      71              :         // Figure out the highest protocol version we both support
      72            0 :         info!("waiting for agent to send protocol version range");
      73            0 :         let Some(message) = source.next().await else {
      74            0 :             bail!("websocket connection closed while performing protocol handshake")
      75              :         };
      76              : 
      77            0 :         let message = message.context("failed to read protocol version range off connection")?;
      78              : 
      79            0 :         let Message::Text(message_text) = message else {
      80              :             // All messages should be in text form, since we don't do any
      81              :             // pinging/ponging. See nhooyr/websocket's implementation and the
      82              :             // agent for more info
      83            0 :             bail!("received non-text message during proocol handshake: {message:?}")
      84              :         };
      85              : 
      86            0 :         let monitor_range = ProtocolRange {
      87            0 :             min: PROTOCOL_MIN_VERSION,
      88            0 :             max: PROTOCOL_MAX_VERSION,
      89            0 :         };
      90              : 
      91            0 :         let agent_range: ProtocolRange = serde_json::from_str(&message_text)
      92            0 :             .context("failed to deserialize protocol version range")?;
      93              : 
      94            0 :         info!(range = ?agent_range, "received protocol version range");
      95              : 
      96            0 :         let highest_shared_version = match monitor_range.highest_shared_version(&agent_range) {
      97            0 :             Ok(version) => {
      98            0 :                 sink.send(Message::Text(
      99            0 :                     serde_json::to_string(&ProtocolResponse::Version(version)).unwrap(),
     100            0 :                 ))
     101            0 :                 .await
     102            0 :                 .context("failed to notify agent of negotiated protocol version")?;
     103            0 :                 version
     104              :             }
     105            0 :             Err(e) => {
     106            0 :                 sink.send(Message::Text(
     107            0 :                     serde_json::to_string(&ProtocolResponse::Error(format!(
     108            0 :                         "Received protocol version range {} which does not overlap with {}",
     109            0 :                         agent_range, monitor_range
     110            0 :                     )))
     111            0 :                     .unwrap(),
     112            0 :                 ))
     113            0 :                 .await
     114            0 :                 .context("failed to notify agent of no overlap between protocol version ranges")?;
     115            0 :                 Err(e).context("error determining suitable protocol version range")?
     116              :             }
     117              :         };
     118              : 
     119            0 :         Ok(Self {
     120            0 :             sink,
     121            0 :             source,
     122            0 :             notify_upscale_events,
     123            0 :             request_upscale_events,
     124            0 :             proto_version: highest_shared_version,
     125            0 :         })
     126            0 :     }
     127              : 
     128              :     /// Notify the cgroup manager that we have received upscale and wait for
     129              :     /// the acknowledgement.
     130            0 :     #[tracing::instrument(skip_all, fields(?resources))]
     131              :     pub async fn notify_upscale(&self, resources: Sequenced<Resources>) -> anyhow::Result<()> {
     132              :         self.notify_upscale_events
     133              :             .send(resources)
     134              :             .await
     135              :             .context("failed to send resources and oneshot sender across channel")
     136              :     }
     137              : 
     138              :     /// Send a message to the agent.
     139              :     ///
     140              :     /// Although this function is small, it has one major benefit: it is the only
     141              :     /// way to send data across the connection, and you can only pass in a proper
     142              :     /// `OutboundMsg`. Without safeguards like this, it's easy to accidentally
     143              :     /// serialize the wrong thing and send it, since `self.sink.send` will take
     144              :     /// any string.
     145            0 :     pub async fn send(&mut self, message: OutboundMsg) -> anyhow::Result<()> {
     146            0 :         info!(?message, "sending message");
     147            0 :         let json = serde_json::to_string(&message).context("failed to serialize message")?;
     148            0 :         self.sink
     149            0 :             .send(Message::Text(json))
     150            0 :             .await
     151            0 :             .context("stream error sending message")
     152            0 :     }
     153              : }
        

Generated by: LCOV version 2.1-beta