Line data Source code
1 : //! A stand-alone program that routes connections, e.g. from
2 : //! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
3 : //!
4 : //! This allows connecting to pods/services running in the same Kubernetes cluster from
5 : //! the outside. Similar to an ingress controller for HTTPS.
6 :
7 : use std::io;
8 : use std::net::SocketAddr;
9 : use std::path::Path;
10 : use std::sync::Arc;
11 :
12 : use anyhow::{Context, anyhow, bail, ensure};
13 : use clap::Arg;
14 : use futures::future::Either;
15 : use futures::{FutureExt, TryFutureExt};
16 : use itertools::Itertools;
17 : use rustls::crypto::ring;
18 : use rustls::pki_types::{DnsName, PrivateKeyDer};
19 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
20 : use tokio::net::TcpListener;
21 : use tokio_rustls::TlsConnector;
22 : use tokio_rustls::server::TlsStream;
23 : use tokio_util::sync::CancellationToken;
24 : use tracing::{Instrument, error, info};
25 : use utils::project_git_version;
26 : use utils::sentry_init::init_sentry;
27 :
28 : use crate::context::RequestContext;
29 : use crate::metrics::{Metrics, ServiceInfo};
30 : use crate::pglb::TlsRequired;
31 : use crate::pqproto::FeStartupPacket;
32 : use crate::protocol2::ConnectionInfo;
33 : use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
34 : use crate::stream::{PqStream, Stream};
35 : use crate::util::run_until_cancelled;
36 :
// Defines the `GIT_VERSION` constant, which is reported by `--version` (see `cli`) and passed
// to `init_sentry` in `run` (presumably the build's git revision — see
// `utils::project_git_version`).
project_git_version!(GIT_VERSION);
38 :
39 0 : fn cli() -> clap::Command {
40 0 : clap::Command::new("Neon proxy/router")
41 0 : .version(GIT_VERSION)
42 0 : .arg(
43 0 : Arg::new("listen")
44 0 : .short('l')
45 0 : .long("listen")
46 0 : .help("listen for incoming client connections on ip:port")
47 0 : .default_value("127.0.0.1:4432"),
48 : )
49 0 : .arg(
50 0 : Arg::new("listen-tls")
51 0 : .long("listen-tls")
52 0 : .help("listen for incoming client connections on ip:port, requiring TLS to compute")
53 0 : .default_value("127.0.0.1:4433"),
54 : )
55 0 : .arg(
56 0 : Arg::new("tls-key")
57 0 : .short('k')
58 0 : .long("tls-key")
59 0 : .help("path to TLS key for client postgres connections")
60 0 : .required(true),
61 : )
62 0 : .arg(
63 0 : Arg::new("tls-cert")
64 0 : .short('c')
65 0 : .long("tls-cert")
66 0 : .help("path to TLS cert for client postgres connections")
67 0 : .required(true),
68 : )
69 0 : .arg(
70 0 : Arg::new("dest")
71 0 : .short('d')
72 0 : .long("destination")
73 0 : .help("append this domain zone to the SNI hostname to get the destination address")
74 0 : .required(true),
75 : )
76 0 : }
77 :
78 0 : pub async fn run() -> anyhow::Result<()> {
79 0 : let _logging_guard = crate::logging::init()?;
80 0 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
81 0 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
82 :
83 0 : let args = cli().get_matches();
84 0 : let destination: String = args
85 0 : .get_one::<String>("dest")
86 0 : .expect("string argument defined")
87 0 : .parse()?;
88 :
89 : // Configure TLS
90 0 : let tls_config = match (
91 0 : args.get_one::<String>("tls-key"),
92 0 : args.get_one::<String>("tls-cert"),
93 : ) {
94 0 : (Some(key_path), Some(cert_path)) => parse_tls(key_path.as_ref(), cert_path.as_ref())?,
95 0 : _ => bail!("tls-key and tls-cert must be specified"),
96 : };
97 :
98 0 : let compute_tls_config =
99 0 : Arc::new(crate::tls::client_config::compute_client_config_with_root_certs()?);
100 :
101 : // Start listening for incoming client connections
102 0 : let proxy_address: SocketAddr = args
103 0 : .get_one::<String>("listen")
104 0 : .expect("listen argument defined")
105 0 : .parse()?;
106 0 : let proxy_address_compute_tls: SocketAddr = args
107 0 : .get_one::<String>("listen-tls")
108 0 : .expect("listen-tls argument defined")
109 0 : .parse()?;
110 :
111 0 : info!("Starting sni router on {proxy_address}");
112 0 : info!("Starting sni router on {proxy_address_compute_tls}");
113 0 : let proxy_listener = TcpListener::bind(proxy_address).await?;
114 0 : let proxy_listener_compute_tls = TcpListener::bind(proxy_address_compute_tls).await?;
115 :
116 0 : let cancellation_token = CancellationToken::new();
117 0 : let dest = Arc::new(destination);
118 :
119 0 : let main = tokio::spawn(task_main(
120 0 : dest.clone(),
121 0 : tls_config.clone(),
122 0 : None,
123 0 : proxy_listener,
124 0 : cancellation_token.clone(),
125 : ))
126 0 : .map(crate::error::flatten_err);
127 :
128 0 : let main_tls = tokio::spawn(task_main(
129 0 : dest,
130 0 : tls_config,
131 0 : Some(compute_tls_config),
132 0 : proxy_listener_compute_tls,
133 0 : cancellation_token.clone(),
134 : ))
135 0 : .map(crate::error::flatten_err);
136 :
137 0 : Metrics::get()
138 0 : .service
139 0 : .info
140 0 : .set_label(ServiceInfo::running());
141 :
142 0 : let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
143 :
144 : // the signal task cant ever succeed.
145 : // the main task can error, or can succeed on cancellation.
146 : // we want to immediately exit on either of these cases
147 0 : let main = futures::future::try_join(main, main_tls);
148 0 : let signal = match futures::future::select(signals_task, main).await {
149 0 : Either::Left((res, _)) => crate::error::flatten_err(res)?,
150 0 : Either::Right((res, _)) => {
151 0 : res?;
152 0 : return Ok(());
153 : }
154 : };
155 :
156 : // maintenance tasks return `Infallible` success values, this is an impossible value
157 : // so this match statically ensures that there are no possibilities for that value
158 : match signal {}
159 0 : }
160 :
161 0 : pub(super) fn parse_tls(
162 0 : key_path: &Path,
163 0 : cert_path: &Path,
164 0 : ) -> anyhow::Result<Arc<rustls::ServerConfig>> {
165 0 : let key = {
166 0 : let key_bytes = std::fs::read(key_path).context("TLS key file")?;
167 :
168 0 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
169 :
170 0 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
171 : PrivateKeyDer::Pkcs8(
172 0 : keys.pop()
173 0 : .expect("keys should not be empty")
174 0 : .context(format!(
175 0 : "Failed to read TLS keys at '{}'",
176 0 : key_path.display()
177 0 : ))?,
178 : )
179 : };
180 :
181 0 : let cert_chain_bytes = std::fs::read(cert_path).context(format!(
182 0 : "Failed to read TLS cert file at '{}.'",
183 0 : cert_path.display()
184 0 : ))?;
185 :
186 0 : let cert_chain: Vec<_> = {
187 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
188 0 : .try_collect()
189 0 : .with_context(|| {
190 0 : format!(
191 0 : "Failed to read TLS certificate chain from bytes from file at '{}'.",
192 0 : cert_path.display()
193 : )
194 0 : })?
195 : };
196 :
197 0 : let tls_config =
198 0 : rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
199 0 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
200 0 : .context("ring should support TLS1.2 and TLS1.3")?
201 0 : .with_no_client_auth()
202 0 : .with_single_cert(cert_chain, key)?
203 0 : .into();
204 :
205 0 : Ok(tls_config)
206 0 : }
207 :
208 0 : pub(super) async fn task_main(
209 0 : dest_suffix: Arc<String>,
210 0 : tls_config: Arc<rustls::ServerConfig>,
211 0 : compute_tls_config: Option<Arc<rustls::ClientConfig>>,
212 0 : listener: tokio::net::TcpListener,
213 0 : cancellation_token: CancellationToken,
214 0 : ) -> anyhow::Result<()> {
215 : // When set for the server socket, the keepalive setting
216 : // will be inherited by all accepted client sockets.
217 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
218 :
219 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
220 :
221 0 : while let Some(accept_result) =
222 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
223 : {
224 0 : let (socket, peer_addr) = accept_result?;
225 :
226 0 : let session_id = uuid::Uuid::new_v4();
227 0 : let tls_config = Arc::clone(&tls_config);
228 0 : let dest_suffix = Arc::clone(&dest_suffix);
229 0 : let compute_tls_config = compute_tls_config.clone();
230 :
231 0 : connections.spawn(
232 0 : async move {
233 0 : socket
234 0 : .set_nodelay(true)
235 0 : .context("failed to set socket option")?;
236 :
237 0 : let ctx = RequestContext::new(
238 0 : session_id,
239 0 : ConnectionInfo {
240 0 : addr: peer_addr,
241 0 : extra: None,
242 0 : },
243 0 : crate::metrics::Protocol::SniRouter,
244 : );
245 0 : handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
246 0 : }
247 0 : .unwrap_or_else(|e| {
248 0 : if let Some(FirstMessage(io_error)) = e.downcast_ref() {
249 : // this is noisy. if we get EOF on the very first message that's likely
250 : // just NLB doing a healthcheck.
251 0 : if io_error.kind() == io::ErrorKind::UnexpectedEof {
252 0 : return;
253 0 : }
254 0 : }
255 :
256 : // Acknowledge that the task has finished with an error.
257 0 : error!("per-client task finished with an error: {e:#}");
258 0 : })
259 0 : .instrument(tracing::info_span!("handle_client", ?session_id)),
260 : );
261 : }
262 :
263 0 : connections.close();
264 0 : drop(listener);
265 :
266 0 : connections.wait().await;
267 :
268 0 : info!("all client connections have finished");
269 0 : Ok(())
270 0 : }
271 :
/// Wraps an I/O error raised while reading the client's very first (startup) message.
///
/// `task_main` downcasts to this type so that an `UnexpectedEof` at this stage (likely just an
/// NLB health check) is not logged as an error. `#[error(transparent)]` forwards `Display`
/// and `source` to the wrapped `io::Error`.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
struct FirstMessage(io::Error);
275 :
276 0 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
277 0 : ctx: &RequestContext,
278 0 : raw_stream: S,
279 0 : tls_config: Arc<rustls::ServerConfig>,
280 0 : ) -> anyhow::Result<TlsStream<S>> {
281 0 : let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream))
282 0 : .await
283 0 : .map_err(FirstMessage)?;
284 :
285 0 : match msg {
286 : FeStartupPacket::SslRequest { direct: None } => {
287 0 : let raw = stream.accept_tls().await?;
288 :
289 0 : Ok(raw
290 0 : .upgrade(tls_config, !ctx.has_private_peer_addr())
291 0 : .await?)
292 : }
293 0 : unexpected => {
294 0 : info!(
295 : ?unexpected,
296 0 : "unexpected startup packet, rejecting connection"
297 : );
298 0 : Err(stream.throw_error(TlsRequired, None).await)?
299 : }
300 : }
301 0 : }
302 :
303 0 : async fn handle_client(
304 0 : ctx: RequestContext,
305 0 : dest_suffix: Arc<String>,
306 0 : tls_config: Arc<rustls::ServerConfig>,
307 0 : compute_tls_config: Option<Arc<rustls::ClientConfig>>,
308 0 : stream: impl AsyncRead + AsyncWrite + Unpin,
309 0 : ) -> anyhow::Result<()> {
310 0 : let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
311 :
312 : // Cut off first part of the SNI domain
313 : // We receive required destination details in the format of
314 : // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
315 0 : let sni = tls_stream
316 0 : .get_ref()
317 0 : .1
318 0 : .server_name()
319 0 : .ok_or(anyhow!("SNI missing"))?;
320 0 : let dest: Vec<&str> = sni
321 0 : .split_once('.')
322 0 : .context("invalid SNI")?
323 : .0
324 0 : .splitn(3, "--")
325 0 : .collect();
326 0 : let port = dest[2].parse::<u16>().context("invalid port")?;
327 0 : let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
328 :
329 0 : info!("destination: {}", destination);
330 :
331 0 : let mut client = tokio::net::TcpStream::connect(&destination).await?;
332 :
333 0 : let client = if let Some(compute_tls_config) = compute_tls_config {
334 0 : info!("upgrading TLS");
335 :
336 : // send SslRequest
337 0 : client
338 0 : .write_all(b"\x00\x00\x00\x08\x04\xd2\x16\x2f")
339 0 : .await?;
340 :
341 : // wait for S/N respons
342 0 : let mut resp = b'N';
343 0 : client.read_exact(std::slice::from_mut(&mut resp)).await?;
344 :
345 : // error if not S
346 0 : ensure!(resp == b'S', "compute refused TLS");
347 :
348 : // upgrade to TLS.
349 0 : let domain = DnsName::try_from(destination)?;
350 0 : let domain = rustls::pki_types::ServerName::DnsName(domain);
351 0 : let client = TlsConnector::from(compute_tls_config)
352 0 : .connect(domain, client)
353 0 : .await?;
354 0 : Connection::Tls(client)
355 : } else {
356 0 : Connection::Raw(client)
357 : };
358 :
359 : // doesn't yet matter as pg-sni-router doesn't report analytics logs
360 0 : ctx.set_success();
361 0 : ctx.log_connect();
362 :
363 : // Starting from here we only proxy the client's traffic.
364 0 : info!("performing the proxy pass...");
365 :
366 0 : let res = match client {
367 0 : Connection::Raw(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
368 0 : Connection::Tls(mut c) => copy_bidirectional_client_compute(&mut tls_stream, &mut c).await,
369 : };
370 :
371 0 : match res {
372 0 : Ok(_) => Ok(()),
373 0 : Err(ErrorSource::Client(err)) => Err(err).context("client"),
374 0 : Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
375 : }
376 0 : }
377 :
/// The compute-side half of a proxied connection.
///
/// It is plain TCP for the `--listen` listener and TLS over TCP for `--listen-tls`.
// `clippy::large_enum_variant` flags `Tls` as much larger than `Raw`. `handle_client` builds
// one value per client connection and consumes it in a single `match`, so boxing the variant
// would only add an allocation.
#[allow(clippy::large_enum_variant)]
enum Connection {
    /// Unencrypted TCP stream to compute.
    Raw(tokio::net::TcpStream),
    /// TLS client stream to compute, established after compute accepted an `SSLRequest`.
    Tls(tokio_rustls::client::TlsStream<tokio::net::TcpStream>),
}
|