LCOV - code coverage report
Current view: top level - storage_broker/src/bin - storage_broker.rs (source / functions) Coverage Total Hit
Test: 1e20c4f2b28aa592527961bb32170ebbd2c9172f.info Lines: 39.4 % 485 191
Test Date: 2025-07-16 12:29:03 Functions: 36.5 % 52 19

            Line data    Source code
       1              : //! Simple pub-sub based on grpc (tonic) and Tokio broadcast channel for storage
       2              : //! nodes messaging.
       3              : //!
       4              : //! Subscriptions to 1) single timeline 2) all timelines are possible. We could
       5              : //! add subscription to a set of timelines to save grpc streams, but testing
       6              : //! shows many individual streams are also ok.
       7              : //!
       8              : //! Message is dropped if subscriber can't consume it, not affecting other
       9              : //! subscribers.
      10              : //!
      11              : //! Only safekeeper message is supported, but it is not hard to add something
      12              : //! else with generics.
      13              : use std::collections::HashMap;
      14              : use std::convert::Infallible;
      15              : use std::net::SocketAddr;
      16              : use std::pin::Pin;
      17              : use std::sync::Arc;
      18              : use std::time::Duration;
      19              : 
      20              : use bytes::Bytes;
      21              : use camino::Utf8PathBuf;
      22              : use clap::{Parser, command};
      23              : use futures::future::OptionFuture;
      24              : use futures_core::Stream;
      25              : use futures_util::StreamExt;
      26              : use http_body_util::combinators::BoxBody;
      27              : use http_body_util::{Empty, Full};
      28              : use http_utils::tls_certs::ReloadingCertificateResolver;
      29              : use hyper::body::Incoming;
      30              : use hyper::header::CONTENT_TYPE;
      31              : use hyper::service::service_fn;
      32              : use hyper::{Method, StatusCode};
      33              : use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
      34              : use metrics::{Encoder, TextEncoder};
      35              : use parking_lot::RwLock;
      36              : use storage_broker::metrics::{
      37              :     BROADCAST_DROPPED_MESSAGES_TOTAL, BROADCASTED_MESSAGES_TOTAL, NUM_PUBS, NUM_SUBS_ALL,
      38              :     NUM_SUBS_TIMELINE, PROCESSED_MESSAGES_TOTAL, PUBLISHED_ONEOFF_MESSAGES_TOTAL,
      39              : };
      40              : use storage_broker::proto::broker_service_server::{BrokerService, BrokerServiceServer};
      41              : use storage_broker::proto::subscribe_safekeeper_info_request::SubscriptionKey as ProtoSubscriptionKey;
      42              : use storage_broker::proto::{
      43              :     FilterTenantTimelineId, MessageType, SafekeeperDiscoveryRequest, SafekeeperDiscoveryResponse,
      44              :     SafekeeperTimelineInfo, SubscribeByFilterRequest, SubscribeSafekeeperInfoRequest, TypedMessage,
      45              : };
      46              : use storage_broker::{DEFAULT_KEEPALIVE_INTERVAL, parse_proto_ttid};
      47              : use tokio::net::TcpListener;
      48              : use tokio::sync::broadcast;
      49              : use tokio::sync::broadcast::error::RecvError;
      50              : use tokio::time;
      51              : use tonic::codegen::Service;
      52              : use tonic::{Code, Request, Response, Status};
      53              : use tracing::*;
      54              : use utils::id::TenantTimelineId;
      55              : use utils::logging::{self, LogFormat};
      56              : use utils::sentry_init::init_sentry;
      57              : use utils::signals::ShutdownSignals;
      58              : use utils::{project_build_tag, project_git_version};
      59              : 
      60              : project_git_version!(GIT_VERSION);
      61              : project_build_tag!(BUILD_TAG);
      62              : 
      63              : const DEFAULT_CHAN_SIZE: usize = 32;
      64              : const DEFAULT_ALL_KEYS_CHAN_SIZE: usize = 16384;
      65              : 
      66              : const DEFAULT_SSL_KEY_FILE: &str = "server.key";
      67              : const DEFAULT_SSL_CERT_FILE: &str = "server.crt";
      68              : const DEFAULT_SSL_CERT_RELOAD_PERIOD: &str = "60s";
      69              : 
      70              : #[derive(Parser, Debug)]
      71              : #[command(version = GIT_VERSION, about = "Broker for neon storage nodes communication", long_about = None)]
      72              : #[clap(group(
      73              :     clap::ArgGroup::new("listen-addresses")
      74              :         .required(true)
      75              :         .multiple(true)
      76              :         .args(&["listen_addr", "listen_https_addr"]),
      77              : ))]
      78              : struct Args {
      79              :     /// Endpoint to listen HTTP on.
      80              :     #[arg(short, long)]
      81              :     listen_addr: Option<SocketAddr>,
      82              :     /// Endpoint to listen HTTPS on.
      83              :     #[arg(long)]
      84              :     listen_https_addr: Option<SocketAddr>,
      85              :     /// Size of the queue to the per timeline subscriber.
      86              :     #[arg(long, default_value_t = DEFAULT_CHAN_SIZE)]
      87              :     timeline_chan_size: usize,
      88              :     /// Size of the queue to the all keys subscriber.
      89              :     #[arg(long, default_value_t = DEFAULT_ALL_KEYS_CHAN_SIZE)]
      90              :     all_keys_chan_size: usize,
      91              :     /// HTTP/2 keepalive interval.
      92              :     #[arg(long, value_parser = humantime::parse_duration, default_value = DEFAULT_KEEPALIVE_INTERVAL)]
      93              :     http2_keepalive_interval: Duration,
      94              :     /// Format for logging, either 'plain' or 'json'.
      95              :     #[arg(long, default_value = "plain")]
      96              :     log_format: String,
      97              :     /// Path to a file with certificate's private key for https API.
      98              :     #[arg(long, default_value = DEFAULT_SSL_KEY_FILE)]
      99              :     ssl_key_file: Utf8PathBuf,
     100              :     /// Path to a file with a X509 certificate for https API.
     101              :     #[arg(long, default_value = DEFAULT_SSL_CERT_FILE)]
     102              :     ssl_cert_file: Utf8PathBuf,
     103              :     /// Period to reload certificate and private key from files.
     104              :     #[arg(long, value_parser = humantime::parse_duration, default_value = DEFAULT_SSL_CERT_RELOAD_PERIOD)]
     105              :     ssl_cert_reload_period: Duration,
     106              : }
     107              : 
     108              : /// Id of publisher for registering in maps
     109              : type PubId = u64;
     110              : 
     111              : /// Id of subscriber for registering in maps
     112              : type SubId = u64;
     113              : 
     114              : /// Single enum type for all messages.
     115              : #[derive(Clone, Debug, PartialEq)]
     116              : #[allow(clippy::enum_variant_names)]
     117              : enum Message {
     118              :     SafekeeperTimelineInfo(SafekeeperTimelineInfo),
     119              :     SafekeeperDiscoveryRequest(SafekeeperDiscoveryRequest),
     120              :     SafekeeperDiscoveryResponse(SafekeeperDiscoveryResponse),
     121              : }
     122              : 
     123              : impl Message {
     124              :     /// Convert proto message to internal message.
     125              :     #[allow(clippy::result_large_err, reason = "TODO")]
     126            0 :     pub fn from(proto_msg: TypedMessage) -> Result<Self, Status> {
     127            0 :         match proto_msg.r#type() {
     128              :             MessageType::SafekeeperTimelineInfo => Ok(Message::SafekeeperTimelineInfo(
     129            0 :                 proto_msg.safekeeper_timeline_info.ok_or_else(|| {
     130            0 :                     Status::new(Code::InvalidArgument, "missing safekeeper_timeline_info")
     131            0 :                 })?,
     132              :             )),
     133              :             MessageType::SafekeeperDiscoveryRequest => Ok(Message::SafekeeperDiscoveryRequest(
     134            0 :                 proto_msg.safekeeper_discovery_request.ok_or_else(|| {
     135            0 :                     Status::new(
     136            0 :                         Code::InvalidArgument,
     137              :                         "missing safekeeper_discovery_request",
     138              :                     )
     139            0 :                 })?,
     140              :             )),
     141              :             MessageType::SafekeeperDiscoveryResponse => Ok(Message::SafekeeperDiscoveryResponse(
     142            0 :                 proto_msg.safekeeper_discovery_response.ok_or_else(|| {
     143            0 :                     Status::new(
     144            0 :                         Code::InvalidArgument,
     145              :                         "missing safekeeper_discovery_response",
     146              :                     )
     147            0 :                 })?,
     148              :             )),
     149            0 :             MessageType::Unknown => Err(Status::new(
     150            0 :                 Code::InvalidArgument,
     151            0 :                 format!("invalid message type: {:?}", proto_msg.r#type),
     152            0 :             )),
     153              :         }
     154            0 :     }
     155              : 
     156              :     /// Get the tenant_timeline_id from the message.
     157              :     #[allow(clippy::result_large_err, reason = "TODO")]
     158            2 :     pub fn tenant_timeline_id(&self) -> Result<Option<TenantTimelineId>, Status> {
     159            2 :         match self {
     160            2 :             Message::SafekeeperTimelineInfo(msg) => Ok(msg
     161            2 :                 .tenant_timeline_id
     162            2 :                 .as_ref()
     163            2 :                 .map(parse_proto_ttid)
     164            2 :                 .transpose()?),
     165            0 :             Message::SafekeeperDiscoveryRequest(msg) => Ok(msg
     166            0 :                 .tenant_timeline_id
     167            0 :                 .as_ref()
     168            0 :                 .map(parse_proto_ttid)
     169            0 :                 .transpose()?),
     170            0 :             Message::SafekeeperDiscoveryResponse(msg) => Ok(msg
     171            0 :                 .tenant_timeline_id
     172            0 :                 .as_ref()
     173            0 :                 .map(parse_proto_ttid)
     174            0 :                 .transpose()?),
     175              :         }
     176            2 :     }
     177              : 
     178              :     /// Convert internal message to the protobuf struct.
     179            0 :     pub fn as_typed_message(&self) -> TypedMessage {
     180            0 :         let mut res = TypedMessage {
     181            0 :             r#type: self.message_type() as i32,
     182            0 :             ..Default::default()
     183            0 :         };
     184            0 :         match self {
     185            0 :             Message::SafekeeperTimelineInfo(msg) => {
     186            0 :                 res.safekeeper_timeline_info = Some(msg.clone())
     187              :             }
     188            0 :             Message::SafekeeperDiscoveryRequest(msg) => {
     189            0 :                 res.safekeeper_discovery_request = Some(msg.clone())
     190              :             }
     191            0 :             Message::SafekeeperDiscoveryResponse(msg) => {
     192            0 :                 res.safekeeper_discovery_response = Some(msg.clone())
     193              :             }
     194              :         }
     195            0 :         res
     196            0 :     }
     197              : 
     198              :     /// Get the message type.
     199            0 :     pub fn message_type(&self) -> MessageType {
     200            0 :         match self {
     201            0 :             Message::SafekeeperTimelineInfo(_) => MessageType::SafekeeperTimelineInfo,
     202            0 :             Message::SafekeeperDiscoveryRequest(_) => MessageType::SafekeeperDiscoveryRequest,
     203            0 :             Message::SafekeeperDiscoveryResponse(_) => MessageType::SafekeeperDiscoveryResponse,
     204              :         }
     205            0 :     }
     206              : }
     207              : 
     208              : #[derive(Copy, Clone, Debug)]
     209              : enum SubscriptionKey {
     210              :     All,
     211              :     Timeline(TenantTimelineId),
     212              : }
     213              : 
     214              : impl SubscriptionKey {
     215              :     /// Parse protobuf subkey (protobuf doesn't have fixed size bytes, we get vectors).
     216              :     #[allow(clippy::result_large_err, reason = "TODO")]
     217            0 :     pub fn from_proto_subscription_key(key: ProtoSubscriptionKey) -> Result<Self, Status> {
     218            0 :         match key {
     219            0 :             ProtoSubscriptionKey::All(_) => Ok(SubscriptionKey::All),
     220            0 :             ProtoSubscriptionKey::TenantTimelineId(proto_ttid) => {
     221            0 :                 Ok(SubscriptionKey::Timeline(parse_proto_ttid(&proto_ttid)?))
     222              :             }
     223              :         }
     224            0 :     }
     225              : 
     226              :     /// Parse from FilterTenantTimelineId
     227              :     #[allow(clippy::result_large_err, reason = "TODO")]
     228            0 :     pub fn from_proto_filter_tenant_timeline_id(
     229            0 :         opt: Option<&FilterTenantTimelineId>,
     230            0 :     ) -> Result<Self, Status> {
     231            0 :         if opt.is_none() {
     232            0 :             return Ok(SubscriptionKey::All);
     233            0 :         }
     234              : 
     235            0 :         let f = opt.unwrap();
     236            0 :         if !f.enabled {
     237            0 :             return Ok(SubscriptionKey::All);
     238            0 :         }
     239              : 
     240            0 :         let ttid =
     241            0 :             parse_proto_ttid(f.tenant_timeline_id.as_ref().ok_or_else(|| {
     242            0 :                 Status::new(Code::InvalidArgument, "missing tenant_timeline_id")
     243            0 :             })?)?;
     244            0 :         Ok(SubscriptionKey::Timeline(ttid))
     245            0 :     }
     246              : }
     247              : 
     248              : /// Channel to timeline subscribers.
     249              : struct ChanToTimelineSub {
     250              :     chan: broadcast::Sender<Message>,
     251              :     /// Tracked separately to know when to delete the shmem entry. receiver_count()
     252              :     /// is unhandy for that as unregistering and dropping the receiver side
     253              :     /// happens at different moments.
     254              :     num_subscribers: u64,
     255              : }
     256              : 
     257              : struct SharedState {
     258              :     next_pub_id: PubId,
     259              :     num_pubs: i64,
     260              :     next_sub_id: SubId,
     261              :     num_subs_to_timelines: i64,
     262              :     chans_to_timeline_subs: HashMap<TenantTimelineId, ChanToTimelineSub>,
     263              :     num_subs_to_all: i64,
     264              :     chan_to_all_subs: broadcast::Sender<Message>,
     265              : }
     266              : 
     267              : impl SharedState {
     268            1 :     pub fn new(all_keys_chan_size: usize) -> Self {
     269            1 :         SharedState {
     270            1 :             next_pub_id: 0,
     271            1 :             num_pubs: 0,
     272            1 :             next_sub_id: 0,
     273            1 :             num_subs_to_timelines: 0,
     274            1 :             chans_to_timeline_subs: HashMap::new(),
     275            1 :             num_subs_to_all: 0,
     276            1 :             chan_to_all_subs: broadcast::channel(all_keys_chan_size).0,
     277            1 :         }
     278            1 :     }
     279              : 
     280              :     // Register new publisher.
     281            1 :     pub fn register_publisher(&mut self) -> PubId {
     282            1 :         let pub_id = self.next_pub_id;
     283            1 :         self.next_pub_id += 1;
     284            1 :         self.num_pubs += 1;
     285            1 :         NUM_PUBS.set(self.num_pubs);
     286            1 :         pub_id
     287            1 :     }
     288              : 
     289              :     // Unregister publisher.
     290            1 :     pub fn unregister_publisher(&mut self) {
     291            1 :         self.num_pubs -= 1;
     292            1 :         NUM_PUBS.set(self.num_pubs);
     293            1 :     }
     294              : 
     295              :     // Register new subscriber.
     296            2 :     pub fn register_subscriber(
     297            2 :         &mut self,
     298            2 :         sub_key: SubscriptionKey,
     299            2 :         timeline_chan_size: usize,
     300            2 :     ) -> (SubId, broadcast::Receiver<Message>) {
     301            2 :         let sub_id = self.next_sub_id;
     302            2 :         self.next_sub_id += 1;
     303            2 :         let sub_rx = match sub_key {
     304              :             SubscriptionKey::All => {
     305            1 :                 self.num_subs_to_all += 1;
     306            1 :                 NUM_SUBS_ALL.set(self.num_subs_to_all);
     307            1 :                 self.chan_to_all_subs.subscribe()
     308              :             }
     309            1 :             SubscriptionKey::Timeline(ttid) => {
     310            1 :                 self.num_subs_to_timelines += 1;
     311            1 :                 NUM_SUBS_TIMELINE.set(self.num_subs_to_timelines);
     312              :                 // Create new broadcast channel for this key, or subscribe to
     313              :                 // the existing one.
     314            1 :                 let chan_to_timeline_sub =
     315            1 :                     self.chans_to_timeline_subs
     316            1 :                         .entry(ttid)
     317            1 :                         .or_insert(ChanToTimelineSub {
     318            1 :                             chan: broadcast::channel(timeline_chan_size).0,
     319            1 :                             num_subscribers: 0,
     320            1 :                         });
     321            1 :                 chan_to_timeline_sub.num_subscribers += 1;
     322            1 :                 chan_to_timeline_sub.chan.subscribe()
     323              :             }
     324              :         };
     325            2 :         (sub_id, sub_rx)
     326            2 :     }
     327              : 
     328              :     // Unregister the subscriber.
     329            2 :     pub fn unregister_subscriber(&mut self, sub_key: SubscriptionKey) {
     330            2 :         match sub_key {
     331            1 :             SubscriptionKey::All => {
     332            1 :                 self.num_subs_to_all -= 1;
     333            1 :                 NUM_SUBS_ALL.set(self.num_subs_to_all);
     334            1 :             }
     335            1 :             SubscriptionKey::Timeline(ttid) => {
     336            1 :                 self.num_subs_to_timelines -= 1;
     337            1 :                 NUM_SUBS_TIMELINE.set(self.num_subs_to_timelines);
     338              : 
     339              :                 // Remove from the map, destroying the channel, if we are the
     340              :                 // last subscriber to this timeline.
     341              : 
     342              :                 // Missing entry is a bug; we must have registered.
     343            1 :                 let chan_to_timeline_sub = self
     344            1 :                     .chans_to_timeline_subs
     345            1 :                     .get_mut(&ttid)
     346            1 :                     .expect("failed to find sub entry in shmem during unregister");
     347            1 :                 chan_to_timeline_sub.num_subscribers -= 1;
     348            1 :                 if chan_to_timeline_sub.num_subscribers == 0 {
     349            1 :                     self.chans_to_timeline_subs.remove(&ttid);
     350            1 :                 }
     351              :             }
     352              :         }
     353            2 :     }
     354              : }
     355              : 
     356              : // SharedState wrapper.
     357              : #[derive(Clone)]
     358              : struct Registry {
     359              :     shared_state: Arc<RwLock<SharedState>>,
     360              :     timeline_chan_size: usize,
     361              : }
     362              : 
     363              : impl Registry {
     364              :     // Register new publisher in shared state.
     365            1 :     pub fn register_publisher(&self, remote_addr: SocketAddr) -> Publisher {
     366            1 :         let pub_id = self.shared_state.write().register_publisher();
     367            1 :         info!("publication started id={} addr={:?}", pub_id, remote_addr);
     368            1 :         Publisher {
     369            1 :             id: pub_id,
     370            1 :             registry: self.clone(),
     371            1 :             remote_addr,
     372            1 :         }
     373            1 :     }
     374              : 
     375            1 :     pub fn unregister_publisher(&self, publisher: &Publisher) {
     376            1 :         self.shared_state.write().unregister_publisher();
     377            1 :         info!(
     378            0 :             "publication ended id={} addr={:?}",
     379              :             publisher.id, publisher.remote_addr
     380              :         );
     381            1 :     }
     382              : 
     383              :     // Register new subscriber in shared state.
     384            2 :     pub fn register_subscriber(
     385            2 :         &self,
     386            2 :         sub_key: SubscriptionKey,
     387            2 :         remote_addr: SocketAddr,
     388            2 :     ) -> Subscriber {
     389            2 :         let (sub_id, sub_rx) = self
     390            2 :             .shared_state
     391            2 :             .write()
     392            2 :             .register_subscriber(sub_key, self.timeline_chan_size);
     393            2 :         info!(
     394            0 :             "subscription started id={}, key={:?}, addr={:?}",
     395              :             sub_id, sub_key, remote_addr
     396              :         );
     397            2 :         Subscriber {
     398            2 :             id: sub_id,
     399            2 :             key: sub_key,
     400            2 :             sub_rx,
     401            2 :             registry: self.clone(),
     402            2 :             remote_addr,
     403            2 :         }
     404            2 :     }
     405              : 
     406              :     // Unregister the subscriber
     407            2 :     pub fn unregister_subscriber(&self, subscriber: &Subscriber) {
     408            2 :         self.shared_state
     409            2 :             .write()
     410            2 :             .unregister_subscriber(subscriber.key);
     411            2 :         info!(
     412            0 :             "subscription ended id={}, key={:?}, addr={:?}",
     413              :             subscriber.id, subscriber.key, subscriber.remote_addr
     414              :         );
     415            2 :     }
     416              : 
     417              :     /// Send msg to relevant subscribers.
     418              :     #[allow(clippy::result_large_err, reason = "TODO")]
     419            2 :     pub fn send_msg(&self, msg: &Message) -> Result<(), Status> {
     420            2 :         PROCESSED_MESSAGES_TOTAL.inc();
     421              : 
     422              :         // send message to subscribers for everything
     423            2 :         let shared_state = self.shared_state.read();
     424              :         // Err means there are no subscribers; that is fine.
     425            2 :         shared_state.chan_to_all_subs.send(msg.clone()).ok();
     426              : 
     427              :         // send message to per timeline subscribers, if there is ttid
     428            2 :         let ttid = msg.tenant_timeline_id()?;
     429            2 :         if let Some(ttid) = ttid {
     430            2 :             if let Some(subs) = shared_state.chans_to_timeline_subs.get(&ttid) {
     431            1 :                 // Err can't happen here: the tx is dropped only when the last
     432            1 :                 // subscriber unregisters, which also removes it from the map.
     433            1 :                 subs.chan
     434            1 :                     .send(msg.clone())
     435            1 :                     .expect("rx is still in the map with zero subscribers");
     436            1 :             }
     437            0 :         }
     438            2 :         Ok(())
     439            2 :     }
     440              : }
     441              : 
     442              : // Private subscriber state.
     443              : struct Subscriber {
     444              :     id: SubId,
     445              :     key: SubscriptionKey,
     446              :     // Subscriber receives messages from publishers here.
     447              :     sub_rx: broadcast::Receiver<Message>,
     448              :     // to unregister itself from shared state in Drop
     449              :     registry: Registry,
     450              :     // for logging
     451              :     remote_addr: SocketAddr,
     452              : }
     453              : 
     454              : impl Drop for Subscriber {
     455            2 :     fn drop(&mut self) {
     456            2 :         self.registry.unregister_subscriber(self);
     457            2 :     }
     458              : }
     459              : 
     460              : // Private publisher state
     461              : struct Publisher {
     462              :     id: PubId,
     463              :     registry: Registry,
     464              :     // for logging
     465              :     remote_addr: SocketAddr,
     466              : }
     467              : 
     468              : impl Publisher {
     469              :     /// Send msg to relevant subscribers.
     470              :     #[allow(clippy::result_large_err, reason = "TODO")]
     471            2 :     pub fn send_msg(&mut self, msg: &Message) -> Result<(), Status> {
     472            2 :         self.registry.send_msg(msg)
     473            2 :     }
     474              : }
     475              : 
     476              : impl Drop for Publisher {
     477            1 :     fn drop(&mut self) {
     478            1 :         self.registry.unregister_publisher(self);
     479            1 :     }
     480              : }
     481              : 
     482              : struct Broker {
     483              :     registry: Registry,
     484              : }
     485              : 
     486              : #[tonic::async_trait]
     487              : impl BrokerService for Broker {
     488            0 :     async fn publish_safekeeper_info(
     489              :         &self,
     490              :         request: Request<tonic::Streaming<SafekeeperTimelineInfo>>,
     491            0 :     ) -> Result<Response<()>, Status> {
     492            0 :         let &RemoteAddr(remote_addr) = request
     493            0 :             .extensions()
     494            0 :             .get()
     495            0 :             .expect("RemoteAddr inserted by handler");
     496            0 :         let mut publisher = self.registry.register_publisher(remote_addr);
     497              : 
     498            0 :         let mut stream = request.into_inner();
     499              : 
     500              :         loop {
     501            0 :             match stream.next().await {
     502            0 :                 Some(Ok(msg)) => publisher.send_msg(&Message::SafekeeperTimelineInfo(msg))?,
     503            0 :                 Some(Err(e)) => return Err(e), // grpc error from the stream
     504            0 :                 None => break,                 // closed stream
     505              :             }
     506              :         }
     507              : 
     508            0 :         Ok(Response::new(()))
     509            0 :     }
     510              : 
     511              :     type SubscribeSafekeeperInfoStream =
     512              :         Pin<Box<dyn Stream<Item = Result<SafekeeperTimelineInfo, Status>> + Send + 'static>>;
     513              : 
     514            0 :     async fn subscribe_safekeeper_info( // streams SafekeeperTimelineInfo messages for the requested subscription key
     515              :         &self,
     516              :         request: Request<SubscribeSafekeeperInfoRequest>,
     517            0 :     ) -> Result<Response<Self::SubscribeSafekeeperInfoStream>, Status> {
     518            0 :         let &RemoteAddr(remote_addr) = request // peer address; main() stores it in the request extensions
     519            0 :             .extensions()
     520            0 :             .get()
     521            0 :             .expect("RemoteAddr inserted by handler");
     522            0 :         let proto_key = request
     523            0 :             .into_inner()
     524            0 :             .subscription_key
     525            0 :             .ok_or_else(|| Status::new(Code::InvalidArgument, "missing subscription key"))?; // a missing key is the caller's error
     526            0 :         let sub_key = SubscriptionKey::from_proto_subscription_key(proto_key)?;
     527            0 :         let mut subscriber = self.registry.register_subscriber(sub_key, remote_addr);
     528              : 
     529              :         // Turn the broadcast receiver into a stream of Result items, as the method's return type demands.
     530            0 :         let output = async_stream::try_stream! {
     531              :             let mut warn_interval = time::interval(Duration::from_millis(1000));
     532              :             let mut missed_msgs: u64 = 0; // messages dropped since the last warning was logged
     533              :             loop {
     534              :                 match subscriber.sub_rx.recv().await {
     535              :                     Ok(info) => {
     536              :                         // Count only messages actually yielded to the subscriber, as subscribe_by_filter does.
     537              :                         if let Message::SafekeeperTimelineInfo(info) = info {
     538              :                             yield info;
     539              :                             BROADCASTED_MESSAGES_TOTAL.inc();
     540              :                         }
     541              :                     },
     542              :                     Err(RecvError::Lagged(skipped_msg)) => { // this receiver fell behind and skipped_msg messages were lost
     543              :                         BROADCAST_DROPPED_MESSAGES_TOTAL.inc_by(skipped_msg);
     544              :                         missed_msgs += skipped_msg;
     545              :                         if (futures::poll!(Box::pin(warn_interval.tick()))).is_ready() { // polled once without waiting: warn only when a 1s tick is due
     546              :                             warn!("subscription id={}, key={:?} addr={:?} dropped {} messages, channel is full",
     547              :                                 subscriber.id, subscriber.key, subscriber.remote_addr, missed_msgs);
     548              :                             missed_msgs = 0;
     549              :                         }
     550              :                     }
     551              :                     Err(RecvError::Closed) => {
     552              :                         // can't happen, we never drop the channel while there is a subscriber
     553              :                         Err(Status::new(Code::Internal, "channel unexpectantly closed"))?;
     554              :                     }
     555              :                 }
     556              :             }
     557              :         };
     558              : 
     559            0 :         Ok(Response::new(
     560            0 :             Box::pin(output) as Self::SubscribeSafekeeperInfoStream
     561            0 :         ))
     562            0 :     }
     563              : 
     564              :     type SubscribeByFilterStream =
     565              :         Pin<Box<dyn Stream<Item = Result<TypedMessage, Status>> + Send + 'static>>;
     566              : 
     567              :     /// Subscribe to all messages, limited by a filter: a tenant/timeline id and a set of message types.
     568            0 :     async fn subscribe_by_filter(
     569              :         &self,
     570              :         request: Request<SubscribeByFilterRequest>,
     571            0 :     ) -> std::result::Result<Response<Self::SubscribeByFilterStream>, Status> {
     572            0 :         let &RemoteAddr(remote_addr) = request // peer address; main() stores it in the request extensions
     573            0 :             .extensions()
     574            0 :             .get()
     575            0 :             .expect("RemoteAddr inserted by handler");
     576            0 :         let proto_filter = request.into_inner();
     577            0 :         let ttid_filter = proto_filter.tenant_timeline_id.as_ref();
     578              : 
     579            0 :         let sub_key = SubscriptionKey::from_proto_filter_tenant_timeline_id(ttid_filter)?;
     580            0 :         let types_set = proto_filter // numeric values of the requested message types, deduplicated for lookups
     581            0 :             .types
     582            0 :             .iter()
     583            0 :             .map(|t| t.r#type)
     584            0 :             .collect::<std::collections::HashSet<_>>();
     585              : 
     586            0 :         let mut subscriber = self.registry.register_subscriber(sub_key, remote_addr);
     587              : 
     588              :         // Turn the broadcast receiver into a stream of Result items, as the method's return type demands.
     589            0 :         let output = async_stream::try_stream! {
     590              :             let mut warn_interval = time::interval(Duration::from_millis(1000));
     591              :             let mut missed_msgs: u64 = 0; // messages dropped since the last warning was logged
     592              :             loop {
     593              :                 match subscriber.sub_rx.recv().await {
     594              :                     Ok(msg) => {
     595              :                         let msg_type = msg.message_type() as i32;
     596              :                         if types_set.contains(&msg_type) { // messages of types that were not requested are silently skipped
     597              :                             yield msg.as_typed_message();
     598              :                             BROADCASTED_MESSAGES_TOTAL.inc();
     599              :                         }
     600              :                     },
     601              :                     Err(RecvError::Lagged(skipped_msg)) => { // this receiver fell behind and skipped_msg messages were lost
     602              :                         BROADCAST_DROPPED_MESSAGES_TOTAL.inc_by(skipped_msg);
     603              :                         missed_msgs += skipped_msg;
     604              :                         if (futures::poll!(Box::pin(warn_interval.tick()))).is_ready() { // polled once without waiting: warn only when a 1s tick is due
     605              :                             warn!("subscription id={}, key={:?} addr={:?} dropped {} messages, channel is full",
     606              :                                 subscriber.id, subscriber.key, subscriber.remote_addr, missed_msgs);
     607              :                             missed_msgs = 0;
     608              :                         }
     609              :                     }
     610              :                     Err(RecvError::Closed) => {
     611              :                         // can't happen, we never drop the channel while there is a subscriber
     612              :                         Err(Status::new(Code::Internal, "channel unexpectantly closed"))?;
     613              :                     }
     614              :                 }
     615              :             }
     616              :         };
     617              : 
     618            0 :         Ok(Response::new(
     619            0 :             Box::pin(output) as Self::SubscribeByFilterStream
     620            0 :         ))
     621            0 :     }
     622              : 
     623              :     /// Publish one message: convert the request into a `Message`, bump the one-off counter and pass it to `registry.send_msg`.
     624            0 :     async fn publish_one(
     625              :         &self,
     626              :         request: Request<TypedMessage>,
     627            0 :     ) -> std::result::Result<Response<()>, Status> {
     628            0 :         let msg = Message::from(request.into_inner())?; // conversion failure is returned to the caller before anything is counted
     629            0 :         PUBLISHED_ONEOFF_MESSAGES_TOTAL.inc();
     630            0 :         self.registry.send_msg(&msg)?;
     631            0 :         Ok(Response::new(()))
     632            0 :     }
     633              : }
     634              : 
     635              : // We serve only metrics (GET /metrics) and healthcheck (GET /status) through http1; anything else gets 404.
     636            0 : async fn http1_handler(
     637            0 :     req: hyper::Request<Incoming>,
     638            0 : ) -> Result<hyper::Response<BoxBody<Bytes, Infallible>>, Infallible> {
     639            0 :     let resp = match (req.method(), req.uri().path()) {
     640            0 :         (&Method::GET, "/metrics") => {
     641            0 :             let mut buffer = vec![];
     642            0 :             let metrics = metrics::gather();
     643            0 :             let encoder = TextEncoder::new();
     644            0 :             encoder.encode(&metrics, &mut buffer).unwrap(); // NOTE(review): panics if encoding the gathered metrics fails
     645              : 
     646            0 :             hyper::Response::builder()
     647            0 :                 .status(StatusCode::OK)
     648            0 :                 .header(CONTENT_TYPE, encoder.format_type())
     649            0 :                 .body(BoxBody::new(Full::new(Bytes::from(buffer))))
     650            0 :                 .unwrap()
     651              :         }
     652            0 :         (&Method::GET, "/status") => hyper::Response::builder() // healthcheck: 200 with an empty body
     653            0 :             .status(StatusCode::OK)
     654            0 :             .body(BoxBody::new(Empty::new()))
     655            0 :             .unwrap(),
     656            0 :         _ => hyper::Response::builder() // any other method/path combination
     657            0 :             .status(StatusCode::NOT_FOUND)
     658            0 :             .body(BoxBody::new(Empty::new()))
     659            0 :             .unwrap(),
     660              :     };
     661            0 :     Ok(resp)
     662            0 : }
     663              : 
     664              : #[derive(Clone, Copy)]
     665              : struct RemoteAddr(SocketAddr); // peer address of a connection; main() inserts it into request extensions for the gRPC handlers
     666              : 
     667              : #[tokio::main]
     668            0 : async fn main() -> Result<(), Box<dyn std::error::Error>> {
     669            0 :     let args = Args::parse();
     670              : 
     671              :     // important to keep the order of:
     672              :     // 1. init logging
     673              :     // 2. tracing panic hook
     674              :     // 3. sentry
     675            0 :     logging::init(
     676            0 :         LogFormat::from_config(&args.log_format)?,
     677            0 :         logging::TracingErrorLayerEnablement::Disabled,
     678            0 :         logging::Output::Stdout,
     679            0 :     )?;
     680            0 :     logging::replace_panic_hook_with_tracing_panic_hook().forget();
     681              :     // initialize sentry if SENTRY_DSN is provided
     682            0 :     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
     683            0 :     info!("version: {GIT_VERSION} build_tag: {BUILD_TAG}");
     684            0 :     metrics::set_build_info_metric(GIT_VERSION, BUILD_TAG);
     685              : 
     686              :     // On any shutdown signal, log its receipt and exit.
     687            0 :     std::thread::spawn(move || {
     688            0 :         ShutdownSignals::handle(|signal| {
     689            0 :             info!("received {}, terminating", signal.name());
     690            0 :             std::process::exit(0);
     691              :         })
     692            0 :     });
     693              : 
     694            0 :     let registry = Registry {
     695            0 :         shared_state: Arc::new(RwLock::new(SharedState::new(args.all_keys_chan_size))),
     696            0 :         timeline_chan_size: args.timeline_chan_size,
     697            0 :     };
     698            0 :     let storage_broker_impl = Broker {
     699            0 :         registry: registry.clone(),
     700            0 :     };
     701            0 :     let storage_broker_server = BrokerServiceServer::new(storage_broker_impl);
     702              : 
     703            0 :     let http_listener = match &args.listen_addr { // plain-TCP listener, bound only when an address is configured
     704            0 :         Some(addr) => {
     705            0 :             info!("listening HTTP on {}", addr);
     706            0 :             Some(TcpListener::bind(addr).await?)
     707              :         }
     708            0 :         None => None,
     709              :     };
     710              : 
     711            0 :     let (https_listener, tls_acceptor) = match &args.listen_https_addr { // both are Some or both are None
     712            0 :         Some(addr) => {
     713            0 :             let listener = TcpListener::bind(addr).await?;
     714              : 
     715            0 :             let cert_resolver = ReloadingCertificateResolver::new(
     716            0 :                 "main",
     717            0 :                 &args.ssl_key_file,
     718            0 :                 &args.ssl_cert_file,
     719            0 :                 args.ssl_cert_reload_period,
     720            0 :             )
     721            0 :             .await?;
     722              : 
     723            0 :             let mut tls_config = rustls::ServerConfig::builder()
     724            0 :                 .with_no_client_auth()
     725            0 :                 .with_cert_resolver(cert_resolver);
     726              : 
     727              :             // Tonic is HTTP/2 only and it negotiates it with ALPN; http/1.1 is offered too (presumably for the http1 handler).
     728            0 :             tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
     729              : 
     730            0 :             let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));
     731              : 
     732            0 :             info!("listening HTTPS on {}", addr);
     733            0 :             (Some(listener), Some(acceptor))
     734              :         }
     735            0 :         None => (None, None),
     736              :     };
     737              : 
     738              :     // grpc is served along with http1 for metrics on a single port, hence we
     739              :     // don't use tonic's Server.
     740            0 :     loop {
     741            0 :         let (conn, is_https) = tokio::select! { // an absent listener makes its OptionFuture yield None, so that branch never matches
     742            0 :             Some(conn) = OptionFuture::from(http_listener.as_ref().map(|l| l.accept())) => (conn, false),
     743            0 :             Some(conn) = OptionFuture::from(https_listener.as_ref().map(|l| l.accept())) => (conn, true),
     744            0 :         };
     745            0 : 
     746            0 :         let (tcp_stream, addr) = match conn {
     747            0 :             Ok(v) => v,
     748            0 :             Err(e) => {
     749            0 :                 info!("couldn't accept connection: {e}");
     750            0 :                 continue;
     751            0 :             }
     752            0 :         };
     753            0 : 
     754            0 :         let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
     755            0 :         builder.http1().timer(TokioTimer::new());
     756            0 :         builder
     757            0 :             .http2()
     758            0 :             .timer(TokioTimer::new())
     759            0 :             .keep_alive_interval(Some(args.http2_keepalive_interval))
     760            0 :             // This matches the tonic server default. It allows us to support production-like workloads.
     761            0 :             .max_concurrent_streams(None);
     762            0 : 
     763            0 :         let storage_broker_server_cloned = storage_broker_server.clone();
     764            0 :         let remote_addr = RemoteAddr(addr);
     765            0 :         let service_fn_ = async move {
     766            0 :             service_fn(move |mut req| {
     767            0 :                 // That's what tonic's MakeSvc.call does to pass conninfo to
     768            0 :                 // the request handler (and where its request.remote_addr()
     769            0 :                 // expects it to find).
     770            0 :                 req.extensions_mut().insert(remote_addr);
     771            0 : 
     772            0 :                 // Technically this second clone is not needed, but having it
     773            0 :                 // consumed by the async block is apparently unavoidable. BTW,
     774            0 :                 // the error message is enigmatic, see
     775            0 :                 // https://github.com/rust-lang/rust/issues/68119
     776            0 :                 //
     777            0 :                 // We could get away without async block at all, but then we
     778            0 :                 // need to resort to futures::Either to merge the result,
     779            0 :                 // which is not as pleasing to the eye.
     780            0 :                 let mut storage_broker_server_svc = storage_broker_server_cloned.clone();
     781            0 :                 async move {
     782            0 :                     if req.headers().get("content-type").map(|x| x.as_bytes())
     783            0 :                         == Some(b"application/grpc")
     784            0 :                     {
     785            0 :                         let res_resp = storage_broker_server_svc.call(req).await;
     786            0 :                         // Grpc and http1 handlers have slightly different
     787            0 :                         // Response types: it is UnsyncBoxBody for the
     788            0 :                         // former one (not sure why) and BoxBody<Bytes, Infallible>
     789            0 :                         // for the latter. Both implement HttpBody though,
     790            0 :                         // and `Either` is used to merge them.
     791            0 :                         res_resp.map(|resp| resp.map(http_body_util::Either::Left))
     792            0 :                     } else {
     793            0 :                         let res_resp = http1_handler(req).await;
     794            0 :                         res_resp.map(|resp| resp.map(http_body_util::Either::Right))
     795            0 :                     }
     796            0 :                 }
     797            0 :             })
     798            0 :         }
     799            0 :         .await;
     800            0 : 
     801            0 :         let tls_acceptor = tls_acceptor.clone();
     802            0 : 
     803            0 :         tokio::task::spawn(async move { // the TLS handshake and serving run in a per-connection task, off the accept loop
     804            0 :             let res = if is_https {
     805            0 :                 let tls_acceptor =
     806            0 :                     tls_acceptor.expect("tls_acceptor is set together with https_listener");
     807            0 : 
     808            0 :                 let tls_stream = match tls_acceptor.accept(tcp_stream).await {
     809            0 :                     Ok(tls_stream) => tls_stream,
     810            0 :                     Err(e) => {
     811            0 :                         info!("error accepting TLS connection from {addr}: {e}");
     812            0 :                         return;
     813            0 :                     }
     814            0 :                 };
     815            0 : 
     816            0 :                 builder
     817            0 :                     .serve_connection(TokioIo::new(tls_stream), service_fn_)
     818            0 :                     .await
     819            0 :             } else {
     820            0 :                 builder
     821            0 :                     .serve_connection(TokioIo::new(tcp_stream), service_fn_)
     822            0 :                     .await
     823            0 :             };
     824            0 : 
     825            0 :             if let Err(e) = res {
     826            0 :                 info!(%is_https, "error serving connection from {addr}: {e}");
     827            0 :             }
     828            0 :         });
     829            0 :     }
     830            0 : }
     831              : 
     832              : #[cfg(test)]
     833              : mod tests {
     834              :     use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId;
     835              :     use tokio::sync::broadcast::error::TryRecvError;
     836              :     use utils::id::{TenantId, TimelineId};
     837              : 
     838              :     use super::*;
     839              : 
     840            2 :     fn msg(timeline_id: Vec<u8>) -> Message { // SafekeeperTimelineInfo message with an all-zero tenant id and the given timeline id
     841            2 :         Message::SafekeeperTimelineInfo(SafekeeperTimelineInfo {
     842            2 :             safekeeper_id: 1,
     843            2 :             tenant_timeline_id: Some(ProtoTenantTimelineId {
     844            2 :                 tenant_id: vec![0x00; 16],
     845            2 :                 timeline_id,
     846            2 :             }),
     847            2 :             term: 0,
     848            2 :             last_log_term: 0,
     849            2 :             flush_lsn: 1,
     850            2 :             commit_lsn: 2,
     851            2 :             backup_lsn: 3,
     852            2 :             remote_consistent_lsn: 4,
     853            2 :             peer_horizon_lsn: 5,
     854            2 :             safekeeper_connstr: "neon-1-sk-1.local:7676".to_owned(),
     855            2 :             http_connstr: "neon-1-sk-1.local:7677".to_owned(),
     856            2 :             https_connstr: Some("neon-1-sk-1.local:7678".to_owned()),
     857            2 :             local_start_lsn: 0,
     858            2 :             availability_zone: None,
     859            2 :             standby_horizon: 0,
     860            2 :         })
     861            2 :     }
     862              : 
     863            3 :     fn tli_from_u64(i: u64) -> Vec<u8> { // 16-byte timeline id: eight 0xFF bytes followed by `i` in big-endian
     864            3 :         let mut timeline_id = vec![0xFF; 8];
     865            3 :         timeline_id.extend_from_slice(&i.to_be_bytes());
     866            3 :         timeline_id
     867            3 :     }
     868              : 
     869            3 :     fn mock_addr() -> SocketAddr { // fixed loopback address passed as the remote address when registering
     870            3 :         "127.0.0.1:8080".parse().unwrap()
     871            3 :     }
     872              : 
     873              :     #[tokio::test]
     874            1 :     async fn test_registry() {
     875            1 :         let registry = Registry {
     876            1 :             shared_state: Arc::new(RwLock::new(SharedState::new(16))),
     877            1 :             timeline_chan_size: 16,
     878            1 :         };
     879              : 
     880              :         // one subscriber for timeline 2 only, and one subscriber for all timelines
     881            1 :         let ttid_2 = TenantTimelineId {
     882            1 :             tenant_id: TenantId::from_slice(&[0x00; 16]).unwrap(),
     883            1 :             timeline_id: TimelineId::from_slice(&tli_from_u64(2)).unwrap(),
     884            1 :         };
     885            1 :         let sub_key_2 = SubscriptionKey::Timeline(ttid_2);
     886            1 :         let mut subscriber_2 = registry.register_subscriber(sub_key_2, mock_addr());
     887            1 :         let mut subscriber_all = registry.register_subscriber(SubscriptionKey::All, mock_addr());
     888              : 
     889              :         // send two messages with different keys (timelines 1 and 2)
     890            1 :         let msg_1 = msg(tli_from_u64(1));
     891            1 :         let msg_2 = msg(tli_from_u64(2));
     892            1 :         let mut publisher = registry.register_publisher(mock_addr());
     893            1 :         publisher.send_msg(&msg_1).expect("failed to send msg");
     894            1 :         publisher.send_msg(&msg_2).expect("failed to send msg");
     895              : 
     896              :         // msg with key 2 should arrive at subscriber_2
     897            1 :         assert_eq!(subscriber_2.sub_rx.try_recv().unwrap(), msg_2);
     898              : 
     899              :         // but nothing more: msg_1 belongs to another timeline
     900            1 :         assert_eq!(
     901            1 :             subscriber_2.sub_rx.try_recv().unwrap_err(),
     902              :             TryRecvError::Empty
     903              :         );
     904              : 
     905              :         // subscriber_all should receive both messages, in the order they were sent
     906            1 :         assert_eq!(subscriber_all.sub_rx.try_recv().unwrap(), msg_1);
     907            1 :         assert_eq!(subscriber_all.sub_rx.try_recv().unwrap(), msg_2);
     908            1 :         assert_eq!(
     909            1 :             subscriber_all.sub_rx.try_recv().unwrap_err(),
     910            1 :             TryRecvError::Empty
     911            1 :         );
     912            1 :     }
     913              : }
        

Generated by: LCOV version 2.1-beta