TLA Line data Source code
1 : use crate::{
2 : console::messages::{DatabaseInfo, KickSession},
3 : waiters::{self, Waiter, Waiters},
4 : };
5 : use anyhow::Context;
6 : use once_cell::sync::Lazy;
7 : use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
8 : use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
9 : use std::{convert::Infallible, future};
10 : use tokio::net::{TcpListener, TcpStream};
11 : use tracing::{error, info, info_span, Instrument};
12 :
13 : static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
14 :
15 : /// Give caller an opportunity to wait for the cloud's reply.
16 CBC 3 : pub async fn with_waiter<R, T, E>(
17 3 : psql_session_id: impl Into<String>,
18 3 : action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
19 3 : ) -> Result<T, E>
20 3 : where
21 3 : R: std::future::Future<Output = Result<T, E>>,
22 3 : E: From<waiters::RegisterError>,
23 3 : {
24 3 : let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
25 3 : action(waiter).await
26 3 : }
27 :
28 3 : pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
29 3 : CPLANE_WAITERS.notify(psql_session_id, msg)
30 3 : }
31 :
32 : /// Console management API listener task.
33 : /// It spawns console response handlers needed for the link auth.
34 16 : pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
35 16 : scopeguard::defer! {
36 16 : info!("mgmt has shut down");
37 16 : }
38 :
39 : loop {
40 19 : let (socket, peer_addr) = listener.accept().await?;
41 3 : info!("accepted connection from {peer_addr}");
42 :
43 3 : socket
44 3 : .set_nodelay(true)
45 3 : .context("failed to set client socket option")?;
46 :
47 3 : let span = info_span!("mgmt", peer = %peer_addr);
48 :
49 3 : tokio::task::spawn(
50 3 : async move {
51 3 : info!("serving a new console management API connection");
52 :
53 : // these might be long running connections, have a separate logging for cancelling
54 : // on shutdown and other ways of stopping.
55 3 : let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
56 UBC 0 : let _e = span.entered();
57 0 : info!("console management API task cancelled");
58 CBC 3 : });
59 :
60 12 : if let Err(e) = handle_connection(socket).await {
61 UBC 0 : error!("serving failed with an error: {e}");
62 : } else {
63 CBC 3 : info!("serving completed");
64 : }
65 :
66 : // we can no longer get dropped
67 3 : scopeguard::ScopeGuard::into_inner(cancelled);
68 3 : }
69 3 : .instrument(span),
70 3 : );
71 : }
72 UBC 0 : }
73 :
74 CBC 3 : async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
75 3 : let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
76 12 : pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
77 3 : }
78 :
79 : /// A message received by `mgmt` when a compute node is ready.
80 : pub type ComputeReady = Result<DatabaseInfo, String>;
81 :
82 : // TODO: replace with an http-based protocol.
83 : struct MgmtHandler;
84 : #[async_trait::async_trait]
85 : impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
86 3 : async fn process_query(
87 3 : &mut self,
88 3 : pgb: &mut PostgresBackendTCP,
89 3 : query: &str,
90 3 : ) -> Result<(), QueryError> {
91 3 : try_process_query(pgb, query).map_err(|e| {
92 UBC 0 : error!("failed to process response: {e:?}");
93 0 : e
94 CBC 3 : })
95 3 : }
96 : }
97 :
98 3 : fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
99 3 : let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
100 :
101 3 : let span = info_span!("event", session_id = resp.session_id);
102 3 : let _enter = span.enter();
103 3 : info!("got response: {:?}", resp.result);
104 :
105 3 : match notify(resp.session_id, Ok(resp.result)) {
106 : Ok(()) => {
107 3 : pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
108 3 : .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
109 3 : .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
110 : }
111 UBC 0 : Err(e) => {
112 0 : error!("failed to deliver response to per-client task");
113 0 : pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
114 : }
115 : }
116 :
117 CBC 3 : Ok(())
118 3 : }
|