|             Line data    Source code 
       1              : use std::collections::HashMap;
       2              : use std::fmt::Display;
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use anyhow::{Context, anyhow};
       7              : use futures::StreamExt;
       8              : use futures::future::Either;
       9              : use pageserver_api::shard::ShardIdentity;
      10              : use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend};
      11              : use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder};
      12              : use postgres_ffi::{PgMajorVersion, get_current_timestamp};
      13              : use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive};
      14              : use tokio::io::{AsyncRead, AsyncWrite};
      15              : use tokio::sync::mpsc::error::SendError;
      16              : use tokio::task::JoinHandle;
      17              : use tokio::time::MissedTickBehavior;
      18              : use tracing::{Instrument, error, info, info_span};
      19              : use utils::critical_timeline;
      20              : use utils::lsn::Lsn;
      21              : use utils::postgres_client::{Compression, InterpretedFormat};
      22              : use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
      23              : use wal_decoder::wire_format::ToWireFormat;
      24              : 
      25              : use crate::metrics::WAL_READERS;
      26              : use crate::send_wal::{EndWatchView, WalSenderGuard};
      27              : use crate::timeline::WalResidentTimeline;
      28              : use crate::wal_reader_stream::{StreamingWalReader, WalBytes};
      29              : 
      30              : /// Identifier used to differentiate between senders of the same
      31              : /// shard.
      32              : ///
      33              : /// In the steady state there's only one sender per shard, but two pageservers may
      34              : /// temporarily have the same shard attached and attempt to ingest
      35              : /// WAL for it. See also [`ShardSenderId`].
      36              : #[derive(Hash, Eq, PartialEq, Copy, Clone)]
      37              : struct SenderId(u8); // Ids start at 0 (`first`) and later ones are produced by incrementing (`next`).
      38              : 
      39              : impl SenderId {
      40            6 :     fn first() -> Self { // Returns the initial sender id, which wraps 0.
      41            6 :         SenderId(0)
      42            6 :     }
      43              : 
      44            2 :     fn next(&self) -> Self { // Returns the id following `self`; `self` is left untouched.
      45            2 :         SenderId(self.0.checked_add(1).expect("few senders")) // Panics if the u8 counter would overflow.
      46            2 :     }
      47              : }
      48              : 
      49              : #[derive(Hash, Eq, PartialEq)]
      50              : struct ShardSenderId { // Identifies one particular sender of one particular shard.
      51              :     shard: ShardIdentity, // The shard the sender is attached to.
      52              :     sender_id: SenderId, // Distinguishes senders that share the same shard.
      53              : }
      54              : 
      55              : impl Display for ShardSenderId {
      56            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      57            0 :         write!(f, "{}{}", self.sender_id.0, self.shard.shard_slug()) // Numeric sender id immediately followed by the shard slug, no separator.
      58            0 :     }
      59              : }
      60              : 
      61              : impl ShardSenderId {
      62          877 :     fn new(shard: ShardIdentity, sender_id: SenderId) -> Self { // Plain constructor; stores both parts as given.
      63          877 :         ShardSenderId { shard, sender_id }
      64          877 :     }
      65              : 
      66            0 :     fn shard(&self) -> ShardIdentity { // Returns the shard part by value.
      67            0 :         self.shard
      68            0 :     }
      69              : }
      70              : 
      71              : /// Shard-aware fan-out interpreted record reader.
      72              : /// Reads WAL from disk, decodes it, interprets it, and sends
      73              : /// it to any [`InterpretedWalSender`] connected to it.
      74              : /// Each [`InterpretedWalSender`] corresponds to one shard
      75              : /// and gets interpreted records concerning that shard only.
      76              : pub(crate) struct InterpretedWalReader {
      77              :     wal_stream: StreamingWalReader, // Stream polled for chunks of raw WAL.
      78              :     shard_senders: HashMap<ShardIdentity, smallvec::SmallVec<[ShardSenderState; 1]>>, // Senders attached to each shard; inline capacity of one.
      79              :     shard_notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>>, // Optional receiving end for `AttachShardNotification` messages.
      80              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>, // Shared state; `spawn` hands a clone to the returned handle.
      81              :     pg_version: PgMajorVersion, // Passed to the WAL decoder and to record interpretation.
      82              : }
      83              : 
      84              : /// A handle for [`InterpretedWalReader`] which allows for interacting with it
      85              : /// when it runs as a separate tokio task.
      86              : #[derive(Debug)]
      87              : pub(crate) struct InterpretedWalReaderHandle {
      88              :     join_handle: JoinHandle<Result<(), InterpretedWalReaderError>>, // Join handle of the task created by `InterpretedWalReader::spawn`.
      89              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>, // Same state object that the reader task writes to.
      90              :     shard_notification_tx: tokio::sync::mpsc::UnboundedSender<AttachShardNotification>, // Sending half of the channel whose receiver the reader owns.
      91              : }
      92              : 
      93              : struct ShardSenderState { // Per-sender bookkeeping held by the reader in `shard_senders`.
      94              :     sender_id: SenderId, // Tells apart senders attached to the same shard.
      95              :     tx: tokio::sync::mpsc::Sender<Batch>, // Channel on which batches are delivered to this sender.
      96              :     next_record_lsn: Lsn, // Only records whose `next_record_lsn` is greater than this are forwarded; raised after a successful send.
      97              : }
      98              : 
      99              : /// State of [`InterpretedWalReader`] visible outside of the task running it.
     100              : #[derive(Debug)]
     101              : pub(crate) enum InterpretedWalReaderState {
     102              :     Running {
     103              :         current_position: Lsn, // Starts at the reader's start position; updated after each decoded batch and possibly moved back by `maybe_reset`.
     104              :         /// Tracks the start of the PG WAL LSN from which the current batch of
     105              :         /// interpreted records originated. `None` until the first WAL chunk of a batch arrives.
     106              :         current_batch_wal_start: Option<Lsn>,
     107              :     },
     108              :     Done, // Written by the scope guard in `run_impl` when that function exits.
     109              : }
     110              : 
     111              : pub(crate) struct Batch { // Message the reader sends over each sender's channel.
     112              :     wal_end_lsn: Lsn, // Taken from the `WalBytes` chunk the batch was built from.
     113              :     available_wal_end_lsn: Lsn, // Taken from the same `WalBytes` chunk.
     114              :     records: InterpretedWalRecords, // Interpreted records destined for the receiving sender (may be empty).
     115              : }
     116              : 
     117              : #[derive(thiserror::Error, Debug)]
     118              : pub enum InterpretedWalReaderError {
     119              :     /// Decoding the raw WAL stream into records failed.
     120              :     #[error("decode error: {0}")]
     121              :     Decode(#[from] WalDecodeError),
     122              :     #[error("read or interpret error: {0}")]
     123              :     ReadOrInterpret(#[from] anyhow::Error), // Reading WAL from the stream or interpreting a decoded record failed.
     124              :     #[error("wal stream closed")]
     125              :     WalStreamClosed, // The WAL stream yielded `None`, which `run_impl` treats as an error.
     126              : }
     127              : 
     128              : enum CurrentPositionUpdate { // Outcome of `InterpretedWalReaderState::maybe_reset`.
     129              :     Reset { from: Lsn, to: Lsn }, // The position was moved from `from` to `to`.
     130              :     NotReset(Lsn), // The position was left alone; carries that position.
     131              : }
     132              : 
     133              : impl CurrentPositionUpdate {
     134            0 :     fn current_position(&self) -> Lsn { // Position in effect after the update.
     135            0 :         match self {
     136            0 :             CurrentPositionUpdate::Reset { from: _, to } => *to,
     137            0 :             CurrentPositionUpdate::NotReset(lsn) => *lsn,
     138              :         }
     139            0 :     }
     140              : 
     141            0 :     fn previous_position(&self) -> Lsn { // Position in effect before the update; equals the current one when nothing was reset.
     142            0 :         match self {
     143            0 :             CurrentPositionUpdate::Reset { from, to: _ } => *from,
     144            0 :             CurrentPositionUpdate::NotReset(lsn) => *lsn,
     145              :         }
     146            0 :     }
     147              : }
     148              : 
     149              : impl InterpretedWalReaderState {
     150            6 :     fn current_position(&self) -> Option<Lsn> { // `Some(position)` while running, `None` once the reader is done.
     151            6 :         match self {
     152              :             InterpretedWalReaderState::Running {
     153            3 :                 current_position, ..
     154            3 :             } => Some(*current_position),
     155            3 :             InterpretedWalReaderState::Done => None,
     156              :         }
     157            6 :     }
     158              : 
     159              :     #[cfg(test)]
     160           39 :     fn current_batch_wal_start(&self) -> Option<Lsn> { // Test-only accessor; `None` when no batch start is recorded or the reader is done.
     161           39 :         match self {
     162              :             InterpretedWalReaderState::Running {
     163           39 :                 current_batch_wal_start,
     164              :                 ..
     165           39 :             } => *current_batch_wal_start,
     166            0 :             InterpretedWalReaderState::Done => None,
     167              :         }
     168           39 :     }
     169              : 
     170              :     // Reset the current position of the WAL reader if the requested starting position
     171              :     // of the new shard is smaller than the current value. Panics if the reader is `Done`.
     172            4 :     fn maybe_reset(&mut self, new_shard_start_pos: Lsn) -> CurrentPositionUpdate {
     173            4 :         match self {
     174              :             InterpretedWalReaderState::Running {
     175            4 :                 current_position,
     176            4 :                 current_batch_wal_start,
     177              :             } => {
     178            4 :                 if new_shard_start_pos < *current_position {
     179            3 :                     let from = *current_position;
     180            3 :                     *current_position = new_shard_start_pos;
     181            3 :                     *current_batch_wal_start = None; // Forget the in-progress batch start; `update_current_batch_wal_start` records a new one.
     182            3 :                     CurrentPositionUpdate::Reset {
     183            3 :                         from,
     184            3 :                         to: *current_position,
     185            3 :                     }
     186              :                 } else {
     187              :                     // Edge case: The new shard is at the same current position as
     188              :                     // the reader. Note that the current position is WAL record aligned,
     189              :                     // so the reader might have done some partial reads and updated the
     190              :                     // batch start. If that's the case, adjust the batch start to match
     191              :                     // starting position of the new shard. It can lead to some shards
     192              :                     // seeing overlaps, but in that case the actual record LSNs are checked
     193              :                     // which should be fine based on the filtering logic.
     194            1 :                     if let Some(start) = current_batch_wal_start {
     195            0 :                         *start = std::cmp::min(*start, new_shard_start_pos);
     196            1 :                     }
     197            1 :                     CurrentPositionUpdate::NotReset(*current_position)
     198              :                 }
     199              :             }
     200              :             InterpretedWalReaderState::Done => {
     201            0 :                 panic!("maybe_reset called on finished reader")
     202              :             }
     203              :         }
     204            4 :     }
     205              : 
     206           50 :     fn update_current_batch_wal_start(&mut self, lsn: Lsn) { // Records `lsn` as the batch start only if none is set yet; panics if `Done`.
     207           50 :         match self {
     208              :             InterpretedWalReaderState::Running {
     209           50 :                 current_batch_wal_start,
     210              :                 ..
     211              :             } => {
     212           50 :                 if current_batch_wal_start.is_none() {
     213            6 :                     *current_batch_wal_start = Some(lsn);
     214           44 :                 }
     215              :             }
     216              :             InterpretedWalReaderState::Done => {
     217            0 :                 panic!("update_current_batch_wal_start called on finished reader")
     218              :             }
     219              :         }
     220           50 :     }
     221              : 
     222           41 :     fn replace_current_batch_wal_start(&mut self, with: Lsn) -> Lsn { // Stores `with` as the new batch start and returns the previous one.
     223           41 :         match self {
     224              :             InterpretedWalReaderState::Running {
     225           41 :                 current_batch_wal_start,
     226              :                 ..
     227           41 :             } => current_batch_wal_start.replace(with).unwrap(), // Panics if no batch start had been recorded.
     228              :             InterpretedWalReaderState::Done => {
     229            0 :                 panic!("take_current_batch_wal_start called on finished reader") // NOTE(review): the message names `take_current_batch_wal_start`, not this function.
     230              :             }
     231              :         }
     232           41 :     }
     233              : 
     234           41 :     fn update_current_position(&mut self, lsn: Lsn) { // Overwrites the current position with `lsn`; panics if `Done`.
     235           41 :         match self {
     236              :             InterpretedWalReaderState::Running {
     237           41 :                 current_position, ..
     238           41 :             } => {
     239           41 :                 *current_position = lsn;
     240           41 :             }
     241              :             InterpretedWalReaderState::Done => {
     242            0 :                 panic!("update_current_position called on finished reader")
     243              :             }
     244              :         }
     245           41 :     }
     246              : }
     247              : 
     248              : pub(crate) struct AttachShardNotification { // Message type carried by the reader's shard notification channel.
     249              :     shard_id: ShardIdentity, // NOTE(review): presumably the shard being attached; the handler is not visible here.
     250              :     sender: tokio::sync::mpsc::Sender<Batch>, // NOTE(review): presumably where batches for that shard should be sent; confirm in the handler.
     251              :     start_pos: Lsn, // NOTE(review): presumably the LSN the new shard wants to start from; confirm in the handler.
     252              : }
     253              : 
     254              : impl InterpretedWalReader {
     255              :     /// Spawn the reader in a separate tokio task and return a handle to it.
     256            3 :     pub(crate) fn spawn(
     257            3 :         wal_stream: StreamingWalReader,
     258            3 :         start_pos: Lsn,
     259            3 :         tx: tokio::sync::mpsc::Sender<Batch>,
     260            3 :         shard: ShardIdentity,
     261            3 :         pg_version: PgMajorVersion,
     262            3 :         appname: &Option<String>,
     263            3 :     ) -> InterpretedWalReaderHandle {
     264            3 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     265            3 :             current_position: start_pos,
     266            3 :             current_batch_wal_start: None,
     267            3 :         }));
     268              : 
     269            3 :         let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel(); // Receiver goes to the reader, sender to the handle.
     270              : 
     271            3 :         let ttid = wal_stream.ttid; // Copied out before `wal_stream` is moved into the reader.
     272              : 
     273            3 :         let reader = InterpretedWalReader {
     274            3 :             wal_stream,
     275            3 :             shard_senders: HashMap::from([(
     276            3 :                 shard,
     277            3 :                 smallvec::smallvec![ShardSenderState {
     278            0 :                     sender_id: SenderId::first(), // The initial shard's sender gets the first id.
     279            0 :                     tx,
     280            0 :                     next_record_lsn: start_pos,
     281            0 :                 }],
     282              :             )]),
     283            3 :             shard_notification_rx: Some(shard_notification_rx),
     284            3 :             state: state.clone(),
     285            3 :             pg_version,
     286              :         };
     287              : 
     288            3 :         let metric = WAL_READERS
     289            3 :             .get_metric_with_label_values(&["task", appname.as_deref().unwrap_or("safekeeper")]) // Labels: "task" plus the appname, defaulting to "safekeeper".
     290            3 :             .unwrap(); // Panics if the metric lookup fails.
     291              : 
     292            3 :         let join_handle = tokio::task::spawn(
     293            3 :             async move {
     294            3 :                 metric.inc();
     295            3 :                 scopeguard::defer! {
     296              :                     metric.dec(); // Undo the increment when the task's future finishes or is dropped.
     297              :                 }
     298              : 
     299            3 :                 reader
     300            3 :                     .run_impl(start_pos)
     301            3 :                     .await
     302            0 :                     .inspect_err(|err| match err { // Report the error here; the result itself is still returned through the join handle.
     303              :                         // TODO: we may want to differentiate these errors further.
     304              :                         InterpretedWalReaderError::Decode(_) => {
     305            0 :                             critical_timeline!(
     306            0 :                                 ttid.tenant_id,
     307            0 :                                 ttid.timeline_id,
     308            0 :                                 "failed to read WAL record: {err:?}"
     309              :                             );
     310              :                         }
     311            0 :                         err => error!("failed to read WAL record: {err}"),
     312            0 :                     })
     313            0 :             }
     314            3 :             .instrument(info_span!("interpreted wal reader")),
     315              :         );
     316              : 
     317            3 :         InterpretedWalReaderHandle {
     318            3 :             join_handle,
     319            3 :             state,
     320            3 :             shard_notification_tx,
     321            3 :         }
     322            3 :     }
     323              : 
     324              :     /// Construct the reader without spawning anything.
     325              :     /// Callers should drive the future returned by [`Self::run`].
     326            1 :     pub(crate) fn new(
     327            1 :         wal_stream: StreamingWalReader,
     328            1 :         start_pos: Lsn,
     329            1 :         tx: tokio::sync::mpsc::Sender<Batch>,
     330            1 :         shard: ShardIdentity,
     331            1 :         pg_version: PgMajorVersion,
     332            1 :         shard_notification_rx: Option<
     333            1 :             tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>,
     334            1 :         >,
     335            1 :     ) -> InterpretedWalReader {
     336            1 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     337            1 :             current_position: start_pos,
     338            1 :             current_batch_wal_start: None,
     339            1 :         }));
     340              : 
     341              :         InterpretedWalReader {
     342            1 :             wal_stream,
     343            1 :             shard_senders: HashMap::from([(
     344            1 :                 shard,
     345            1 :                 smallvec::smallvec![ShardSenderState {
     346            0 :                     sender_id: SenderId::first(), // The initial shard's sender gets the first id.
     347            0 :                     tx,
     348            0 :                     next_record_lsn: start_pos,
     349            0 :                 }],
     350              :             )]),
     351            1 :             shard_notification_rx, // Stored as given by the caller; may be `None`.
     352            1 :             state: state.clone(), // No handle is returned here, so the reader keeps the only lasting reference.
     353            1 :             pg_version,
     354              :         }
     355            1 :     }
     356              : 
     357              :     /// Entry point for future (polling) based wal reader. The tail expression is always an `Err`.
     358            1 :     pub(crate) async fn run(
     359            1 :         self,
     360            1 :         start_pos: Lsn,
     361            1 :         appname: &Option<String>,
     362            1 :     ) -> Result<(), CopyStreamHandlerEnd> {
     363            1 :         let metric = WAL_READERS
     364            1 :             .get_metric_with_label_values(&["future", appname.as_deref().unwrap_or("safekeeper")]) // Labels: "future" plus the appname, defaulting to "safekeeper".
     365            1 :             .unwrap(); // Panics if the metric lookup fails.
     366              : 
     367            1 :         metric.inc();
     368            1 :         scopeguard::defer! {
     369              :             metric.dec(); // Undo the increment when this future finishes or is dropped.
     370              :         }
     371              : 
     372            1 :         let ttid = self.wal_stream.ttid; // Read before `self` is consumed by `run_impl`.
     373            1 :         match self.run_impl(start_pos).await {
     374            0 :             Err(err @ InterpretedWalReaderError::Decode(_)) => {
     375            0 :                 critical_timeline!(
     376            0 :                     ttid.tenant_id,
     377            0 :                     ttid.timeline_id,
     378            0 :                     "failed to decode WAL record: {err:?}"
     379              :                 );
     380              :             }
     381            0 :             Err(err) => error!("failed to read WAL record: {err}"),
     382            0 :             Ok(()) => info!("interpreted wal reader exiting"),
     383              :         }
     384              : 
     385            0 :         Err(CopyStreamHandlerEnd::Other(anyhow!( // Same error value whether `run_impl` returned `Ok` or `Err`.
     386            0 :             "interpreted wal reader finished"
     387            0 :         )))
     388            0 :     }
     389              : 
     390              :     /// Send interpreted WAL to one or more [`InterpretedWalSender`]s
     391              :     /// Stops when an error is encountered or when the [`InterpretedWalReaderHandle`]
     392              :     /// goes out of scope.
     393            4 :     async fn run_impl(mut self, start_pos: Lsn) -> Result<(), InterpretedWalReaderError> {
     394            4 :         let defer_state = self.state.clone();
     395            4 :         scopeguard::defer! {
     396              :             *defer_state.write().unwrap() = InterpretedWalReaderState::Done;
     397              :         }
     398              : 
     399            4 :         let mut wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version);
     400              : 
     401              :         loop {
     402           58 :             tokio::select! {
     403              :                 // Main branch for reading WAL and forwarding it
     404           58 :                 wal_or_reset = self.wal_stream.next() => {
     405           50 :                     let wal = wal_or_reset.map(|wor| wor.get_wal().expect("reset handled in select branch below"));
     406              :                     let WalBytes {
     407           50 :                         wal,
     408           50 :                         wal_start_lsn,
     409           50 :                         wal_end_lsn,
     410           50 :                         available_wal_end_lsn,
     411           50 :                     } = match wal {
     412           50 :                         Some(some) => some.map_err(InterpretedWalReaderError::ReadOrInterpret)?,
     413              :                         None => {
     414              :                             // [`StreamingWalReader::next`] is an endless stream of WAL.
     415              :                             // It shouldn't ever finish unless it panicked or became internally
     416              :                             // inconsistent.
     417            0 :                             return Result::Err(InterpretedWalReaderError::WalStreamClosed);
     418              :                         }
     419              :                     };
     420              : 
     421           50 :                     self.state.write().unwrap().update_current_batch_wal_start(wal_start_lsn);
     422              : 
     423           50 :                     wal_decoder.feed_bytes(&wal);
     424              : 
     425              :                     // Deserialize and interpret WAL records from this batch of WAL.
     426              :                     // Interpreted records for each shard are collected separately.
     427           50 :                     let shard_ids = self.shard_senders.keys().copied().collect::<Vec<_>>();
     428           50 :                     let mut records_by_sender: HashMap<ShardSenderId, Vec<InterpretedWalRecord>> = HashMap::new();
     429           50 :                     let mut max_next_record_lsn = None;
     430           50 :                     let mut max_end_record_lsn = None;
     431          656 :                     while let Some((next_record_lsn, recdata)) = wal_decoder.poll_decode()?
     432              :                     {
     433          606 :                         assert!(next_record_lsn.is_aligned());
     434          606 :                         max_next_record_lsn = Some(next_record_lsn);
     435          606 :                         max_end_record_lsn = Some(wal_decoder.lsn());
     436              : 
     437          606 :                         let interpreted = InterpretedWalRecord::from_bytes_filtered(
     438          606 :                             recdata,
     439          606 :                             &shard_ids,
     440          606 :                             next_record_lsn,
     441          606 :                             self.pg_version,
     442              :                         )
     443          606 :                         .with_context(|| "Failed to interpret WAL")?;
     444              : 
     445         1412 :                         for (shard, record) in interpreted {
     446              :                             // Shard zero needs to track the start LSN of the latest record
     447              :                             // in addition to the LSN of the next record to ingest. The former
     448              :                             // is included in basebackup persisted by the compute in WAL.
     449          806 :                             if !shard.is_shard_zero() && record.is_empty() {
     450          200 :                                 continue;
     451          606 :                             }
     452              : 
     453          606 :                             let mut states_iter = self.shard_senders
     454          606 :                                 .get(&shard)
     455          606 :                                 .expect("keys collected above")
     456          606 :                                 .iter()
     457         1002 :                                 .filter(|state| record.next_record_lsn > state.next_record_lsn)
     458          606 :                                 .peekable();
     459          996 :                             while let Some(state) = states_iter.next() {
     460          796 :                                 let shard_sender_id = ShardSenderId::new(shard, state.sender_id);
     461              : 
     462              :                                 // The most common case is one sender per shard. Peek and break to avoid the
     463              :                                 // clone in that situation.
     464          796 :                                 if states_iter.peek().is_none() {
     465          406 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record);
     466          406 :                                     break;
     467          390 :                                 } else {
     468          390 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record.clone());
     469          390 :                                 }
     470              :                             }
     471              :                         }
     472              :                     }
     473              : 
     474           50 :                     let max_next_record_lsn = match max_next_record_lsn {
     475           41 :                         Some(lsn) => lsn,
     476              :                         None => {
     477            9 :                             continue;
     478              :                         }
     479              :                     };
     480              : 
     481              :                     // Update the current position such that new receivers can decide
     482              :                     // whether to attach to us or spawn a new WAL reader.
     483           41 :                     let batch_wal_start_lsn = {
     484           41 :                         let mut guard = self.state.write().unwrap();
     485           41 :                         guard.update_current_position(max_next_record_lsn);
     486           41 :                         guard.replace_current_batch_wal_start(max_end_record_lsn.unwrap())
     487              :                     };
     488              : 
     489              :                     // Send interpreted records downstream. Anything that has already been seen
     490              :                     // by a shard is filtered out.
     491           41 :                     let mut shard_senders_to_remove = Vec::new();
     492           96 :                     for (shard, states) in &mut self.shard_senders {
     493          136 :                         for state in states {
     494           81 :                             let shard_sender_id = ShardSenderId::new(*shard, state.sender_id);
     495              : 
     496           81 :                             let batch = if max_next_record_lsn > state.next_record_lsn {
     497              :                                 // This batch contains at least one record that this shard has not
     498              :                                 // seen yet.
     499           67 :                                 let records = records_by_sender.remove(&shard_sender_id).unwrap_or_default();
     500              : 
     501           67 :                                 InterpretedWalRecords {
     502           67 :                                     records,
     503           67 :                                     next_record_lsn: max_next_record_lsn,
     504           67 :                                     raw_wal_start_lsn: Some(batch_wal_start_lsn),
     505           67 :                                 }
     506           14 :                             } else if wal_end_lsn > state.next_record_lsn {
     507              :                                 // All the records in this batch were seen by the shard
     508              :                                 // However, the batch maps to a chunk of WAL that the
     509              :                                 // shard has not yet seen. Notify it of the start LSN
     510              :                                 // of the PG WAL chunk such that it doesn't look like a gap.
     511            0 :                                 InterpretedWalRecords {
     512            0 :                                     records: Vec::default(),
     513            0 :                                     next_record_lsn: state.next_record_lsn,
     514            0 :                                     raw_wal_start_lsn: Some(batch_wal_start_lsn),
     515            0 :                                 }
     516              :                             } else {
     517              :                                 // The shard has seen this chunk of WAL before. Skip it.
     518           14 :                                 continue;
     519              :                             };
     520              : 
     521           67 :                             let res = state.tx.send(Batch {
     522           67 :                                 wal_end_lsn,
     523           67 :                                 available_wal_end_lsn,
     524           67 :                                 records: batch,
     525           67 :                             }).await;
     526              : 
     527           67 :                             if res.is_err() {
     528            0 :                                 shard_senders_to_remove.push(shard_sender_id);
     529           67 :                             } else {
     530           67 :                                 state.next_record_lsn = std::cmp::max(state.next_record_lsn, max_next_record_lsn);
     531           67 :                             }
     532              :                         }
     533              :                     }
     534              : 
     535              :                     // Clean up any shard senders that have dropped out.
     536              :                     // This is inefficient, but such events are rare (connection to PS termination)
     537              :                     // and the number of subscriptions on the same shards very small (only one
     538              :                     // for the steady state).
     539           41 :                     for to_remove in shard_senders_to_remove {
     540            0 :                         let shard_senders = self.shard_senders.get_mut(&to_remove.shard()).expect("saw it above");
     541            0 :                         if let Some(idx) = shard_senders.iter().position(|s| s.sender_id == to_remove.sender_id) {
     542            0 :                             shard_senders.remove(idx);
     543            0 :                             tracing::info!("Removed shard sender {}", to_remove);
     544            0 :                         }
     545              : 
     546            0 :                         if shard_senders.is_empty() {
     547            0 :                             self.shard_senders.remove(&to_remove.shard());
     548            0 :                         }
     549              :                     }
     550              :                 },
     551              :                 // Listen for new shards that want to attach to this reader.
     552              :                 // If the reader is not running as a task, then this is not supported
     553              :                 // (see the pending branch below).
     554           58 :                 notification = match self.shard_notification_rx.as_mut() {
     555           58 :                         Some(rx) => Either::Left(rx.recv()),
     556            0 :                         None => Either::Right(std::future::pending())
     557              :                     } => {
     558            4 :                     if let Some(n) = notification {
     559            4 :                         let AttachShardNotification { shard_id, sender, start_pos } = n;
     560              : 
     561              :                         // Update internal and external state, then reset the WAL stream
     562              :                         // if required.
     563            4 :                         let senders = self.shard_senders.entry(shard_id).or_default();
     564              : 
     565              :                         // Clean up any shard senders that have dropped out before adding the new
     566              :                         // one. This avoids a build up of dead senders.
     567            4 :                         senders.retain(|sender| {
     568            3 :                             let closed = sender.tx.is_closed();
     569              : 
     570            3 :                             if closed {
     571            0 :                                 let sender_id = ShardSenderId::new(shard_id, sender.sender_id);
     572            0 :                                 tracing::info!("Removed shard sender {}", sender_id);
     573            3 :                             }
     574              : 
     575            3 :                             !closed
     576            3 :                         });
     577              : 
     578            4 :                         let new_sender_id = match senders.last() {
     579            2 :                             Some(sender) => sender.sender_id.next(),
     580            2 :                             None => SenderId::first()
     581              :                         };
     582              : 
     583            4 :                         senders.push(ShardSenderState { sender_id: new_sender_id, tx: sender, next_record_lsn: start_pos});
     584              : 
     585              :                         // If the shard is subscribing below the current position then we need
     586              :                         // to update the cursor that tracks where we are at in the WAL
     587              :                         // ([`Self::state`]) and reset the WAL stream itself
     588              :                         // ([`Self::wal_stream`]). This must be done atomically from the POV of
     589              :                         // anything outside the select statement.
     590            4 :                         let position_reset = self.state.write().unwrap().maybe_reset(start_pos);
     591            4 :                         match position_reset {
     592            3 :                             CurrentPositionUpdate::Reset { from: _, to } => {
     593            3 :                                 self.wal_stream.reset(to).await;
     594            3 :                                 wal_decoder = WalStreamDecoder::new(to, self.pg_version);
     595              :                             },
     596            1 :                             CurrentPositionUpdate::NotReset(_) => {}
     597              :                         };
     598              : 
     599            4 :                         tracing::info!(
     600            0 :                             "Added shard sender {} with start_pos={} previous_pos={} current_pos={}",
     601            0 :                             ShardSenderId::new(shard_id, new_sender_id),
     602              :                             start_pos,
     603            0 :                             position_reset.previous_position(),
     604            0 :                             position_reset.current_position(),
     605              :                         );
     606            0 :                     }
     607              :                 }
     608              :             }
     609              :         }
     610            0 :     }
     611              : 
     612              :     #[cfg(test)]
     613            1 :     fn state(&self) -> Arc<std::sync::RwLock<InterpretedWalReaderState>> {
     614            1 :         self.state.clone()
     615            1 :     }
     616              : }
     617              : 
     618              : impl InterpretedWalReaderHandle {
     619              :     /// Fan-out the reader by attaching a new shard to it
     620            3 :     pub(crate) fn fanout(
     621            3 :         &self,
     622            3 :         shard_id: ShardIdentity,
     623            3 :         sender: tokio::sync::mpsc::Sender<Batch>,
     624            3 :         start_pos: Lsn,
     625            3 :     ) -> Result<(), SendError<AttachShardNotification>> {
     626            3 :         self.shard_notification_tx.send(AttachShardNotification {
     627            3 :             shard_id,
     628            3 :             sender,
     629            3 :             start_pos,
     630            3 :         })
     631            3 :     }
     632              : 
     633              :     /// Get the current WAL position of the reader
     634            6 :     pub(crate) fn current_position(&self) -> Option<Lsn> {
     635            6 :         self.state.read().unwrap().current_position()
     636            6 :     }
     637              : 
     638            6 :     pub(crate) fn abort(&self) {
     639            6 :         self.join_handle.abort()
     640            6 :     }
     641              : }
     642              : 
     643              : impl Drop for InterpretedWalReaderHandle {
     644            3 :     fn drop(&mut self) {
     645            3 :         tracing::info!("Aborting interpreted wal reader");
     646            3 :         self.abort()
     647            3 :     }
     648              : }
     649              : 
/// Sends interpreted WAL to one client connection.
///
/// Batches arrive over `rx`, are serialized and written to `pgb`; see
/// `run` for the streaming loop.
pub(crate) struct InterpretedWalSender<'a, IO> {
    /// Wire format used when serializing each batch of records.
    pub(crate) format: InterpretedFormat,
    /// Optional compression applied when serializing each batch of records.
    pub(crate) compression: Option<Compression>,
    /// Application name of the client; only used in the message emitted
    /// when streaming ends.
    pub(crate) appname: Option<String>,

    /// Timeline consulted to decide whether this sender should stop.
    pub(crate) tli: WalResidentTimeline,
    /// Initial WAL position; reported in the end-of-stream message if no
    /// batch was received before streaming ended.
    pub(crate) start_lsn: Lsn,

    /// Connection that WAL record and keep-alive messages are written to.
    pub(crate) pgb: &'a mut PostgresBackend<IO>,
    /// Source of the `wal_end` value advertised in keep-alive messages.
    pub(crate) end_watch_view: EndWatchView,
    /// Guard used to look up this sender's remote consistent LSN.
    pub(crate) wal_sender_guard: Arc<WalSenderGuard>,
    /// Receiving end of the channel over which batches arrive. `run`
    /// treats the channel closing as an error.
    pub(crate) rx: tokio::sync::mpsc::Receiver<Batch>,
}
     663              : 
     664              : impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
     665              :     /// Send interpreted WAL records over the network.
     666              :     /// Also manages keep-alives if nothing was sent for a while.
     667            0 :     pub(crate) async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> {
     668            0 :         let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1));
     669            0 :         keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
     670            0 :         keepalive_ticker.reset();
     671              : 
     672            0 :         let mut wal_position = self.start_lsn;
     673              : 
     674              :         loop {
     675            0 :             tokio::select! {
     676            0 :                 batch = self.rx.recv() => {
     677            0 :                     let batch = match batch {
     678            0 :                         Some(b) => b,
     679              :                         None => {
     680            0 :                             return Result::Err(
     681            0 :                                 CopyStreamHandlerEnd::Other(anyhow!("Interpreted WAL reader exited early"))
     682            0 :                             );
     683              :                         }
     684              :                     };
     685              : 
     686            0 :                     wal_position = batch.wal_end_lsn;
     687              : 
     688            0 :                     let buf = batch
     689            0 :                         .records
     690            0 :                         .to_wire(self.format, self.compression)
     691            0 :                         .await
     692            0 :                         .with_context(|| "Failed to serialize interpreted WAL")
     693            0 :                         .map_err(CopyStreamHandlerEnd::from)?;
     694              : 
     695              :                     // Reset the keep alive ticker since we are sending something
     696              :                     // over the wire now.
     697            0 :                     keepalive_ticker.reset();
     698              : 
     699            0 :                     self.pgb
     700            0 :                         .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody {
     701            0 :                             streaming_lsn: batch.wal_end_lsn.0,
     702            0 :                             commit_lsn: batch.available_wal_end_lsn.0,
     703            0 :                             data: &buf,
     704            0 :                         })).await?;
     705              :                 }
     706              :                 // Send a periodic keep alive when the connection has been idle for a while.
     707              :                 // Since we've been idle, also check if we can stop streaming.
     708            0 :                 _ = keepalive_ticker.tick() => {
     709            0 :                     if let Some(remote_consistent_lsn) = self.wal_sender_guard
     710            0 :                         .walsenders()
     711            0 :                         .get_ws_remote_consistent_lsn(self.wal_sender_guard.id())
     712              :                     {
     713            0 :                         if self.tli.should_walsender_stop(remote_consistent_lsn).await {
     714              :                             // Stop streaming if the receivers are caught up and
     715              :                             // there's no active compute. This causes the loop in
     716              :                             // [`crate::send_interpreted_wal::InterpretedWalSender::run`]
     717              :                             // to exit and terminate the WAL stream.
     718            0 :                             break;
     719            0 :                         }
     720            0 :                     }
     721              : 
     722            0 :                     self.pgb
     723            0 :                         .write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
     724            0 :                             wal_end: self.end_watch_view.get().0,
     725            0 :                             timestamp: get_current_timestamp(),
     726            0 :                             request_reply: true,
     727            0 :                         }))
     728            0 :                         .await?;
     729              :                 },
     730              :             }
     731              :         }
     732              : 
     733            0 :         Err(CopyStreamHandlerEnd::ServerInitiated(format!(
     734            0 :             "ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
     735            0 :             self.appname, wal_position,
     736            0 :         )))
     737            0 :     }
     738              : }
     739              : #[cfg(test)]
     740              : mod tests {
     741              :     use std::collections::HashMap;
     742              :     use std::str::FromStr;
     743              :     use std::time::Duration;
     744              : 
     745              :     use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardIdentity};
     746              :     use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion};
     747              :     use tokio::sync::mpsc::error::TryRecvError;
     748              :     use utils::id::{NodeId, TenantTimelineId};
     749              :     use utils::lsn::Lsn;
     750              :     use utils::shard::{ShardCount, ShardNumber};
     751              : 
     752              :     use crate::send_interpreted_wal::{AttachShardNotification, Batch, InterpretedWalReader};
     753              :     use crate::test_utils::Env;
     754              :     use crate::wal_reader_stream::StreamingWalReader;
     755              : 
     756              :     #[tokio::test]
     757            1 :     async fn test_interpreted_wal_reader_fanout() {
     758            1 :         let _ = env_logger::builder().is_test(true).try_init();
     759              : 
     760              :         const SIZE: usize = 8 * 1024;
     761              :         const MSG_COUNT: usize = 200;
     762              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
     763              :         const SHARD_COUNT: u8 = 2;
     764              : 
     765            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     766            1 :         let env = Env::new(true).unwrap();
     767            1 :         let tli = env
     768            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     769            1 :             .await
     770            1 :             .unwrap();
     771              : 
     772            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     773            1 :         let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, c"neon-file:", None)
     774            1 :             .await
     775            1 :             .unwrap();
     776            1 :         let end_pos = end_watch.get();
     777              : 
     778            1 :         tracing::info!("Doing first round of reads ...");
     779              : 
     780            1 :         let streaming_wal_reader = StreamingWalReader::new(
     781            1 :             resident_tli,
     782            1 :             None,
     783            1 :             start_lsn,
     784            1 :             end_pos,
     785            1 :             end_watch,
     786              :             MAX_SEND_SIZE,
     787              :         );
     788              : 
     789            1 :         let shard_0 =
     790            1 :             ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
     791            1 :                 .unwrap();
     792              : 
     793            1 :         let shard_1 =
     794            1 :             ShardIdentity::new(ShardNumber(1), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
     795            1 :                 .unwrap();
     796              : 
     797            1 :         let mut shards = HashMap::new();
     798              : 
     799            3 :         for shard_number in 0..SHARD_COUNT {
     800            2 :             let shard_id = ShardIdentity::new(
     801            2 :                 ShardNumber(shard_number),
     802            2 :                 ShardCount(SHARD_COUNT),
     803            2 :                 DEFAULT_STRIPE_SIZE,
     804            2 :             )
     805            2 :             .unwrap();
     806            2 :             let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     807            2 :             shards.insert(shard_id, (Some(tx), Some(rx)));
     808            2 :         }
     809              : 
     810            1 :         let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
     811            1 :         let mut shard_0_rx = shards.get_mut(&shard_0).unwrap().1.take().unwrap();
     812              : 
     813            1 :         let handle = InterpretedWalReader::spawn(
     814            1 :             streaming_wal_reader,
     815            1 :             start_lsn,
     816            1 :             shard_0_tx,
     817            1 :             shard_0,
     818              :             PG_VERSION,
     819            1 :             &Some("pageserver".to_string()),
     820              :         );
     821              : 
     822            1 :         tracing::info!("Reading all WAL with only shard 0 attached ...");
     823              : 
     824            1 :         let mut shard_0_interpreted_records = Vec::new();
     825           13 :         while let Some(batch) = shard_0_rx.recv().await {
     826           13 :             shard_0_interpreted_records.push(batch.records);
     827           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     828            1 :                 break;
     829           12 :             }
     830              :         }
     831              : 
     832            1 :         let shard_1_tx = shards.get_mut(&shard_1).unwrap().0.take().unwrap();
     833            1 :         let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
     834              : 
     835            1 :         tracing::info!("Attaching shard 1 to the reader at start of WAL");
     836            1 :         handle.fanout(shard_1, shard_1_tx, start_lsn).unwrap();
     837              : 
     838            1 :         tracing::info!("Reading all WAL with shard 0 and shard 1 attached ...");
     839              : 
     840            1 :         let mut shard_1_interpreted_records = Vec::new();
     841           13 :         while let Some(batch) = shard_1_rx.recv().await {
     842           13 :             shard_1_interpreted_records.push(batch.records);
     843           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     844            1 :                 break;
     845           12 :             }
     846              :         }
     847              : 
     848              :         // This test uses logical messages. Those only go to shard 0. Check that the
     849              :         // filtering worked and shard 1 did not get any.
     850            1 :         assert!(
     851            1 :             shard_1_interpreted_records
     852            1 :                 .iter()
     853           13 :                 .all(|recs| recs.records.is_empty())
     854              :         );
     855              : 
     856              :         // Shard 0 should not receive anything more since the reader is
     857              :         // going through wal that it has already processed.
     858            1 :         let res = shard_0_rx.try_recv();
     859            1 :         if let Ok(ref ok) = res {
     860            0 :             tracing::error!(
     861            0 :                 "Shard 0 received batch: wal_end_lsn={} available_wal_end_lsn={}",
     862              :                 ok.wal_end_lsn,
     863              :                 ok.available_wal_end_lsn
     864              :             );
     865            1 :         }
     866            1 :         assert!(matches!(res, Err(TryRecvError::Empty)));
     867              : 
     868              :         // Check that the next records lsns received by the two shards match up.
     869            1 :         let shard_0_next_lsns = shard_0_interpreted_records
     870            1 :             .iter()
     871            1 :             .map(|recs| recs.next_record_lsn)
     872            1 :             .collect::<Vec<_>>();
     873            1 :         let shard_1_next_lsns = shard_1_interpreted_records
     874            1 :             .iter()
     875            1 :             .map(|recs| recs.next_record_lsn)
     876            1 :             .collect::<Vec<_>>();
     877            1 :         assert_eq!(shard_0_next_lsns, shard_1_next_lsns);
     878              : 
     879            1 :         handle.abort();
     880            1 :         let mut done = false;
     881            2 :         for _ in 0..5 {
     882            2 :             if handle.current_position().is_none() {
     883            1 :                 done = true;
     884            1 :                 break;
     885            1 :             }
     886            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
     887            1 :         }
     888            1 : 
     889            1 :         assert!(done);
     890            1 :     }
     891              : 
     892              :     #[tokio::test]
     893            1 :     async fn test_interpreted_wal_reader_same_shard_fanout() {
     894            1 :         let _ = env_logger::builder().is_test(true).try_init();
     895              : 
     896              :         const SIZE: usize = 8 * 1024;
     897              :         const MSG_COUNT: usize = 200;
     898              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
     899              :         const SHARD_COUNT: u8 = 2;
     900              : 
     901            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     902            1 :         let env = Env::new(true).unwrap();
     903            1 :         let tli = env
     904            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     905            1 :             .await
     906            1 :             .unwrap();
     907              : 
     908            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     909            1 :         let mut next_record_lsns = Vec::default();
     910            1 :         let end_watch = Env::write_wal(
     911            1 :             tli,
     912            1 :             start_lsn,
     913            1 :             SIZE,
     914            1 :             MSG_COUNT,
     915            1 :             c"neon-file:",
     916            1 :             Some(&mut next_record_lsns),
     917            1 :         )
     918            1 :         .await
     919            1 :         .unwrap();
     920            1 :         let end_pos = end_watch.get();
     921              : 
     922            1 :         let streaming_wal_reader = StreamingWalReader::new(
     923            1 :             resident_tli,
     924            1 :             None,
     925            1 :             start_lsn,
     926            1 :             end_pos,
     927            1 :             end_watch,
     928              :             MAX_SEND_SIZE,
     929              :         );
     930              : 
     931            1 :         let shard_0 =
     932            1 :             ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
     933            1 :                 .unwrap();
     934              : 
     935              :         struct Sender {
     936              :             tx: Option<tokio::sync::mpsc::Sender<Batch>>,
     937              :             rx: tokio::sync::mpsc::Receiver<Batch>,
     938              :             shard: ShardIdentity,
     939              :             start_lsn: Lsn,
     940              :             received_next_record_lsns: Vec<Lsn>,
     941              :         }
     942              : 
     943              :         impl Sender {
     944            3 :             fn new(start_lsn: Lsn, shard: ShardIdentity) -> Self {
     945            3 :                 let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     946            3 :                 Self {
     947            3 :                     tx: Some(tx),
     948            3 :                     rx,
     949            3 :                     shard,
     950            3 :                     start_lsn,
     951            3 :                     received_next_record_lsns: Vec::default(),
     952            3 :                 }
     953            3 :             }
     954              :         }
     955              : 
     956            1 :         assert!(next_record_lsns.len() > 7);
     957            1 :         let start_lsns = vec![
     958            1 :             next_record_lsns[5],
     959            1 :             next_record_lsns[1],
     960            1 :             next_record_lsns[3],
     961              :         ];
     962            1 :         let mut senders = start_lsns
     963            1 :             .into_iter()
     964            3 :             .map(|lsn| Sender::new(lsn, shard_0))
     965            1 :             .collect::<Vec<_>>();
     966              : 
     967            1 :         let first_sender = senders.first_mut().unwrap();
     968            1 :         let handle = InterpretedWalReader::spawn(
     969            1 :             streaming_wal_reader,
     970            1 :             first_sender.start_lsn,
     971            1 :             first_sender.tx.take().unwrap(),
     972            1 :             first_sender.shard,
     973              :             PG_VERSION,
     974            1 :             &Some("pageserver".to_string()),
     975              :         );
     976              : 
     977            2 :         for sender in senders.iter_mut().skip(1) {
     978            2 :             handle
     979            2 :                 .fanout(sender.shard, sender.tx.take().unwrap(), sender.start_lsn)
     980            2 :                 .unwrap();
     981            2 :         }
     982              : 
     983            3 :         for sender in senders.iter_mut() {
     984              :             loop {
     985           39 :                 let batch = sender.rx.recv().await.unwrap();
     986           39 :                 tracing::info!(
     987            0 :                     "Sender with start_lsn={} received batch ending at {} with {} records",
     988              :                     sender.start_lsn,
     989              :                     batch.wal_end_lsn,
     990            0 :                     batch.records.records.len()
     991              :                 );
     992              : 
     993          627 :                 for rec in batch.records.records {
     994          588 :                     sender.received_next_record_lsns.push(rec.next_record_lsn);
     995          588 :                 }
     996              : 
     997           39 :                 if batch.wal_end_lsn == batch.available_wal_end_lsn {
     998            3 :                     break;
     999           36 :                 }
    1000              :             }
    1001              :         }
    1002              : 
    1003            1 :         handle.abort();
    1004            1 :         let mut done = false;
    1005            2 :         for _ in 0..5 {
    1006            2 :             if handle.current_position().is_none() {
    1007            1 :                 done = true;
    1008            1 :                 break;
    1009            1 :             }
    1010            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
    1011              :         }
    1012              : 
    1013            1 :         assert!(done);
    1014              : 
    1015            4 :         for sender in senders {
    1016            3 :             tracing::info!(
    1017            1 :                 "Validating records received by sender with start_lsn={}",
    1018            1 :                 sender.start_lsn
    1019            1 :             );
    1020            1 : 
    1021            3 :             assert!(sender.received_next_record_lsns.is_sorted());
    1022            3 :             let expected = next_record_lsns
    1023            3 :                 .iter()
    1024          600 :                 .filter(|lsn| **lsn > sender.start_lsn)
    1025            3 :                 .copied()
    1026            3 :                 .collect::<Vec<_>>();
    1027            3 :             assert_eq!(sender.received_next_record_lsns, expected);
    1028            1 :         }
    1029            1 :     }
    1030              : 
    1031              :     #[tokio::test]
    1032            1 :     async fn test_batch_start_tracking_on_reset() {
    1033              :         // When the WAL stream is reset to an older LSN,
    1034              :         // the current batch start LSN should be invalidated.
    1035              :         // This test constructs such a scenario:
    1036              :         // 1. Shard 0 is reading somewhere ahead
    1037              :         // 2. Reader reads some WAL, but does not decode a full record (partial read)
    1038              :         // 3. Shard 1 attaches to the reader and resets it to an older LSN
    1039              :         // 4. Shard 1 should get the correct batch WAL start LSN
    1040            1 :         let _ = env_logger::builder().is_test(true).try_init();
    1041              : 
    1042              :         const SIZE: usize = 64 * 1024;
    1043              :         const MSG_COUNT: usize = 10;
    1044              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
    1045              :         const SHARD_COUNT: u8 = 2;
    1046              :         const WAL_READER_BATCH_SIZE: usize = 8192;
    1047              : 
    1048            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
    1049            1 :         let env = Env::new(true).unwrap();
    1050            1 :         let mut next_record_lsns = Vec::default();
    1051            1 :         let tli = env
    1052            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
    1053            1 :             .await
    1054            1 :             .unwrap();
    1055              : 
    1056            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
    1057            1 :         let end_watch = Env::write_wal(
    1058            1 :             tli,
    1059            1 :             start_lsn,
    1060            1 :             SIZE,
    1061            1 :             MSG_COUNT,
    1062            1 :             c"neon-file:",
    1063            1 :             Some(&mut next_record_lsns),
    1064            1 :         )
    1065            1 :         .await
    1066            1 :         .unwrap();
    1067              : 
    1068            1 :         assert!(next_record_lsns.len() > 3);
    1069            1 :         let shard_0_start_lsn = next_record_lsns[3];
    1070              : 
    1071            1 :         let end_pos = end_watch.get();
    1072              : 
    1073            1 :         let streaming_wal_reader = StreamingWalReader::new(
    1074            1 :             resident_tli,
    1075            1 :             None,
    1076            1 :             shard_0_start_lsn,
    1077            1 :             end_pos,
    1078            1 :             end_watch,
    1079              :             WAL_READER_BATCH_SIZE,
    1080              :         );
    1081              : 
    1082            1 :         let shard_0 =
    1083            1 :             ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
    1084            1 :                 .unwrap();
    1085              : 
    1086            1 :         let shard_1 =
    1087            1 :             ShardIdentity::new(ShardNumber(1), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
    1088            1 :                 .unwrap();
    1089              : 
    1090            1 :         let mut shards = HashMap::new();
    1091              : 
    1092            3 :         for shard_number in 0..SHARD_COUNT {
    1093            2 :             let shard_id = ShardIdentity::new(
    1094            2 :                 ShardNumber(shard_number),
    1095            2 :                 ShardCount(SHARD_COUNT),
    1096            2 :                 DEFAULT_STRIPE_SIZE,
    1097            2 :             )
    1098            2 :             .unwrap();
    1099            2 :             let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
    1100            2 :             shards.insert(shard_id, (Some(tx), Some(rx)));
    1101            2 :         }
    1102              : 
    1103            1 :         let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
    1104              : 
    1105            1 :         let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
    1106              : 
    1107            1 :         let reader = InterpretedWalReader::new(
    1108            1 :             streaming_wal_reader,
    1109            1 :             shard_0_start_lsn,
    1110            1 :             shard_0_tx,
    1111            1 :             shard_0,
    1112              :             PG_VERSION,
    1113            1 :             Some(shard_notification_rx),
    1114              :         );
    1115              : 
    1116            1 :         let reader_state = reader.state();
    1117            1 :         let mut reader_fut = std::pin::pin!(reader.run(shard_0_start_lsn, &None));
    1118              :         loop {
    1119           39 :             let poll = futures::poll!(reader_fut.as_mut());
    1120           39 :             assert!(poll.is_pending());
    1121              : 
    1122           39 :             let guard = reader_state.read().unwrap();
    1123           39 :             if guard.current_batch_wal_start().is_some() {
    1124            1 :                 break;
    1125           38 :             }
    1126              :         }
    1127              : 
    1128            1 :         shard_notification_tx
    1129            1 :             .send(AttachShardNotification {
    1130            1 :                 shard_id: shard_1,
    1131            1 :                 sender: shards.get_mut(&shard_1).unwrap().0.take().unwrap(),
    1132            1 :                 start_pos: start_lsn,
    1133            1 :             })
    1134            1 :             .unwrap();
    1135              : 
    1136            1 :         let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
    1137            1 :         loop {
    1138           91 :             let poll = futures::poll!(reader_fut.as_mut());
    1139           91 :             assert!(poll.is_pending());
    1140            1 : 
    1141           91 :             let try_recv_res = shard_1_rx.try_recv();
    1142           90 :             match try_recv_res {
    1143            1 :                 Ok(batch) => {
    1144            1 :                     assert_eq!(batch.records.raw_wal_start_lsn.unwrap(), start_lsn);
    1145            1 :                     break;
    1146            1 :                 }
    1147           90 :                 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {}
    1148            1 :                 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
    1149            1 :                     unreachable!();
    1150            1 :                 }
    1151            1 :             }
    1152            1 :         }
    1153            1 :     }
    1154              : 
    1155              :     #[tokio::test]
    1156            1 :     async fn test_shard_zero_does_not_skip_empty_records() {
    1157            1 :         let _ = env_logger::builder().is_test(true).try_init();
    1158              : 
    1159              :         const SIZE: usize = 8 * 1024;
    1160              :         const MSG_COUNT: usize = 10;
    1161              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
    1162              : 
    1163            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
    1164            1 :         let env = Env::new(true).unwrap();
    1165            1 :         let tli = env
    1166            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
    1167            1 :             .await
    1168            1 :             .unwrap();
    1169              : 
    1170            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
    1171            1 :         let mut next_record_lsns = Vec::new();
    1172            1 :         let end_watch = Env::write_wal(
    1173            1 :             tli,
    1174            1 :             start_lsn,
    1175            1 :             SIZE,
    1176            1 :             MSG_COUNT,
    1177            1 :             // This is a logical message prefix that is not persisted to key value storage.
    1178            1 :             // We use it in order to validate that shard zero receives emtpy interpreted records.
    1179            1 :             c"test:",
    1180            1 :             Some(&mut next_record_lsns),
    1181            1 :         )
    1182            1 :         .await
    1183            1 :         .unwrap();
    1184            1 :         let end_pos = end_watch.get();
    1185              : 
    1186            1 :         let streaming_wal_reader = StreamingWalReader::new(
    1187            1 :             resident_tli,
    1188            1 :             None,
    1189            1 :             start_lsn,
    1190            1 :             end_pos,
    1191            1 :             end_watch,
    1192              :             MAX_SEND_SIZE,
    1193              :         );
    1194              : 
    1195            1 :         let shard = ShardIdentity::unsharded();
    1196            1 :         let (records_tx, mut records_rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
    1197              : 
    1198            1 :         let handle = InterpretedWalReader::spawn(
    1199            1 :             streaming_wal_reader,
    1200            1 :             start_lsn,
    1201            1 :             records_tx,
    1202            1 :             shard,
    1203              :             PG_VERSION,
    1204            1 :             &Some("pageserver".to_string()),
    1205              :         );
    1206              : 
    1207            1 :         let mut interpreted_records = Vec::new();
    1208            1 :         while let Some(batch) = records_rx.recv().await {
    1209            1 :             interpreted_records.push(batch.records);
    1210            1 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
    1211            1 :                 break;
    1212            0 :             }
    1213              :         }
    1214              : 
    1215            1 :         let received_next_record_lsns = interpreted_records
    1216            1 :             .into_iter()
    1217            1 :             .flat_map(|b| b.records)
    1218            1 :             .map(|rec| rec.next_record_lsn)
    1219            1 :             .collect::<Vec<_>>();
    1220              : 
    1221              :         // By default this also includes the start LSN. Trim it since it shouldn't be received.
    1222            1 :         let next_record_lsns = next_record_lsns.into_iter().skip(1).collect::<Vec<_>>();
    1223              : 
    1224            1 :         assert_eq!(received_next_record_lsns, next_record_lsns);
    1225              : 
    1226            1 :         handle.abort();
    1227            1 :         let mut done = false;
    1228            2 :         for _ in 0..5 {
    1229            2 :             if handle.current_position().is_none() {
    1230            1 :                 done = true;
    1231            1 :                 break;
    1232            1 :             }
    1233            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
    1234            1 :         }
    1235            1 : 
    1236            1 :         assert!(done);
    1237            1 :     }
    1238              : }
         |