Line data Source code
1 : use crate::{
2 : console::messages::{DatabaseInfo, KickSession},
3 : waiters::{self, Waiter, Waiters},
4 : };
5 : use anyhow::Context;
6 : use once_cell::sync::Lazy;
7 : use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
8 : use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
9 : use std::{convert::Infallible, future};
10 : use tokio::net::{TcpListener, TcpStream};
11 : use tracing::{error, info, info_span, Instrument};
12 :
13 : static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
14 :
15 : /// Give caller an opportunity to wait for the cloud's reply.
16 0 : pub fn get_waiter(
17 0 : psql_session_id: impl Into<String>,
18 0 : ) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
19 0 : CPLANE_WAITERS.register(psql_session_id.into())
20 0 : }
21 :
22 0 : pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
23 0 : CPLANE_WAITERS.notify(psql_session_id, msg)
24 0 : }
25 :
26 : /// Console management API listener task.
27 : /// It spawns console response handlers needed for the link auth.
28 0 : pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
29 : scopeguard::defer! {
30 : info!("mgmt has shut down");
31 : }
32 :
33 : loop {
34 0 : let (socket, peer_addr) = listener.accept().await?;
35 0 : info!("accepted connection from {peer_addr}");
36 :
37 0 : socket
38 0 : .set_nodelay(true)
39 0 : .context("failed to set client socket option")?;
40 :
41 0 : let span = info_span!("mgmt", peer = %peer_addr);
42 :
43 0 : tokio::task::spawn(
44 0 : async move {
45 0 : info!("serving a new console management API connection");
46 :
47 : // these might be long running connections, have a separate logging for cancelling
48 : // on shutdown and other ways of stopping.
49 0 : let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
50 0 : let _e = span.entered();
51 0 : info!("console management API task cancelled");
52 0 : });
53 :
54 0 : if let Err(e) = handle_connection(socket).await {
55 0 : error!("serving failed with an error: {e}");
56 : } else {
57 0 : info!("serving completed");
58 : }
59 :
60 : // we can no longer get dropped
61 0 : scopeguard::ScopeGuard::into_inner(cancelled);
62 0 : }
63 0 : .instrument(span),
64 0 : );
65 : }
66 0 : }
67 :
68 0 : async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
69 0 : let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
70 0 : pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
71 0 : }
72 :
73 : /// A message received by `mgmt` when a compute node is ready.
74 : pub type ComputeReady = DatabaseInfo;
75 :
76 : // TODO: replace with an http-based protocol.
77 : struct MgmtHandler;
78 : #[async_trait::async_trait]
79 : impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
80 0 : async fn process_query(
81 0 : &mut self,
82 0 : pgb: &mut PostgresBackendTCP,
83 0 : query: &str,
84 0 : ) -> Result<(), QueryError> {
85 0 : try_process_query(pgb, query).map_err(|e| {
86 0 : error!("failed to process response: {e:?}");
87 0 : e
88 0 : })
89 0 : }
90 : }
91 :
92 0 : fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
93 0 : let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
94 :
95 0 : let span = info_span!("event", session_id = resp.session_id);
96 0 : let _enter = span.enter();
97 0 : info!("got response: {:?}", resp.result);
98 :
99 0 : match notify(resp.session_id, resp.result) {
100 : Ok(()) => {
101 0 : pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
102 0 : .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
103 0 : .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
104 : }
105 0 : Err(e) => {
106 0 : error!("failed to deliver response to per-client task");
107 0 : pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
108 : }
109 : }
110 :
111 0 : Ok(())
112 0 : }
|