LCOV - differential code coverage report
Current view: top level - proxy/src/console - mgmt.rs (source / functions) Coverage Total Hit UBC CBC
Current: f6946e90941b557c917ac98cd5a7e9506d180f3e.info Lines: 86.6 % 67 58 9 58
Current Date: 2023-10-19 02:04:12 Functions: 63.0 % 27 17 10 17
Baseline: c8637f37369098875162f194f92736355783b050.info
Baseline Date: 2023-10-18 20:25:20

           TLA  Line data    Source code
       1                 : use crate::{
       2                 :     console::messages::{DatabaseInfo, KickSession},
       3                 :     waiters::{self, Waiter, Waiters},
       4                 : };
       5                 : use anyhow::Context;
       6                 : use once_cell::sync::Lazy;
       7                 : use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
       8                 : use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
       9                 : use std::{convert::Infallible, future};
      10                 : use tokio::net::{TcpListener, TcpStream};
      11                 : use tracing::{error, info, info_span, Instrument};
      12                 : 
      13                 : static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
      14                 : 
      15                 : /// Give caller an opportunity to wait for the cloud's reply.
      16 CBC           3 : pub async fn with_waiter<R, T, E>(
      17               3 :     psql_session_id: impl Into<String>,
      18               3 :     action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
      19               3 : ) -> Result<T, E>
      20               3 : where
      21               3 :     R: std::future::Future<Output = Result<T, E>>,
      22               3 :     E: From<waiters::RegisterError>,
      23               3 : {
      24               3 :     let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
      25               3 :     action(waiter).await
      26               3 : }
      27                 : 
      28               3 : pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
      29               3 :     CPLANE_WAITERS.notify(psql_session_id, msg)
      30               3 : }
      31                 : 
      32                 : /// Console management API listener task.
      33                 : /// It spawns console response handlers needed for the link auth.
      34              16 : pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
      35              16 :     scopeguard::defer! {
      36              16 :         info!("mgmt has shut down");
      37              16 :     }
      38                 : 
      39                 :     loop {
      40              19 :         let (socket, peer_addr) = listener.accept().await?;
      41               3 :         info!("accepted connection from {peer_addr}");
      42                 : 
      43               3 :         socket
      44               3 :             .set_nodelay(true)
      45               3 :             .context("failed to set client socket option")?;
      46                 : 
      47               3 :         let span = info_span!("mgmt", peer = %peer_addr);
      48                 : 
      49               3 :         tokio::task::spawn(
      50               3 :             async move {
      51               3 :                 info!("serving a new console management API connection");
      52                 : 
      53                 :                 // these might be long running connections, have a separate logging for cancelling
      54                 :                 // on shutdown and other ways of stopping.
      55               3 :                 let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
      56 UBC           0 :                     let _e = span.entered();
      57               0 :                     info!("console management API task cancelled");
      58 CBC           3 :                 });
      59                 : 
      60              12 :                 if let Err(e) = handle_connection(socket).await {
      61 UBC           0 :                     error!("serving failed with an error: {e}");
      62                 :                 } else {
      63 CBC           3 :                     info!("serving completed");
      64                 :                 }
      65                 : 
      66                 :                 // we can no longer get dropped
      67               3 :                 scopeguard::ScopeGuard::into_inner(cancelled);
      68               3 :             }
      69               3 :             .instrument(span),
      70               3 :         );
      71                 :     }
      72 UBC           0 : }
      73                 : 
      74 CBC           3 : async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
      75               3 :     let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
      76              12 :     pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
      77               3 : }
      78                 : 
      79                 : /// A message received by `mgmt` when a compute node is ready.
      80                 : pub type ComputeReady = Result<DatabaseInfo, String>;
      81                 : 
      82                 : // TODO: replace with an http-based protocol.
      83                 : struct MgmtHandler;
      84                 : #[async_trait::async_trait]
      85                 : impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
      86               3 :     async fn process_query(
      87               3 :         &mut self,
      88               3 :         pgb: &mut PostgresBackendTCP,
      89               3 :         query: &str,
      90               3 :     ) -> Result<(), QueryError> {
      91               3 :         try_process_query(pgb, query).map_err(|e| {
      92 UBC           0 :             error!("failed to process response: {e:?}");
      93               0 :             e
      94 CBC           3 :         })
      95               3 :     }
      96                 : }
      97                 : 
      98               3 : fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
      99               3 :     let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
     100                 : 
     101               3 :     let span = info_span!("event", session_id = resp.session_id);
     102               3 :     let _enter = span.enter();
     103               3 :     info!("got response: {:?}", resp.result);
     104                 : 
     105               3 :     match notify(resp.session_id, Ok(resp.result)) {
     106                 :         Ok(()) => {
     107               3 :             pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
     108               3 :                 .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
     109               3 :                 .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
     110                 :         }
     111 UBC           0 :         Err(e) => {
     112               0 :             error!("failed to deliver response to per-client task");
     113               0 :             pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
     114                 :         }
     115                 :     }
     116                 : 
     117 CBC           3 :     Ok(())
     118               3 : }
        

Generated by: LCOV version 2.1-beta