TLA Line data Source code
//! A stand-alone program that routes connections, e.g. from
//! `aaa--bbb--1234.external.domain` to `aaa.bbb.internal.domain:1234`.
//!
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside, similar to what an ingress controller does for HTTPS.
6 : use std::{net::SocketAddr, sync::Arc};
7 :
8 : use futures::future::Either;
9 : use tokio::net::TcpListener;
10 :
11 : use anyhow::{anyhow, bail, ensure, Context};
12 : use clap::{self, Arg};
13 : use futures::TryFutureExt;
14 : use proxy::console::messages::MetricsAuxInfo;
15 : use proxy::stream::{PqStream, Stream};
16 :
17 : use tokio::io::{AsyncRead, AsyncWrite};
18 : use tokio_util::sync::CancellationToken;
19 : use utils::{project_git_version, sentry_init::init_sentry};
20 :
21 : use tracing::{error, info, warn, Instrument};
22 :
23 : project_git_version!(GIT_VERSION);
24 :
25 CBC 1 : fn cli() -> clap::Command {
26 1 : clap::Command::new("Neon proxy/router")
27 1 : .version(GIT_VERSION)
28 1 : .arg(
29 1 : Arg::new("listen")
30 1 : .short('l')
31 1 : .long("listen")
32 1 : .help("listen for incoming client connections on ip:port")
33 1 : .default_value("127.0.0.1:4432"),
34 1 : )
35 1 : .arg(
36 1 : Arg::new("tls-key")
37 1 : .short('k')
38 1 : .long("tls-key")
39 1 : .help("path to TLS key for client postgres connections")
40 1 : .required(true),
41 1 : )
42 1 : .arg(
43 1 : Arg::new("tls-cert")
44 1 : .short('c')
45 1 : .long("tls-cert")
46 1 : .help("path to TLS cert for client postgres connections")
47 1 : .required(true),
48 1 : )
49 1 : .arg(
50 1 : Arg::new("dest")
51 1 : .short('d')
52 1 : .long("destination")
53 1 : .help("append this domain zone to the SNI hostname to get the destination address")
54 1 : .required(true),
55 1 : )
56 1 : }
57 :
58 : #[tokio::main]
59 1 : async fn main() -> anyhow::Result<()> {
60 1 : let _logging_guard = proxy::logging::init().await?;
61 1 : let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
62 1 : let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
63 1 :
64 1 : let args = cli().get_matches();
65 1 : let destination: String = args.get_one::<String>("dest").unwrap().parse()?;
66 :
67 : // Configure TLS
68 1 : let tls_config: Arc<rustls::ServerConfig> = match (
69 1 : args.get_one::<String>("tls-key"),
70 1 : args.get_one::<String>("tls-cert"),
71 : ) {
72 1 : (Some(key_path), Some(cert_path)) => {
73 1 : let key = {
74 1 : let key_bytes = std::fs::read(key_path).context("TLS key file")?;
75 1 : let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
76 1 : .context(format!("Failed to read TLS keys at '{key_path}'"))?;
77 :
78 1 : ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
79 1 : keys.pop().map(rustls::PrivateKey).unwrap()
80 : };
81 :
82 1 : let cert_chain_bytes = std::fs::read(cert_path)
83 1 : .context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
84 :
85 1 : let cert_chain = {
86 1 : rustls_pemfile::certs(&mut &cert_chain_bytes[..])
87 1 : .context(format!(
88 1 : "Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
89 1 : ))?
90 1 : .into_iter()
91 1 : .map(rustls::Certificate)
92 1 : .collect()
93 1 : };
94 1 :
95 1 : rustls::ServerConfig::builder()
96 1 : .with_safe_default_cipher_suites()
97 1 : .with_safe_default_kx_groups()
98 1 : .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
99 1 : .with_no_client_auth()
100 1 : .with_single_cert(cert_chain, key)?
101 1 : .into()
102 : }
103 UBC 0 : _ => bail!("tls-key and tls-cert must be specified"),
104 : };
105 :
106 : // Start listening for incoming client connections
107 CBC 1 : let proxy_address: SocketAddr = args.get_one::<String>("listen").unwrap().parse()?;
108 1 : info!("Starting sni router on {proxy_address}");
109 1 : let proxy_listener = TcpListener::bind(proxy_address).await?;
110 :
111 1 : let cancellation_token = CancellationToken::new();
112 1 :
113 1 : let main = tokio::spawn(task_main(
114 1 : Arc::new(destination),
115 1 : tls_config,
116 1 : proxy_listener,
117 1 : cancellation_token.clone(),
118 1 : ));
119 1 : let signals_task = tokio::spawn(proxy::handle_signals(cancellation_token));
120 :
121 : // the signal task cant ever succeed.
122 : // the main task can error, or can succeed on cancellation.
123 : // we want to immediately exit on either of these cases
124 1 : let signal = match futures::future::select(signals_task, main).await {
125 UBC 0 : Either::Left((res, _)) => proxy::flatten_err(res)?,
126 CBC 1 : Either::Right((res, _)) => return proxy::flatten_err(res),
127 : };
128 :
129 : // maintenance tasks return `Infallible` success values, this is an impossible value
130 : // so this match statically ensures that there are no possibilities for that value
131 : match signal {}
132 : }
133 :
134 1 : async fn task_main(
135 1 : dest_suffix: Arc<String>,
136 1 : tls_config: Arc<rustls::ServerConfig>,
137 1 : listener: tokio::net::TcpListener,
138 1 : cancellation_token: CancellationToken,
139 1 : ) -> anyhow::Result<()> {
140 : // When set for the server socket, the keepalive setting
141 : // will be inherited by all accepted client sockets.
142 1 : socket2::SockRef::from(&listener).set_keepalive(true)?;
143 :
144 1 : let mut connections = tokio::task::JoinSet::new();
145 :
146 : loop {
147 3 : tokio::select! {
148 2 : accept_result = listener.accept() => {
149 : let (socket, peer_addr) = accept_result?;
150 :
151 : let session_id = uuid::Uuid::new_v4();
152 : let tls_config = Arc::clone(&tls_config);
153 : let dest_suffix = Arc::clone(&dest_suffix);
154 :
155 : connections.spawn(
156 : async move {
157 2 : socket
158 2 : .set_nodelay(true)
159 2 : .context("failed to set socket option")?;
160 :
161 2 : info!(%peer_addr, "serving");
162 11 : handle_client(dest_suffix, tls_config, socket).await
163 2 : }
164 2 : .unwrap_or_else(|e| {
165 2 : // Acknowledge that the task has finished with an error.
166 2 : error!("per-client task finished with an error: {e:#}");
167 2 : })
168 : .instrument(tracing::info_span!("handle_client", ?session_id))
169 : );
170 : }
171 UBC 0 : Some(Err(e)) = connections.join_next(), if !connections.is_empty() => {
172 : if !e.is_panic() && !e.is_cancelled() {
173 0 : warn!("unexpected error from joined connection task: {e:?}");
174 : }
175 : }
176 : _ = cancellation_token.cancelled() => {
177 : drop(listener);
178 : break;
179 : }
180 : }
181 : }
182 :
183 : // Drain connections
184 CBC 1 : info!("waiting for all client connections to finish");
185 2 : while let Some(res) = connections.join_next().await {
186 1 : if let Err(e) = res {
187 UBC 0 : if !e.is_panic() && !e.is_cancelled() {
188 0 : warn!("unexpected error from joined connection task: {e:?}");
189 0 : }
190 CBC 1 : }
191 : }
192 1 : info!("all client connections have finished");
193 1 : Ok(())
194 1 : }
195 :
196 : const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
197 :
198 2 : async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
199 2 : raw_stream: S,
200 2 : tls_config: Arc<rustls::ServerConfig>,
201 2 : ) -> anyhow::Result<Stream<S>> {
202 2 : let mut stream = PqStream::new(Stream::from_raw(raw_stream));
203 :
204 2 : let msg = stream.read_startup_packet().await?;
205 : use pq_proto::FeStartupPacket::*;
206 :
207 1 : match msg {
208 : SslRequest => {
209 1 : stream
210 1 : .write_message(&pq_proto::BeMessage::EncryptionResponse(true))
211 UBC 0 : .await?;
212 : // Upgrade raw stream into a secure TLS-backed stream.
213 : // NOTE: We've consumed `tls`; this fact will be used later.
214 :
215 CBC 1 : let (raw, read_buf) = stream.into_inner();
216 1 : // TODO: Normally, client doesn't send any data before
217 1 : // server says TLS handshake is ok and read_buf is empy.
218 1 : // However, you could imagine pipelining of postgres
219 1 : // SSLRequest + TLS ClientHello in one hunk similar to
220 1 : // pipelining in our node js driver. We should probably
221 1 : // support that by chaining read_buf with the stream.
222 1 : if !read_buf.is_empty() {
223 UBC 0 : bail!("data is sent before server replied with EncryptionResponse");
224 CBC 1 : }
225 2 : Ok(raw.upgrade(tls_config).await?)
226 : }
227 UBC 0 : unexpected => {
228 0 : info!(
229 0 : ?unexpected,
230 0 : "unexpected startup packet, rejecting connection"
231 0 : );
232 0 : stream.throw_error_str(ERR_INSECURE_CONNECTION).await?
233 : }
234 : }
235 CBC 2 : }
236 :
237 2 : async fn handle_client(
238 2 : dest_suffix: Arc<String>,
239 2 : tls_config: Arc<rustls::ServerConfig>,
240 2 : stream: impl AsyncRead + AsyncWrite + Unpin,
241 2 : ) -> anyhow::Result<()> {
242 4 : let tls_stream = ssl_handshake(stream, tls_config).await?;
243 :
244 : // Cut off first part of the SNI domain
245 : // We receive required destination details in the format of
246 : // `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
247 1 : let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
248 1 : let dest: Vec<&str> = sni
249 1 : .split_once('.')
250 1 : .context("invalid SNI")?
251 : .0
252 1 : .splitn(3, "--")
253 1 : .collect();
254 1 : let port = dest[2].parse::<u16>().context("invalid port")?;
255 1 : let destination = format!("{}.{}.{}:{}", dest[0], dest[1], dest_suffix, port);
256 :
257 1 : info!("destination: {}", destination);
258 :
259 2 : let client = tokio::net::TcpStream::connect(destination).await?;
260 :
261 1 : let metrics_aux: MetricsAuxInfo = Default::default();
262 5 : proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await
263 2 : }
|