Line data Source code
1 : /// A stand-alone program that routes connections, e.g. from
2 : /// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
3 : ///
4 : /// This allows connecting to pods/services running in the same Kubernetes cluster from
5 : /// the outside. Similar to an ingress controller for HTTPS.
6 : use std::{net::SocketAddr, sync::Arc};
7 :
8 : use anyhow::{Context, anyhow, bail, ensure};
9 : use clap::Arg;
10 : use futures::future::Either;
11 : use futures::{FutureExt, TryFutureExt};
12 : use itertools::Itertools;
13 : use rustls::crypto::ring;
14 : use rustls::pki_types::{DnsName, PrivateKeyDer};
15 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16 : use tokio::net::TcpListener;
17 : use tokio_rustls::TlsConnector;
18 : use tokio_util::sync::CancellationToken;
19 : use tracing::{Instrument, error, info};
20 : use utils::project_git_version;
21 : use utils::sentry_init::init_sentry;
22 :
23 : use crate::context::RequestContext;
24 : use crate::metrics::{Metrics, ThreadPoolMetrics};
25 : use crate::protocol2::ConnectionInfo;
26 : use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled};
27 : use crate::stream::{PqStream, Stream};
28 : use crate::tls::TlsServerEndPoint;
29 :
30 : project_git_version!(GIT_VERSION);
31 :
32 0 : fn cli() -> clap::Command {
33 0 : clap::Command::new("Neon proxy/router")
34 0 : .version(GIT_VERSION)
35 0 : .arg(
36 0 : Arg::new("listen")
37 0 : .short('l')
38 0 : .long("listen")
39 0 : .help("listen for incoming client connections on ip:port")
40 0 : .default_value("127.0.0.1:4432"),
41 0 : )
42 0 : .arg(
43 0 : Arg::new("listen-tls")
44 0 : .long("listen-tls")
45 0 : .help("listen for incoming client connections on ip:port, requiring TLS to compute")
46 0 : .default_value("127.0.0.1:4433"),
47 0 : )
48 0 : .arg(
49 0 : Arg::new("tls-key")
50 0 : .short('k')
51 0 : .long("tls-key")
52 0 : .help("path to TLS key for client postgres connections")
53 0 : .required(true),
54 0 : )
55 0 : .arg(
56 0 : Arg::new("tls-cert")
57 0 : .short('c')
58 0 : .long("tls-cert")
59 0 : .help("path to TLS cert for client postgres connections")
60 0 : .required(true),
61 0 : )
62 0 : .arg(
63 0 : Arg::new("dest")
64 0 : .short('d')
65 0 : .long("destination")
66 0 : .help("append this domain zone to the SNI hostname to get the destination address")
67 0 : .required(true),
68 0 : )
69 0 : }
70 :
71 0 : pub async fn run() -> anyhow::Result<()> {
72 0 : let _logging_guard = crate::logging::init().await?;
73 0 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
74 0 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
75 0 :
76 0 : Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
77 0 :
78 0 : let args = cli().get_matches();
79 0 : let destination: String = args
80 0 : .get_one::<String>("dest")
81 0 : .expect("string argument defined")
82 0 : .parse()?;
83 :
84 : // Configure TLS
85 0 : let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
86 0 : args.get_one::<String>("tls-key"),
87 0 : args.get_one::<String>("tls-cert"),
88 : ) {
89 0 : (Some(key_path), Some(cert_path)) => {
90 0 : let key = {
91 0 : let key_bytes = std::fs::read(key_path).context("TLS key file")?;
92 :
93 0 : let mut keys =
94 0 : rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
95 0 :
96 0 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
97 : PrivateKeyDer::Pkcs8(
98 0 : keys.pop()
99 0 : .expect("keys should not be empty")
100 0 : .context(format!("Failed to read TLS keys at '{key_path}'"))?,
101 : )
102 : };
103 :
104 0 : let cert_chain_bytes = std::fs::read(cert_path)
105 0 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
106 :
107 0 : let cert_chain: Vec<_> = {
108 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
109 0 : .try_collect()
110 0 : .with_context(|| {
111 0 : format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
112 0 : })?
113 : };
114 :
115 : // needed for channel bindings
116 0 : let first_cert = cert_chain.first().context("missing certificate")?;
117 0 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
118 :
119 0 : let tls_config =
120 0 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
121 0 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
122 0 : .context("ring should support TLS1.2 and TLS1.3")?
123 0 : .with_no_client_auth()
124 0 : .with_single_cert(cert_chain, key)?
125 0 : .into();
126 0 :
127 0 : (tls_config, tls_server_end_point)
128 : }
129 0 : _ => bail!("tls-key and tls-cert must be specified"),
130 : };
131 :
132 0 : let compute_tls_config =
133 0 : Arc::new(crate::tls::client_config::compute_client_config_with_root_certs()?);
134 :
135 : // Start listening for incoming client connections
136 0 : let proxy_address: SocketAddr = args
137 0 : .get_one::<String>("listen")
138 0 : .expect("listen argument defined")
139 0 : .parse()?;
140 0 : let proxy_address_compute_tls: SocketAddr = args
141 0 : .get_one::<String>("listen-tls")
142 0 : .expect("listen-tls argument defined")
143 0 : .parse()?;
144 :
145 0 : info!("Starting sni router on {proxy_address}");
146 0 : info!("Starting sni router on {proxy_address_compute_tls}");
147 0 : let proxy_listener = TcpListener::bind(proxy_address).await?;
148 0 : let proxy_listener_compute_tls = TcpListener::bind(proxy_address_compute_tls).await?;
149 :
150 0 : let cancellation_token = CancellationToken::new();
151 0 : let dest = Arc::new(destination);
152 0 :
153 0 : let main = tokio::spawn(task_main(
154 0 : dest.clone(),
155 0 : tls_config.clone(),
156 0 : None,
157 0 : tls_server_end_point,
158 0 : proxy_listener,
159 0 : cancellation_token.clone(),
160 0 : ))
161 0 : .map(crate::error::flatten_err);
162 0 :
163 0 : let main_tls = tokio::spawn(task_main(
164 0 : dest,
165 0 : tls_config,
166 0 : Some(compute_tls_config),
167 0 : tls_server_end_point,
168 0 : proxy_listener_compute_tls,
169 0 : cancellation_token.clone(),
170 0 : ))
171 0 : .map(crate::error::flatten_err);
172 0 : let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
173 0 :
174 0 : // the signal task cant ever succeed.
175 0 : // the main task can error, or can succeed on cancellation.
176 0 : // we want to immediately exit on either of these cases
177 0 : let main = futures::future::try_join(main, main_tls);
178 0 : let signal = match futures::future::select(signals_task, main).await {
179 0 : Either::Left((res, _)) => crate::error::flatten_err(res)?,
180 0 : Either::Right((res, _)) => {
181 0 : res?;
182 0 : return Ok(());
183 : }
184 : };
185 :
186 : // maintenance tasks return `Infallible` success values, this is an impossible value
187 : // so this match statically ensures that there are no possibilities for that value
188 : match signal {}
189 0 : }
190 :
191 0 : async fn task_main(
192 0 : dest_suffix: Arc<String>,
193 0 : tls_config: Arc<rustls::ServerConfig>,
194 0 : compute_tls_config: Option<Arc<rustls::ClientConfig>>,
195 0 : tls_server_end_point: TlsServerEndPoint,
196 0 : listener: tokio::net::TcpListener,
197 0 : cancellation_token: CancellationToken,
198 0 : ) -> anyhow::Result<()> {
199 0 : // When set for the server socket, the keepalive setting
200 0 : // will be inherited by all accepted client sockets.
201 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
202 :
203 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
204 :
205 0 : while let Some(accept_result) =
206 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
207 : {
208 0 : let (socket, peer_addr) = accept_result?;
209 :
210 0 : let session_id = uuid::Uuid::new_v4();
211 0 : let tls_config = Arc::clone(&tls_config);
212 0 : let dest_suffix = Arc::clone(&dest_suffix);
213 0 : let compute_tls_config = compute_tls_config.clone();
214 0 :
215 0 : connections.spawn(
216 0 : async move {
217 0 : socket
218 0 : .set_nodelay(true)
219 0 : .context("failed to set socket option")?;
220 :
221 0 : info!(%peer_addr, "serving");
222 0 : let ctx = RequestContext::new(
223 0 : session_id,
224 0 : ConnectionInfo {
225 0 : addr: peer_addr,
226 0 : extra: None,
227 0 : },
228 0 : crate::metrics::Protocol::SniRouter,
229 0 : "sni",
230 0 : );
231 0 : handle_client(
232 0 : ctx,
233 0 : dest_suffix,
234 0 : tls_config,
235 0 : compute_tls_config,
236 0 : tls_server_end_point,
237 0 : socket,
238 0 : )
239 0 : .await
240 0 : }
241 0 : .unwrap_or_else(|e| {
242 0 : // Acknowledge that the task has finished with an error.
243 0 : error!("per-client task finished with an error: {e:#}");
244 0 : })
245 0 : .instrument(tracing::info_span!("handle_client", ?session_id)),
246 : );
247 : }
248 :
249 0 : connections.close();
250 0 : drop(listener);
251 0 :
252 0 : connections.wait().await;
253 :
254 0 : info!("all client connections have finished");
255 0 : Ok(())
256 0 : }
257 :
258 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
259 :
260 0 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
261 0 : ctx: &RequestContext,
262 0 : raw_stream: S,
263 0 : tls_config: Arc<rustls::ServerConfig>,
264 0 : tls_server_end_point: TlsServerEndPoint,
265 0 : ) -> anyhow::Result<Stream<S>> {
266 0 : let mut stream = PqStream::new(Stream::from_raw(raw_stream));
267 :
268 0 : let msg = stream.read_startup_packet().await?;
269 : use pq_proto::FeStartupPacket::SslRequest;
270 :
271 0 : match msg {
272 : SslRequest { direct: false } => {
273 0 : stream
274 0 : .write_message(&pq_proto::BeMessage::EncryptionResponse(true))
275 0 : .await?;
276 :
277 : // Upgrade raw stream into a secure TLS-backed stream.
278 : // NOTE: We've consumed `tls`; this fact will be used later.
279 :
280 0 : let (raw, read_buf) = stream.into_inner();
281 0 : // TODO: Normally, client doesn't send any data before
282 0 : // server says TLS handshake is ok and read_buf is empty.
283 0 : // However, you could imagine pipelining of postgres
284 0 : // SSLRequest + TLS ClientHello in one hunk similar to
285 0 : // pipelining in our node js driver. We should probably
286 0 : // support that by chaining read_buf with the stream.
287 0 : if !read_buf.is_empty() {
288 0 : bail!("data is sent before server replied with EncryptionResponse");
289 0 : }
290 0 :
291 0 : Ok(Stream::Tls {
292 0 : tls: Box::new(
293 0 : raw.upgrade(tls_config, !ctx.has_private_peer_addr())
294 0 : .await?,
295 : ),
296 0 : tls_server_end_point,
297 : })
298 : }
299 0 : unexpected => {
300 0 : info!(
301 : ?unexpected,
302 0 : "unexpected startup packet, rejecting connection"
303 : );
304 0 : stream
305 0 : .throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None)
306 0 : .await?
307 : }
308 : }
309 0 : }
310 :
311 0 : async fn handle_client(
312 0 : ctx: RequestContext,
313 0 : dest_suffix: Arc<String>,
314 0 : tls_config: Arc<rustls::ServerConfig>,
315 0 : compute_tls_config: Option<Arc<rustls::ClientConfig>>,
316 0 : tls_server_end_point: TlsServerEndPoint,
317 0 : stream: impl AsyncRead + AsyncWrite + Unpin,
318 0 : ) -> anyhow::Result<()> {
319 0 : let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
320 :
321 : // Cut off first part of the SNI domain
322 : // We receive required destination details in the format of
323 : // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
324 0 : let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
325 0 : let dest: Vec<&str> = sni
326 0 : .split_once('.')
327 0 : .context("invalid SNI")?
328 : .0
329 0 : .splitn(3, "--")
330 0 : .collect();
331 0 : let port = dest[2].parse::<u16>().context("invalid port")?;
332 0 : let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
333 0 :
334 0 : info!("destination: {}", destination);
335 :
336 0 : let mut client = tokio::net::TcpStream::connect(&destination).await?;
337 :
338 0 : let client = if let Some(compute_tls_config) = compute_tls_config {
339 0 : info!("upgrading TLS");
340 :
341 : // send SslRequest
342 0 : client
343 0 : .write_all(b"\x00\x00\x00\x08\x04\xd2\x16\x2f")
344 0 : .await?;
345 :
346 : // wait for S/N respons
347 0 : let mut resp = b'N';
348 0 : client.read_exact(std::slice::from_mut(&mut resp)).await?;
349 :
350 : // error if not S
351 0 : ensure!(resp == b'S', "compute refused TLS");
352 :
353 : // upgrade to TLS.
354 0 : let domain = DnsName::try_from(destination)?;
355 0 : let domain = rustls::pki_types::ServerName::DnsName(domain);
356 0 : let client = TlsConnector::from(compute_tls_config)
357 0 : .connect(domain, client)
358 0 : .await?;
359 0 : Connection::Tls(client)
360 : } else {
361 0 : Connection::Raw(client)
362 : };
363 :
364 : // doesn't yet matter as pg-sni-router doesn't report analytics logs
365 0 : ctx.set_success();
366 0 : ctx.log_connect();
367 0 :
368 0 : // Starting from here we only proxy the client's traffic.
369 0 : info!("performing the proxy pass...");
370 :
371 0 : let res = match client {
372 0 : Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
373 0 : Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
374 : };
375 :
376 0 : match res {
377 0 : Ok(_) => Ok(()),
378 0 : Err(ErrorSource::Client(err)) => Err(err).context("client"),
379 0 : Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
380 : }
381 0 : }
382 :
383 : enum Connection {
384 : Raw(tokio::net::TcpStream),
385 : Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
386 : }
|