LCOV - code coverage report
Current view: top level - proxy/src/console - mgmt.rs (source / functions) Coverage Total Hit
Test: 8ac049b474321fdc72ddcb56d7165153a1a900e8.info Lines: 86.6 % 67 58
Test Date: 2023-09-06 10:18:01 Functions: 63.0 % 27 17

            Line data    Source code
       1              : use crate::{
       2              :     console::messages::{DatabaseInfo, KickSession},
       3              :     waiters::{self, Waiter, Waiters},
       4              : };
       5              : use anyhow::Context;
       6              : use once_cell::sync::Lazy;
       7              : use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
       8              : use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
       9              : use std::{convert::Infallible, future};
      10              : use tokio::net::{TcpListener, TcpStream};
      11              : use tracing::{error, info, info_span, Instrument};
      12              : 
      13              : static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
      14              : 
      15              : /// Give caller an opportunity to wait for the cloud's reply.
      16            3 : pub async fn with_waiter<R, T, E>(
      17            3 :     psql_session_id: impl Into<String>,
      18            3 :     action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
      19            3 : ) -> Result<T, E>
      20            3 : where
      21            3 :     R: std::future::Future<Output = Result<T, E>>,
      22            3 :     E: From<waiters::RegisterError>,
      23            3 : {
      24            3 :     let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
      25            3 :     action(waiter).await
      26            3 : }
      27              : 
      28            3 : pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
      29            3 :     CPLANE_WAITERS.notify(psql_session_id, msg)
      30            3 : }
      31              : 
      32              : /// Console management API listener task.
      33              : /// It spawns console response handlers needed for the link auth.
      34           14 : pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
      35           14 :     scopeguard::defer! {
      36           14 :         info!("mgmt has shut down");
      37           14 :     }
      38              : 
      39              :     loop {
      40           17 :         let (socket, peer_addr) = listener.accept().await?;
      41            3 :         info!("accepted connection from {peer_addr}");
      42              : 
      43            3 :         socket
      44            3 :             .set_nodelay(true)
      45            3 :             .context("failed to set client socket option")?;
      46              : 
      47            3 :         let span = info_span!("mgmt", peer = %peer_addr);
      48              : 
      49            3 :         tokio::task::spawn(
      50            3 :             async move {
      51            3 :                 info!("serving a new console management API connection");
      52              : 
      53              :                 // these might be long running connections, have a separate logging for cancelling
      54              :                 // on shutdown and other ways of stopping.
      55            3 :                 let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
      56            0 :                     let _e = span.entered();
      57            0 :                     info!("console management API task cancelled");
      58            3 :                 });
      59              : 
      60           12 :                 if let Err(e) = handle_connection(socket).await {
      61            0 :                     error!("serving failed with an error: {e}");
      62              :                 } else {
      63            3 :                     info!("serving completed");
      64              :                 }
      65              : 
      66              :                 // we can no longer get dropped
      67            3 :                 scopeguard::ScopeGuard::into_inner(cancelled);
      68            3 :             }
      69            3 :             .instrument(span),
      70            3 :         );
      71              :     }
      72            0 : }
      73              : 
      74            3 : async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
      75            3 :     let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
      76           12 :     pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
      77            3 : }
      78              : 
      79              : /// A message received by `mgmt` when a compute node is ready.
      80              : pub type ComputeReady = Result<DatabaseInfo, String>;
      81              : 
      82              : // TODO: replace with an http-based protocol.
      83              : struct MgmtHandler;
      84              : #[async_trait::async_trait]
      85              : impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
      86            3 :     async fn process_query(
      87            3 :         &mut self,
      88            3 :         pgb: &mut PostgresBackendTCP,
      89            3 :         query: &str,
      90            3 :     ) -> Result<(), QueryError> {
      91            3 :         try_process_query(pgb, query).map_err(|e| {
      92            0 :             error!("failed to process response: {e:?}");
      93            0 :             e
      94            3 :         })
      95            3 :     }
      96              : }
      97              : 
      98            3 : fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
      99            3 :     let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
     100              : 
     101            3 :     let span = info_span!("event", session_id = resp.session_id);
     102            3 :     let _enter = span.enter();
     103            3 :     info!("got response: {:?}", resp.result);
     104              : 
     105            3 :     match notify(resp.session_id, Ok(resp.result)) {
     106              :         Ok(()) => {
     107            3 :             pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
     108            3 :                 .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
     109            3 :                 .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
     110              :         }
     111            0 :         Err(e) => {
     112            0 :             error!("failed to deliver response to per-client task");
     113            0 :             pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
     114              :         }
     115              :     }
     116              : 
     117            3 :     Ok(())
     118            3 : }
        

Generated by: LCOV version 2.1-beta