Line data Source code
//!
//! WAL service: listens for client connections, receives WAL from the
//! walproposer and sends it to WAL receivers.
//!
5 : use std::os::fd::AsRawFd;
6 : use std::sync::Arc;
7 : use std::time::Duration;
8 :
9 : use anyhow::{Context, Result};
10 : use postgres_backend::{AuthType, PostgresBackend, QueryError};
11 : use safekeeper_api::models::ConnectionId;
12 : use tokio::net::TcpStream;
13 : use tokio_io_timeout::TimeoutReader;
14 : use tokio_util::sync::CancellationToken;
15 : use tracing::*;
16 : use utils::auth::Scope;
17 : use utils::measured_stream::MeasuredStream;
18 :
19 : use crate::handler::SafekeeperPostgresHandler;
20 : use crate::metrics::TrafficMetrics;
21 : use crate::{GlobalTimelines, SafeKeeperConf};
22 :
23 : /// Accept incoming TCP connections and spawn them into a background thread.
24 : ///
25 : /// allowed_auth_scope is either SafekeeperData (wide JWT tokens giving access
26 : /// to any tenant are allowed) or Tenant (only tokens giving access to specific
27 : /// tenant are allowed). Doesn't matter if auth is disabled in conf.
28 0 : pub async fn task_main(
29 0 : conf: Arc<SafeKeeperConf>,
30 0 : pg_listener: std::net::TcpListener,
31 0 : allowed_auth_scope: Scope,
32 0 : global_timelines: Arc<GlobalTimelines>,
33 0 : ) -> anyhow::Result<()> {
34 0 : // Tokio's from_std won't do this for us, per its comment.
35 0 : pg_listener.set_nonblocking(true)?;
36 :
37 0 : let listener = tokio::net::TcpListener::from_std(pg_listener)?;
38 0 : let mut connection_count: ConnectionCount = 0;
39 :
40 : loop {
41 0 : let (socket, peer_addr) = listener.accept().await.context("accept")?;
42 0 : debug!("accepted connection from {}", peer_addr);
43 0 : let conf = conf.clone();
44 0 : let conn_id = issue_connection_id(&mut connection_count);
45 0 : let global_timelines = global_timelines.clone();
46 0 : tokio::spawn(
47 0 : async move {
48 0 : if let Err(err) = handle_socket(socket, conf, conn_id, allowed_auth_scope, global_timelines).await {
49 0 : error!("connection handler exited: {}", err);
50 0 : }
51 0 : }
52 0 : .instrument(info_span!("", cid = %conn_id, ttid = field::Empty, application_name = field::Empty, shard = field::Empty)),
53 : );
54 : }
55 0 : }
56 :
57 : /// This is run by `task_main` above, inside a background thread.
58 : ///
59 0 : async fn handle_socket(
60 0 : socket: TcpStream,
61 0 : conf: Arc<SafeKeeperConf>,
62 0 : conn_id: ConnectionId,
63 0 : allowed_auth_scope: Scope,
64 0 : global_timelines: Arc<GlobalTimelines>,
65 0 : ) -> Result<(), QueryError> {
66 0 : socket.set_nodelay(true)?;
67 0 : let socket_fd = socket.as_raw_fd();
68 0 : let peer_addr = socket.peer_addr()?;
69 :
70 : // Set timeout on reading from the socket. It prevents hanged up connection
71 : // if client suddenly disappears. Note that TCP_KEEPALIVE is not enabled by
72 : // default, and tokio doesn't provide ability to set it out of the box.
73 0 : let mut socket = TimeoutReader::new(socket);
74 0 : let wal_service_timeout = Duration::from_secs(60 * 10);
75 0 : socket.set_timeout(Some(wal_service_timeout));
76 0 : // pin! is here because TimeoutReader (due to storing sleep future inside)
77 0 : // is not Unpin, and all pgbackend/framed/tokio dependencies require stream
78 0 : // to be Unpin. Which is reasonable, as indeed something like TimeoutReader
79 0 : // shouldn't be moved.
80 0 : let socket = std::pin::pin!(socket);
81 0 :
82 0 : let traffic_metrics = TrafficMetrics::new();
83 0 : if let Some(current_az) = conf.availability_zone.as_deref() {
84 0 : traffic_metrics.set_sk_az(current_az);
85 0 : }
86 :
87 0 : let socket = MeasuredStream::new(
88 0 : socket,
89 0 : |cnt| {
90 0 : traffic_metrics.observe_read(cnt);
91 0 : },
92 0 : |cnt| {
93 0 : traffic_metrics.observe_write(cnt);
94 0 : },
95 0 : );
96 :
97 0 : let auth_key = match allowed_auth_scope {
98 0 : Scope::Tenant => conf.pg_tenant_only_auth.clone(),
99 0 : _ => conf.pg_auth.clone(),
100 : };
101 0 : let auth_type = match auth_key {
102 0 : None => AuthType::Trust,
103 0 : Some(_) => AuthType::NeonJWT,
104 : };
105 0 : let auth_pair = auth_key.map(|key| (allowed_auth_scope, key));
106 0 : let mut conn_handler = SafekeeperPostgresHandler::new(
107 0 : conf,
108 0 : conn_id,
109 0 : Some(traffic_metrics.clone()),
110 0 : auth_pair,
111 0 : global_timelines,
112 0 : );
113 0 : let pgbackend = PostgresBackend::new_from_io(socket_fd, socket, peer_addr, auth_type, None)?;
114 : // libpq protocol between safekeeper and walproposer / pageserver
115 : // We don't use shutdown.
116 0 : pgbackend
117 0 : .run(&mut conn_handler, &CancellationToken::new())
118 0 : .await
119 0 : }
120 :
121 : pub type ConnectionCount = u32;
122 :
123 0 : pub fn issue_connection_id(count: &mut ConnectionCount) -> ConnectionId {
124 0 : *count = count.wrapping_add(1);
125 0 : *count
126 0 : }
|