Line data Source code
1 : use std::convert::Infallible;
2 :
3 : use anyhow::Context;
4 : use once_cell::sync::Lazy;
5 : use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
6 : use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
7 : use tokio::net::{TcpListener, TcpStream};
8 : use tokio_util::sync::CancellationToken;
9 : use tracing::{error, info, info_span, Instrument};
10 :
11 : use crate::control_plane::messages::{DatabaseInfo, KickSession};
12 : use crate::waiters::{self, Waiter, Waiters};
13 :
14 : static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
15 :
16 : /// Give caller an opportunity to wait for the cloud's reply.
17 0 : pub(crate) fn get_waiter(
18 0 : psql_session_id: impl Into<String>,
19 0 : ) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
20 0 : CPLANE_WAITERS.register(psql_session_id.into())
21 0 : }
22 :
23 0 : pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
24 0 : CPLANE_WAITERS.notify(psql_session_id, msg)
25 0 : }
26 :
27 : /// Management API listener task.
28 : /// It spawns management response handlers needed for the console redirect auth flow.
29 0 : pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
30 0 : scopeguard::defer! {
31 0 : info!("mgmt has shut down");
32 0 : }
33 :
34 : loop {
35 0 : let (socket, peer_addr) = listener.accept().await?;
36 0 : info!("accepted connection from {peer_addr}");
37 :
38 0 : socket
39 0 : .set_nodelay(true)
40 0 : .context("failed to set client socket option")?;
41 :
42 0 : let span = info_span!("mgmt", peer = %peer_addr);
43 :
44 0 : tokio::task::spawn(
45 0 : async move {
46 0 : info!("serving a new management API connection");
47 :
48 : // these might be long running connections, have a separate logging for cancelling
49 : // on shutdown and other ways of stopping.
50 0 : let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
51 0 : let _e = span.entered();
52 0 : info!("management API task cancelled");
53 0 : });
54 :
55 0 : if let Err(e) = handle_connection(socket).await {
56 0 : error!("serving failed with an error: {e}");
57 : } else {
58 0 : info!("serving completed");
59 : }
60 :
61 : // we can no longer get dropped
62 0 : scopeguard::ScopeGuard::into_inner(cancelled);
63 0 : }
64 0 : .instrument(span),
65 0 : );
66 : }
67 0 : }
68 :
69 0 : async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
70 0 : let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
71 0 : pgbackend
72 0 : .run(&mut MgmtHandler, &CancellationToken::new())
73 0 : .await
74 0 : }
75 :
/// A message received by `mgmt` when a compute node is ready.
///
/// This is the payload passed to `notify` and yielded through the `Waiter` returned by
/// `get_waiter`. It is an alias of the control plane's [`DatabaseInfo`].
pub(crate) type ComputeReady = DatabaseInfo;
78 :
79 : // TODO: replace with an http-based protocol.
80 : struct MgmtHandler;
81 :
82 : impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
83 0 : async fn process_query(
84 0 : &mut self,
85 0 : pgb: &mut PostgresBackendTCP,
86 0 : query: &str,
87 0 : ) -> Result<(), QueryError> {
88 0 : try_process_query(pgb, query).map_err(|e| {
89 0 : error!("failed to process response: {e:?}");
90 0 : e
91 0 : })
92 0 : }
93 : }
94 :
95 0 : fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
96 0 : let resp: KickSession<'_> =
97 0 : serde_json::from_str(query).context("Failed to parse query as json")?;
98 :
99 0 : let span = info_span!("event", session_id = resp.session_id);
100 0 : let _enter = span.enter();
101 0 : info!("got response: {:?}", resp.result);
102 :
103 0 : match notify(resp.session_id, resp.result) {
104 : Ok(()) => {
105 0 : pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
106 0 : .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
107 0 : .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
108 : }
109 0 : Err(e) => {
110 0 : error!("failed to deliver response to per-client task");
111 0 : pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
112 : }
113 : }
114 :
115 0 : Ok(())
116 0 : }
|