Line data Source code
1 : #[cfg(test)]
2 : mod tests;
3 :
4 : pub mod connect_compute;
5 : mod copy_bidirectional;
6 : pub mod handshake;
7 : pub mod passthrough;
8 : pub mod retry;
9 : pub mod wake_compute;
10 :
11 : use crate::{
12 : auth,
13 : cancellation::{self, CancellationHandler},
14 : compute,
15 : config::{ProxyConfig, TlsConfig},
16 : context::RequestMonitoring,
17 : error::ReportableError,
18 : metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
19 : protocol2::WithClientIp,
20 : proxy::handshake::{handshake, HandshakeData},
21 : rate_limiter::EndpointRateLimiter,
22 : stream::{PqStream, Stream},
23 : EndpointCacheKey,
24 : };
25 : use futures::TryFutureExt;
26 : use itertools::Itertools;
27 : use once_cell::sync::OnceCell;
28 : use pq_proto::{BeMessage as Be, StartupMessageParams};
29 : use regex::Regex;
30 : use smol_str::{format_smolstr, SmolStr};
31 : use std::sync::Arc;
32 : use thiserror::Error;
33 : use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
34 : use tokio_util::sync::CancellationToken;
35 : use tracing::{error, info, Instrument};
36 :
37 : use self::{
38 : connect_compute::{connect_to_compute, TcpMechanism},
39 : passthrough::ProxyPassthrough,
40 : };
41 :
// Error text telling the client that the connection is insecure and suggesting `sslmode=require`.
// NOTE(review): not referenced in the visible part of this file; presumably used by a submodule — confirm.
42 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
43 :
44 0 : pub async fn run_until_cancelled<F: std::future::Future>(
45 0 : f: F,
46 0 : cancellation_token: &CancellationToken,
47 0 : ) -> Option<F::Output> {
48 0 : match futures::future::select(
49 0 : std::pin::pin!(f),
50 0 : std::pin::pin!(cancellation_token.cancelled()),
51 0 : )
52 0 : .await
53 : {
54 0 : futures::future::Either::Left((f, _)) => Some(f),
55 0 : futures::future::Either::Right(((), _)) => None,
56 : }
57 0 : }
58 :
59 0 : pub async fn task_main(
60 0 : config: &'static ProxyConfig,
61 0 : listener: tokio::net::TcpListener,
62 0 : cancellation_token: CancellationToken,
63 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
64 0 : cancellation_handler: Arc<CancellationHandler>,
65 0 : ) -> anyhow::Result<()> {
66 0 : scopeguard::defer! {
67 0 : info!("proxy has shut down");
68 : }
69 :
70 : // When set for the server socket, the keepalive setting
71 : // will be inherited by all accepted client sockets.
72 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
73 :
74 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
75 :
76 0 : while let Some(accept_result) =
77 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
78 0 : {
79 0 : let (socket, peer_addr) = accept_result?;
80 :
81 0 : let session_id = uuid::Uuid::new_v4();
82 0 : let cancellation_handler = Arc::clone(&cancellation_handler);
83 0 : let endpoint_rate_limiter = endpoint_rate_limiter.clone();
84 0 :
85 0 : connections.spawn(async move {
86 0 : let mut socket = WithClientIp::new(socket);
87 0 : let mut peer_addr = peer_addr.ip();
88 0 : match socket.wait_for_addr().await {
89 0 : Ok(Some(addr)) => peer_addr = addr.ip(),
90 0 : Err(e) => {
91 0 : error!("per-client task finished with an error: {e:#}");
92 0 : return;
93 : }
94 0 : Ok(None) if config.require_client_ip => {
95 0 : error!("missing required client IP");
96 0 : return;
97 : }
98 0 : Ok(None) => {}
99 : }
100 :
101 0 : match socket.inner.set_nodelay(true) {
102 0 : Ok(()) => {},
103 0 : Err(e) => {
104 0 : error!("per-client task finished with an error: failed to set socket option: {e:#}");
105 0 : return;
106 : },
107 : };
108 :
109 0 : let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
110 0 : let span = ctx.span.clone();
111 :
112 0 : let res = handle_client(
113 0 : config,
114 0 : &mut ctx,
115 0 : cancellation_handler,
116 0 : socket,
117 0 : ClientMode::Tcp,
118 0 : endpoint_rate_limiter,
119 0 : )
120 0 : .instrument(span.clone())
121 0 : .await;
122 :
123 0 : match res {
124 0 : Err(e) => {
125 0 : // todo: log and push to ctx the error kind
126 0 : ctx.set_error_kind(e.get_error_kind());
127 0 : ctx.log();
128 0 : error!(parent: &span, "per-client task finished with an error: {e:#}");
129 : }
130 0 : Ok(None) => {
131 0 : ctx.set_success();
132 0 : ctx.log();
133 0 : }
134 0 : Ok(Some(p)) => {
135 0 : ctx.set_success();
136 0 : ctx.log();
137 0 : match p.proxy_pass().instrument(span.clone()).await {
138 0 : Ok(()) => {}
139 0 : Err(e) => {
140 0 : error!(parent: &span, "per-client task finished with an error: {e:#}");
141 : }
142 : }
143 : }
144 : }
145 0 : });
146 : }
147 :
148 0 : connections.close();
149 0 : drop(listener);
150 0 :
151 0 : // Drain connections
152 0 : connections.wait().await;
153 :
154 0 : Ok(())
155 0 : }
156 :
157 : pub enum ClientMode {
158 : Tcp,
159 : Websockets { hostname: Option<String> },
160 : }
161 :
162 : /// Abstracts the logic of handling TCP vs WS clients
163 : impl ClientMode {
164 0 : pub fn allow_cleartext(&self) -> bool {
165 0 : match self {
166 0 : ClientMode::Tcp => false,
167 0 : ClientMode::Websockets { .. } => true,
168 : }
169 0 : }
170 :
171 0 : pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
172 0 : match self {
173 0 : ClientMode::Tcp => config.allow_self_signed_compute,
174 0 : ClientMode::Websockets { .. } => false,
175 : }
176 0 : }
177 :
178 0 : fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
179 0 : match self {
180 0 : ClientMode::Tcp => s.sni_hostname(),
181 0 : ClientMode::Websockets { hostname } => hostname.as_deref(),
182 : }
183 0 : }
184 :
185 0 : fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
186 0 : match self {
187 0 : ClientMode::Tcp => tls,
188 : // TLS is None here if using websockets, because the connection is already encrypted.
189 0 : ClientMode::Websockets { .. } => None,
190 : }
191 0 : }
192 : }
193 :
194 0 : #[derive(Debug, Error)]
195 : // almost all errors should be reported to the user, but there's a few cases where we cannot
196 : // 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
197 : // 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
198 : // we cannot be sure the client even understands our error message
199 : // 3. PrepareClient: The client disconnected, so we can't tell them anyway...
200 : pub enum ClientRequestError {
201 : #[error("{0}")]
202 : Cancellation(#[from] cancellation::CancelError),
203 : #[error("{0}")]
204 : Handshake(#[from] handshake::HandshakeError),
205 : #[error("{0}")]
206 : HandshakeTimeout(#[from] tokio::time::error::Elapsed),
207 : #[error("{0}")]
208 : PrepareClient(#[from] std::io::Error),
209 : #[error("{0}")]
210 : ReportedError(#[from] crate::stream::ReportedError),
211 : }
212 :
213 : impl ReportableError for ClientRequestError {
214 0 : fn get_error_kind(&self) -> crate::error::ErrorKind {
215 0 : match self {
216 0 : ClientRequestError::Cancellation(e) => e.get_error_kind(),
217 0 : ClientRequestError::Handshake(e) => e.get_error_kind(),
218 0 : ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
219 0 : ClientRequestError::ReportedError(e) => e.get_error_kind(),
220 0 : ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
221 : }
222 0 : }
223 : }
224 :
225 0 : pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
226 0 : config: &'static ProxyConfig,
227 0 : ctx: &mut RequestMonitoring,
228 0 : cancellation_handler: Arc<CancellationHandler>,
229 0 : stream: S,
230 0 : mode: ClientMode,
231 0 : endpoint_rate_limiter: Arc<EndpointRateLimiter>,
232 0 : ) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
233 0 : info!("handling interactive connection from client");
234 :
235 0 : let proto = ctx.protocol;
236 0 : let _client_gauge = NUM_CLIENT_CONNECTION_GAUGE
237 0 : .with_label_values(&[proto])
238 0 : .guard();
239 0 : let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
240 0 : .with_label_values(&[proto])
241 0 : .guard();
242 0 :
243 0 : let tls = config.tls_config.as_ref();
244 0 :
245 0 : let pause = ctx.latency_timer.pause();
246 0 : let do_handshake = handshake(stream, mode.handshake_tls(tls));
247 0 : let (mut stream, params) =
248 0 : match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
249 0 : HandshakeData::Startup(stream, params) => (stream, params),
250 0 : HandshakeData::Cancel(cancel_key_data) => {
251 0 : return Ok(cancellation_handler
252 0 : .cancel_session(cancel_key_data, ctx.session_id)
253 0 : .await
254 0 : .map(|()| None)?)
255 : }
256 : };
257 0 : drop(pause);
258 0 :
259 0 : let hostname = mode.hostname(stream.get_ref());
260 0 :
261 0 : let common_names = tls.map(|tls| &tls.common_names);
262 0 :
263 0 : // Extract credentials which we're going to use for auth.
264 0 : let result = config
265 0 : .auth_backend
266 0 : .as_ref()
267 0 : .map(|_| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
268 0 : .transpose();
269 :
270 0 : let user_info = match result {
271 0 : Ok(user_info) => user_info,
272 0 : Err(e) => stream.throw_error(e).await?,
273 : };
274 :
275 : // check rate limit
276 0 : if let Some(ep) = user_info.get_endpoint() {
277 0 : if !endpoint_rate_limiter.check(ep) {
278 0 : return stream
279 0 : .throw_error(auth::AuthError::too_many_connections())
280 0 : .await?;
281 0 : }
282 0 : }
283 :
284 0 : let user = user_info.get_user().to_owned();
285 0 : let user_info = match user_info
286 0 : .authenticate(
287 0 : ctx,
288 0 : &mut stream,
289 0 : mode.allow_cleartext(),
290 0 : &config.authentication_config,
291 0 : )
292 0 : .await
293 : {
294 0 : Ok(auth_result) => auth_result,
295 0 : Err(e) => {
296 0 : let db = params.get("database");
297 0 : let app = params.get("application_name");
298 0 : let params_span = tracing::info_span!("", ?user, ?db, ?app);
299 :
300 0 : return stream.throw_error(e).instrument(params_span).await?;
301 : }
302 : };
303 :
304 0 : let mut node = connect_to_compute(
305 0 : ctx,
306 0 : &TcpMechanism { params: ¶ms },
307 0 : &user_info,
308 0 : mode.allow_self_signed_compute(config),
309 0 : )
310 0 : .or_else(|e| stream.throw_error(e))
311 0 : .await?;
312 :
313 0 : let session = cancellation_handler.get_session();
314 0 : prepare_client_connection(&node, &session, &mut stream).await?;
315 :
316 : // Before proxy passing, forward to compute whatever data is left in the
317 : // PqStream input buffer. Normally there is none, but our serverless npm
318 : // driver in pipeline mode sends startup, password and first query
319 : // immediately after opening the connection.
320 0 : let (stream, read_buf) = stream.into_inner();
321 0 : node.stream.write_all(&read_buf).await?;
322 :
323 0 : Ok(Some(ProxyPassthrough {
324 0 : client: stream,
325 0 : aux: node.aux.clone(),
326 0 : compute: node,
327 0 : req: _request_gauge,
328 0 : conn: _client_gauge,
329 0 : cancel: session,
330 0 : }))
331 0 : }
332 :
333 : /// Finish client connection initialization: confirm auth success, send params, etc.
334 0 : #[tracing::instrument(skip_all)]
335 : async fn prepare_client_connection(
336 : node: &compute::PostgresConnection,
337 : session: &cancellation::Session,
338 : stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
339 : ) -> Result<(), std::io::Error> {
340 : // Register compute's query cancellation token and produce a new, unique one.
341 : // The new token (cancel_key_data) will be sent to the client.
342 : let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
343 :
344 : // Forward all postgres connection params to the client.
345 : // Right now the implementation is very hacky and inefficent (ideally,
346 : // we don't need an intermediate hashmap), but at least it should be correct.
347 : for (name, value) in &node.params {
348 : // TODO: Theoretically, this could result in a big pile of params...
349 : stream.write_message_noflush(&Be::ParameterStatus {
350 : name: name.as_bytes(),
351 : value: value.as_bytes(),
352 : })?;
353 : }
354 :
355 : stream
356 : .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
357 : .write_message(&Be::ReadyForQuery)
358 : .await?;
359 :
360 : Ok(())
361 : }
362 :
363 26 : #[derive(Debug, Clone, PartialEq, Eq, Default)]
364 : pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
365 :
366 : impl NeonOptions {
367 22 : pub fn parse_params(params: &StartupMessageParams) -> Self {
368 22 : params
369 22 : .options_raw()
370 22 : .map(Self::parse_from_iter)
371 22 : .unwrap_or_default()
372 22 : }
373 14 : pub fn parse_options_raw(options: &str) -> Self {
374 14 : Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
375 14 : }
376 :
377 26 : fn parse_from_iter<'a>(options: impl Iterator<Item = &'a str>) -> Self {
378 26 : let mut options = options
379 26 : .filter_map(neon_option)
380 26 : .map(|(k, v)| (k.into(), v.into()))
381 26 : .collect_vec();
382 26 : options.sort();
383 26 : Self(options)
384 26 : }
385 :
386 8 : pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
387 8 : // prefix + format!(" {k}:{v}")
388 8 : // kinda jank because SmolStr is immutable
389 8 : std::iter::once(prefix)
390 8 : .chain(self.0.iter().flat_map(|(k, v)| [" ", &**k, ":", &**v]))
391 8 : .collect::<SmolStr>()
392 8 : .into()
393 8 : }
394 :
395 : /// <https://swagger.io/docs/specification/serialization/> DeepObject format
396 : /// `paramName[prop1]=value1¶mName[prop2]=value2&...`
397 0 : pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
398 0 : self.0
399 0 : .iter()
400 0 : .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
401 0 : .collect()
402 0 : }
403 : }
404 :
405 64 : pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
406 64 : static RE: OnceCell<Regex> = OnceCell::new();
407 64 : let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
408 :
409 64 : let cap = re.captures(bytes)?;
410 8 : let (_, [k, v]) = cap.extract();
411 8 : Some((k, v))
412 64 : }
|