LCOV - code coverage report
Current view: top level - safekeeper/src - send_interpreted_wal.rs (source / functions)
Test: 5445d246133daeceb0507e6cc0797ab7c1c70cb8.info
Test Date: 2025-03-12 18:05:02
Coverage: Lines: 86.9 % (775 of 892 hit) | Functions: 75.0 % (39 of 52 hit)

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::fmt::Display;
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use anyhow::{Context, anyhow};
       7              : use futures::StreamExt;
       8              : use futures::future::Either;
       9              : use pageserver_api::shard::ShardIdentity;
      10              : use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend};
      11              : use postgres_ffi::get_current_timestamp;
      12              : use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder};
      13              : use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive};
      14              : use tokio::io::{AsyncRead, AsyncWrite};
      15              : use tokio::sync::mpsc::error::SendError;
      16              : use tokio::task::JoinHandle;
      17              : use tokio::time::MissedTickBehavior;
      18              : use tracing::{Instrument, error, info, info_span};
      19              : use utils::critical;
      20              : use utils::lsn::Lsn;
      21              : use utils::postgres_client::{Compression, InterpretedFormat};
      22              : use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
      23              : use wal_decoder::wire_format::ToWireFormat;
      24              : 
      25              : use crate::metrics::WAL_READERS;
      26              : use crate::send_wal::{EndWatchView, WalSenderGuard};
      27              : use crate::timeline::WalResidentTimeline;
      28              : use crate::wal_reader_stream::{StreamingWalReader, WalBytes};
      29              : 
      30              : /// Identifier used to differentiate between senders of the same
      31              : /// shard.
      32              : ///
      33              : /// In the steady state there's only one, but two pageservers may
      34              : /// temporarily have the same shard attached and attempt to ingest
      35              : /// WAL for it. See also [`ShardSenderId`].
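                      : ///
                      : /// A new sender takes the id after its shard's current last one; a minimal
                      : /// sketch of the allocation (mirroring the attach path in `run_impl`):
                      : ///
                      : /// ```ignore
                      : /// let new_sender_id = match senders.last() {
                      : ///     Some(sender) => sender.sender_id.next(),
                      : ///     None => SenderId::first(),
                      : /// };
                      : /// ```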
      36              : #[derive(Hash, Eq, PartialEq, Copy, Clone)]
      37              : struct SenderId(u8);
      38              : 
      39              : impl SenderId {
      40            6 :     fn first() -> Self {
      41            6 :         SenderId(0)
      42            6 :     }
      43              : 
      44            2 :     fn next(&self) -> Self {
      45            2 :         SenderId(self.0.checked_add(1).expect("few senders"))
      46            2 :     }
      47              : }
      48              : 
      49              : #[derive(Hash, Eq, PartialEq)]
      50              : struct ShardSenderId {
      51              :     shard: ShardIdentity,
      52              :     sender_id: SenderId,
      53              : }
      54              : 
      55              : impl Display for ShardSenderId {
      56            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      57            0 :         write!(f, "{}{}", self.sender_id.0, self.shard.shard_slug())
      58            0 :     }
      59              : }
      60              : 
      61              : impl ShardSenderId {
      62          877 :     fn new(shard: ShardIdentity, sender_id: SenderId) -> Self {
      63          877 :         ShardSenderId { shard, sender_id }
      64          877 :     }
      65              : 
      66            0 :     fn shard(&self) -> ShardIdentity {
      67            0 :         self.shard
      68            0 :     }
      69              : }
      70              : 
      71              : /// Shard-aware fan-out interpreted record reader.
       72              : /// Reads WAL from disk, decodes it, interprets it, and sends
      73              : /// it to any [`InterpretedWalSender`] connected to it.
      74              : /// Each [`InterpretedWalSender`] corresponds to one shard
      75              : /// and gets interpreted records concerning that shard only.
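                      : ///
                      : /// A hedged usage sketch; construction of `wal_stream`, `tx`, `shard` and
                      : /// the second sender's arguments is elided and left to the caller:
                      : ///
                      : /// ```ignore
                      : /// let handle = InterpretedWalReader::spawn(
                      : ///     wal_stream, start_pos, tx, shard, pg_version, &appname,
                      : /// );
                      : /// // Another sender for the same or a different shard may attach later:
                      : /// handle.fanout(other_shard, other_tx, other_start_pos)?;
                      : /// ```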
      76              : pub(crate) struct InterpretedWalReader {
      77              :     wal_stream: StreamingWalReader,
      78              :     shard_senders: HashMap<ShardIdentity, smallvec::SmallVec<[ShardSenderState; 1]>>,
      79              :     shard_notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>>,
      80              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
      81              :     pg_version: u32,
      82              : }
      83              : 
      84              : /// A handle for [`InterpretedWalReader`] which allows for interacting with it
      85              : /// when it runs as a separate tokio task.
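                      : ///
                      : /// A sketch of typical use, assuming `handle` came from
                      : /// [`InterpretedWalReader::spawn`]:
                      : ///
                      : /// ```ignore
                      : /// // `current_position` returns `None` once the reader has finished,
                      : /// // letting a new receiver decide whether to attach or spawn afresh.
                      : /// if handle.current_position().is_some() {
                      : ///     handle.fanout(shard, tx, start_pos)?;
                      : /// }
                      : /// ```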
      86              : #[derive(Debug)]
      87              : pub(crate) struct InterpretedWalReaderHandle {
      88              :     join_handle: JoinHandle<Result<(), InterpretedWalReaderError>>,
      89              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
      90              :     shard_notification_tx: tokio::sync::mpsc::UnboundedSender<AttachShardNotification>,
      91              : }
      92              : 
      93              : struct ShardSenderState {
      94              :     sender_id: SenderId,
      95              :     tx: tokio::sync::mpsc::Sender<Batch>,
      96              :     next_record_lsn: Lsn,
      97              : }
      98              : 
      99              : /// State of [`InterpretedWalReader`] visible outside of the task running it.
     100              : #[derive(Debug)]
     101              : pub(crate) enum InterpretedWalReaderState {
     102              :     Running {
     103              :         current_position: Lsn,
     104              :         /// Tracks the start of the PG WAL LSN from which the current batch of
     105              :         /// interpreted records originated.
     106              :         current_batch_wal_start: Option<Lsn>,
     107              :     },
     108              :     Done,
     109              : }
     110              : 
     111              : pub(crate) struct Batch {
     112              :     wal_end_lsn: Lsn,
     113              :     available_wal_end_lsn: Lsn,
     114              :     records: InterpretedWalRecords,
     115              : }
     116              : 
     117              : #[derive(thiserror::Error, Debug)]
     118              : pub enum InterpretedWalReaderError {
     119              :     /// Handler initiates the end of streaming.
     120              :     #[error("decode error: {0}")]
     121              :     Decode(#[from] WalDecodeError),
     122              :     #[error("read or interpret error: {0}")]
     123              :     ReadOrInterpret(#[from] anyhow::Error),
     124              :     #[error("wal stream closed")]
     125              :     WalStreamClosed,
     126              : }
     127              : 
     128              : enum CurrentPositionUpdate {
     129              :     Reset { from: Lsn, to: Lsn },
     130              :     NotReset(Lsn),
     131              : }
     132              : 
     133              : impl CurrentPositionUpdate {
     134            0 :     fn current_position(&self) -> Lsn {
     135            0 :         match self {
     136            0 :             CurrentPositionUpdate::Reset { from: _, to } => *to,
     137            0 :             CurrentPositionUpdate::NotReset(lsn) => *lsn,
     138              :         }
     139            0 :     }
     140              : 
     141            0 :     fn previous_position(&self) -> Lsn {
     142            0 :         match self {
     143            0 :             CurrentPositionUpdate::Reset { from, to: _ } => *from,
     144            0 :             CurrentPositionUpdate::NotReset(lsn) => *lsn,
     145              :         }
     146            0 :     }
     147              : }
     148              : 
     149              : impl InterpretedWalReaderState {
     150            6 :     fn current_position(&self) -> Option<Lsn> {
     151            6 :         match self {
     152              :             InterpretedWalReaderState::Running {
     153            3 :                 current_position, ..
     154            3 :             } => Some(*current_position),
     155            3 :             InterpretedWalReaderState::Done => None,
     156              :         }
     157            6 :     }
     158              : 
     159              :     #[cfg(test)]
     160           57 :     fn current_batch_wal_start(&self) -> Option<Lsn> {
     161           57 :         match self {
     162              :             InterpretedWalReaderState::Running {
     163           57 :                 current_batch_wal_start,
     164           57 :                 ..
     165           57 :             } => *current_batch_wal_start,
     166            0 :             InterpretedWalReaderState::Done => None,
     167              :         }
     168           57 :     }
     169              : 
     170              :     // Reset the current position of the WAL reader if the requested starting position
     171              :     // of the new shard is smaller than the current value.
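                      :     //
                      :     // A worked sketch: with the reader at 0/2000, a shard attaching at 0/1000
                      :     // yields `Reset { from: 0/2000, to: 0/1000 }` (the caller then rewinds the
                      :     // WAL stream), while a shard attaching at 0/3000 yields `NotReset(0/2000)`
                      :     // and the reader position is left untouched.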
     172            4 :     fn maybe_reset(&mut self, new_shard_start_pos: Lsn) -> CurrentPositionUpdate {
     173            4 :         match self {
     174              :             InterpretedWalReaderState::Running {
     175            4 :                 current_position,
     176            4 :                 current_batch_wal_start,
     177            4 :             } => {
     178            4 :                 if new_shard_start_pos < *current_position {
     179            3 :                     let from = *current_position;
     180            3 :                     *current_position = new_shard_start_pos;
     181            3 :                     *current_batch_wal_start = None;
     182            3 :                     CurrentPositionUpdate::Reset {
     183            3 :                         from,
     184            3 :                         to: *current_position,
     185            3 :                     }
     186              :                 } else {
     187              :                     // Edge case: The new shard is at the same current position as
     188              :                     // the reader. Note that the current position is WAL record aligned,
     189              :                     // so the reader might have done some partial reads and updated the
      190              :                     // batch start. If that's the case, adjust the batch start to match
      191              :                     // the starting position of the new shard. This can lead to some shards
      192              :                     // seeing overlapping WAL, but that is safe: per-record LSN filtering
      193              :                     // ensures that no record is delivered to a shard twice.
     194            1 :                     if let Some(start) = current_batch_wal_start {
     195            0 :                         *start = std::cmp::min(*start, new_shard_start_pos);
     196            1 :                     }
     197            1 :                     CurrentPositionUpdate::NotReset(*current_position)
     198              :                 }
     199              :             }
     200              :             InterpretedWalReaderState::Done => {
     201            0 :                 panic!("maybe_reset called on finished reader")
     202              :             }
     203              :         }
     204            4 :     }
     205              : 
     206           50 :     fn update_current_batch_wal_start(&mut self, lsn: Lsn) {
     207           50 :         match self {
     208              :             InterpretedWalReaderState::Running {
     209           50 :                 current_batch_wal_start,
     210           50 :                 ..
     211           50 :             } => {
     212           50 :                 if current_batch_wal_start.is_none() {
     213            6 :                     *current_batch_wal_start = Some(lsn);
     214           44 :                 }
     215              :             }
     216              :             InterpretedWalReaderState::Done => {
     217            0 :                 panic!("update_current_batch_wal_start called on finished reader")
     218              :             }
     219              :         }
     220           50 :     }
     221              : 
     222           41 :     fn replace_current_batch_wal_start(&mut self, with: Lsn) -> Lsn {
     223           41 :         match self {
     224              :             InterpretedWalReaderState::Running {
     225           41 :                 current_batch_wal_start,
     226           41 :                 ..
     227           41 :             } => current_batch_wal_start.replace(with).unwrap(),
     228              :             InterpretedWalReaderState::Done => {
      229            0 :                 panic!("replace_current_batch_wal_start called on finished reader")
     230              :             }
     231              :         }
     232           41 :     }
     233              : 
     234           41 :     fn update_current_position(&mut self, lsn: Lsn) {
     235           41 :         match self {
     236              :             InterpretedWalReaderState::Running {
     237           41 :                 current_position, ..
     238           41 :             } => {
     239           41 :                 *current_position = lsn;
     240           41 :             }
     241              :             InterpretedWalReaderState::Done => {
     242            0 :                 panic!("update_current_position called on finished reader")
     243              :             }
     244              :         }
     245           41 :     }
     246              : }
     247              : 
     248              : pub(crate) struct AttachShardNotification {
     249              :     shard_id: ShardIdentity,
     250              :     sender: tokio::sync::mpsc::Sender<Batch>,
     251              :     start_pos: Lsn,
     252              : }
     253              : 
     254              : impl InterpretedWalReader {
     255              :     /// Spawn the reader in a separate tokio task and return a handle
     256            3 :     pub(crate) fn spawn(
     257            3 :         wal_stream: StreamingWalReader,
     258            3 :         start_pos: Lsn,
     259            3 :         tx: tokio::sync::mpsc::Sender<Batch>,
     260            3 :         shard: ShardIdentity,
     261            3 :         pg_version: u32,
     262            3 :         appname: &Option<String>,
     263            3 :     ) -> InterpretedWalReaderHandle {
     264            3 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     265            3 :             current_position: start_pos,
     266            3 :             current_batch_wal_start: None,
     267            3 :         }));
     268            3 : 
     269            3 :         let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
     270              : 
     271            3 :         let reader = InterpretedWalReader {
     272            3 :             wal_stream,
     273            3 :             shard_senders: HashMap::from([(
     274            3 :                 shard,
     275            3 :                 smallvec::smallvec![ShardSenderState {
     276            0 :                     sender_id: SenderId::first(),
     277            0 :                     tx,
     278            0 :                     next_record_lsn: start_pos,
     279            0 :                 }],
     280              :             )]),
     281            3 :             shard_notification_rx: Some(shard_notification_rx),
     282            3 :             state: state.clone(),
     283            3 :             pg_version,
     284            3 :         };
     285            3 : 
     286            3 :         let metric = WAL_READERS
     287            3 :             .get_metric_with_label_values(&["task", appname.as_deref().unwrap_or("safekeeper")])
     288            3 :             .unwrap();
     289              : 
     290            3 :         let join_handle = tokio::task::spawn(
     291            3 :             async move {
     292            3 :                 metric.inc();
     293            3 :                 scopeguard::defer! {
     294            3 :                     metric.dec();
     295            3 :                 }
     296            3 : 
     297            3 :                 reader
     298            3 :                     .run_impl(start_pos)
     299            3 :                     .await
     300            0 :                     .inspect_err(|err| match err {
     301              :                         // TODO: we may want to differentiate these errors further.
     302              :                         InterpretedWalReaderError::Decode(_) => {
     303            0 :                             critical!("failed to decode WAL record: {err:?}");
     304              :                         }
     305            0 :                         err => error!("failed to read WAL record: {err}"),
     306            0 :                     })
     307            0 :             }
     308            3 :             .instrument(info_span!("interpreted wal reader")),
     309              :         );
     310              : 
     311            3 :         InterpretedWalReaderHandle {
     312            3 :             join_handle,
     313            3 :             state,
     314            3 :             shard_notification_tx,
     315            3 :         }
     316            3 :     }
     317              : 
     318              :     /// Construct the reader without spawning anything
     319              :     /// Callers should drive the future returned by [`Self::run`].
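                      :     ///
                      :     /// A minimal sketch of driving the reader inline (same arguments as
                      :     /// [`Self::spawn`]):
                      :     ///
                      :     /// ```ignore
                      :     /// let reader = InterpretedWalReader::new(
                      :     ///     wal_stream, start_pos, tx, shard, pg_version, None,
                      :     /// );
                      :     /// // Runs until an error is hit; the future always resolves to an `Err`.
                      :     /// reader.run(start_pos, &appname).await?;
                      :     /// ```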
     320            1 :     pub(crate) fn new(
     321            1 :         wal_stream: StreamingWalReader,
     322            1 :         start_pos: Lsn,
     323            1 :         tx: tokio::sync::mpsc::Sender<Batch>,
     324            1 :         shard: ShardIdentity,
     325            1 :         pg_version: u32,
     326            1 :         shard_notification_rx: Option<
     327            1 :             tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>,
     328            1 :         >,
     329            1 :     ) -> InterpretedWalReader {
     330            1 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     331            1 :             current_position: start_pos,
     332            1 :             current_batch_wal_start: None,
     333            1 :         }));
     334            1 : 
     335            1 :         InterpretedWalReader {
     336            1 :             wal_stream,
     337            1 :             shard_senders: HashMap::from([(
     338            1 :                 shard,
     339            1 :                 smallvec::smallvec![ShardSenderState {
     340            0 :                     sender_id: SenderId::first(),
     341            0 :                     tx,
     342            0 :                     next_record_lsn: start_pos,
     343            0 :                 }],
     344              :             )]),
     345            1 :             shard_notification_rx,
     346            1 :             state: state.clone(),
     347            1 :             pg_version,
     348            1 :         }
     349            1 :     }
     350              : 
     351              :     /// Entry point for future (polling) based wal reader.
     352            1 :     pub(crate) async fn run(
     353            1 :         self,
     354            1 :         start_pos: Lsn,
     355            1 :         appname: &Option<String>,
     356            1 :     ) -> Result<(), CopyStreamHandlerEnd> {
     357            1 :         let metric = WAL_READERS
     358            1 :             .get_metric_with_label_values(&["future", appname.as_deref().unwrap_or("safekeeper")])
     359            1 :             .unwrap();
     360            1 : 
     361            1 :         metric.inc();
     362            1 :         scopeguard::defer! {
     363            1 :             metric.dec();
     364            1 :         }
     365            1 : 
     366            1 :         match self.run_impl(start_pos).await {
     367            0 :             Err(err @ InterpretedWalReaderError::Decode(_)) => {
     368            0 :                 critical!("failed to decode WAL record: {err:?}");
     369              :             }
     370            0 :             Err(err) => error!("failed to read WAL record: {err}"),
     371            0 :             Ok(()) => info!("interpreted wal reader exiting"),
     372              :         }
     373              : 
     374            0 :         Err(CopyStreamHandlerEnd::Other(anyhow!(
     375            0 :             "interpreted wal reader finished"
     376            0 :         )))
     377            0 :     }
     378              : 
     379              :     /// Send interpreted WAL to one or more [`InterpretedWalSender`]s
     380              :     /// Stops when an error is encountered or when the [`InterpretedWalReaderHandle`]
     381              :     /// goes out of scope.
     382            4 :     async fn run_impl(mut self, start_pos: Lsn) -> Result<(), InterpretedWalReaderError> {
     383            4 :         let defer_state = self.state.clone();
     384            4 :         scopeguard::defer! {
     385            4 :             *defer_state.write().unwrap() = InterpretedWalReaderState::Done;
     386            4 :         }
     387            4 : 
     388            4 :         let mut wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version);
     389              : 
     390              :         loop {
     391           58 :             tokio::select! {
     392              :                 // Main branch for reading WAL and forwarding it
     393           58 :                 wal_or_reset = self.wal_stream.next() => {
     394           50 :                     let wal = wal_or_reset.map(|wor| wor.get_wal().expect("reset handled in select branch below"));
     395              :                     let WalBytes {
     396           50 :                         wal,
     397           50 :                         wal_start_lsn,
     398           50 :                         wal_end_lsn,
     399           50 :                         available_wal_end_lsn,
     400           50 :                     } = match wal {
     401           50 :                         Some(some) => some.map_err(InterpretedWalReaderError::ReadOrInterpret)?,
     402              :                         None => {
     403              :                             // [`StreamingWalReader::next`] is an endless stream of WAL.
     404              :                             // It shouldn't ever finish unless it panicked or became internally
     405              :                             // inconsistent.
     406            0 :                             return Result::Err(InterpretedWalReaderError::WalStreamClosed);
     407              :                         }
     408              :                     };
     409              : 
     410           50 :                     self.state.write().unwrap().update_current_batch_wal_start(wal_start_lsn);
     411           50 : 
     412           50 :                     wal_decoder.feed_bytes(&wal);
     413           50 : 
     414           50 :                     // Deserialize and interpret WAL records from this batch of WAL.
     415           50 :                     // Interpreted records for each shard are collected separately.
     416           50 :                     let shard_ids = self.shard_senders.keys().copied().collect::<Vec<_>>();
     417           50 :                     let mut records_by_sender: HashMap<ShardSenderId, Vec<InterpretedWalRecord>> = HashMap::new();
     418           50 :                     let mut max_next_record_lsn = None;
     419           50 :                     let mut max_end_record_lsn = None;
     420          656 :                     while let Some((next_record_lsn, recdata)) = wal_decoder.poll_decode()?
     421              :                     {
     422          606 :                         assert!(next_record_lsn.is_aligned());
     423          606 :                         max_next_record_lsn = Some(next_record_lsn);
     424          606 :                         max_end_record_lsn = Some(wal_decoder.lsn());
     425              : 
     426          606 :                         let interpreted = InterpretedWalRecord::from_bytes_filtered(
     427          606 :                             recdata,
     428          606 :                             &shard_ids,
     429          606 :                             next_record_lsn,
     430          606 :                             self.pg_version,
     431          606 :                         )
     432          606 :                         .with_context(|| "Failed to interpret WAL")?;
     433              : 
     434         1412 :                         for (shard, record) in interpreted {
     435              :                             // Shard zero needs to track the start LSN of the latest record
      436              :                             // in addition to the LSN of the next record to ingest. The former
      437              :                             // is included in the basebackup persisted by the compute in WAL.
     438          806 :                             if !shard.is_shard_zero() && record.is_empty() {
     439          200 :                                 continue;
     440          606 :                             }
     441          606 : 
     442          606 :                             let mut states_iter = self.shard_senders
     443          606 :                                 .get(&shard)
     444          606 :                                 .expect("keys collected above")
     445          606 :                                 .iter()
     446         1002 :                                 .filter(|state| record.next_record_lsn > state.next_record_lsn)
     447          606 :                                 .peekable();
     448          996 :                             while let Some(state) = states_iter.next() {
     449          796 :                                 let shard_sender_id = ShardSenderId::new(shard, state.sender_id);
     450          796 : 
     451          796 :                                 // The most common case is one sender per shard. Peek and break to avoid the
     452          796 :                                 // clone in that situation.
     453          796 :                                 if states_iter.peek().is_none() {
     454          406 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record);
     455          406 :                                     break;
     456          390 :                                 } else {
     457          390 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record.clone());
     458          390 :                                 }
     459              :                             }
     460              :                         }
     461              :                     }
     462              : 
     463           50 :                     let max_next_record_lsn = match max_next_record_lsn {
     464           41 :                         Some(lsn) => lsn,
     465              :                         None => {
     466            9 :                             continue;
     467              :                         }
     468              :                     };
     469              : 
     470              :                     // Update the current position such that new receivers can decide
     471              :                     // whether to attach to us or spawn a new WAL reader.
     472           41 :                     let batch_wal_start_lsn = {
     473           41 :                         let mut guard = self.state.write().unwrap();
     474           41 :                         guard.update_current_position(max_next_record_lsn);
     475           41 :                         guard.replace_current_batch_wal_start(max_end_record_lsn.unwrap())
     476           41 :                     };
     477           41 : 
     478           41 :                     // Send interpreted records downstream. Anything that has already been seen
     479           41 :                     // by a shard is filtered out.
     480           41 :                     let mut shard_senders_to_remove = Vec::new();
     481           96 :                     for (shard, states) in &mut self.shard_senders {
     482          136 :                         for state in states {
     483           81 :                             let shard_sender_id = ShardSenderId::new(*shard, state.sender_id);
     484              : 
     485           81 :                             let batch = if max_next_record_lsn > state.next_record_lsn {
     486              :                                 // This batch contains at least one record that this shard has not
     487              :                                 // seen yet.
     488           67 :                                 let records = records_by_sender.remove(&shard_sender_id).unwrap_or_default();
     489           67 : 
     490           67 :                                 InterpretedWalRecords {
     491           67 :                                     records,
     492           67 :                                     next_record_lsn: max_next_record_lsn,
     493           67 :                                     raw_wal_start_lsn: Some(batch_wal_start_lsn),
     494           67 :                                 }
     495           14 :                             } else if wal_end_lsn > state.next_record_lsn {
     496              :                                 // All the records in this batch were seen by the shard
     497              :                                 // However, the batch maps to a chunk of WAL that the
     498              :                                 // shard has not yet seen. Notify it of the start LSN
     499              :                                 // of the PG WAL chunk such that it doesn't look like a gap.
     500            0 :                                 InterpretedWalRecords {
     501            0 :                                     records: Vec::default(),
     502            0 :                                     next_record_lsn: state.next_record_lsn,
     503            0 :                                     raw_wal_start_lsn: Some(batch_wal_start_lsn),
     504            0 :                                 }
     505              :                             } else {
     506              :                                 // The shard has seen this chunk of WAL before. Skip it.
     507           14 :                                 continue;
     508              :                             };
     509              : 
     510           67 :                             let res = state.tx.send(Batch {
     511           67 :                                 wal_end_lsn,
     512           67 :                                 available_wal_end_lsn,
     513           67 :                                 records: batch,
     514           67 :                             }).await;
     515              : 
     516           67 :                             if res.is_err() {
     517            0 :                                 shard_senders_to_remove.push(shard_sender_id);
     518           67 :                             } else {
     519           67 :                                 state.next_record_lsn = std::cmp::max(state.next_record_lsn, max_next_record_lsn);
     520           67 :                             }
     521              :                         }
     522              :                     }
     523              : 
     524              :                     // Clean up any shard senders that have dropped out.
      525              :                     // This is inefficient, but such events are rare (pageserver connection
      526              :                     // termination) and the number of subscriptions on the same shard is
      527              :                     // very small (only one in the steady state).
     528           41 :                     for to_remove in shard_senders_to_remove {
     529            0 :                         let shard_senders = self.shard_senders.get_mut(&to_remove.shard()).expect("saw it above");
     530            0 :                         if let Some(idx) = shard_senders.iter().position(|s| s.sender_id == to_remove.sender_id) {
     531            0 :                             shard_senders.remove(idx);
     532            0 :                             tracing::info!("Removed shard sender {}", to_remove);
     533            0 :                         }
     534              : 
     535            0 :                         if shard_senders.is_empty() {
     536            0 :                             self.shard_senders.remove(&to_remove.shard());
     537            0 :                         }
     538              :                     }
     539              :                 },
     540              :                 // Listen for new shards that want to attach to this reader.
     541              :                 // If the reader is not running as a task, then this is not supported
     542              :                 // (see the pending branch below).
     543           58 :                 notification = match self.shard_notification_rx.as_mut() {
     544           58 :                         Some(rx) => Either::Left(rx.recv()),
     545            0 :                         None => Either::Right(std::future::pending())
     546              :                     } => {
     547            4 :                     if let Some(n) = notification {
     548            4 :                         let AttachShardNotification { shard_id, sender, start_pos } = n;
     549            4 : 
     550            4 :                         // Update internal and external state, then reset the WAL stream
     551            4 :                         // if required.
     552            4 :                         let senders = self.shard_senders.entry(shard_id).or_default();
     553            4 :                         let new_sender_id = match senders.last() {
     554            2 :                             Some(sender) => sender.sender_id.next(),
     555            2 :                             None => SenderId::first()
     556              :                         };
     557              : 
     558            4 :                         senders.push(ShardSenderState { sender_id: new_sender_id, tx: sender, next_record_lsn: start_pos});
     559            4 : 
     560            4 :                         // If the shard is subscribing below the current position, then we need
     561            4 :                         // to update the cursor that tracks where we are in the WAL
     562            4 :                         // ([`Self::state`]) and reset the WAL stream itself
     563            4 :                         // ([`Self::wal_stream`]). This must be done atomically from the POV of
     564            4 :                         // anything outside the select statement.
     565            4 :                         let position_reset = self.state.write().unwrap().maybe_reset(start_pos);
     566            4 :                         match position_reset {
     567            3 :                             CurrentPositionUpdate::Reset { from: _, to } => {
     568            3 :                                 self.wal_stream.reset(to).await;
     569            3 :                                 wal_decoder = WalStreamDecoder::new(to, self.pg_version);
     570              :                             },
     571            1 :                             CurrentPositionUpdate::NotReset(_) => {}
     572              :                         };
     573              : 
     574            4 :                         tracing::info!(
     575            0 :                             "Added shard sender {} with start_pos={} previous_pos={} current_pos={}",
     576            0 :                             ShardSenderId::new(shard_id, new_sender_id),
     577            0 :                             start_pos,
     578            0 :                             position_reset.previous_position(),
     579            0 :                             position_reset.current_position(),
     580              :                         );
     581            0 :                     }
     582              :                 }
     583              :             }
     584              :         }
     585            0 :     }
     586              : 
     587              :     #[cfg(test)]
     588            1 :     fn state(&self) -> Arc<std::sync::RwLock<InterpretedWalReaderState>> {
     589            1 :         self.state.clone()
     590            1 :     }
     591              : }
     592              : 
     593              : impl InterpretedWalReaderHandle {
     594              :     /// Fan-out the reader by attaching a new shard to it
     595            3 :     pub(crate) fn fanout(
     596            3 :         &self,
     597            3 :         shard_id: ShardIdentity,
     598            3 :         sender: tokio::sync::mpsc::Sender<Batch>,
     599            3 :         start_pos: Lsn,
     600            3 :     ) -> Result<(), SendError<AttachShardNotification>> {
     601            3 :         self.shard_notification_tx.send(AttachShardNotification {
     602            3 :             shard_id,
     603            3 :             sender,
     604            3 :             start_pos,
     605            3 :         })
     606            3 :     }
     607              : 
     608              :     /// Get the current WAL position of the reader
     609            6 :     pub(crate) fn current_position(&self) -> Option<Lsn> {
     610            6 :         self.state.read().unwrap().current_position()
     611            6 :     }
     612              : 
     613            6 :     pub(crate) fn abort(&self) {
     614            6 :         self.join_handle.abort()
     615            6 :     }
     616              : }
     617              : 
     618              : impl Drop for InterpretedWalReaderHandle {
     619            3 :     fn drop(&mut self) {
     620            3 :         tracing::info!("Aborting interpreted wal reader");
     621            3 :         self.abort()
     622            3 :     }
     623              : }
     624              : 
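                      : /// Sends batches of interpreted WAL records received from an
                      : /// [`InterpretedWalReader`] over the network.
                      : ///
                      : /// A hedged wiring sketch; field construction is elided and the channel
                      : /// capacity is an arbitrary choice:
                      : ///
                      : /// ```ignore
                      : /// let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(2);
                      : /// let handle = InterpretedWalReader::spawn(
                      : ///     wal_stream, start_lsn, tx, shard, pg_version, &appname,
                      : /// );
                      : /// let sender = InterpretedWalSender {
                      : ///     format, compression, appname, tli, start_lsn,
                      : ///     pgb: &mut pgb, end_watch_view, wal_sender_guard, rx,
                      : /// };
                      : /// sender.run().await?;
                      : /// ```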
     625              : pub(crate) struct InterpretedWalSender<'a, IO> {
     626              :     pub(crate) format: InterpretedFormat,
     627              :     pub(crate) compression: Option<Compression>,
     628              :     pub(crate) appname: Option<String>,
     629              : 
     630              :     pub(crate) tli: WalResidentTimeline,
     631              :     pub(crate) start_lsn: Lsn,
     632              : 
     633              :     pub(crate) pgb: &'a mut PostgresBackend<IO>,
     634              :     pub(crate) end_watch_view: EndWatchView,
     635              :     pub(crate) wal_sender_guard: Arc<WalSenderGuard>,
     636              :     pub(crate) rx: tokio::sync::mpsc::Receiver<Batch>,
     637              : }
     638              : 
     639              : impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
     640              :     /// Send interpreted WAL records over the network.
     641              :     /// Also manages keep-alives if nothing was sent for a while.
     642            0 :     pub(crate) async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> {
     643            0 :         let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1));
     644            0 :         keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
     645            0 :         keepalive_ticker.reset();
     646            0 : 
     647            0 :         let mut wal_position = self.start_lsn;
     648              : 
     649              :         loop {
     650            0 :             tokio::select! {
     651            0 :                 batch = self.rx.recv() => {
     652            0 :                     let batch = match batch {
     653            0 :                         Some(b) => b,
     654              :                         None => {
     655            0 :                             return Result::Err(
     656            0 :                                 CopyStreamHandlerEnd::Other(anyhow!("Interpreted WAL reader exited early"))
     657            0 :                             );
     658              :                         }
     659              :                     };
     660              : 
     661            0 :                     wal_position = batch.wal_end_lsn;
     662              : 
     663            0 :                     let buf = batch
     664            0 :                         .records
     665            0 :                         .to_wire(self.format, self.compression)
     666            0 :                         .await
     667            0 :                         .with_context(|| "Failed to serialize interpreted WAL")
     668            0 :                         .map_err(CopyStreamHandlerEnd::from)?;
     669              : 
     670              :                     // Reset the keep alive ticker since we are sending something
     671              :                     // over the wire now.
     672            0 :                     keepalive_ticker.reset();
     673            0 : 
     674            0 :                     self.pgb
     675            0 :                         .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody {
     676            0 :                             streaming_lsn: batch.wal_end_lsn.0,
     677            0 :                             commit_lsn: batch.available_wal_end_lsn.0,
     678            0 :                             data: &buf,
     679            0 :                         })).await?;
     680              :                 }
     681              :                 // Send a periodic keep alive when the connection has been idle for a while.
     682              :                 // Since we've been idle, also check if we can stop streaming.
     683            0 :                 _ = keepalive_ticker.tick() => {
     684            0 :                     if let Some(remote_consistent_lsn) = self.wal_sender_guard
     685            0 :                         .walsenders()
     686            0 :                         .get_ws_remote_consistent_lsn(self.wal_sender_guard.id())
     687              :                     {
     688            0 :                         if self.tli.should_walsender_stop(remote_consistent_lsn).await {
     689              :                             // Stop streaming if the receivers are caught up and
     690              :                             // there's no active compute. This causes the loop in
     691              :                             // [`crate::send_interpreted_wal::InterpretedWalSender::run`]
     692              :                             // to exit and terminate the WAL stream.
     693            0 :                             break;
     694            0 :                         }
     695            0 :                     }
     696              : 
     697            0 :                     self.pgb
     698            0 :                         .write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
     699            0 :                             wal_end: self.end_watch_view.get().0,
     700            0 :                             timestamp: get_current_timestamp(),
     701            0 :                             request_reply: true,
     702            0 :                         }))
     703            0 :                         .await?;
     704              :                 },
     705              :             }
     706              :         }
     707              : 
     708            0 :         Err(CopyStreamHandlerEnd::ServerInitiated(format!(
      709            0 :             "ending streaming to {:?} at {}, receiver is caught up and there are no computes",
     710            0 :             self.appname, wal_position,
     711            0 :         )))
     712            0 :     }
     713              : }
     714              : #[cfg(test)]
     715              : mod tests {
     716              :     use std::collections::HashMap;
     717              :     use std::str::FromStr;
     718              :     use std::time::Duration;
     719              : 
     720              :     use pageserver_api::shard::{ShardIdentity, ShardStripeSize};
     721              :     use postgres_ffi::MAX_SEND_SIZE;
     722              :     use tokio::sync::mpsc::error::TryRecvError;
     723              :     use utils::id::{NodeId, TenantTimelineId};
     724              :     use utils::lsn::Lsn;
     725              :     use utils::shard::{ShardCount, ShardNumber};
     726              : 
     727              :     use crate::send_interpreted_wal::{AttachShardNotification, Batch, InterpretedWalReader};
     728              :     use crate::test_utils::Env;
     729              :     use crate::wal_reader_stream::StreamingWalReader;
     730              : 
     731              :     #[tokio::test]
     732            1 :     async fn test_interpreted_wal_reader_fanout() {
     733            1 :         let _ = env_logger::builder().is_test(true).try_init();
     734            1 : 
     735            1 :         const SIZE: usize = 8 * 1024;
     736            1 :         const MSG_COUNT: usize = 200;
     737            1 :         const PG_VERSION: u32 = 17;
     738            1 :         const SHARD_COUNT: u8 = 2;
     739            1 : 
     740            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     741            1 :         let env = Env::new(true).unwrap();
     742            1 :         let tli = env
     743            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     744            1 :             .await
     745            1 :             .unwrap();
     746            1 : 
     747            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     748            1 :         let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, c"neon-file:", None)
     749            1 :             .await
     750            1 :             .unwrap();
     751            1 :         let end_pos = end_watch.get();
     752            1 : 
     753            1 :         tracing::info!("Doing first round of reads ...");
     754            1 : 
     755            1 :         let streaming_wal_reader = StreamingWalReader::new(
     756            1 :             resident_tli,
     757            1 :             None,
     758            1 :             start_lsn,
     759            1 :             end_pos,
     760            1 :             end_watch,
     761            1 :             MAX_SEND_SIZE,
     762            1 :         );
     763            1 : 
     764            1 :         let shard_0 = ShardIdentity::new(
     765            1 :             ShardNumber(0),
     766            1 :             ShardCount(SHARD_COUNT),
     767            1 :             ShardStripeSize::default(),
     768            1 :         )
     769            1 :         .unwrap();
     770            1 : 
     771            1 :         let shard_1 = ShardIdentity::new(
     772            1 :             ShardNumber(1),
     773            1 :             ShardCount(SHARD_COUNT),
     774            1 :             ShardStripeSize::default(),
     775            1 :         )
     776            1 :         .unwrap();
     777            1 : 
     778            1 :         let mut shards = HashMap::new();
     779            1 : 
     780            3 :         for shard_number in 0..SHARD_COUNT {
     781            2 :             let shard_id = ShardIdentity::new(
     782            2 :                 ShardNumber(shard_number),
     783            2 :                 ShardCount(SHARD_COUNT),
     784            2 :                 ShardStripeSize::default(),
     785            2 :             )
     786            2 :             .unwrap();
     787            2 :             let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     788            2 :             shards.insert(shard_id, (Some(tx), Some(rx)));
     789            2 :         }
     790            1 : 
     791            1 :         let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
     792            1 :         let mut shard_0_rx = shards.get_mut(&shard_0).unwrap().1.take().unwrap();
     793            1 : 
     794            1 :         let handle = InterpretedWalReader::spawn(
     795            1 :             streaming_wal_reader,
     796            1 :             start_lsn,
     797            1 :             shard_0_tx,
     798            1 :             shard_0,
     799            1 :             PG_VERSION,
     800            1 :             &Some("pageserver".to_string()),
     801            1 :         );
     802            1 : 
     803            1 :         tracing::info!("Reading all WAL with only shard 0 attached ...");
     804            1 : 
     805            1 :         let mut shard_0_interpreted_records = Vec::new();
     806           13 :         while let Some(batch) = shard_0_rx.recv().await {
     807           13 :             shard_0_interpreted_records.push(batch.records);
     808           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     809            1 :                 break;
     810           12 :             }
     811            1 :         }
     812            1 : 
     813            1 :         let shard_1_tx = shards.get_mut(&shard_1).unwrap().0.take().unwrap();
     814            1 :         let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
     815            1 : 
     816            1 :         tracing::info!("Attaching shard 1 to the reader at start of WAL");
     817            1 :         handle.fanout(shard_1, shard_1_tx, start_lsn).unwrap();
     818            1 : 
     819            1 :         tracing::info!("Reading all WAL with shard 0 and shard 1 attached ...");
     820            1 : 
     821            1 :         let mut shard_1_interpreted_records = Vec::new();
     822           13 :         while let Some(batch) = shard_1_rx.recv().await {
     823           13 :             shard_1_interpreted_records.push(batch.records);
     824           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     825            1 :                 break;
     826           12 :             }
     827            1 :         }
     828            1 : 
     829            1 :         // This test uses logical messages. Those only go to shard 0. Check that the
     830            1 :         // filtering worked and shard 1 did not get any.
     831            1 :         assert!(
     832            1 :             shard_1_interpreted_records
     833            1 :                 .iter()
     834           13 :                 .all(|recs| recs.records.is_empty())
     835            1 :         );
     836            1 : 
     837            1 :         // Shard 0 should not receive anything more since the reader is
      838            1 :         // going through WAL that it has already processed.
     839            1 :         let res = shard_0_rx.try_recv();
     840            1 :         if let Ok(ref ok) = res {
     841            1 :             tracing::error!(
     842            1 :                 "Shard 0 received batch: wal_end_lsn={} available_wal_end_lsn={}",
     843            1 :                 ok.wal_end_lsn,
     844            1 :                 ok.available_wal_end_lsn
     845            1 :             );
     846            1 :         }
     847            1 :         assert!(matches!(res, Err(TryRecvError::Empty)));
     848            1 : 
     849            1 :         // Check that the next records lsns received by the two shards match up.
     850            1 :         let shard_0_next_lsns = shard_0_interpreted_records
     851            1 :             .iter()
     852           13 :             .map(|recs| recs.next_record_lsn)
     853            1 :             .collect::<Vec<_>>();
     854            1 :         let shard_1_next_lsns = shard_1_interpreted_records
     855            1 :             .iter()
     856           13 :             .map(|recs| recs.next_record_lsn)
     857            1 :             .collect::<Vec<_>>();
     858            1 :         assert_eq!(shard_0_next_lsns, shard_1_next_lsns);
     859            1 : 
     860            1 :         handle.abort();
     861            1 :         let mut done = false;
     862            2 :         for _ in 0..5 {
     863            2 :             if handle.current_position().is_none() {
     864            1 :                 done = true;
     865            1 :                 break;
     866            1 :             }
     867            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
     868            1 :         }
     869            1 : 
     870            1 :         assert!(done);
     871            1 :     }
     872              : 
     873              :     #[tokio::test]
     874            1 :     async fn test_interpreted_wal_reader_same_shard_fanout() {
     875            1 :         let _ = env_logger::builder().is_test(true).try_init();
     876            1 : 
     877            1 :         const SIZE: usize = 8 * 1024;
     878            1 :         const MSG_COUNT: usize = 200;
     879            1 :         const PG_VERSION: u32 = 17;
     880            1 :         const SHARD_COUNT: u8 = 2;
     881            1 : 
     882            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     883            1 :         let env = Env::new(true).unwrap();
     884            1 :         let tli = env
     885            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     886            1 :             .await
     887            1 :             .unwrap();
     888            1 : 
     889            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     890            1 :         let mut next_record_lsns = Vec::default();
     891            1 :         let end_watch = Env::write_wal(
     892            1 :             tli,
     893            1 :             start_lsn,
     894            1 :             SIZE,
     895            1 :             MSG_COUNT,
     896            1 :             c"neon-file:",
     897            1 :             Some(&mut next_record_lsns),
     898            1 :         )
     899            1 :         .await
     900            1 :         .unwrap();
     901            1 :         let end_pos = end_watch.get();
     902            1 : 
     903            1 :         let streaming_wal_reader = StreamingWalReader::new(
     904            1 :             resident_tli,
     905            1 :             None,
     906            1 :             start_lsn,
     907            1 :             end_pos,
     908            1 :             end_watch,
     909            1 :             MAX_SEND_SIZE,
     910            1 :         );
     911            1 : 
     912            1 :         let shard_0 = ShardIdentity::new(
     913            1 :             ShardNumber(0),
     914            1 :             ShardCount(SHARD_COUNT),
     915            1 :             ShardStripeSize::default(),
     916            1 :         )
     917            1 :         .unwrap();
     918            1 : 
     919            1 :         struct Sender {
     920            1 :             tx: Option<tokio::sync::mpsc::Sender<Batch>>,
     921            1 :             rx: tokio::sync::mpsc::Receiver<Batch>,
     922            1 :             shard: ShardIdentity,
     923            1 :             start_lsn: Lsn,
     924            1 :             received_next_record_lsns: Vec<Lsn>,
     925            1 :         }
     926            1 : 
     927            1 :         impl Sender {
     928            3 :             fn new(start_lsn: Lsn, shard: ShardIdentity) -> Self {
     929            3 :                 let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     930            3 :                 Self {
     931            3 :                     tx: Some(tx),
     932            3 :                     rx,
     933            3 :                     shard,
     934            3 :                     start_lsn,
     935            3 :                     received_next_record_lsns: Vec::default(),
     936            3 :                 }
     937            3 :             }
     938            1 :         }
     939            1 : 
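                       :         // Pick three attach points out of order (records 5, 1 and 3) so the
                       :         // senders added via fanout start behind the position the reader was
                       :         // spawned at, forcing it to re-read older WAL for them.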
     940            1 :         assert!(next_record_lsns.len() > 7);
     941            1 :         let start_lsns = vec![
     942            1 :             next_record_lsns[5],
     943            1 :             next_record_lsns[1],
     944            1 :             next_record_lsns[3],
     945            1 :         ];
     946            1 :         let mut senders = start_lsns
     947            1 :             .into_iter()
     948            3 :             .map(|lsn| Sender::new(lsn, shard_0))
     949            1 :             .collect::<Vec<_>>();
     950            1 : 
     951            1 :         let first_sender = senders.first_mut().unwrap();
     952            1 :         let handle = InterpretedWalReader::spawn(
     953            1 :             streaming_wal_reader,
     954            1 :             first_sender.start_lsn,
     955            1 :             first_sender.tx.take().unwrap(),
     956            1 :             first_sender.shard,
     957            1 :             PG_VERSION,
     958            1 :             &Some("pageserver".to_string()),
     959            1 :         );
     960            1 : 
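                       :         // Attach the remaining senders for the same shard to the running
                       :         // reader; each one gets its own channel and start position.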
     961            2 :         for sender in senders.iter_mut().skip(1) {
     962            2 :             handle
     963            2 :                 .fanout(sender.shard, sender.tx.take().unwrap(), sender.start_lsn)
     964            2 :                 .unwrap();
     965            2 :         }
     966            1 : 
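                       :         // Drain each sender until the batch end catches up with the
                       :         // available WAL end, i.e. everything has been received.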
     967            3 :         for sender in senders.iter_mut() {
     968            1 :             loop {
     969           39 :                 let batch = sender.rx.recv().await.unwrap();
     970           39 :                 tracing::info!(
     971            1 :                     "Sender with start_lsn={} received batch ending at {} with {} records",
     972            0 :                     sender.start_lsn,
     973            0 :                     batch.wal_end_lsn,
     974            0 :                     batch.records.records.len()
     975            1 :                 );
     976            1 : 
     977          627 :                 for rec in batch.records.records {
     978          588 :                     sender.received_next_record_lsns.push(rec.next_record_lsn);
     979          588 :                 }
     980            1 : 
     981           39 :                 if batch.wal_end_lsn == batch.available_wal_end_lsn {
     982            3 :                     break;
     983           36 :                 }
     984            1 :             }
     985            1 :         }
     986            1 : 
     987            1 :         handle.abort();
     988            1 :         let mut done = false;
     989            2 :         for _ in 0..5 {
     990            2 :             if handle.current_position().is_none() {
     991            1 :                 done = true;
     992            1 :                 break;
     993            1 :             }
     994            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
     995            1 :         }
     996            1 : 
     997            1 :         assert!(done);
     998            1 : 
     999            4 :         for sender in senders {
    1000            3 :             tracing::info!(
    1001            1 :                 "Validating records received by sender with start_lsn={}",
    1002            1 :                 sender.start_lsn
    1003            1 :             );
    1004            1 : 
    1005            3 :             assert!(sender.received_next_record_lsns.is_sorted());
    1006            3 :             let expected = next_record_lsns
    1007            3 :                 .iter()
    1008          600 :                 .filter(|lsn| **lsn > sender.start_lsn)
    1009            3 :                 .copied()
    1010            3 :                 .collect::<Vec<_>>();
    1011            3 :             assert_eq!(sender.received_next_record_lsns, expected);
    1012            1 :         }
    1013            1 :     }
    1014              : 
    1015              :     #[tokio::test]
    1016            1 :     async fn test_batch_start_tracking_on_reset() {
    1017            1 :         // When the WAL stream is reset to an older LSN,
    1018            1 :         // the current batch start LSN should be invalidated.
    1019            1 :         // This test constructs such a scenario:
    1020            1 :         // 1. Shard 0 is reading somewhere ahead
    1021            1 :         // 2. Reader reads some WAL, but does not decode a full record (partial read)
    1022            1 :         // 3. Shard 1 attaches to the reader and resets it to an older LSN
    1023            1 :         // 4. Shard 1 should get the correct batch WAL start LSN
    1024            1 :         let _ = env_logger::builder().is_test(true).try_init();
    1025            1 : 
    1026            1 :         const SIZE: usize = 64 * 1024;
    1027            1 :         const MSG_COUNT: usize = 10;
    1028            1 :         const PG_VERSION: u32 = 17;
    1029            1 :         const SHARD_COUNT: u8 = 2;
    1030            1 :         const WAL_READER_BATCH_SIZE: usize = 8192;
    1031            1 : 
    1032            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
    1033            1 :         let env = Env::new(true).unwrap();
    1034            1 :         let mut next_record_lsns = Vec::default();
    1035            1 :         let tli = env
    1036            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
    1037            1 :             .await
    1038            1 :             .unwrap();
    1039            1 : 
    1040            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
    1041            1 :         let end_watch = Env::write_wal(
    1042            1 :             tli,
    1043            1 :             start_lsn,
    1044            1 :             SIZE,
    1045            1 :             MSG_COUNT,
    1046            1 :             c"neon-file:",
    1047            1 :             Some(&mut next_record_lsns),
    1048            1 :         )
    1049            1 :         .await
    1050            1 :         .unwrap();
    1051            1 : 
    1052            1 :         assert!(next_record_lsns.len() > 3);
    1053            1 :         let shard_0_start_lsn = next_record_lsns[3];
    1054            1 : 
    1055            1 :         let end_pos = end_watch.get();
    1056            1 : 
    1057            1 :         let streaming_wal_reader = StreamingWalReader::new(
    1058            1 :             resident_tli,
    1059            1 :             None,
    1060            1 :             shard_0_start_lsn,
    1061            1 :             end_pos,
    1062            1 :             end_watch,
    1063            1 :             WAL_READER_BATCH_SIZE,
    1064            1 :         );
    1065            1 : 
    1066            1 :         let shard_0 = ShardIdentity::new(
    1067            1 :             ShardNumber(0),
    1068            1 :             ShardCount(SHARD_COUNT),
    1069            1 :             ShardStripeSize::default(),
    1070            1 :         )
    1071            1 :         .unwrap();
    1072            1 : 
    1073            1 :         let shard_1 = ShardIdentity::new(
    1074            1 :             ShardNumber(1),
    1075            1 :             ShardCount(SHARD_COUNT),
    1076            1 :             ShardStripeSize::default(),
    1077            1 :         )
    1078            1 :         .unwrap();
    1079            1 : 
    1080            1 :         let mut shards = HashMap::new();
    1081            1 : 
    1082            3 :         for shard_number in 0..SHARD_COUNT {
    1083            2 :             let shard_id = ShardIdentity::new(
    1084            2 :                 ShardNumber(shard_number),
    1085            2 :                 ShardCount(SHARD_COUNT),
    1086            2 :                 ShardStripeSize::default(),
    1087            2 :             )
    1088            2 :             .unwrap();
    1089            2 :             let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
    1090            2 :             shards.insert(shard_id, (Some(tx), Some(rx)));
    1091            2 :         }
    1092            1 : 
    1093            1 :         let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
    1094            1 : 
    1095            1 :         let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
    1096            1 : 
    1097            1 :         let reader = InterpretedWalReader::new(
    1098            1 :             streaming_wal_reader,
    1099            1 :             shard_0_start_lsn,
    1100            1 :             shard_0_tx,
    1101            1 :             shard_0,
    1102            1 :             PG_VERSION,
    1103            1 :             Some(shard_notification_rx),
    1104            1 :         );
    1105            1 : 
    1106            1 :         let reader_state = reader.state();
    1107            1 :         let mut reader_fut = std::pin::pin!(reader.run(shard_0_start_lsn, &None));
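                       :         // Drive the reader by hand with `futures::poll!` until it has read
                       :         // some WAL without decoding a complete record, i.e. until a batch
                       :         // start LSN is recorded (step 2 of the scenario above).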
    1108            1 :         loop {
    1109           57 :             let poll = futures::poll!(reader_fut.as_mut());
    1110           57 :             assert!(poll.is_pending());
    1111            1 : 
    1112           57 :             let guard = reader_state.read().unwrap();
    1113           57 :             if guard.current_batch_wal_start().is_some() {
    1114            1 :                 break;
    1115           56 :             }
    1116            1 :         }
    1117            1 : 
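                       :         // Attach shard 1 at an LSN older than what the reader has consumed.
                       :         // This resets the WAL stream and must invalidate the partial batch
                       :         // start tracked above (step 3).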
    1118            1 :         shard_notification_tx
    1119            1 :             .send(AttachShardNotification {
    1120            1 :                 shard_id: shard_1,
    1121            1 :                 sender: shards.get_mut(&shard_1).unwrap().0.take().unwrap(),
    1122            1 :                 start_pos: start_lsn,
    1123            1 :             })
    1124            1 :             .unwrap();
    1125            1 : 
    1126            1 :         let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
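                       :         // Keep polling until shard 1's first batch arrives: its raw WAL
                       :         // start must be the reset position, not the stale pre-reset batch
                       :         // start (step 4).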
    1127            1 :         loop {
    1128           64 :             let poll = futures::poll!(reader_fut.as_mut());
    1129           64 :             assert!(poll.is_pending());
    1130            1 : 
    1131           64 :             let try_recv_res = shard_1_rx.try_recv();
    1132           63 :             match try_recv_res {
    1133            1 :                 Ok(batch) => {
    1134            1 :                     assert_eq!(batch.records.raw_wal_start_lsn.unwrap(), start_lsn);
    1135            1 :                     break;
    1136            1 :                 }
    1137           63 :                 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {}
    1138            1 :                 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
    1139            1 :                     unreachable!();
    1140            1 :                 }
    1141            1 :             }
    1142            1 :         }
    1143            1 :     }
    1144              : 
    1145              :     #[tokio::test]
    1146            1 :     async fn test_shard_zero_does_not_skip_empty_records() {
    1147            1 :         let _ = env_logger::builder().is_test(true).try_init();
    1148            1 : 
    1149            1 :         const SIZE: usize = 8 * 1024;
    1150            1 :         const MSG_COUNT: usize = 10;
    1151            1 :         const PG_VERSION: u32 = 17;
    1152            1 : 
    1153            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
    1154            1 :         let env = Env::new(true).unwrap();
    1155            1 :         let tli = env
    1156            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
    1157            1 :             .await
    1158            1 :             .unwrap();
    1159            1 : 
    1160            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
    1161            1 :         let mut next_record_lsns = Vec::new();
    1162            1 :         let end_watch = Env::write_wal(
    1163            1 :             tli,
    1164            1 :             start_lsn,
    1165            1 :             SIZE,
    1166            1 :             MSG_COUNT,
     1167            1 :             // This is a logical message prefix that is not persisted to key-value storage.
     1168            1 :             // We use it to validate that shard zero still receives empty interpreted records.
    1169            1 :             c"test:",
    1170            1 :             Some(&mut next_record_lsns),
    1171            1 :         )
    1172            1 :         .await
    1173            1 :         .unwrap();
    1174            1 :         let end_pos = end_watch.get();
    1175            1 : 
    1176            1 :         let streaming_wal_reader = StreamingWalReader::new(
    1177            1 :             resident_tli,
    1178            1 :             None,
    1179            1 :             start_lsn,
    1180            1 :             end_pos,
    1181            1 :             end_watch,
    1182            1 :             MAX_SEND_SIZE,
    1183            1 :         );
    1184            1 : 
    1185            1 :         let shard = ShardIdentity::unsharded();
    1186            1 :         let (records_tx, mut records_rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
    1187            1 : 
    1188            1 :         let handle = InterpretedWalReader::spawn(
    1189            1 :             streaming_wal_reader,
    1190            1 :             start_lsn,
    1191            1 :             records_tx,
    1192            1 :             shard,
    1193            1 :             PG_VERSION,
    1194            1 :             &Some("pageserver".to_string()),
    1195            1 :         );
    1196            1 : 
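                       :         // Collect batches until the reader has caught up with the end of
                       :         // the available WAL.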
    1197            1 :         let mut interpreted_records = Vec::new();
    1198            1 :         while let Some(batch) = records_rx.recv().await {
    1199            1 :             interpreted_records.push(batch.records);
    1200            1 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
    1201            1 :                 break;
    1202            1 :             }
    1203            1 :         }
    1204            1 : 
    1205            1 :         let received_next_record_lsns = interpreted_records
    1206            1 :             .into_iter()
    1207            1 :             .flat_map(|b| b.records)
    1208            9 :             .map(|rec| rec.next_record_lsn)
    1209            1 :             .collect::<Vec<_>>();
    1210            1 : 
    1211            1 :         // By default this also includes the start LSN. Trim it since it shouldn't be received.
    1212            1 :         let next_record_lsns = next_record_lsns.into_iter().skip(1).collect::<Vec<_>>();
    1213            1 : 
    1214            1 :         assert_eq!(received_next_record_lsns, next_record_lsns);
    1215            1 : 
    1216            1 :         handle.abort();
    1217            1 :         let mut done = false;
    1218            2 :         for _ in 0..5 {
    1219            2 :             if handle.current_position().is_none() {
    1220            1 :                 done = true;
    1221            1 :                 break;
    1222            1 :             }
    1223            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
    1224            1 :         }
    1225            1 : 
    1226            1 :         assert!(done);
    1227            1 :     }
    1228              : }
        

Generated by: LCOV version 2.1-beta