LCOV - code coverage report
Current view: top level - proxy/src/control_plane - mgmt.rs (source / functions) Coverage Total Hit
Test: 8ff8efadb0253cf618c612650348666c0c564111.info Lines: 0.0 % 64 0
Test Date: 2024-11-20 17:53:50 Functions: 0.0 % 13 0

            Line data    Source code
       1              : use std::convert::Infallible;
       2              : 
       3              : use anyhow::Context;
       4              : use once_cell::sync::Lazy;
       5              : use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
       6              : use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
       7              : use tokio::net::{TcpListener, TcpStream};
       8              : use tokio_util::sync::CancellationToken;
       9              : use tracing::{error, info, info_span, Instrument};
      10              : 
      11              : use crate::control_plane::messages::{DatabaseInfo, KickSession};
      12              : use crate::waiters::{self, Waiter, Waiters};
      13              : 
      14              : static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
      15              : 
      16              : /// Give caller an opportunity to wait for the cloud's reply.
      17            0 : pub(crate) fn get_waiter(
      18            0 :     psql_session_id: impl Into<String>,
      19            0 : ) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
      20            0 :     CPLANE_WAITERS.register(psql_session_id.into())
      21            0 : }
      22              : 
      23            0 : pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
      24            0 :     CPLANE_WAITERS.notify(psql_session_id, msg)
      25            0 : }
      26              : 
      27              : /// Management API listener task.
      28              : /// It spawns management response handlers needed for the console redirect auth flow.
      29            0 : pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
      30            0 :     scopeguard::defer! {
      31            0 :         info!("mgmt has shut down");
      32            0 :     }
      33              : 
      34              :     loop {
      35            0 :         let (socket, peer_addr) = listener.accept().await?;
      36            0 :         info!("accepted connection from {peer_addr}");
      37              : 
      38            0 :         socket
      39            0 :             .set_nodelay(true)
      40            0 :             .context("failed to set client socket option")?;
      41              : 
      42            0 :         let span = info_span!("mgmt", peer = %peer_addr);
      43              : 
      44            0 :         tokio::task::spawn(
      45            0 :             async move {
      46            0 :                 info!("serving a new management API connection");
      47              : 
      48              :                 // these might be long running connections, have a separate logging for cancelling
      49              :                 // on shutdown and other ways of stopping.
      50            0 :                 let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
      51            0 :                     let _e = span.entered();
      52            0 :                     info!("management API task cancelled");
      53            0 :                 });
      54              : 
      55            0 :                 if let Err(e) = handle_connection(socket).await {
      56            0 :                     error!("serving failed with an error: {e}");
      57              :                 } else {
      58            0 :                     info!("serving completed");
      59              :                 }
      60              : 
      61              :                 // we can no longer get dropped
      62            0 :                 scopeguard::ScopeGuard::into_inner(cancelled);
      63            0 :             }
      64            0 :             .instrument(span),
      65            0 :         );
      66              :     }
      67            0 : }
      68              : 
      69            0 : async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
      70            0 :     let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
      71            0 :     pgbackend
      72            0 :         .run(&mut MgmtHandler, &CancellationToken::new())
      73            0 :         .await
      74            0 : }
      75              : 
      76              : /// A message received by `mgmt` when a compute node is ready.
      77              : pub(crate) type ComputeReady = DatabaseInfo;
      78              : 
      79              : // TODO: replace with an http-based protocol.
      80              : struct MgmtHandler;
      81              : 
      82              : impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
      83            0 :     async fn process_query(
      84            0 :         &mut self,
      85            0 :         pgb: &mut PostgresBackendTCP,
      86            0 :         query: &str,
      87            0 :     ) -> Result<(), QueryError> {
      88            0 :         try_process_query(pgb, query).map_err(|e| {
      89            0 :             error!("failed to process response: {e:?}");
      90            0 :             e
      91            0 :         })
      92            0 :     }
      93              : }
      94              : 
      95            0 : fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
      96            0 :     let resp: KickSession<'_> =
      97            0 :         serde_json::from_str(query).context("Failed to parse query as json")?;
      98              : 
      99            0 :     let span = info_span!("event", session_id = resp.session_id);
     100            0 :     let _enter = span.enter();
     101            0 :     info!("got response: {:?}", resp.result);
     102              : 
     103            0 :     match notify(resp.session_id, resp.result) {
     104              :         Ok(()) => {
     105            0 :             pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
     106            0 :                 .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
     107            0 :                 .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
     108              :         }
     109            0 :         Err(e) => {
     110            0 :             error!("failed to deliver response to per-client task");
     111            0 :             pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
     112              :         }
     113              :     }
     114              : 
     115            0 :     Ok(())
     116            0 : }
        

Generated by: LCOV version 2.1-beta