Line data Source code
1 : #[cfg(test)]
2 : mod tests;
3 :
4 : pub(crate) mod connect_auth;
5 : pub(crate) mod connect_compute;
6 : pub(crate) mod retry;
7 : pub(crate) mod wake_compute;
8 :
9 : use std::collections::HashSet;
10 : use std::convert::Infallible;
11 : use std::sync::Arc;
12 :
13 : use futures::TryStreamExt;
14 : use itertools::Itertools;
15 : use once_cell::sync::OnceCell;
16 : use postgres_client::RawCancelToken;
17 : use postgres_client::connect_raw::StartupStream;
18 : use postgres_protocol::message::backend::Message;
19 : use regex::Regex;
20 : use serde::{Deserialize, Serialize};
21 : use smol_str::{SmolStr, format_smolstr};
22 : use tokio::io::{AsyncRead, AsyncWrite};
23 : use tokio::net::TcpStream;
24 : use tokio::sync::oneshot;
25 : use tracing::Instrument;
26 :
27 : use crate::cancellation::{CancelClosure, CancellationHandler};
28 : use crate::compute::{ComputeConnection, PostgresError, RustlsStream};
29 : use crate::config::ProxyConfig;
30 : use crate::context::RequestContext;
31 : pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
32 : use crate::pglb::{ClientMode, ClientRequestError};
33 : use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
34 : use crate::rate_limiter::EndpointRateLimiter;
35 : use crate::stream::{PqStream, Stream};
36 : use crate::types::EndpointCacheKey;
37 : use crate::{auth, compute};
38 :
39 : #[allow(clippy::too_many_arguments)]
40 0 : pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
41 0 : config: &'static ProxyConfig,
42 0 : auth_backend: &'static auth::Backend<'static, ()>,
43 0 : ctx: &RequestContext,
44 0 : cancellation_handler: Arc<CancellationHandler>,
45 0 : client: &mut PqStream<Stream<S>>,
46 0 : mode: &ClientMode,
47 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
48 0 : common_names: Option<&HashSet<String>>,
49 0 : params: &StartupMessageParams,
50 0 : ) -> Result<(ComputeConnection, oneshot::Sender<Infallible>), ClientRequestError> {
51 0 : let hostname = mode.hostname(client.get_ref());
52 : // Extract credentials which we're going to use for auth.
53 0 : let result = auth_backend
54 0 : .as_ref()
55 0 : .map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, params, hostname, common_names))
56 0 : .transpose();
57 :
58 0 : let user_info = match result {
59 0 : Ok(user_info) => user_info,
60 0 : Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
61 : };
62 :
63 0 : let user = user_info.get_user().to_owned();
64 0 : let user_info = match user_info
65 0 : .authenticate(
66 0 : ctx,
67 0 : client,
68 0 : mode.allow_cleartext(),
69 0 : &config.authentication_config,
70 0 : endpoint_rate_limiter,
71 : )
72 0 : .await
73 : {
74 0 : Ok(auth_result) => auth_result,
75 0 : Err(e) => {
76 0 : let db = params.get("database");
77 0 : let app = params.get("application_name");
78 0 : let params_span = tracing::info_span!("", ?user, ?db, ?app);
79 :
80 0 : return Err(client
81 0 : .throw_error(e, Some(ctx))
82 0 : .instrument(params_span)
83 0 : .await)?;
84 : }
85 : };
86 :
87 0 : let (cplane, creds) = match user_info {
88 0 : auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
89 0 : auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
90 : };
91 0 : let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
92 0 : let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
93 0 : auth_info.set_startup_params(params, params_compat);
94 :
95 0 : let backend = auth::Backend::ControlPlane(cplane, creds.info);
96 :
97 : // TODO: callback to pglb
98 0 : let res = connect_auth::connect_to_compute_and_auth(
99 0 : ctx,
100 0 : config,
101 0 : &backend,
102 0 : auth_info,
103 0 : connect_compute::TlsNegotiation::Postgres,
104 0 : )
105 0 : .await;
106 :
107 0 : let mut node = match res {
108 0 : Ok(node) => node,
109 0 : Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
110 : };
111 :
112 0 : send_client_greeting(ctx, &config.greetings, client);
113 :
114 0 : let auth::Backend::ControlPlane(_, user_info) = backend else {
115 0 : unreachable!("ensured above");
116 : };
117 :
118 0 : let session = cancellation_handler.get_key();
119 :
120 0 : let (process_id, secret_key) =
121 0 : forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?;
122 0 : let hostname = node.hostname.to_string();
123 :
124 0 : let session_id = ctx.session_id();
125 0 : let (cancel_on_shutdown, cancel) = oneshot::channel();
126 0 : tokio::spawn(async move {
127 0 : session
128 0 : .maintain_cancel_key(
129 0 : session_id,
130 0 : cancel,
131 0 : &CancelClosure {
132 0 : socket_addr: node.socket_addr,
133 0 : cancel_token: RawCancelToken {
134 0 : ssl_mode: node.ssl_mode,
135 0 : process_id,
136 0 : secret_key,
137 0 : },
138 0 : hostname,
139 0 : user_info,
140 0 : },
141 0 : &config.connect_to_compute,
142 0 : )
143 0 : .await;
144 0 : });
145 :
146 0 : Ok((node, cancel_on_shutdown))
147 0 : }
148 :
149 : /// Greet the client with any useful information.
150 0 : pub(crate) fn send_client_greeting(
151 0 : ctx: &RequestContext,
152 0 : greetings: &String,
153 0 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
154 0 : ) {
155 : // Expose session_id to clients if we have a greeting message.
156 0 : if !greetings.is_empty() {
157 0 : let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id());
158 0 : client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
159 0 : }
160 :
161 : // Forward recorded latencies for probing requests
162 0 : if let Some(testodrome_id) = ctx.get_testodrome_id() {
163 0 : client.write_message(BeMessage::ParameterStatus {
164 0 : name: "neon.testodrome_id".as_bytes(),
165 0 : value: testodrome_id.as_bytes(),
166 0 : });
167 0 :
168 0 : let latency_measured = ctx.get_proxy_latency();
169 0 :
170 0 : client.write_message(BeMessage::ParameterStatus {
171 0 : name: "neon.cplane_latency".as_bytes(),
172 0 : value: latency_measured.cplane.as_micros().to_string().as_bytes(),
173 0 : });
174 0 :
175 0 : client.write_message(BeMessage::ParameterStatus {
176 0 : name: "neon.client_latency".as_bytes(),
177 0 : value: latency_measured.client.as_micros().to_string().as_bytes(),
178 0 : });
179 0 :
180 0 : client.write_message(BeMessage::ParameterStatus {
181 0 : name: "neon.compute_latency".as_bytes(),
182 0 : value: latency_measured.compute.as_micros().to_string().as_bytes(),
183 0 : });
184 0 :
185 0 : client.write_message(BeMessage::ParameterStatus {
186 0 : name: "neon.retry_latency".as_bytes(),
187 0 : value: latency_measured.retry.as_micros().to_string().as_bytes(),
188 0 : });
189 0 : }
190 0 : }
191 :
192 0 : pub(crate) async fn forward_compute_params_to_client(
193 0 : ctx: &RequestContext,
194 0 : cancel_key_data: CancelKeyData,
195 0 : client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
196 0 : compute: &mut StartupStream<TcpStream, RustlsStream>,
197 0 : ) -> Result<(i32, i32), ClientRequestError> {
198 0 : let mut process_id = 0;
199 0 : let mut secret_key = 0;
200 :
201 0 : let err = loop {
202 : // if the client buffer is too large, let's write out some bytes now to save some space
203 0 : client.write_if_full().await?;
204 :
205 0 : let msg = match compute.try_next().await {
206 0 : Ok(msg) => msg,
207 0 : Err(e) => break postgres_client::Error::io(e),
208 : };
209 :
210 0 : match msg {
211 : // Send our cancellation key data instead.
212 0 : Some(Message::BackendKeyData(body)) => {
213 0 : client.write_message(BeMessage::BackendKeyData(cancel_key_data));
214 0 : process_id = body.process_id();
215 0 : secret_key = body.secret_key();
216 0 : }
217 : // Forward all postgres connection params to the client.
218 0 : Some(Message::ParameterStatus(body)) => {
219 0 : if let Ok(name) = body.name()
220 0 : && let Ok(value) = body.value()
221 0 : {
222 0 : client.write_message(BeMessage::ParameterStatus {
223 0 : name: name.as_bytes(),
224 0 : value: value.as_bytes(),
225 0 : });
226 0 : }
227 : }
228 : // Forward all notices to the client.
229 0 : Some(Message::NoticeResponse(notice)) => {
230 0 : client.write_raw(notice.as_bytes().len(), b'N', |buf| {
231 0 : buf.extend_from_slice(notice.as_bytes());
232 0 : });
233 : }
234 : Some(Message::ReadyForQuery(_)) => {
235 0 : client.write_message(BeMessage::ReadyForQuery);
236 0 : return Ok((process_id, secret_key));
237 : }
238 0 : Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body),
239 0 : Some(_) => break postgres_client::Error::unexpected_message(),
240 0 : None => break postgres_client::Error::closed(),
241 : }
242 : };
243 :
244 0 : Err(client
245 0 : .throw_error(PostgresError::Postgres(err), Some(ctx))
246 0 : .await)?
247 0 : }
248 :
249 0 : #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
250 : pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);
251 :
252 : impl NeonOptions {
253 : // proxy options:
254 :
255 : /// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
256 : pub const PARAMS_COMPAT: &'static str = "proxy_params_compat";
257 :
258 : // cplane options:
259 :
260 : /// `LSN` allows provisioning an ephemeral compute with time-travel to the provided LSN.
261 : const LSN: &'static str = "lsn";
262 :
263 : /// `TIMESTAMP` allows provisioning an ephemeral compute with time-travel to the provided timestamp.
264 : const TIMESTAMP: &'static str = "timestamp";
265 :
266 : /// `ENDPOINT_TYPE` allows configuring an ephemeral compute to be read_only or read_write.
267 : const ENDPOINT_TYPE: &'static str = "endpoint_type";
268 :
269 13 : pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
270 13 : params
271 13 : .options_raw()
272 13 : .map(Self::parse_from_iter)
273 13 : .unwrap_or_default()
274 13 : }
275 :
276 13 : pub(crate) fn parse_options_raw(options: &str) -> Self {
277 13 : Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
278 13 : }
279 :
280 0 : pub(crate) fn get(&self, key: &str) -> Option<SmolStr> {
281 0 : self.0
282 0 : .iter()
283 0 : .find_map(|(k, v)| (k == key).then_some(v))
284 0 : .cloned()
285 0 : }
286 :
287 2 : pub(crate) fn is_ephemeral(&self) -> bool {
288 2 : self.0.iter().any(|(k, _)| match &**k {
289 : // This is not a cplane option, we know it does not create ephemeral computes.
290 0 : Self::PARAMS_COMPAT => false,
291 0 : Self::LSN => true,
292 0 : Self::TIMESTAMP => true,
293 0 : Self::ENDPOINT_TYPE => true,
294 : // err on the side of caution. any cplane options we don't know about
295 : // might lead to ephemeral computes.
296 0 : _ => true,
297 0 : })
298 2 : }
299 :
300 20 : fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
301 20 : let mut options = options
302 20 : .filter_map(neon_option)
303 20 : .map(|(k, v)| (k.into(), v.into()))
304 20 : .collect_vec();
305 20 : options.sort();
306 20 : Self(options)
307 20 : }
308 :
309 4 : pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
310 : // prefix + format!(" {k}:{v}")
311 : // kinda jank because SmolStr is immutable
312 4 : std::iter::once(prefix)
313 4 : .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
314 4 : .collect::<SmolStr>()
315 4 : .into()
316 4 : }
317 :
318 : /// <https://swagger.io/docs/specification/serialization/> DeepObject format
319 : /// `paramName[prop1]=value1¶mName[prop2]=value2&...`
320 0 : pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
321 0 : self.0
322 0 : .iter()
323 0 : .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
324 0 : .collect()
325 0 : }
326 : }
327 :
328 34 : pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
329 : static RE: OnceCell<Regex> = OnceCell::new();
330 34 : let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").expect("regex should be correct"));
331 :
332 34 : let cap = re.captures(bytes)?;
333 5 : let (_, [k, v]) = cap.extract();
334 5 : Some((k, v))
335 34 : }
|