LCOV - code coverage report
Current view: top level - safekeeper/src - wal_service.rs (source / functions) Coverage Total Hit
Test: 9e3a1ccbd88185d44390421f76c05f0bf588f617.info Lines: 0.0 % 74 0
Test Date: 2025-07-29 14:19:29 Functions: 0.0 % 9 0

            Line data    Source code
       1              : //!
       2              : //!   WAL service listens for client connections and
       3              : //!   receive WAL from wal_proposer and send it to WAL receivers
       4              : //!
       5              : use std::os::fd::AsRawFd;
       6              : use std::sync::Arc;
       7              : use std::time::Duration;
       8              : 
       9              : use anyhow::{Context, Result};
      10              : use postgres_backend::{AuthType, PostgresBackend, QueryError};
      11              : use safekeeper_api::models::ConnectionId;
      12              : use tokio::net::TcpStream;
      13              : use tokio_io_timeout::TimeoutReader;
      14              : use tokio_util::sync::CancellationToken;
      15              : use tracing::*;
      16              : use utils::auth::Scope;
      17              : use utils::measured_stream::MeasuredStream;
      18              : 
      19              : use crate::handler::SafekeeperPostgresHandler;
      20              : use crate::metrics::TrafficMetrics;
      21              : use crate::{GlobalTimelines, SafeKeeperConf};
      22              : 
      23              : /// Accept incoming TCP connections and spawn them into a background thread.
      24              : ///
      25              : /// allowed_auth_scope is either SafekeeperData (wide JWT tokens giving access
      26              : /// to any tenant are allowed) or Tenant (only tokens giving access to specific
      27              : /// tenant are allowed). Doesn't matter if auth is disabled in conf.
      28            0 : pub async fn task_main(
      29            0 :     conf: Arc<SafeKeeperConf>,
      30            0 :     pg_listener: std::net::TcpListener,
      31            0 :     allowed_auth_scope: Scope,
      32            0 :     tls_config: Option<Arc<rustls::ServerConfig>>,
      33            0 :     global_timelines: Arc<GlobalTimelines>,
      34            0 : ) -> anyhow::Result<()> {
      35              :     // Tokio's from_std won't do this for us, per its comment.
      36            0 :     pg_listener.set_nonblocking(true)?;
      37              : 
      38            0 :     let listener = tokio::net::TcpListener::from_std(pg_listener)?;
      39            0 :     let mut connection_count: ConnectionCount = 0;
      40              : 
      41              :     loop {
      42            0 :         let (socket, peer_addr) = listener.accept().await.context("accept")?;
      43            0 :         debug!("accepted connection from {}", peer_addr);
      44            0 :         let conf = conf.clone();
      45            0 :         let conn_id = issue_connection_id(&mut connection_count);
      46            0 :         let global_timelines = global_timelines.clone();
      47            0 :         let tls_config = tls_config.clone();
      48            0 :         tokio::spawn(
      49            0 :             async move {
      50            0 :                 if let Err(err) = handle_socket(socket, conf, conn_id, allowed_auth_scope, tls_config, global_timelines).await {
      51            0 :                     error!("connection handler exited: {}", err);
      52            0 :                 }
      53            0 :             }
      54            0 :             .instrument(info_span!("", cid = %conn_id, ttid = field::Empty, application_name = field::Empty, shard = field::Empty)),
      55              :         );
      56              :     }
      57            0 : }
      58              : 
      59              : /// This is run by `task_main` above, inside a background thread.
      60              : ///
      61            0 : async fn handle_socket(
      62            0 :     socket: TcpStream,
      63            0 :     conf: Arc<SafeKeeperConf>,
      64            0 :     conn_id: ConnectionId,
      65            0 :     allowed_auth_scope: Scope,
      66            0 :     tls_config: Option<Arc<rustls::ServerConfig>>,
      67            0 :     global_timelines: Arc<GlobalTimelines>,
      68            0 : ) -> Result<(), QueryError> {
      69            0 :     socket.set_nodelay(true)?;
      70            0 :     let socket_fd = socket.as_raw_fd();
      71            0 :     let peer_addr = socket.peer_addr()?;
      72              : 
      73              :     // Set timeout on reading from the socket. It prevents hanged up connection
      74              :     // if client suddenly disappears. Note that TCP_KEEPALIVE is not enabled by
      75              :     // default, and tokio doesn't provide ability to set it out of the box.
      76            0 :     let mut socket = TimeoutReader::new(socket);
      77            0 :     let wal_service_timeout = Duration::from_secs(60 * 10);
      78            0 :     socket.set_timeout(Some(wal_service_timeout));
      79              :     // pin! is here because TimeoutReader (due to storing sleep future inside)
      80              :     // is not Unpin, and all pgbackend/framed/tokio dependencies require stream
      81              :     // to be Unpin. Which is reasonable, as indeed something like TimeoutReader
      82              :     // shouldn't be moved.
      83            0 :     let socket = std::pin::pin!(socket);
      84              : 
      85            0 :     let traffic_metrics = TrafficMetrics::new();
      86            0 :     if let Some(current_az) = conf.availability_zone.as_deref() {
      87            0 :         traffic_metrics.set_sk_az(current_az);
      88            0 :     }
      89              : 
      90            0 :     let socket = MeasuredStream::new(
      91            0 :         socket,
      92            0 :         |cnt| {
      93            0 :             traffic_metrics.observe_read(cnt);
      94            0 :         },
      95            0 :         |cnt| {
      96            0 :             traffic_metrics.observe_write(cnt);
      97            0 :         },
      98              :     );
      99              : 
     100            0 :     let auth_key = match allowed_auth_scope {
     101            0 :         Scope::Tenant => conf.pg_tenant_only_auth.clone(),
     102            0 :         _ => conf.pg_auth.clone(),
     103              :     };
     104            0 :     let auth_type = match auth_key {
     105            0 :         None => AuthType::Trust,
     106            0 :         Some(_) => AuthType::NeonJWT,
     107              :     };
     108            0 :     let auth_pair = auth_key.map(|key| (allowed_auth_scope, key));
     109            0 :     let mut conn_handler = SafekeeperPostgresHandler::new(
     110            0 :         conf,
     111            0 :         conn_id,
     112            0 :         Some(traffic_metrics.clone()),
     113            0 :         auth_pair,
     114            0 :         global_timelines,
     115              :     );
     116            0 :     let pgbackend =
     117            0 :         PostgresBackend::new_from_io(socket_fd, socket, peer_addr, auth_type, tls_config)?;
     118              :     // libpq protocol between safekeeper and walproposer / pageserver
     119              :     // We don't use shutdown.
     120            0 :     pgbackend
     121            0 :         .run(&mut conn_handler, &CancellationToken::new())
     122            0 :         .await
     123            0 : }
     124              : 
     125              : pub type ConnectionCount = u32;
     126              : 
     127            0 : pub fn issue_connection_id(count: &mut ConnectionCount) -> ConnectionId {
     128            0 :     *count = count.wrapping_add(1);
     129            0 :     *count
     130            0 : }
        

Generated by: LCOV version 2.1-beta