Line data Source code
1 : /// A stand-alone program that routes connections, e.g. from
2 : /// `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
3 : ///
4 : /// This allows connecting to pods/services running in the same Kubernetes cluster from
5 : /// the outside. Similar to an ingress controller for HTTPS.
6 : use std::{net::SocketAddr, sync::Arc};
7 :
8 : use futures::future::Either;
9 : use itertools::Itertools;
10 : use proxy::config::TlsServerEndPoint;
11 : use proxy::context::RequestMonitoring;
12 : use proxy::metrics::{Metrics, ThreadPoolMetrics};
13 : use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
14 : use rustls::pki_types::PrivateKeyDer;
15 : use tokio::net::TcpListener;
16 :
17 : use anyhow::{anyhow, bail, ensure, Context};
18 : use clap::Arg;
19 : use futures::TryFutureExt;
20 : use proxy::stream::{PqStream, Stream};
21 :
22 : use tokio::io::{AsyncRead, AsyncWrite};
23 : use tokio_util::sync::CancellationToken;
24 : use utils::{project_git_version, sentry_init::init_sentry};
25 :
26 : use tracing::{error, info, Instrument};
27 :
28 : project_git_version!(GIT_VERSION);
29 :
30 0 : fn cli() -> clap::Command {
31 0 : clap::Command::new("Neon proxy/router")
32 0 : .version(GIT_VERSION)
33 0 : .arg(
34 0 : Arg::new("listen")
35 0 : .short('l')
36 0 : .long("listen")
37 0 : .help("listen for incoming client connections on ip:port")
38 0 : .default_value("127.0.0.1:4432"),
39 0 : )
40 0 : .arg(
41 0 : Arg::new("tls-key")
42 0 : .short('k')
43 0 : .long("tls-key")
44 0 : .help("path to TLS key for client postgres connections")
45 0 : .required(true),
46 0 : )
47 0 : .arg(
48 0 : Arg::new("tls-cert")
49 0 : .short('c')
50 0 : .long("tls-cert")
51 0 : .help("path to TLS cert for client postgres connections")
52 0 : .required(true),
53 0 : )
54 0 : .arg(
55 0 : Arg::new("dest")
56 0 : .short('d')
57 0 : .long("destination")
58 0 : .help("append this domain zone to the SNI hostname to get the destination address")
59 0 : .required(true),
60 0 : )
61 0 : }
62 :
63 : #[tokio::main]
64 0 : async fn main() -> anyhow::Result<()> {
65 0 : let _logging_guard = proxy::logging::init().await?;
66 0 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
67 0 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
68 0 :
69 0 : Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
70 0 :
71 0 : let args = cli().get_matches();
72 0 : let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
73 0 :
74 0 : // Configure TLS
75 0 : let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
76 0 : args.get_one::<String>("tls-key"),
77 0 : args.get_one::<String>("tls-cert"),
78 0 : ) {
79 0 : (Some(key_path), Some(cert_path)) => {
80 0 : let key = {
81 0 : let key_bytes = std::fs::read(key_path).context("TLS key file")?;
82 0 :
83 0 : let mut keys =
84 0 : rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
85 0 :
86 0 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
87 0 : PrivateKeyDer::Pkcs8(
88 0 : keys.pop()
89 0 : .unwrap()
90 0 : .context(format!("Failed to read TLS keys at '{key_path}'"))?,
91 0 : )
92 0 : };
93 0 :
94 0 : let cert_chain_bytes = std::fs::read(cert_path)
95 0 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
96 0 :
97 0 : let cert_chain: Vec<_> = {
98 0 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
99 0 : .try_collect()
100 0 : .with_context(|| {
101 0 : format!("Failed to read TLS certificate chain from bytes from file at '{cert_path}'.")
102 0 : })?
103 0 : };
104 0 :
105 0 : // needed for channel bindings
106 0 : let first_cert = cert_chain.first().context("missing certificate")?;
107 0 : let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
108 0 :
109 0 : let tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[
110 0 : &rustls::version::TLS13,
111 0 : &rustls::version::TLS12,
112 0 : ])
113 0 : .with_no_client_auth()
114 0 : .with_single_cert(cert_chain, key)?
115 0 : .into();
116 0 :
117 0 : (tls_config, tls_server_end_point)
118 0 : }
119 0 : _ => bail!("tls-key and tls-cert must be specified"),
120 0 : };
121 0 :
122 0 : // Start listening for incoming client connections
123 0 : let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
124 0 : info!("Starting sni router on {proxy_address}");
125 0 : let proxy_listener = TcpListener::bind(proxy_address).await?;
126 0 :
127 0 : let cancellation_token = CancellationToken::new();
128 0 :
129 0 : let main = tokio::spawn(task_main(
130 0 : Arc::new(destination),
131 0 : tls_config,
132 0 : tls_server_end_point,
133 0 : proxy_listener,
134 0 : cancellation_token.clone(),
135 0 : ));
136 0 : let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token, || async {
137 0 : Ok(())
138 0 : }));
139 0 :
140 0 : // the signal task cant ever succeed.
141 0 : // the main task can error, or can succeed on cancellation.
142 0 : // we want to immediately exit on either of these cases
143 0 : let signal = match futures::future::select(signals_task, main).await {
144 0 : Either::Left((res, _)) => proxy::flatten_err(res)?,
145 0 : Either::Right((res, _)) => return proxy::flatten_err(res),
146 0 : };
147 0 :
148 0 : // maintenance tasks return `Infallible` success values, this is an impossible value
149 0 : // so this match statically ensures that there are no possibilities for that value
150 0 : match signal {}
151 0 : }
152 :
153 0 : async fn task_main(
154 0 : dest_suffix: Arc<String>,
155 0 : tls_config: Arc<rustls::ServerConfig>,
156 0 : tls_server_end_point: TlsServerEndPoint,
157 0 : listener: tokio::net::TcpListener,
158 0 : cancellation_token: CancellationToken,
159 0 : ) -> anyhow::Result<()> {
160 0 : // When set for the server socket, the keepalive setting
161 0 : // will be inherited by all accepted client sockets.
162 0 : socket2::SockRef::from(&listener).set_keepalive(true)?;
163 :
164 0 : let connections = tokio_util::task::task_tracker::TaskTracker::new();
165 :
166 0 : while let Some(accept_result) =
167 0 : run_until_cancelled(listener.accept(), &cancellation_token).await
168 : {
169 0 : let (socket, peer_addr) = accept_result?;
170 :
171 0 : let session_id = uuid::Uuid::new_v4();
172 0 : let tls_config = Arc::clone(&tls_config);
173 0 : let dest_suffix = Arc::clone(&dest_suffix);
174 0 :
175 0 : connections.spawn(
176 0 : async move {
177 0 : socket
178 0 : .set_nodelay(true)
179 0 : .context("failed to set socket option")?;
180 :
181 0 : info!(%peer_addr, "serving");
182 0 : let ctx = RequestMonitoring::new(
183 0 : session_id,
184 0 : peer_addr.ip(),
185 0 : proxy::metrics::Protocol::SniRouter,
186 0 : "sni",
187 0 : );
188 0 : handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
189 0 : }
190 0 : .unwrap_or_else(|e| {
191 0 : // Acknowledge that the task has finished with an error.
192 0 : error!("per-client task finished with an error: {e:#}");
193 0 : })
194 0 : .instrument(tracing::info_span!("handle_client", ?session_id)),
195 : );
196 : }
197 :
198 0 : connections.close();
199 0 : drop(listener);
200 0 :
201 0 : connections.wait().await;
202 :
203 0 : info!("all client connections have finished");
204 0 : Ok(())
205 0 : }
206 :
207 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
208 :
209 0 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
210 0 : ctx: &RequestMonitoring,
211 0 : raw_stream: S,
212 0 : tls_config: Arc<rustls::ServerConfig>,
213 0 : tls_server_end_point: TlsServerEndPoint,
214 0 : ) -> anyhow::Result<Stream<S>> {
215 0 : let mut stream = PqStream::new(Stream::from_raw(raw_stream));
216 :
217 0 : let msg = stream.read_startup_packet().await?;
218 : use pq_proto::FeStartupPacket::*;
219 :
220 0 : match msg {
221 : SslRequest { direct: false } => {
222 0 : stream
223 0 : .write_message(&pq_proto::BeMessage::EncryptionResponse(true))
224 0 : .await?;
225 :
226 : // Upgrade raw stream into a secure TLS-backed stream.
227 : // NOTE: We've consumed `tls`; this fact will be used later.
228 :
229 0 : let (raw, read_buf) = stream.into_inner();
230 0 : // TODO: Normally, client doesn't send any data before
231 0 : // server says TLS handshake is ok and read_buf is empy.
232 0 : // However, you could imagine pipelining of postgres
233 0 : // SSLRequest + TLS ClientHello in one hunk similar to
234 0 : // pipelining in our node js driver. We should probably
235 0 : // support that by chaining read_buf with the stream.
236 0 : if !read_buf.is_empty() {
237 0 : bail!("data is sent before server replied with EncryptionResponse");
238 0 : }
239 0 :
240 0 : Ok(Stream::Tls {
241 0 : tls: Box::new(
242 0 : raw.upgrade(tls_config, !ctx.has_private_peer_addr())
243 0 : .await?,
244 : ),
245 0 : tls_server_end_point,
246 : })
247 : }
248 0 : unexpected => {
249 0 : info!(
250 : ?unexpected,
251 0 : "unexpected startup packet, rejecting connection"
252 : );
253 0 : stream
254 0 : .throw_error_str(ERR_INSECURE_CONNECTION, proxy::error::ErrorKind::User)
255 0 : .await?
256 : }
257 : }
258 0 : }
259 :
260 0 : async fn handle_client(
261 0 : ctx: RequestMonitoring,
262 0 : dest_suffix: Arc<String>,
263 0 : tls_config: Arc<rustls::ServerConfig>,
264 0 : tls_server_end_point: TlsServerEndPoint,
265 0 : stream: impl AsyncRead + AsyncWrite + Unpin,
266 0 : ) -> anyhow::Result<()> {
267 0 : let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
268 :
269 : // Cut off first part of the SNI domain
270 : // We receive required destination details in the format of
271 : // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
272 0 : let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
273 0 : let dest: Vec<&str> = sni
274 0 : .split_once('.')
275 0 : .context("invalid SNI")?
276 : .0
277 0 : .splitn(3, "--")
278 0 : .collect();
279 0 : let port = dest[2].parse::<u16>().context("invalid port")?;
280 0 : let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
281 0 :
282 0 : info!("destination: {}", destination);
283 :
284 0 : let mut client = tokio::net::TcpStream::connect(destination).await?;
285 :
286 : // doesn't yet matter as pg-sni-router doesn't report analytics logs
287 0 : ctx.set_success();
288 0 : ctx.log_connect();
289 0 :
290 0 : // Starting from here we only proxy the client's traffic.
291 0 : info!("performing the proxy pass...");
292 :
293 0 : match copy_bidirectional_client_compute(&mut tls_stream, &mut client).await {
294 0 : Ok(_) => Ok(()),
295 0 : Err(ErrorSource::Client(err)) => Err(err).context("client"),
296 0 : Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
297 : }
298 0 : }
|