LCOV - code coverage report
Current view: top level - safekeeper/src - send_interpreted_wal.rs (source / functions)
Test: 1d5975439f3c9882b18414799141ebf9a3922c58.info
Test Date: 2025-07-31 15:59:03
Coverage: Lines: 82.3 % (586 of 712 hit) | Functions: 81.8 % (36 of 44 hit)

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::fmt::Display;
       3              : use std::sync::Arc;
       4              : use std::sync::atomic::AtomicBool;
       5              : use std::time::Duration;
       6              : 
       7              : use anyhow::{Context, anyhow};
       8              : use futures::StreamExt;
       9              : use futures::future::Either;
      10              : use pageserver_api::shard::ShardIdentity;
      11              : use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend};
      12              : use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder};
      13              : use postgres_ffi::{PgMajorVersion, get_current_timestamp};
      14              : use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive};
      15              : use tokio::io::{AsyncRead, AsyncWrite};
      16              : use tokio::sync::mpsc::error::SendError;
      17              : use tokio::task::JoinHandle;
      18              : use tokio::time::MissedTickBehavior;
      19              : use tracing::{Instrument, error, info, info_span};
      20              : use utils::critical_timeline;
      21              : use utils::lsn::Lsn;
      22              : use utils::postgres_client::{Compression, InterpretedFormat};
      23              : use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
      24              : use wal_decoder::wire_format::ToWireFormat;
      25              : 
      26              : use crate::metrics::WAL_READERS;
      27              : use crate::send_wal::{EndWatchView, WalSenderGuard};
      28              : use crate::timeline::WalResidentTimeline;
      29              : use crate::wal_reader_stream::{StreamingWalReader, WalBytes};
      30              : 
      31              : /// Identifier used to differentiate between senders of the same
      32              : /// shard.
      33              : ///
      34              : /// In the steady state there's only one, but two pageservers may
      35              : /// temporarily have the same shard attached and attempt to ingest
      36              : /// WAL for it. See also [`ShardSenderId`].
      37              : #[derive(Hash, Eq, PartialEq, Copy, Clone)]
      38              : struct SenderId(u8);
      39              : 
      40              : impl SenderId {
      41            6 :     fn first() -> Self {
      42            6 :         SenderId(0)
      43            6 :     }
      44              : 
      45            2 :     fn next(&self) -> Self {
      46            2 :         SenderId(self.0.checked_add(1).expect("few senders"))
      47            2 :     }
      48              : }
      49              : 
      50              : #[derive(Hash, Eq, PartialEq)]
      51              : struct ShardSenderId {
      52              :     shard: ShardIdentity,
      53              :     sender_id: SenderId,
      54              : }
      55              : 
      56              : impl Display for ShardSenderId {
      57            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      58            0 :         write!(f, "{}{}", self.sender_id.0, self.shard.shard_slug())
      59            0 :     }
      60              : }
      61              : 
      62              : impl ShardSenderId {
      63          877 :     fn new(shard: ShardIdentity, sender_id: SenderId) -> Self {
      64          877 :         ShardSenderId { shard, sender_id }
      65          877 :     }
      66              : 
      67            0 :     fn shard(&self) -> ShardIdentity {
      68            0 :         self.shard
      69            0 :     }
      70              : }
      71              : 
      72              : /// Shard-aware fan-out interpreted record reader.
       73              : /// Reads WAL from disk, decodes it, interprets it, and sends
      74              : /// it to any [`InterpretedWalSender`] connected to it.
      75              : /// Each [`InterpretedWalSender`] corresponds to one shard
      76              : /// and gets interpreted records concerning that shard only.
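                       : ///
                       : /// A minimal usage sketch (illustrative only; the `wal_stream`, `shard`,
                       : /// `pg_version` and channel capacity are assumed to come from the caller):
                       : ///
                       : /// ```ignore
                       : /// let (tx, mut rx) = tokio::sync::mpsc::channel::<Batch>(64);
                       : /// // Spawn the reader as a task; it decodes WAL and fans batches out per shard.
                       : /// let handle = InterpretedWalReader::spawn(
                       : ///     wal_stream, start_pos, tx, shard, pg_version, &Some("pageserver".to_string()));
                       : /// // Consume interpreted record batches for `shard` as they are produced.
                       : /// while let Some(batch) = rx.recv().await {
                       : ///     // batch.records contains only records relevant to `shard`
                       : /// }
                       : /// ```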
      77              : pub(crate) struct InterpretedWalReader {
      78              :     wal_stream: StreamingWalReader,
      79              :     shard_senders: HashMap<ShardIdentity, smallvec::SmallVec<[ShardSenderState; 1]>>,
      80              :     shard_notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>>,
      81              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
      82              :     pg_version: PgMajorVersion,
      83              : }
      84              : 
      85              : /// A handle for [`InterpretedWalReader`] which allows for interacting with it
      86              : /// when it runs as a separate tokio task.
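                       : ///
                       : /// Illustrative sketch of driving the handle (the `handle` binding is assumed
                       : /// to come from [`InterpretedWalReader::spawn`]):
                       : ///
                       : /// ```ignore
                       : /// // Poll progress; `None` means the reader task has already finished.
                       : /// if let Some(pos) = handle.current_position() {
                       : ///     tracing::info!("reader is at {pos}");
                       : /// }
                       : /// // Stop the background task explicitly (also happens on drop).
                       : /// handle.abort();
                       : /// ```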
      87              : #[derive(Debug)]
      88              : pub(crate) struct InterpretedWalReaderHandle {
      89              :     join_handle: JoinHandle<Result<(), InterpretedWalReaderError>>,
      90              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
      91              :     shard_notification_tx: tokio::sync::mpsc::UnboundedSender<AttachShardNotification>,
      92              : }
      93              : 
      94              : struct ShardSenderState {
      95              :     sender_id: SenderId,
      96              :     tx: tokio::sync::mpsc::Sender<Batch>,
      97              :     next_record_lsn: Lsn,
      98              : }
      99              : 
     100              : /// State of [`InterpretedWalReader`] visible outside of the task running it.
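                       : ///
                       : /// Illustrative read of the shared state (lock handling is internal to the
                       : /// reader and its handle; this only shows the enum's shape):
                       : ///
                       : /// ```ignore
                       : /// match &*state.read().unwrap() {
                       : ///     InterpretedWalReaderState::Running { current_position, .. } => {
                       : ///         // reader is active and has processed WAL up to *current_position
                       : ///     }
                       : ///     InterpretedWalReaderState::Done => { /* reader exited */ }
                       : /// }
                       : /// ```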
     101              : #[derive(Debug)]
     102              : pub(crate) enum InterpretedWalReaderState {
     103              :     Running {
     104              :         current_position: Lsn,
     105              :         /// Tracks the start of the PG WAL LSN from which the current batch of
     106              :         /// interpreted records originated.
     107              :         current_batch_wal_start: Option<Lsn>,
     108              :     },
     109              :     Done,
     110              : }
     111              : 
     112              : pub(crate) struct Batch {
     113              :     wal_end_lsn: Lsn,
     114              :     available_wal_end_lsn: Lsn,
     115              :     records: InterpretedWalRecords,
     116              : }
     117              : 
     118              : #[derive(thiserror::Error, Debug)]
     119              : pub enum InterpretedWalReaderError {
     120              :     /// Handler initiates the end of streaming.
     121              :     #[error("decode error: {0}")]
     122              :     Decode(#[from] WalDecodeError),
     123              :     #[error("read or interpret error: {0}")]
     124              :     ReadOrInterpret(#[from] anyhow::Error),
     125              :     #[error("wal stream closed")]
     126              :     WalStreamClosed,
     127              : }
     128              : 
     129              : enum CurrentPositionUpdate {
     130              :     Reset { from: Lsn, to: Lsn },
     131              :     NotReset(Lsn),
     132              : }
     133              : 
     134              : impl CurrentPositionUpdate {
     135            0 :     fn current_position(&self) -> Lsn {
     136            0 :         match self {
     137            0 :             CurrentPositionUpdate::Reset { from: _, to } => *to,
     138            0 :             CurrentPositionUpdate::NotReset(lsn) => *lsn,
     139              :         }
     140            0 :     }
     141              : 
     142            0 :     fn previous_position(&self) -> Lsn {
     143            0 :         match self {
     144            0 :             CurrentPositionUpdate::Reset { from, to: _ } => *from,
     145            0 :             CurrentPositionUpdate::NotReset(lsn) => *lsn,
     146              :         }
     147            0 :     }
     148              : }
     149              : 
     150              : impl InterpretedWalReaderState {
     151            6 :     fn current_position(&self) -> Option<Lsn> {
     152            6 :         match self {
     153              :             InterpretedWalReaderState::Running {
     154            3 :                 current_position, ..
     155            3 :             } => Some(*current_position),
     156            3 :             InterpretedWalReaderState::Done => None,
     157              :         }
     158            6 :     }
     159              : 
     160              :     #[cfg(test)]
     161           64 :     fn current_batch_wal_start(&self) -> Option<Lsn> {
     162           64 :         match self {
     163              :             InterpretedWalReaderState::Running {
     164           64 :                 current_batch_wal_start,
     165              :                 ..
     166           64 :             } => *current_batch_wal_start,
     167            0 :             InterpretedWalReaderState::Done => None,
     168              :         }
     169           64 :     }
     170              : 
     171              :     // Reset the current position of the WAL reader if the requested starting position
     172              :     // of the new shard is smaller than the current value.
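                       :     //
                       :     // Worked example with hypothetical LSNs: with the reader at 0/2000, a
                       :     // shard attaching at 0/1000 produces Reset { from: 0/2000, to: 0/1000 }
                       :     // and clears the batch start, while a shard attaching at 0/3000 produces
                       :     // NotReset(0/2000) and leaves the stream untouched.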
     173            4 :     fn maybe_reset(&mut self, new_shard_start_pos: Lsn) -> CurrentPositionUpdate {
     174            4 :         match self {
     175              :             InterpretedWalReaderState::Running {
     176            4 :                 current_position,
     177            4 :                 current_batch_wal_start,
     178              :             } => {
     179            4 :                 if new_shard_start_pos < *current_position {
     180            3 :                     let from = *current_position;
     181            3 :                     *current_position = new_shard_start_pos;
     182            3 :                     *current_batch_wal_start = None;
     183            3 :                     CurrentPositionUpdate::Reset {
     184            3 :                         from,
     185            3 :                         to: *current_position,
     186            3 :                     }
     187              :                 } else {
     188              :                     // Edge case: The new shard is at the same current position as
     189              :                     // the reader. Note that the current position is WAL record aligned,
     190              :                     // so the reader might have done some partial reads and updated the
      191              :                     // batch start. If that's the case, adjust the batch start to match the
      192              :                     // starting position of the new shard. This can lead to some shards
      193              :                     // seeing overlapping batches, but the per-record LSN filtering
      194              :                     // below handles that case.
     195            1 :                     if let Some(start) = current_batch_wal_start {
     196            0 :                         *start = std::cmp::min(*start, new_shard_start_pos);
     197            1 :                     }
     198            1 :                     CurrentPositionUpdate::NotReset(*current_position)
     199              :                 }
     200              :             }
     201              :             InterpretedWalReaderState::Done => {
     202            0 :                 panic!("maybe_reset called on finished reader")
     203              :             }
     204              :         }
     205            4 :     }
     206              : 
     207           51 :     fn update_current_batch_wal_start(&mut self, lsn: Lsn) {
     208           51 :         match self {
     209              :             InterpretedWalReaderState::Running {
     210           51 :                 current_batch_wal_start,
     211              :                 ..
     212              :             } => {
     213           51 :                 if current_batch_wal_start.is_none() {
     214            6 :                     *current_batch_wal_start = Some(lsn);
     215           45 :                 }
     216              :             }
     217              :             InterpretedWalReaderState::Done => {
     218            0 :                 panic!("update_current_batch_wal_start called on finished reader")
     219              :             }
     220              :         }
     221           51 :     }
     222              : 
     223           41 :     fn replace_current_batch_wal_start(&mut self, with: Lsn) -> Lsn {
     224           41 :         match self {
     225              :             InterpretedWalReaderState::Running {
     226           41 :                 current_batch_wal_start,
     227              :                 ..
     228           41 :             } => current_batch_wal_start.replace(with).unwrap(),
     229              :             InterpretedWalReaderState::Done => {
      230            0 :                 panic!("replace_current_batch_wal_start called on finished reader")
     231              :             }
     232              :         }
     233           41 :     }
     234              : 
     235           41 :     fn update_current_position(&mut self, lsn: Lsn) {
     236           41 :         match self {
     237              :             InterpretedWalReaderState::Running {
     238           41 :                 current_position, ..
     239           41 :             } => {
     240           41 :                 *current_position = lsn;
     241           41 :             }
     242              :             InterpretedWalReaderState::Done => {
     243            0 :                 panic!("update_current_position called on finished reader")
     244              :             }
     245              :         }
     246           41 :     }
     247              : }
     248              : 
     249              : pub(crate) struct AttachShardNotification {
     250              :     shard_id: ShardIdentity,
     251              :     sender: tokio::sync::mpsc::Sender<Batch>,
     252              :     start_pos: Lsn,
     253              : }
     254              : 
     255              : impl InterpretedWalReader {
     256              :     /// Spawn the reader in a separate tokio task and return a handle
     257            3 :     pub(crate) fn spawn(
     258            3 :         wal_stream: StreamingWalReader,
     259            3 :         start_pos: Lsn,
     260            3 :         tx: tokio::sync::mpsc::Sender<Batch>,
     261            3 :         shard: ShardIdentity,
     262            3 :         pg_version: PgMajorVersion,
     263            3 :         appname: &Option<String>,
     264            3 :     ) -> InterpretedWalReaderHandle {
     265            3 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     266            3 :             current_position: start_pos,
     267            3 :             current_batch_wal_start: None,
     268            3 :         }));
     269              : 
     270            3 :         let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
     271              : 
     272            3 :         let ttid = wal_stream.ttid;
     273              : 
     274            3 :         let reader = InterpretedWalReader {
     275            3 :             wal_stream,
     276            3 :             shard_senders: HashMap::from([(
     277            3 :                 shard,
     278            3 :                 smallvec::smallvec![ShardSenderState {
     279            0 :                     sender_id: SenderId::first(),
     280            0 :                     tx,
     281            0 :                     next_record_lsn: start_pos,
     282            0 :                 }],
     283              :             )]),
     284            3 :             shard_notification_rx: Some(shard_notification_rx),
     285            3 :             state: state.clone(),
     286            3 :             pg_version,
     287              :         };
     288              : 
     289            3 :         let metric = WAL_READERS
     290            3 :             .get_metric_with_label_values(&["task", appname.as_deref().unwrap_or("safekeeper")])
     291            3 :             .unwrap();
     292              : 
     293            3 :         let join_handle = tokio::task::spawn(
     294            3 :             async move {
     295            3 :                 metric.inc();
     296            3 :                 scopeguard::defer! {
     297              :                     metric.dec();
     298              :                 }
     299              : 
     300            3 :                 reader
     301            3 :                     .run_impl(start_pos)
     302            3 :                     .await
     303            0 :                     .inspect_err(|err| match err {
     304              :                         // TODO: we may want to differentiate these errors further.
     305              :                         InterpretedWalReaderError::Decode(_) => {
     306            0 :                             critical_timeline!(
     307            0 :                                 ttid.tenant_id,
     308            0 :                                 ttid.timeline_id,
     309              :                                 // Hadron: The corruption flag is only used in PS so that it can feed this information back to SKs.
     310              :                                 // We do not use these flags in SKs.
     311            0 :                                 None::<&AtomicBool>,
     312            0 :                                 "failed to read WAL record: {err:?}"
     313              :                             );
     314              :                         }
     315            0 :                         err => error!("failed to read WAL record: {err}"),
     316            0 :                     })
     317            0 :             }
     318            3 :             .instrument(info_span!("interpreted wal reader")),
     319              :         );
     320              : 
     321            3 :         InterpretedWalReaderHandle {
     322            3 :             join_handle,
     323            3 :             state,
     324            3 :             shard_notification_tx,
     325            3 :         }
     326            3 :     }
     327              : 
      328              :     /// Construct the reader without spawning anything.
     329              :     /// Callers should drive the future returned by [`Self::run`].
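                       :     ///
                       :     /// Illustrative only (assumes a prepared `wal_stream` plus a `tx`/`rx`
                       :     /// channel pair; `None` disables shard attach notifications):
                       :     ///
                       :     /// ```ignore
                       :     /// let reader = InterpretedWalReader::new(
                       :     ///     wal_stream, start_pos, tx, shard, pg_version, None);
                       :     /// // Drive the reader on the current task instead of spawning one.
                       :     /// reader.run(start_pos, &appname).await?;
                       :     /// ```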
     330            1 :     pub(crate) fn new(
     331            1 :         wal_stream: StreamingWalReader,
     332            1 :         start_pos: Lsn,
     333            1 :         tx: tokio::sync::mpsc::Sender<Batch>,
     334            1 :         shard: ShardIdentity,
     335            1 :         pg_version: PgMajorVersion,
     336            1 :         shard_notification_rx: Option<
     337            1 :             tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>,
     338            1 :         >,
     339            1 :     ) -> InterpretedWalReader {
     340            1 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     341            1 :             current_position: start_pos,
     342            1 :             current_batch_wal_start: None,
     343            1 :         }));
     344              : 
     345              :         InterpretedWalReader {
     346            1 :             wal_stream,
     347            1 :             shard_senders: HashMap::from([(
     348            1 :                 shard,
     349            1 :                 smallvec::smallvec![ShardSenderState {
     350            0 :                     sender_id: SenderId::first(),
     351            0 :                     tx,
     352            0 :                     next_record_lsn: start_pos,
     353            0 :                 }],
     354              :             )]),
     355            1 :             shard_notification_rx,
     356            1 :             state: state.clone(),
     357            1 :             pg_version,
     358              :         }
     359            1 :     }
     360              : 
      361              :     /// Entry point for the future-based (polled) WAL reader.
     362            1 :     pub(crate) async fn run(
     363            1 :         self,
     364            1 :         start_pos: Lsn,
     365            1 :         appname: &Option<String>,
     366            1 :     ) -> Result<(), CopyStreamHandlerEnd> {
     367            1 :         let metric = WAL_READERS
     368            1 :             .get_metric_with_label_values(&["future", appname.as_deref().unwrap_or("safekeeper")])
     369            1 :             .unwrap();
     370              : 
     371            1 :         metric.inc();
     372            1 :         scopeguard::defer! {
     373              :             metric.dec();
     374              :         }
     375              : 
     376            1 :         let ttid = self.wal_stream.ttid;
     377            1 :         match self.run_impl(start_pos).await {
     378            0 :             Err(err @ InterpretedWalReaderError::Decode(_)) => {
     379            0 :                 critical_timeline!(
     380            0 :                     ttid.tenant_id,
     381            0 :                     ttid.timeline_id,
     382              :                     // Hadron: The corruption flag is only used in PS so that it can feed this information back to SKs.
     383              :                     // We do not use these flags in SKs.
     384            0 :                     None::<&AtomicBool>,
     385            0 :                     "failed to decode WAL record: {err:?}"
     386              :                 );
     387              :             }
     388            0 :             Err(err) => error!("failed to read WAL record: {err}"),
     389            0 :             Ok(()) => info!("interpreted wal reader exiting"),
     390              :         }
     391              : 
     392            0 :         Err(CopyStreamHandlerEnd::Other(anyhow!(
     393            0 :             "interpreted wal reader finished"
     394            0 :         )))
     395            0 :     }
     396              : 
      397              :     /// Send interpreted WAL to one or more [`InterpretedWalSender`]s.
     398              :     /// Stops when an error is encountered or when the [`InterpretedWalReaderHandle`]
     399              :     /// goes out of scope.
     400            4 :     async fn run_impl(mut self, start_pos: Lsn) -> Result<(), InterpretedWalReaderError> {
     401            4 :         let defer_state = self.state.clone();
     402            4 :         scopeguard::defer! {
     403              :             *defer_state.write().unwrap() = InterpretedWalReaderState::Done;
     404              :         }
     405              : 
     406            4 :         let mut wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version);
     407              : 
     408              :         loop {
     409           59 :             tokio::select! {
     410              :                 // Main branch for reading WAL and forwarding it
     411           59 :                 wal_or_reset = self.wal_stream.next() => {
     412           51 :                     let wal = wal_or_reset.map(|wor| wor.get_wal().expect("reset handled in select branch below"));
     413              :                     let WalBytes {
     414           51 :                         wal,
     415           51 :                         wal_start_lsn,
     416           51 :                         wal_end_lsn,
     417           51 :                         available_wal_end_lsn,
     418           51 :                     } = match wal {
     419           51 :                         Some(some) => some.map_err(InterpretedWalReaderError::ReadOrInterpret)?,
     420              :                         None => {
     421              :                             // [`StreamingWalReader::next`] is an endless stream of WAL.
     422              :                             // It shouldn't ever finish unless it panicked or became internally
     423              :                             // inconsistent.
     424            0 :                             return Result::Err(InterpretedWalReaderError::WalStreamClosed);
     425              :                         }
     426              :                     };
     427              : 
     428           51 :                     self.state.write().unwrap().update_current_batch_wal_start(wal_start_lsn);
     429              : 
     430           51 :                     wal_decoder.feed_bytes(&wal);
     431              : 
     432              :                     // Deserialize and interpret WAL records from this batch of WAL.
     433              :                     // Interpreted records for each shard are collected separately.
     434           51 :                     let shard_ids = self.shard_senders.keys().copied().collect::<Vec<_>>();
     435           51 :                     let mut records_by_sender: HashMap<ShardSenderId, Vec<InterpretedWalRecord>> = HashMap::new();
     436           51 :                     let mut max_next_record_lsn = None;
     437           51 :                     let mut max_end_record_lsn = None;
     438          657 :                     while let Some((next_record_lsn, recdata)) = wal_decoder.poll_decode()?
     439              :                     {
     440          606 :                         assert!(next_record_lsn.is_aligned());
     441          606 :                         max_next_record_lsn = Some(next_record_lsn);
     442          606 :                         max_end_record_lsn = Some(wal_decoder.lsn());
     443              : 
     444          606 :                         let interpreted = InterpretedWalRecord::from_bytes_filtered(
     445          606 :                             recdata,
     446          606 :                             &shard_ids,
     447          606 :                             next_record_lsn,
     448          606 :                             self.pg_version,
     449              :                         )
     450          606 :                         .with_context(|| "Failed to interpret WAL")?;
     451              : 
     452         1412 :                         for (shard, record) in interpreted {
     453              :                             // Shard zero needs to track the start LSN of the latest record
      454              :                             // in addition to the LSN of the next record to ingest. The former
     455              :                             // is included in basebackup persisted by the compute in WAL.
     456          806 :                             if !shard.is_shard_zero() && record.is_empty() {
     457          200 :                                 continue;
     458          606 :                             }
     459              : 
     460          606 :                             let mut states_iter = self.shard_senders
     461          606 :                                 .get(&shard)
     462          606 :                                 .expect("keys collected above")
     463          606 :                                 .iter()
     464         1002 :                                 .filter(|state| record.next_record_lsn > state.next_record_lsn)
     465          606 :                                 .peekable();
     466          996 :                             while let Some(state) = states_iter.next() {
     467          796 :                                 let shard_sender_id = ShardSenderId::new(shard, state.sender_id);
     468              : 
      469              :                                 // The most common case is one sender per shard. Peek and break to avoid the
     470              :                                 // clone in that situation.
     471          796 :                                 if states_iter.peek().is_none() {
     472          406 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record);
     473          406 :                                     break;
     474          390 :                                 } else {
     475          390 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record.clone());
     476          390 :                                 }
     477              :                             }
     478              :                         }
     479              :                     }
     480              : 
     481           51 :                     let max_next_record_lsn = match max_next_record_lsn {
     482           41 :                         Some(lsn) => lsn,
     483              :                         None => {
     484           10 :                             continue;
     485              :                         }
     486              :                     };
     487              : 
     488              :                     // Update the current position such that new receivers can decide
     489              :                     // whether to attach to us or spawn a new WAL reader.
     490           41 :                     let batch_wal_start_lsn = {
     491           41 :                         let mut guard = self.state.write().unwrap();
     492           41 :                         guard.update_current_position(max_next_record_lsn);
     493           41 :                         guard.replace_current_batch_wal_start(max_end_record_lsn.unwrap())
     494              :                     };
     495              : 
     496              :                     // Send interpreted records downstream. Anything that has already been seen
     497              :                     // by a shard is filtered out.
     498           41 :                     let mut shard_senders_to_remove = Vec::new();
     499           96 :                     for (shard, states) in &mut self.shard_senders {
     500          136 :                         for state in states {
     501           81 :                             let shard_sender_id = ShardSenderId::new(*shard, state.sender_id);
     502              : 
     503           81 :                             let batch = if max_next_record_lsn > state.next_record_lsn {
     504              :                                 // This batch contains at least one record that this shard has not
     505              :                                 // seen yet.
     506           67 :                                 let records = records_by_sender.remove(&shard_sender_id).unwrap_or_default();
     507              : 
     508           67 :                                 InterpretedWalRecords {
     509           67 :                                     records,
     510           67 :                                     next_record_lsn: max_next_record_lsn,
     511           67 :                                     raw_wal_start_lsn: Some(batch_wal_start_lsn),
     512           67 :                                 }
     513           14 :                             } else if wal_end_lsn > state.next_record_lsn {
      514              :                                 // All the records in this batch were seen by the shard.
     515              :                                 // However, the batch maps to a chunk of WAL that the
     516              :                                 // shard has not yet seen. Notify it of the start LSN
     517              :                                 // of the PG WAL chunk such that it doesn't look like a gap.
     518            0 :                                 InterpretedWalRecords {
     519            0 :                                     records: Vec::default(),
     520            0 :                                     next_record_lsn: state.next_record_lsn,
     521            0 :                                     raw_wal_start_lsn: Some(batch_wal_start_lsn),
     522            0 :                                 }
     523              :                             } else {
     524              :                                 // The shard has seen this chunk of WAL before. Skip it.
     525           14 :                                 continue;
     526              :                             };
     527              : 
     528           67 :                             let res = state.tx.send(Batch {
     529           67 :                                 wal_end_lsn,
     530           67 :                                 available_wal_end_lsn,
     531           67 :                                 records: batch,
     532           67 :                             }).await;
     533              : 
     534           67 :                             if res.is_err() {
     535            0 :                                 shard_senders_to_remove.push(shard_sender_id);
     536           67 :                             } else {
     537           67 :                                 state.next_record_lsn = std::cmp::max(state.next_record_lsn, max_next_record_lsn);
     538           67 :                             }
     539              :                         }
     540              :                     }
     541              : 
     542              :                     // Clean up any shard senders that have dropped out.
     543              :                     // This is inefficient, but such events are rare (connection to PS termination)
      544              :                     // and the number of subscriptions on the same shard is very small (only one
      545              :                     // in the steady state).
     546           41 :                     for to_remove in shard_senders_to_remove {
     547            0 :                         let shard_senders = self.shard_senders.get_mut(&to_remove.shard()).expect("saw it above");
     548            0 :                         if let Some(idx) = shard_senders.iter().position(|s| s.sender_id == to_remove.sender_id) {
     549            0 :                             shard_senders.remove(idx);
     550            0 :                             tracing::info!("Removed shard sender {}", to_remove);
     551            0 :                         }
     552              : 
     553            0 :                         if shard_senders.is_empty() {
     554            0 :                             self.shard_senders.remove(&to_remove.shard());
     555            0 :                         }
     556              :                     }
     557              :                 },
     558              :                 // Listen for new shards that want to attach to this reader.
     559              :                 // If the reader is not running as a task, then this is not supported
     560              :                 // (see the pending branch below).
     561           59 :                 notification = match self.shard_notification_rx.as_mut() {
     562           59 :                         Some(rx) => Either::Left(rx.recv()),
     563            0 :                         None => Either::Right(std::future::pending())
     564              :                     } => {
     565            4 :                     if let Some(n) = notification {
     566            4 :                         let AttachShardNotification { shard_id, sender, start_pos } = n;
     567              : 
     568              :                         // Update internal and external state, then reset the WAL stream
     569              :                         // if required.
     570            4 :                         let senders = self.shard_senders.entry(shard_id).or_default();
     571              : 
     572              :                         // Clean up any shard senders that have dropped out before adding the new
     573              :                         // one. This avoids a build up of dead senders.
     574            4 :                         senders.retain(|sender| {
     575            3 :                             let closed = sender.tx.is_closed();
     576              : 
     577            3 :                             if closed {
     578            0 :                                 let sender_id = ShardSenderId::new(shard_id, sender.sender_id);
     579            0 :                                 tracing::info!("Removed shard sender {}", sender_id);
     580            3 :                             }
     581              : 
     582            3 :                             !closed
     583            3 :                         });
     584              : 
     585            4 :                         let new_sender_id = match senders.last() {
     586            2 :                             Some(sender) => sender.sender_id.next(),
     587            2 :                             None => SenderId::first()
     588              :                         };
     589              : 
     590            4 :                         senders.push(ShardSenderState { sender_id: new_sender_id, tx: sender, next_record_lsn: start_pos});
     591              : 
      592              :                         // If the shard is subscribing below the current position then we need
     593              :                         // to update the cursor that tracks where we are at in the WAL
     594              :                         // ([`Self::state`]) and reset the WAL stream itself
      595              :                         // ([`Self::wal_stream`]). This must be done atomically from the POV of
     596              :                         // anything outside the select statement.
     597            4 :                         let position_reset = self.state.write().unwrap().maybe_reset(start_pos);
     598            4 :                         match position_reset {
     599            3 :                             CurrentPositionUpdate::Reset { from: _, to } => {
     600            3 :                                 self.wal_stream.reset(to).await;
     601            3 :                                 wal_decoder = WalStreamDecoder::new(to, self.pg_version);
     602              :                             },
     603            1 :                             CurrentPositionUpdate::NotReset(_) => {}
     604              :                         };
     605              : 
     606            4 :                         tracing::info!(
     607            0 :                             "Added shard sender {} with start_pos={} previous_pos={} current_pos={}",
     608            0 :                             ShardSenderId::new(shard_id, new_sender_id),
     609              :                             start_pos,
     610            0 :                             position_reset.previous_position(),
     611            0 :                             position_reset.current_position(),
     612              :                         );
     613            0 :                     }
     614              :                 }
     615              :             }
     616              :         }
     617            0 :     }
     618              : 
     619              :     #[cfg(test)]
     620            1 :     fn state(&self) -> Arc<std::sync::RwLock<InterpretedWalReaderState>> {
     621            1 :         self.state.clone()
     622            1 :     }
     623              : }
     624              : 
     625              : impl InterpretedWalReaderHandle {
     626              :     /// Fan-out the reader by attaching a new shard to it
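                       :     ///
                       :     /// Illustrative (channel capacity and start position are assumptions):
                       :     ///
                       :     /// ```ignore
                       :     /// let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(64);
                       :     /// // Ask the running reader to also serve `new_shard` from `start_pos`.
                       :     /// handle.fanout(new_shard, tx, start_pos)?;
                       :     /// // `rx` now yields batches filtered for `new_shard`.
                       :     /// ```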
     627            3 :     pub(crate) fn fanout(
     628            3 :         &self,
     629            3 :         shard_id: ShardIdentity,
     630            3 :         sender: tokio::sync::mpsc::Sender<Batch>,
     631            3 :         start_pos: Lsn,
     632            3 :     ) -> Result<(), SendError<AttachShardNotification>> {
     633            3 :         self.shard_notification_tx.send(AttachShardNotification {
     634            3 :             shard_id,
     635            3 :             sender,
     636            3 :             start_pos,
     637            3 :         })
     638            3 :     }
     639              : 
     640              :     /// Get the current WAL position of the reader
     641            6 :     pub(crate) fn current_position(&self) -> Option<Lsn> {
     642            6 :         self.state.read().unwrap().current_position()
     643            6 :     }
     644              : 
     645            6 :     pub(crate) fn abort(&self) {
     646            6 :         self.join_handle.abort()
     647            6 :     }
     648              : }
     649              : 
     650              : impl Drop for InterpretedWalReaderHandle {
     651            3 :     fn drop(&mut self) {
     652            3 :         tracing::info!("Aborting interpreted wal reader");
     653            3 :         self.abort()
     654            3 :     }
     655              : }
     656              : 
     657              : pub(crate) struct InterpretedWalSender<'a, IO> {
     658              :     pub(crate) format: InterpretedFormat,
     659              :     pub(crate) compression: Option<Compression>,
     660              :     pub(crate) appname: Option<String>,
     661              : 
     662              :     pub(crate) tli: WalResidentTimeline,
     663              :     pub(crate) start_lsn: Lsn,
     664              : 
     665              :     pub(crate) pgb: &'a mut PostgresBackend<IO>,
     666              :     pub(crate) end_watch_view: EndWatchView,
     667              :     pub(crate) wal_sender_guard: Arc<WalSenderGuard>,
     668              :     pub(crate) rx: tokio::sync::mpsc::Receiver<Batch>,
     669              : }
     670              : 
     671              : impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
     672              :     /// Send interpreted WAL records over the network.
     673              :     /// Also manages keep-alives if nothing was sent for a while.
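                       :     ///
                       :     /// Illustrative wiring (every field value is assumed to be prepared by the
                       :     /// caller, e.g. the connection handler):
                       :     ///
                       :     /// ```ignore
                       :     /// let sender = InterpretedWalSender {
                       :     ///     format, compression, appname, tli, start_lsn,
                       :     ///     pgb, end_watch_view, wal_sender_guard, rx,
                       :     /// };
                       :     /// sender.run().await?;
                       :     /// ```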
     674            0 :     pub(crate) async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> {
     675            0 :         let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1));
     676            0 :         keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
     677            0 :         keepalive_ticker.reset();
     678              : 
     679            0 :         let mut wal_position = self.start_lsn;
     680              : 
     681              :         loop {
     682            0 :             tokio::select! {
     683            0 :                 batch = self.rx.recv() => {
     684            0 :                     let batch = match batch {
     685            0 :                         Some(b) => b,
     686              :                         None => {
     687            0 :                             return Result::Err(
     688            0 :                                 CopyStreamHandlerEnd::Other(anyhow!("Interpreted WAL reader exited early"))
     689            0 :                             );
     690              :                         }
     691              :                     };
     692              : 
     693            0 :                     wal_position = batch.wal_end_lsn;
     694              : 
     695            0 :                     let buf = batch
     696            0 :                         .records
     697            0 :                         .to_wire(self.format, self.compression)
     698            0 :                         .await
     699            0 :                         .with_context(|| "Failed to serialize interpreted WAL")
     700            0 :                         .map_err(CopyStreamHandlerEnd::from)?;
     701              : 
     702              :                     // Reset the keep alive ticker since we are sending something
     703              :                     // over the wire now.
     704            0 :                     keepalive_ticker.reset();
     705              : 
     706            0 :                     self.pgb
     707            0 :                         .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody {
     708            0 :                             streaming_lsn: batch.wal_end_lsn.0,
     709            0 :                             commit_lsn: batch.available_wal_end_lsn.0,
     710            0 :                             data: &buf,
     711            0 :                         })).await?;
     712              :                 }
     713              :                 // Send a periodic keep alive when the connection has been idle for a while.
     714              :                 // Since we've been idle, also check if we can stop streaming.
     715            0 :                 _ = keepalive_ticker.tick() => {
     716            0 :                     if let Some(remote_consistent_lsn) = self.wal_sender_guard
     717            0 :                         .walsenders()
     718            0 :                         .get_ws_remote_consistent_lsn(self.wal_sender_guard.id())
     719              :                     {
     720            0 :                         if self.tli.should_walsender_stop(remote_consistent_lsn).await {
     721              :                             // Stop streaming if the receivers are caught up and
     722              :                             // there's no active compute. This causes the loop in
     723              :                             // [`crate::send_interpreted_wal::InterpretedWalSender::run`]
     724              :                             // to exit and terminate the WAL stream.
     725            0 :                             break;
     726            0 :                         }
     727            0 :                     }
     728              : 
     729            0 :                     self.pgb
     730            0 :                         .write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
     731            0 :                             wal_end: self.end_watch_view.get().0,
     732            0 :                             timestamp: get_current_timestamp(),
     733            0 :                             request_reply: true,
     734            0 :                         }))
     735            0 :                         .await?;
     736              :                 },
     737              :             }
     738              :         }
     739              : 
     740            0 :         Err(CopyStreamHandlerEnd::ServerInitiated(format!(
      741            0 :             "ending streaming to {:?} at {}, receiver is caught up and there are no computes",
     742            0 :             self.appname, wal_position,
     743            0 :         )))
     744            0 :     }
     745              : }
     746              : #[cfg(test)]
     747              : mod tests {
     748              :     use std::collections::HashMap;
     749              :     use std::str::FromStr;
     750              :     use std::time::Duration;
     751              : 
     752              :     use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardIdentity};
     753              :     use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion};
     754              :     use tokio::sync::mpsc::error::TryRecvError;
     755              :     use utils::id::{NodeId, TenantTimelineId};
     756              :     use utils::lsn::Lsn;
     757              :     use utils::shard::{ShardCount, ShardNumber};
     758              : 
     759              :     use crate::send_interpreted_wal::{AttachShardNotification, Batch, InterpretedWalReader};
     760              :     use crate::test_utils::Env;
     761              :     use crate::wal_reader_stream::StreamingWalReader;
     762              : 
     763              :     #[tokio::test]
     764            1 :     async fn test_interpreted_wal_reader_fanout() {
     765            1 :         let _ = env_logger::builder().is_test(true).try_init();
     766              : 
     767              :         const SIZE: usize = 8 * 1024;
     768              :         const MSG_COUNT: usize = 200;
     769              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
     770              :         const SHARD_COUNT: u8 = 2;
     771              : 
     772            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     773            1 :         let env = Env::new(true).unwrap();
     774            1 :         let tli = env
     775            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     776            1 :             .await
     777            1 :             .unwrap();
     778              : 
     779            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     780            1 :         let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, c"neon-file:", None)
     781            1 :             .await
     782            1 :             .unwrap();
     783            1 :         let end_pos = end_watch.get();
     784              : 
     785            1 :         tracing::info!("Doing first round of reads ...");
     786              : 
     787            1 :         let streaming_wal_reader = StreamingWalReader::new(
     788            1 :             resident_tli,
     789            1 :             None,
     790            1 :             start_lsn,
     791            1 :             end_pos,
     792            1 :             end_watch,
     793              :             MAX_SEND_SIZE,
     794              :         );
     795              : 
     796            1 :         let shard_0 =
     797            1 :             ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
     798            1 :                 .unwrap();
     799              : 
     800            1 :         let shard_1 =
     801            1 :             ShardIdentity::new(ShardNumber(1), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
     802            1 :                 .unwrap();
     803              : 
     804            1 :         let mut shards = HashMap::new();
     805              : 
     806            3 :         for shard_number in 0..SHARD_COUNT {
     807            2 :             let shard_id = ShardIdentity::new(
     808            2 :                 ShardNumber(shard_number),
     809            2 :                 ShardCount(SHARD_COUNT),
     810            2 :                 DEFAULT_STRIPE_SIZE,
     811            2 :             )
     812            2 :             .unwrap();
     813            2 :             let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     814            2 :             shards.insert(shard_id, (Some(tx), Some(rx)));
     815            2 :         }
     816              : 
     817            1 :         let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
     818            1 :         let mut shard_0_rx = shards.get_mut(&shard_0).unwrap().1.take().unwrap();
     819              : 
     820            1 :         let handle = InterpretedWalReader::spawn(
     821            1 :             streaming_wal_reader,
     822            1 :             start_lsn,
     823            1 :             shard_0_tx,
     824            1 :             shard_0,
     825              :             PG_VERSION,
     826            1 :             &Some("pageserver".to_string()),
     827              :         );
     828              : 
     829            1 :         tracing::info!("Reading all WAL with only shard 0 attached ...");
     830              : 
     831            1 :         let mut shard_0_interpreted_records = Vec::new();
     832           13 :         while let Some(batch) = shard_0_rx.recv().await {
     833           13 :             shard_0_interpreted_records.push(batch.records);
     834           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     835            1 :                 break;
     836           12 :             }
     837              :         }
     838              : 
     839            1 :         let shard_1_tx = shards.get_mut(&shard_1).unwrap().0.take().unwrap();
     840            1 :         let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
     841              : 
     842            1 :         tracing::info!("Attaching shard 1 to the reader at start of WAL");
     843            1 :         handle.fanout(shard_1, shard_1_tx, start_lsn).unwrap();
     844              : 
     845            1 :         tracing::info!("Reading all WAL with shard 0 and shard 1 attached ...");
     846              : 
     847            1 :         let mut shard_1_interpreted_records = Vec::new();
     848           13 :         while let Some(batch) = shard_1_rx.recv().await {
     849           13 :             shard_1_interpreted_records.push(batch.records);
     850           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     851            1 :                 break;
     852           12 :             }
     853              :         }
     854              : 
     855              :         // This test uses logical messages. Those only go to shard 0. Check that the
     856              :         // filtering worked and shard 1 did not get any.
     857            1 :         assert!(
     858            1 :             shard_1_interpreted_records
     859            1 :                 .iter()
     860           13 :                 .all(|recs| recs.records.is_empty())
     861              :         );
     862              : 
     863              :         // Shard 0 should not receive anything more since the reader is
      864              :         // going through WAL that it has already processed.
     865            1 :         let res = shard_0_rx.try_recv();
     866            1 :         if let Ok(ref ok) = res {
     867            0 :             tracing::error!(
     868            0 :                 "Shard 0 received batch: wal_end_lsn={} available_wal_end_lsn={}",
     869              :                 ok.wal_end_lsn,
     870              :                 ok.available_wal_end_lsn
     871              :             );
     872            1 :         }
     873            1 :         assert!(matches!(res, Err(TryRecvError::Empty)));
     874              : 
      875              :         // Check that the next record LSNs received by the two shards match up.
     876            1 :         let shard_0_next_lsns = shard_0_interpreted_records
     877            1 :             .iter()
     878            1 :             .map(|recs| recs.next_record_lsn)
     879            1 :             .collect::<Vec<_>>();
     880            1 :         let shard_1_next_lsns = shard_1_interpreted_records
     881            1 :             .iter()
     882            1 :             .map(|recs| recs.next_record_lsn)
     883            1 :             .collect::<Vec<_>>();
     884            1 :         assert_eq!(shard_0_next_lsns, shard_1_next_lsns);
     885              : 
     886            1 :         handle.abort();
     887            1 :         let mut done = false;
     888            2 :         for _ in 0..5 {
     889            2 :             if handle.current_position().is_none() {
     890            1 :                 done = true;
     891            1 :                 break;
     892            1 :             }
     893            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
     894            1 :         }
     895            1 : 
     896            1 :         assert!(done);
     897            1 :     }
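    // A minimal sketch (hypothetical helper, not defined in this file) of the
    // drain loop the tests in this module repeat: receive batches until the
    // reader reports that the batch end has caught up with the available WAL
    // end, collecting the interpreted records along the way.
    async fn drain_until_caught_up(
        rx: &mut tokio::sync::mpsc::Receiver<Batch>,
    ) -> Vec<InterpretedWalRecords> {
        let mut batches = Vec::new();
        while let Some(batch) = rx.recv().await {
            // The final batch is the one whose end LSN matches the WAL end
            // that was available when the reader produced it.
            let caught_up = batch.wal_end_lsn == batch.available_wal_end_lsn;
            batches.push(batch.records);
            if caught_up {
                break;
            }
        }
        batches
    }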
     898              : 
     899              :     #[tokio::test]
     900            1 :     async fn test_interpreted_wal_reader_same_shard_fanout() {
     901            1 :         let _ = env_logger::builder().is_test(true).try_init();
     902              : 
     903              :         const SIZE: usize = 8 * 1024;
     904              :         const MSG_COUNT: usize = 200;
     905              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
     906              :         const SHARD_COUNT: u8 = 2;
     907              : 
     908            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     909            1 :         let env = Env::new(true).unwrap();
     910            1 :         let tli = env
     911            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     912            1 :             .await
     913            1 :             .unwrap();
     914              : 
     915            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     916            1 :         let mut next_record_lsns = Vec::default();
     917            1 :         let end_watch = Env::write_wal(
     918            1 :             tli,
     919            1 :             start_lsn,
     920            1 :             SIZE,
     921            1 :             MSG_COUNT,
     922            1 :             c"neon-file:",
     923            1 :             Some(&mut next_record_lsns),
     924            1 :         )
     925            1 :         .await
     926            1 :         .unwrap();
     927            1 :         let end_pos = end_watch.get();
     928              : 
     929            1 :         let streaming_wal_reader = StreamingWalReader::new(
     930            1 :             resident_tli,
     931            1 :             None,
     932            1 :             start_lsn,
     933            1 :             end_pos,
     934            1 :             end_watch,
     935              :             MAX_SEND_SIZE,
     936              :         );
     937              : 
     938            1 :         let shard_0 =
     939            1 :             ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
     940            1 :                 .unwrap();
     941              : 
     942              :         struct Sender {
     943              :             tx: Option<tokio::sync::mpsc::Sender<Batch>>,
     944              :             rx: tokio::sync::mpsc::Receiver<Batch>,
     945              :             shard: ShardIdentity,
     946              :             start_lsn: Lsn,
     947              :             received_next_record_lsns: Vec<Lsn>,
     948              :         }
     949              : 
     950              :         impl Sender {
     951            3 :             fn new(start_lsn: Lsn, shard: ShardIdentity) -> Self {
     952            3 :                 let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     953            3 :                 Self {
     954            3 :                     tx: Some(tx),
     955            3 :                     rx,
     956            3 :                     shard,
     957            3 :                     start_lsn,
     958            3 :                     received_next_record_lsns: Vec::default(),
     959            3 :                 }
     960            3 :             }
     961              :         }
     962              : 
     963            1 :         assert!(next_record_lsns.len() > 7);
     964            1 :         let start_lsns = vec![
     965            1 :             next_record_lsns[5],
     966            1 :             next_record_lsns[1],
     967            1 :             next_record_lsns[3],
     968              :         ];
     969            1 :         let mut senders = start_lsns
     970            1 :             .into_iter()
     971            3 :             .map(|lsn| Sender::new(lsn, shard_0))
     972            1 :             .collect::<Vec<_>>();
     973              : 
     974            1 :         let first_sender = senders.first_mut().unwrap();
     975            1 :         let handle = InterpretedWalReader::spawn(
     976            1 :             streaming_wal_reader,
     977            1 :             first_sender.start_lsn,
     978            1 :             first_sender.tx.take().unwrap(),
     979            1 :             first_sender.shard,
     980              :             PG_VERSION,
     981            1 :             &Some("pageserver".to_string()),
     982              :         );
     983              : 
     984            2 :         for sender in senders.iter_mut().skip(1) {
     985            2 :             handle
     986            2 :                 .fanout(sender.shard, sender.tx.take().unwrap(), sender.start_lsn)
     987            2 :                 .unwrap();
     988            2 :         }
     989              : 
     990            3 :         for sender in senders.iter_mut() {
     991              :             loop {
     992           39 :                 let batch = sender.rx.recv().await.unwrap();
     993           39 :                 tracing::info!(
     994            0 :                     "Sender with start_lsn={} received batch ending at {} with {} records",
     995              :                     sender.start_lsn,
     996              :                     batch.wal_end_lsn,
     997            0 :                     batch.records.records.len()
     998              :                 );
     999              : 
    1000          627 :                 for rec in batch.records.records {
    1001          588 :                     sender.received_next_record_lsns.push(rec.next_record_lsn);
    1002          588 :                 }
    1003              : 
    1004           39 :                 if batch.wal_end_lsn == batch.available_wal_end_lsn {
    1005            3 :                     break;
    1006           36 :                 }
    1007              :             }
    1008              :         }
    1009              : 
    1010            1 :         handle.abort();
    1011            1 :         let mut done = false;
    1012            2 :         for _ in 0..5 {
    1013            2 :             if handle.current_position().is_none() {
    1014            1 :                 done = true;
    1015            1 :                 break;
    1016            1 :             }
    1017            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
    1018              :         }
    1019              : 
    1020            1 :         assert!(done);
    1021              : 
    1022            4 :         for sender in senders {
    1023            3 :             tracing::info!(
    1024            1 :                 "Validating records received by sender with start_lsn={}",
    1025            1 :                 sender.start_lsn
    1026            1 :             );
    1027            1 : 
    1028            3 :             assert!(sender.received_next_record_lsns.is_sorted());
    1029            3 :             let expected = next_record_lsns
    1030            3 :                 .iter()
    1031          600 :                 .filter(|lsn| **lsn > sender.start_lsn)
    1032            3 :                 .copied()
    1033            3 :                 .collect::<Vec<_>>();
    1034            3 :             assert_eq!(sender.received_next_record_lsns, expected);
    1035            1 :         }
    1036            1 :     }
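    // A minimal sketch of the shutdown wait both tests above end with, as a
    // hypothetical helper; the handle type name `InterpretedWalReaderHandle`
    // is an assumption. After `abort()`, `current_position()` returning `None`
    // signals that the reader task has actually exited.
    async fn wait_for_abort(handle: &InterpretedWalReaderHandle) -> bool {
        for _ in 0..5 {
            if handle.current_position().is_none() {
                return true; // reader task observed the abort and shut down
            }
            tokio::time::sleep(Duration::from_millis(1)).await;
        }
        false // still running; callers assert on the returned flag
    }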
    1037              : 
    1038              :     #[tokio::test]
    1039            1 :     async fn test_batch_start_tracking_on_reset() {
    1040              :         // When the WAL stream is reset to an older LSN,
    1041              :         // the current batch start LSN should be invalidated.
    1042              :         // This test constructs such a scenario:
    1043              :         // 1. Shard 0 is reading somewhere ahead
    1044              :         // 2. Reader reads some WAL, but does not decode a full record (partial read)
    1045              :         // 3. Shard 1 attaches to the reader and resets it to an older LSN
    1046              :         // 4. Shard 1 should get the correct batch WAL start LSN
    1047            1 :         let _ = env_logger::builder().is_test(true).try_init();
    1048              : 
    1049              :         const SIZE: usize = 64 * 1024;
    1050              :         const MSG_COUNT: usize = 10;
    1051              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
    1052              :         const SHARD_COUNT: u8 = 2;
    1053              :         const WAL_READER_BATCH_SIZE: usize = 8192;
    1054              : 
    1055            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
    1056            1 :         let env = Env::new(true).unwrap();
    1057            1 :         let mut next_record_lsns = Vec::default();
    1058            1 :         let tli = env
    1059            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
    1060            1 :             .await
    1061            1 :             .unwrap();
    1062              : 
    1063            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
    1064            1 :         let end_watch = Env::write_wal(
    1065            1 :             tli,
    1066            1 :             start_lsn,
    1067            1 :             SIZE,
    1068            1 :             MSG_COUNT,
    1069            1 :             c"neon-file:",
    1070            1 :             Some(&mut next_record_lsns),
    1071            1 :         )
    1072            1 :         .await
    1073            1 :         .unwrap();
    1074              : 
    1075            1 :         assert!(next_record_lsns.len() > 3);
    1076            1 :         let shard_0_start_lsn = next_record_lsns[3];
    1077              : 
    1078            1 :         let end_pos = end_watch.get();
    1079              : 
    1080            1 :         let streaming_wal_reader = StreamingWalReader::new(
    1081            1 :             resident_tli,
    1082            1 :             None,
    1083            1 :             shard_0_start_lsn,
    1084            1 :             end_pos,
    1085            1 :             end_watch,
    1086              :             WAL_READER_BATCH_SIZE,
    1087              :         );
    1088              : 
    1089            1 :         let shard_0 =
    1090            1 :             ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
    1091            1 :                 .unwrap();
    1092              : 
    1093            1 :         let shard_1 =
    1094            1 :             ShardIdentity::new(ShardNumber(1), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE)
    1095            1 :                 .unwrap();
    1096              : 
    1097            1 :         let mut shards = HashMap::new();
    1098              : 
    1099            3 :         for shard_number in 0..SHARD_COUNT {
    1100            2 :             let shard_id = ShardIdentity::new(
    1101            2 :                 ShardNumber(shard_number),
    1102            2 :                 ShardCount(SHARD_COUNT),
    1103            2 :                 DEFAULT_STRIPE_SIZE,
    1104            2 :             )
    1105            2 :             .unwrap();
    1106            2 :             let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
    1107            2 :             shards.insert(shard_id, (Some(tx), Some(rx)));
    1108            2 :         }
    1109              : 
    1110            1 :         let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
    1111              : 
    1112            1 :         let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
    1113              : 
    1114            1 :         let reader = InterpretedWalReader::new(
    1115            1 :             streaming_wal_reader,
    1116            1 :             shard_0_start_lsn,
    1117            1 :             shard_0_tx,
    1118            1 :             shard_0,
    1119              :             PG_VERSION,
    1120            1 :             Some(shard_notification_rx),
    1121              :         );
    1122              : 
    1123            1 :         let reader_state = reader.state();
    1124            1 :         let mut reader_fut = std::pin::pin!(reader.run(shard_0_start_lsn, &None));
    1125              :         loop {
    1126           64 :             let poll = futures::poll!(reader_fut.as_mut());
    1127           64 :             assert!(poll.is_pending());
    1128              : 
    1129           64 :             let guard = reader_state.read().unwrap();
    1130           64 :             if guard.current_batch_wal_start().is_some() {
    1131            1 :                 break;
    1132           63 :             }
    1133              :         }
    1134              : 
    1135            1 :         shard_notification_tx
    1136            1 :             .send(AttachShardNotification {
    1137            1 :                 shard_id: shard_1,
    1138            1 :                 sender: shards.get_mut(&shard_1).unwrap().0.take().unwrap(),
    1139            1 :                 start_pos: start_lsn,
    1140            1 :             })
    1141            1 :             .unwrap();
    1142              : 
    1143            1 :         let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
    1144            1 :         loop {
    1145           45 :             let poll = futures::poll!(reader_fut.as_mut());
    1146           45 :             assert!(poll.is_pending());
    1147            1 : 
    1148           45 :             let try_recv_res = shard_1_rx.try_recv();
    1149           44 :             match try_recv_res {
    1150            1 :                 Ok(batch) => {
    1151            1 :                     assert_eq!(batch.records.raw_wal_start_lsn.unwrap(), start_lsn);
    1152            1 :                     break;
    1153            1 :                 }
    1154           44 :                 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {}
    1155            1 :                 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
    1156            1 :                     unreachable!();
    1157            1 :                 }
    1158            1 :             }
    1159            1 :         }
    1160            1 :     }
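    // A self-contained sketch (hypothetical; only std, tokio and futures) of
    // the manual polling technique used in test_batch_start_tracking_on_reset:
    // pin the future on the stack, advance it one poll at a time with
    // `futures::poll!`, and inspect shared state between polls instead of
    // spawning it as a task.
    async fn poll_step_demo() {
        use std::sync::{Arc, Mutex};

        let state = Arc::new(Mutex::new(0u32));
        let task_state = state.clone();
        let mut fut = std::pin::pin!(async move {
            loop {
                *task_state.lock().unwrap() += 1;
                tokio::task::yield_now().await; // hand control back to the poller
            }
        });

        loop {
            // Each poll advances the future by exactly one step (up to the
            // next await point); it never completes, so it is always pending.
            let poll = futures::poll!(fut.as_mut());
            assert!(poll.is_pending());
            if *state.lock().unwrap() >= 3 {
                break; // observed the side effect we were polling for
            }
        }
    }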
    1161              : 
    1162              :     #[tokio::test]
    1163            1 :     async fn test_shard_zero_does_not_skip_empty_records() {
    1164            1 :         let _ = env_logger::builder().is_test(true).try_init();
    1165              : 
    1166              :         const SIZE: usize = 8 * 1024;
    1167              :         const MSG_COUNT: usize = 10;
    1168              :         const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
    1169              : 
    1170            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
    1171            1 :         let env = Env::new(true).unwrap();
    1172            1 :         let tli = env
    1173            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
    1174            1 :             .await
    1175            1 :             .unwrap();
    1176              : 
    1177            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
    1178            1 :         let mut next_record_lsns = Vec::new();
    1179            1 :         let end_watch = Env::write_wal(
    1180            1 :             tli,
    1181            1 :             start_lsn,
    1182            1 :             SIZE,
    1183            1 :             MSG_COUNT,
     1184            1 :             // This is a logical message prefix that is not persisted to key-value storage.
     1185            1 :             // We use it to validate that shard zero receives empty interpreted records.
    1186            1 :             c"test:",
    1187            1 :             Some(&mut next_record_lsns),
    1188            1 :         )
    1189            1 :         .await
    1190            1 :         .unwrap();
    1191            1 :         let end_pos = end_watch.get();
    1192              : 
    1193            1 :         let streaming_wal_reader = StreamingWalReader::new(
    1194            1 :             resident_tli,
    1195            1 :             None,
    1196            1 :             start_lsn,
    1197            1 :             end_pos,
    1198            1 :             end_watch,
    1199              :             MAX_SEND_SIZE,
    1200              :         );
    1201              : 
    1202            1 :         let shard = ShardIdentity::unsharded();
    1203            1 :         let (records_tx, mut records_rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
    1204              : 
    1205            1 :         let handle = InterpretedWalReader::spawn(
    1206            1 :             streaming_wal_reader,
    1207            1 :             start_lsn,
    1208            1 :             records_tx,
    1209            1 :             shard,
    1210              :             PG_VERSION,
    1211            1 :             &Some("pageserver".to_string()),
    1212              :         );
    1213              : 
    1214            1 :         let mut interpreted_records = Vec::new();
    1215            1 :         while let Some(batch) = records_rx.recv().await {
    1216            1 :             interpreted_records.push(batch.records);
    1217            1 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
    1218            1 :                 break;
    1219            0 :             }
    1220              :         }
    1221              : 
    1222            1 :         let received_next_record_lsns = interpreted_records
    1223            1 :             .into_iter()
    1224            1 :             .flat_map(|b| b.records)
    1225            1 :             .map(|rec| rec.next_record_lsn)
    1226            1 :             .collect::<Vec<_>>();
    1227              : 
     1228              :         // By default, next_record_lsns also includes the start LSN. Trim it since it shouldn't be received.
    1229            1 :         let next_record_lsns = next_record_lsns.into_iter().skip(1).collect::<Vec<_>>();
    1230              : 
    1231            1 :         assert_eq!(received_next_record_lsns, next_record_lsns);
    1232              : 
    1233            1 :         handle.abort();
    1234            1 :         let mut done = false;
    1235            2 :         for _ in 0..5 {
    1236            2 :             if handle.current_position().is_none() {
    1237            1 :                 done = true;
    1238            1 :                 break;
    1239            1 :             }
    1240            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
    1241            1 :         }
    1242            1 : 
    1243            1 :         assert!(done);
    1244            1 :     }
    1245              : }
        

Generated by: LCOV version 2.1-beta