LCOV - code coverage report
Current view: top level - safekeeper/src - send_interpreted_wal.rs (source / functions) Coverage Total Hit
Test: 07bee600374ccd486c69370d0972d9035964fe68.info Lines: 78.1 % 599 468
Test Date: 2025-02-20 13:11:02 Functions: 62.5 % 40 25

            Line data    Source code
       1              : use std::collections::HashMap;
       2              : use std::fmt::Display;
       3              : use std::sync::Arc;
       4              : use std::time::Duration;
       5              : 
       6              : use anyhow::{anyhow, Context};
       7              : use futures::future::Either;
       8              : use futures::StreamExt;
       9              : use pageserver_api::shard::ShardIdentity;
      10              : use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend};
      11              : use postgres_ffi::waldecoder::WalDecodeError;
      12              : use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder};
      13              : use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive};
      14              : use tokio::io::{AsyncRead, AsyncWrite};
      15              : use tokio::sync::mpsc::error::SendError;
      16              : use tokio::task::JoinHandle;
      17              : use tokio::time::MissedTickBehavior;
      18              : use tracing::{error, info, info_span, Instrument};
      19              : use utils::critical;
      20              : use utils::lsn::Lsn;
      21              : use utils::postgres_client::Compression;
      22              : use utils::postgres_client::InterpretedFormat;
      23              : use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
      24              : use wal_decoder::wire_format::ToWireFormat;
      25              : 
      26              : use crate::metrics::WAL_READERS;
      27              : use crate::send_wal::{EndWatchView, WalSenderGuard};
      28              : use crate::timeline::WalResidentTimeline;
      29              : use crate::wal_reader_stream::{StreamingWalReader, WalBytes};
      30              : 
      31              : /// Identifier used to differentiate between senders of the same
      32              : /// shard.
      33              : ///
      34              : /// In the steady state there's only one, but two pageservers may
      35              : /// temporarily have the same shard attached and attempt to ingest
      36              : /// WAL for it. See also [`ShardSenderId`].
      37              : #[derive(Hash, Eq, PartialEq, Copy, Clone)]
      38              : struct SenderId(u8);
      39              : 
      40              : impl SenderId {
      41            3 :     fn first() -> Self {
      42            3 :         SenderId(0)
      43            3 :     }
      44              : 
      45            2 :     fn next(&self) -> Self {
      46            2 :         SenderId(self.0.checked_add(1).expect("few senders"))
      47            2 :     }
      48              : }
      49              : 
      50              : #[derive(Hash, Eq, PartialEq)]
      51              : struct ShardSenderId {
      52              :     shard: ShardIdentity,
      53              :     sender_id: SenderId,
      54              : }
      55              : 
      56              : impl Display for ShardSenderId {
      57            0 :     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
      58            0 :         write!(f, "{}{}", self.sender_id.0, self.shard.shard_slug())
      59            0 :     }
      60              : }
      61              : 
      62              : impl ShardSenderId {
      63          852 :     fn new(shard: ShardIdentity, sender_id: SenderId) -> Self {
      64          852 :         ShardSenderId { shard, sender_id }
      65          852 :     }
      66              : 
      67            0 :     fn shard(&self) -> ShardIdentity {
      68            0 :         self.shard
      69            0 :     }
      70              : }
      71              : 
      72              : /// Shard-aware fan-out interpreted record reader.
      73              : /// Reads WAL from disk, decodes it, intepretets it, and sends
      74              : /// it to any [`InterpretedWalSender`] connected to it.
      75              : /// Each [`InterpretedWalSender`] corresponds to one shard
      76              : /// and gets interpreted records concerning that shard only.
      77              : pub(crate) struct InterpretedWalReader {
      78              :     wal_stream: StreamingWalReader,
      79              :     shard_senders: HashMap<ShardIdentity, smallvec::SmallVec<[ShardSenderState; 1]>>,
      80              :     shard_notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>>,
      81              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
      82              :     pg_version: u32,
      83              : }
      84              : 
      85              : /// A handle for [`InterpretedWalReader`] which allows for interacting with it
      86              : /// when it runs as a separate tokio task.
      87              : #[derive(Debug)]
      88              : pub(crate) struct InterpretedWalReaderHandle {
      89              :     join_handle: JoinHandle<Result<(), InterpretedWalReaderError>>,
      90              :     state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
      91              :     shard_notification_tx: tokio::sync::mpsc::UnboundedSender<AttachShardNotification>,
      92              : }
      93              : 
      94              : struct ShardSenderState {
      95              :     sender_id: SenderId,
      96              :     tx: tokio::sync::mpsc::Sender<Batch>,
      97              :     next_record_lsn: Lsn,
      98              : }
      99              : 
     100              : /// State of [`InterpretedWalReader`] visible outside of the task running it.
     101              : #[derive(Debug)]
     102              : pub(crate) enum InterpretedWalReaderState {
     103              :     Running { current_position: Lsn },
     104              :     Done,
     105              : }
     106              : 
     107              : pub(crate) struct Batch {
     108              :     wal_end_lsn: Lsn,
     109              :     available_wal_end_lsn: Lsn,
     110              :     records: InterpretedWalRecords,
     111              : }
     112              : 
     113              : #[derive(thiserror::Error, Debug)]
     114              : pub enum InterpretedWalReaderError {
     115              :     /// Handler initiates the end of streaming.
     116              :     #[error("decode error: {0}")]
     117              :     Decode(#[from] WalDecodeError),
     118              :     #[error("read or interpret error: {0}")]
     119              :     ReadOrInterpret(#[from] anyhow::Error),
     120              :     #[error("wal stream closed")]
     121              :     WalStreamClosed,
     122              : }
     123              : 
     124              : enum CurrentPositionUpdate {
     125              :     Reset(Lsn),
     126              :     NotReset(Lsn),
     127              : }
     128              : 
     129              : impl CurrentPositionUpdate {
     130            0 :     fn current_position(&self) -> Lsn {
     131            0 :         match self {
     132            0 :             CurrentPositionUpdate::Reset(lsn) => *lsn,
     133            0 :             CurrentPositionUpdate::NotReset(lsn) => *lsn,
     134              :         }
     135            0 :     }
     136              : }
     137              : 
     138              : impl InterpretedWalReaderState {
     139            4 :     fn current_position(&self) -> Option<Lsn> {
     140            4 :         match self {
     141              :             InterpretedWalReaderState::Running {
     142            2 :                 current_position, ..
     143            2 :             } => Some(*current_position),
     144            2 :             InterpretedWalReaderState::Done => None,
     145              :         }
     146            4 :     }
     147              : 
     148              :     // Reset the current position of the WAL reader if the requested starting position
     149              :     // of the new shard is smaller than the current value.
     150            3 :     fn maybe_reset(&mut self, new_shard_start_pos: Lsn) -> CurrentPositionUpdate {
     151            3 :         match self {
     152              :             InterpretedWalReaderState::Running {
     153            3 :                 current_position, ..
     154            3 :             } => {
     155            3 :                 if new_shard_start_pos < *current_position {
     156            2 :                     *current_position = new_shard_start_pos;
     157            2 :                     CurrentPositionUpdate::Reset(*current_position)
     158              :                 } else {
     159            1 :                     CurrentPositionUpdate::NotReset(*current_position)
     160              :                 }
     161              :             }
     162              :             InterpretedWalReaderState::Done => {
     163            0 :                 panic!("maybe_reset called on finished reader")
     164              :             }
     165              :         }
     166            3 :     }
     167              : }
     168              : 
     169              : pub(crate) struct AttachShardNotification {
     170              :     shard_id: ShardIdentity,
     171              :     sender: tokio::sync::mpsc::Sender<Batch>,
     172              :     start_pos: Lsn,
     173              : }
     174              : 
     175              : impl InterpretedWalReader {
     176              :     /// Spawn the reader in a separate tokio task and return a handle
     177            2 :     pub(crate) fn spawn(
     178            2 :         wal_stream: StreamingWalReader,
     179            2 :         start_pos: Lsn,
     180            2 :         tx: tokio::sync::mpsc::Sender<Batch>,
     181            2 :         shard: ShardIdentity,
     182            2 :         pg_version: u32,
     183            2 :         appname: &Option<String>,
     184            2 :     ) -> InterpretedWalReaderHandle {
     185            2 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     186            2 :             current_position: start_pos,
     187            2 :         }));
     188            2 : 
     189            2 :         let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
     190              : 
     191            2 :         let reader = InterpretedWalReader {
     192            2 :             wal_stream,
     193            2 :             shard_senders: HashMap::from([(
     194            2 :                 shard,
     195            2 :                 smallvec::smallvec![ShardSenderState {
     196            0 :                     sender_id: SenderId::first(),
     197            0 :                     tx,
     198            0 :                     next_record_lsn: start_pos,
     199            0 :                 }],
     200              :             )]),
     201            2 :             shard_notification_rx: Some(shard_notification_rx),
     202            2 :             state: state.clone(),
     203            2 :             pg_version,
     204            2 :         };
     205            2 : 
     206            2 :         let metric = WAL_READERS
     207            2 :             .get_metric_with_label_values(&["task", appname.as_deref().unwrap_or("safekeeper")])
     208            2 :             .unwrap();
     209              : 
     210            2 :         let join_handle = tokio::task::spawn(
     211            2 :             async move {
     212            2 :                 metric.inc();
     213            2 :                 scopeguard::defer! {
     214            2 :                     metric.dec();
     215            2 :                 }
     216            2 : 
     217            2 :                 reader
     218            2 :                     .run_impl(start_pos)
     219            2 :                     .await
     220            0 :                     .inspect_err(|err| critical!("failed to read WAL record: {err:?}"))
     221            0 :             }
     222            2 :             .instrument(info_span!("interpreted wal reader")),
     223              :         );
     224              : 
     225            2 :         InterpretedWalReaderHandle {
     226            2 :             join_handle,
     227            2 :             state,
     228            2 :             shard_notification_tx,
     229            2 :         }
     230            2 :     }
     231              : 
     232              :     /// Construct the reader without spawning anything
     233              :     /// Callers should drive the future returned by [`Self::run`].
     234            0 :     pub(crate) fn new(
     235            0 :         wal_stream: StreamingWalReader,
     236            0 :         start_pos: Lsn,
     237            0 :         tx: tokio::sync::mpsc::Sender<Batch>,
     238            0 :         shard: ShardIdentity,
     239            0 :         pg_version: u32,
     240            0 :     ) -> InterpretedWalReader {
     241            0 :         let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
     242            0 :             current_position: start_pos,
     243            0 :         }));
     244            0 : 
     245            0 :         InterpretedWalReader {
     246            0 :             wal_stream,
     247            0 :             shard_senders: HashMap::from([(
     248            0 :                 shard,
     249            0 :                 smallvec::smallvec![ShardSenderState {
     250            0 :                     sender_id: SenderId::first(),
     251            0 :                     tx,
     252            0 :                     next_record_lsn: start_pos,
     253            0 :                 }],
     254              :             )]),
     255            0 :             shard_notification_rx: None,
     256            0 :             state: state.clone(),
     257            0 :             pg_version,
     258            0 :         }
     259            0 :     }
     260              : 
     261              :     /// Entry point for future (polling) based wal reader.
     262            0 :     pub(crate) async fn run(
     263            0 :         self,
     264            0 :         start_pos: Lsn,
     265            0 :         appname: &Option<String>,
     266            0 :     ) -> Result<(), CopyStreamHandlerEnd> {
     267            0 :         let metric = WAL_READERS
     268            0 :             .get_metric_with_label_values(&["future", appname.as_deref().unwrap_or("safekeeper")])
     269            0 :             .unwrap();
     270            0 : 
     271            0 :         metric.inc();
     272            0 :         scopeguard::defer! {
     273            0 :             metric.dec();
     274            0 :         }
     275              : 
     276            0 :         if let Err(err) = self.run_impl(start_pos).await {
     277            0 :             critical!("failed to read WAL record: {err:?}");
     278              :         } else {
     279            0 :             info!("interpreted wal reader exiting");
     280              :         }
     281              : 
     282            0 :         Err(CopyStreamHandlerEnd::Other(anyhow!(
     283            0 :             "interpreted wal reader finished"
     284            0 :         )))
     285            0 :     }
     286              : 
     287              :     /// Send interpreted WAL to one or more [`InterpretedWalSender`]s
     288              :     /// Stops when an error is encountered or when the [`InterpretedWalReaderHandle`]
     289              :     /// goes out of scope.
     290            2 :     async fn run_impl(mut self, start_pos: Lsn) -> Result<(), InterpretedWalReaderError> {
     291            2 :         let defer_state = self.state.clone();
     292            2 :         scopeguard::defer! {
     293            2 :             *defer_state.write().unwrap() = InterpretedWalReaderState::Done;
     294            2 :         }
     295            2 : 
     296            2 :         let mut wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version);
     297              : 
     298              :         loop {
     299           44 :             tokio::select! {
     300              :                 // Main branch for reading WAL and forwarding it
     301           44 :                 wal_or_reset = self.wal_stream.next() => {
     302           39 :                     let wal = wal_or_reset.map(|wor| wor.get_wal().expect("reset handled in select branch below"));
     303              :                     let WalBytes {
     304           39 :                         wal,
     305           39 :                         wal_start_lsn: _,
     306           39 :                         wal_end_lsn,
     307           39 :                         available_wal_end_lsn,
     308           39 :                     } = match wal {
     309           39 :                         Some(some) => some.map_err(InterpretedWalReaderError::ReadOrInterpret)?,
     310              :                         None => {
     311              :                             // [`StreamingWalReader::next`] is an endless stream of WAL.
     312              :                             // It shouldn't ever finish unless it panicked or became internally
     313              :                             // inconsistent.
     314            0 :                             return Result::Err(InterpretedWalReaderError::WalStreamClosed);
     315              :                         }
     316              :                     };
     317              : 
     318           39 :                     wal_decoder.feed_bytes(&wal);
     319           39 : 
     320           39 :                     // Deserialize and interpret WAL records from this batch of WAL.
     321           39 :                     // Interpreted records for each shard are collected separately.
     322           39 :                     let shard_ids = self.shard_senders.keys().copied().collect::<Vec<_>>();
     323           39 :                     let mut records_by_sender: HashMap<ShardSenderId, Vec<InterpretedWalRecord>> = HashMap::new();
     324           39 :                     let mut max_next_record_lsn = None;
     325          635 :                     while let Some((next_record_lsn, recdata)) = wal_decoder.poll_decode()?
     326              :                     {
     327          596 :                         assert!(next_record_lsn.is_aligned());
     328          596 :                         max_next_record_lsn = Some(next_record_lsn);
     329              : 
     330          596 :                         let interpreted = InterpretedWalRecord::from_bytes_filtered(
     331          596 :                             recdata,
     332          596 :                             &shard_ids,
     333          596 :                             next_record_lsn,
     334          596 :                             self.pg_version,
     335          596 :                         )
     336          596 :                         .with_context(|| "Failed to interpret WAL")?;
     337              : 
     338         1391 :                         for (shard, record) in interpreted {
     339          795 :                             if record.is_empty() {
     340          199 :                                 continue;
     341          596 :                             }
     342          596 : 
     343          596 :                             let mut states_iter = self.shard_senders
     344          596 :                                 .get(&shard)
     345          596 :                                 .expect("keys collected above")
     346          596 :                                 .iter()
     347          992 :                                 .filter(|state| record.next_record_lsn > state.next_record_lsn)
     348          596 :                                 .peekable();
     349          986 :                             while let Some(state) = states_iter.next() {
     350          787 :                                 let shard_sender_id = ShardSenderId::new(shard, state.sender_id);
     351          787 : 
     352          787 :                                 // The most commont case is one sender per shard. Peek and break to avoid the
     353          787 :                                 // clone in that situation.
     354          787 :                                 if states_iter.peek().is_none() {
     355          397 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record);
     356          397 :                                     break;
     357          390 :                                 } else {
     358          390 :                                     records_by_sender.entry(shard_sender_id).or_default().push(record.clone());
     359          390 :                                 }
     360              :                             }
     361              :                         }
     362              :                     }
     363              : 
     364           39 :                     let max_next_record_lsn = match max_next_record_lsn {
     365           39 :                         Some(lsn) => lsn,
     366            0 :                         None => { continue; }
     367              :                     };
     368              : 
     369              :                     // Update the current position such that new receivers can decide
     370              :                     // whether to attach to us or spawn a new WAL reader.
     371           39 :                     match &mut *self.state.write().unwrap() {
     372           39 :                         InterpretedWalReaderState::Running { current_position, .. } => {
     373           39 :                             *current_position = max_next_record_lsn;
     374           39 :                         },
     375              :                         InterpretedWalReaderState::Done => {
     376            0 :                             unreachable!()
     377              :                         }
     378              :                     }
     379              : 
     380              :                     // Send interpreted records downstream. Anything that has already been seen
     381              :                     // by a shard is filtered out.
     382           39 :                     let mut shard_senders_to_remove = Vec::new();
     383           91 :                     for (shard, states) in &mut self.shard_senders {
     384          130 :                         for state in states {
     385           78 :                             if max_next_record_lsn <= state.next_record_lsn {
     386           13 :                                 continue;
     387           65 :                             }
     388           65 : 
     389           65 :                             let shard_sender_id = ShardSenderId::new(*shard, state.sender_id);
     390           65 :                             let records = records_by_sender.remove(&shard_sender_id).unwrap_or_default();
     391           65 : 
     392           65 :                             let batch = InterpretedWalRecords {
     393           65 :                                 records,
     394           65 :                                 next_record_lsn: Some(max_next_record_lsn),
     395           65 :                             };
     396              : 
     397           65 :                             let res = state.tx.send(Batch {
     398           65 :                                 wal_end_lsn,
     399           65 :                                 available_wal_end_lsn,
     400           65 :                                 records: batch,
     401           65 :                             }).await;
     402              : 
     403           65 :                             if res.is_err() {
     404            0 :                                 shard_senders_to_remove.push(shard_sender_id);
     405           65 :                             } else {
     406           65 :                                 state.next_record_lsn = max_next_record_lsn;
     407           65 :                             }
     408              :                         }
     409              :                     }
     410              : 
     411              :                     // Clean up any shard senders that have dropped out.
     412              :                     // This is inefficient, but such events are rare (connection to PS termination)
     413              :                     // and the number of subscriptions on the same shards very small (only one
     414              :                     // for the steady state).
     415           39 :                     for to_remove in shard_senders_to_remove {
     416            0 :                         let shard_senders = self.shard_senders.get_mut(&to_remove.shard()).expect("saw it above");
     417            0 :                         if let Some(idx) = shard_senders.iter().position(|s| s.sender_id == to_remove.sender_id) {
     418            0 :                             shard_senders.remove(idx);
     419            0 :                             tracing::info!("Removed shard sender {}", to_remove);
     420            0 :                         }
     421              : 
     422            0 :                         if shard_senders.is_empty() {
     423            0 :                             self.shard_senders.remove(&to_remove.shard());
     424            0 :                         }
     425              :                     }
     426              :                 },
     427              :                 // Listen for new shards that want to attach to this reader.
     428              :                 // If the reader is not running as a task, then this is not supported
     429              :                 // (see the pending branch below).
     430           44 :                 notification = match self.shard_notification_rx.as_mut() {
     431           44 :                         Some(rx) => Either::Left(rx.recv()),
     432            0 :                         None => Either::Right(std::future::pending())
     433              :                     } => {
     434            3 :                     if let Some(n) = notification {
     435            3 :                         let AttachShardNotification { shard_id, sender, start_pos } = n;
     436            3 : 
     437            3 :                         // Update internal and external state, then reset the WAL stream
     438            3 :                         // if required.
     439            3 :                         let senders = self.shard_senders.entry(shard_id).or_default();
     440            3 :                         let new_sender_id = match senders.last() {
     441            2 :                             Some(sender) => sender.sender_id.next(),
     442            1 :                             None => SenderId::first()
     443              :                         };
     444              : 
     445            3 :                         senders.push(ShardSenderState { sender_id: new_sender_id, tx: sender, next_record_lsn: start_pos});
     446            3 : 
     447            3 :                         // If the shard is subscribing below the current position the we need
     448            3 :                         // to update the cursor that tracks where we are at in the WAL
     449            3 :                         // ([`Self::state`]) and reset the WAL stream itself
     450            3 :                         // (`[Self::wal_stream`]). This must be done atomically from the POV of
     451            3 :                         // anything outside the select statement.
     452            3 :                         let position_reset = self.state.write().unwrap().maybe_reset(start_pos);
     453            3 :                         match position_reset {
     454            2 :                             CurrentPositionUpdate::Reset(to) => {
     455            2 :                                 self.wal_stream.reset(to).await;
     456            2 :                                 wal_decoder = WalStreamDecoder::new(to, self.pg_version);
     457              :                             },
     458            1 :                             CurrentPositionUpdate::NotReset(_) => {}
     459              :                         };
     460              : 
     461            3 :                         tracing::info!(
     462            0 :                             "Added shard sender {} with start_pos={} current_pos={}",
     463            0 :                             ShardSenderId::new(shard_id, new_sender_id), start_pos, position_reset.current_position()
     464              :                         );
     465            0 :                     }
     466              :                 }
     467              :             }
     468              :         }
     469            0 :     }
     470              : }
     471              : 
     472              : impl InterpretedWalReaderHandle {
     473              :     /// Fan-out the reader by attaching a new shard to it
     474            3 :     pub(crate) fn fanout(
     475            3 :         &self,
     476            3 :         shard_id: ShardIdentity,
     477            3 :         sender: tokio::sync::mpsc::Sender<Batch>,
     478            3 :         start_pos: Lsn,
     479            3 :     ) -> Result<(), SendError<AttachShardNotification>> {
     480            3 :         self.shard_notification_tx.send(AttachShardNotification {
     481            3 :             shard_id,
     482            3 :             sender,
     483            3 :             start_pos,
     484            3 :         })
     485            3 :     }
     486              : 
     487              :     /// Get the current WAL position of the reader
     488            4 :     pub(crate) fn current_position(&self) -> Option<Lsn> {
     489            4 :         self.state.read().unwrap().current_position()
     490            4 :     }
     491              : 
     492            4 :     pub(crate) fn abort(&self) {
     493            4 :         self.join_handle.abort()
     494            4 :     }
     495              : }
     496              : 
     497              : impl Drop for InterpretedWalReaderHandle {
     498            2 :     fn drop(&mut self) {
     499            2 :         tracing::info!("Aborting interpreted wal reader");
     500            2 :         self.abort()
     501            2 :     }
     502              : }
     503              : 
     504              : pub(crate) struct InterpretedWalSender<'a, IO> {
     505              :     pub(crate) format: InterpretedFormat,
     506              :     pub(crate) compression: Option<Compression>,
     507              :     pub(crate) appname: Option<String>,
     508              : 
     509              :     pub(crate) tli: WalResidentTimeline,
     510              :     pub(crate) start_lsn: Lsn,
     511              : 
     512              :     pub(crate) pgb: &'a mut PostgresBackend<IO>,
     513              :     pub(crate) end_watch_view: EndWatchView,
     514              :     pub(crate) wal_sender_guard: Arc<WalSenderGuard>,
     515              :     pub(crate) rx: tokio::sync::mpsc::Receiver<Batch>,
     516              : }
     517              : 
     518              : impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
     519              :     /// Send interpreted WAL records over the network.
     520              :     /// Also manages keep-alives if nothing was sent for a while.
     521            0 :     pub(crate) async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> {
     522            0 :         let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1));
     523            0 :         keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
     524            0 :         keepalive_ticker.reset();
     525            0 : 
     526            0 :         let mut wal_position = self.start_lsn;
     527              : 
     528              :         loop {
     529            0 :             tokio::select! {
     530            0 :                 batch = self.rx.recv() => {
     531            0 :                     let batch = match batch {
     532            0 :                         Some(b) => b,
     533              :                         None => {
     534            0 :                             return Result::Err(
     535            0 :                                 CopyStreamHandlerEnd::Other(anyhow!("Interpreted WAL reader exited early"))
     536            0 :                             );
     537              :                         }
     538              :                     };
     539              : 
     540            0 :                     wal_position = batch.wal_end_lsn;
     541              : 
     542            0 :                     let buf = batch
     543            0 :                         .records
     544            0 :                         .to_wire(self.format, self.compression)
     545            0 :                         .await
     546            0 :                         .with_context(|| "Failed to serialize interpreted WAL")
     547            0 :                         .map_err(CopyStreamHandlerEnd::from)?;
     548              : 
     549              :                     // Reset the keep alive ticker since we are sending something
     550              :                     // over the wire now.
     551            0 :                     keepalive_ticker.reset();
     552            0 : 
     553            0 :                     self.pgb
     554            0 :                         .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody {
     555            0 :                             streaming_lsn: batch.wal_end_lsn.0,
     556            0 :                             commit_lsn: batch.available_wal_end_lsn.0,
     557            0 :                             data: &buf,
     558            0 :                         })).await?;
     559              :                 }
     560              :                 // Send a periodic keep alive when the connection has been idle for a while.
     561              :                 // Since we've been idle, also check if we can stop streaming.
     562            0 :                 _ = keepalive_ticker.tick() => {
     563            0 :                     if let Some(remote_consistent_lsn) = self.wal_sender_guard
     564            0 :                         .walsenders()
     565            0 :                         .get_ws_remote_consistent_lsn(self.wal_sender_guard.id())
     566              :                     {
     567            0 :                         if self.tli.should_walsender_stop(remote_consistent_lsn).await {
     568              :                             // Stop streaming if the receivers are caught up and
     569              :                             // there's no active compute. This causes the loop in
     570              :                             // [`crate::send_interpreted_wal::InterpretedWalSender::run`]
     571              :                             // to exit and terminate the WAL stream.
     572            0 :                             break;
     573            0 :                         }
     574            0 :                     }
     575              : 
     576            0 :                     self.pgb
     577            0 :                         .write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
     578            0 :                             wal_end: self.end_watch_view.get().0,
     579            0 :                             timestamp: get_current_timestamp(),
     580            0 :                             request_reply: true,
     581            0 :                         }))
     582            0 :                         .await?;
     583              :                 },
     584              :             }
     585              :         }
     586              : 
     587            0 :         Err(CopyStreamHandlerEnd::ServerInitiated(format!(
     588            0 :             "ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
     589            0 :             self.appname, wal_position,
     590            0 :         )))
     591            0 :     }
     592              : }
     593              : #[cfg(test)]
     594              : mod tests {
     595              :     use std::{collections::HashMap, str::FromStr, time::Duration};
     596              : 
     597              :     use pageserver_api::shard::{ShardIdentity, ShardStripeSize};
     598              :     use postgres_ffi::MAX_SEND_SIZE;
     599              :     use tokio::sync::mpsc::error::TryRecvError;
     600              :     use utils::{
     601              :         id::{NodeId, TenantTimelineId},
     602              :         lsn::Lsn,
     603              :         shard::{ShardCount, ShardNumber},
     604              :     };
     605              : 
     606              :     use crate::{
     607              :         send_interpreted_wal::{Batch, InterpretedWalReader},
     608              :         test_utils::Env,
     609              :         wal_reader_stream::StreamingWalReader,
     610              :     };
     611              : 
     612              :     #[tokio::test]
     613            1 :     async fn test_interpreted_wal_reader_fanout() {
     614            1 :         let _ = env_logger::builder().is_test(true).try_init();
     615            1 : 
     616            1 :         const SIZE: usize = 8 * 1024;
     617            1 :         const MSG_COUNT: usize = 200;
     618            1 :         const PG_VERSION: u32 = 17;
     619            1 :         const SHARD_COUNT: u8 = 2;
     620            1 : 
     621            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     622            1 :         let env = Env::new(true).unwrap();
     623            1 :         let tli = env
     624            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     625            1 :             .await
     626            1 :             .unwrap();
     627            1 : 
     628            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     629            1 :         let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, None)
     630            1 :             .await
     631            1 :             .unwrap();
     632            1 :         let end_pos = end_watch.get();
     633            1 : 
     634            1 :         tracing::info!("Doing first round of reads ...");
     635            1 : 
     636            1 :         let streaming_wal_reader = StreamingWalReader::new(
     637            1 :             resident_tli,
     638            1 :             None,
     639            1 :             start_lsn,
     640            1 :             end_pos,
     641            1 :             end_watch,
     642            1 :             MAX_SEND_SIZE,
     643            1 :         );
     644            1 : 
     645            1 :         let shard_0 = ShardIdentity::new(
     646            1 :             ShardNumber(0),
     647            1 :             ShardCount(SHARD_COUNT),
     648            1 :             ShardStripeSize::default(),
     649            1 :         )
     650            1 :         .unwrap();
     651            1 : 
     652            1 :         let shard_1 = ShardIdentity::new(
     653            1 :             ShardNumber(1),
     654            1 :             ShardCount(SHARD_COUNT),
     655            1 :             ShardStripeSize::default(),
     656            1 :         )
     657            1 :         .unwrap();
     658            1 : 
     659            1 :         let mut shards = HashMap::new();
     660            1 : 
     661            3 :         for shard_number in 0..SHARD_COUNT {
     662            2 :             let shard_id = ShardIdentity::new(
     663            2 :                 ShardNumber(shard_number),
     664            2 :                 ShardCount(SHARD_COUNT),
     665            2 :                 ShardStripeSize::default(),
     666            2 :             )
     667            2 :             .unwrap();
     668            2 :             let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     669            2 :             shards.insert(shard_id, (Some(tx), Some(rx)));
     670            2 :         }
     671            1 : 
     672            1 :         let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
     673            1 :         let mut shard_0_rx = shards.get_mut(&shard_0).unwrap().1.take().unwrap();
     674            1 : 
     675            1 :         let handle = InterpretedWalReader::spawn(
     676            1 :             streaming_wal_reader,
     677            1 :             start_lsn,
     678            1 :             shard_0_tx,
     679            1 :             shard_0,
     680            1 :             PG_VERSION,
     681            1 :             &Some("pageserver".to_string()),
     682            1 :         );
     683            1 : 
     684            1 :         tracing::info!("Reading all WAL with only shard 0 attached ...");
     685            1 : 
     686            1 :         let mut shard_0_interpreted_records = Vec::new();
     687           13 :         while let Some(batch) = shard_0_rx.recv().await {
     688           13 :             shard_0_interpreted_records.push(batch.records);
     689           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     690            1 :                 break;
     691           12 :             }
     692            1 :         }
     693            1 : 
     694            1 :         let shard_1_tx = shards.get_mut(&shard_1).unwrap().0.take().unwrap();
     695            1 :         let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
     696            1 : 
     697            1 :         tracing::info!("Attaching shard 1 to the reader at start of WAL");
     698            1 :         handle.fanout(shard_1, shard_1_tx, start_lsn).unwrap();
     699            1 : 
     700            1 :         tracing::info!("Reading all WAL with shard 0 and shard 1 attached ...");
     701            1 : 
     702            1 :         let mut shard_1_interpreted_records = Vec::new();
     703           13 :         while let Some(batch) = shard_1_rx.recv().await {
     704           13 :             shard_1_interpreted_records.push(batch.records);
     705           13 :             if batch.wal_end_lsn == batch.available_wal_end_lsn {
     706            1 :                 break;
     707           12 :             }
     708            1 :         }
     709            1 : 
     710            1 :         // This test uses logical messages. Those only go to shard 0. Check that the
     711            1 :         // filtering worked and shard 1 did not get any.
     712            1 :         assert!(shard_1_interpreted_records
     713            1 :             .iter()
     714           13 :             .all(|recs| recs.records.is_empty()));
     715            1 : 
     716            1 :         // Shard 0 should not receive anything more since the reader is
     717            1 :         // going through wal that it has already processed.
     718            1 :         let res = shard_0_rx.try_recv();
     719            1 :         if let Ok(ref ok) = res {
     720            1 :             tracing::error!(
     721            1 :                 "Shard 0 received batch: wal_end_lsn={} available_wal_end_lsn={}",
     722            1 :                 ok.wal_end_lsn,
     723            1 :                 ok.available_wal_end_lsn
     724            1 :             );
     725            1 :         }
     726            1 :         assert!(matches!(res, Err(TryRecvError::Empty)));
     727            1 : 
     728            1 :         // Check that the next records lsns received by the two shards match up.
     729            1 :         let shard_0_next_lsns = shard_0_interpreted_records
     730            1 :             .iter()
     731           13 :             .map(|recs| recs.next_record_lsn)
     732            1 :             .collect::<Vec<_>>();
     733            1 :         let shard_1_next_lsns = shard_1_interpreted_records
     734            1 :             .iter()
     735           13 :             .map(|recs| recs.next_record_lsn)
     736            1 :             .collect::<Vec<_>>();
     737            1 :         assert_eq!(shard_0_next_lsns, shard_1_next_lsns);
     738            1 : 
     739            1 :         handle.abort();
     740            1 :         let mut done = false;
     741            2 :         for _ in 0..5 {
     742            2 :             if handle.current_position().is_none() {
     743            1 :                 done = true;
     744            1 :                 break;
     745            1 :             }
     746            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
     747            1 :         }
     748            1 : 
     749            1 :         assert!(done);
     750            1 :     }
     751              : 
     752              :     #[tokio::test]
     753            1 :     async fn test_interpreted_wal_reader_same_shard_fanout() {
     754            1 :         let _ = env_logger::builder().is_test(true).try_init();
     755            1 : 
     756            1 :         const SIZE: usize = 8 * 1024;
     757            1 :         const MSG_COUNT: usize = 200;
     758            1 :         const PG_VERSION: u32 = 17;
     759            1 :         const SHARD_COUNT: u8 = 2;
     760            1 : 
     761            1 :         let start_lsn = Lsn::from_str("0/149FD18").unwrap();
     762            1 :         let env = Env::new(true).unwrap();
     763            1 :         let tli = env
     764            1 :             .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
     765            1 :             .await
     766            1 :             .unwrap();
     767            1 : 
     768            1 :         let resident_tli = tli.wal_residence_guard().await.unwrap();
     769            1 :         let mut next_record_lsns = Vec::default();
     770            1 :         let end_watch =
     771            1 :             Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, Some(&mut next_record_lsns))
     772            1 :                 .await
     773            1 :                 .unwrap();
     774            1 :         let end_pos = end_watch.get();
     775            1 : 
     776            1 :         let streaming_wal_reader = StreamingWalReader::new(
     777            1 :             resident_tli,
     778            1 :             None,
     779            1 :             start_lsn,
     780            1 :             end_pos,
     781            1 :             end_watch,
     782            1 :             MAX_SEND_SIZE,
     783            1 :         );
     784            1 : 
     785            1 :         let shard_0 = ShardIdentity::new(
     786            1 :             ShardNumber(0),
     787            1 :             ShardCount(SHARD_COUNT),
     788            1 :             ShardStripeSize::default(),
     789            1 :         )
     790            1 :         .unwrap();
     791            1 : 
     792            1 :         struct Sender {
     793            1 :             tx: Option<tokio::sync::mpsc::Sender<Batch>>,
     794            1 :             rx: tokio::sync::mpsc::Receiver<Batch>,
     795            1 :             shard: ShardIdentity,
     796            1 :             start_lsn: Lsn,
     797            1 :             received_next_record_lsns: Vec<Lsn>,
     798            1 :         }
     799            1 : 
     800            1 :         impl Sender {
     801            3 :             fn new(start_lsn: Lsn, shard: ShardIdentity) -> Self {
     802            3 :                 let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
     803            3 :                 Self {
     804            3 :                     tx: Some(tx),
     805            3 :                     rx,
     806            3 :                     shard,
     807            3 :                     start_lsn,
     808            3 :                     received_next_record_lsns: Vec::default(),
     809            3 :                 }
     810            3 :             }
     811            1 :         }
     812            1 : 
     813            1 :         assert!(next_record_lsns.len() > 7);
     814            1 :         let start_lsns = vec![
     815            1 :             next_record_lsns[5],
     816            1 :             next_record_lsns[1],
     817            1 :             next_record_lsns[3],
     818            1 :         ];
     819            1 :         let mut senders = start_lsns
     820            1 :             .into_iter()
     821            3 :             .map(|lsn| Sender::new(lsn, shard_0))
     822            1 :             .collect::<Vec<_>>();
     823            1 : 
     824            1 :         let first_sender = senders.first_mut().unwrap();
     825            1 :         let handle = InterpretedWalReader::spawn(
     826            1 :             streaming_wal_reader,
     827            1 :             first_sender.start_lsn,
     828            1 :             first_sender.tx.take().unwrap(),
     829            1 :             first_sender.shard,
     830            1 :             PG_VERSION,
     831            1 :             &Some("pageserver".to_string()),
     832            1 :         );
     833            1 : 
     834            2 :         for sender in senders.iter_mut().skip(1) {
     835            2 :             handle
     836            2 :                 .fanout(sender.shard, sender.tx.take().unwrap(), sender.start_lsn)
     837            2 :                 .unwrap();
     838            2 :         }
     839            1 : 
     840            3 :         for sender in senders.iter_mut() {
     841            1 :             loop {
     842           39 :                 let batch = sender.rx.recv().await.unwrap();
     843           39 :                 tracing::info!(
     844            1 :                     "Sender with start_lsn={} received batch ending at {} with {} records",
     845            0 :                     sender.start_lsn,
     846            0 :                     batch.wal_end_lsn,
     847            0 :                     batch.records.records.len()
     848            1 :                 );
     849            1 : 
     850          627 :                 for rec in batch.records.records {
     851          588 :                     sender.received_next_record_lsns.push(rec.next_record_lsn);
     852          588 :                 }
     853            1 : 
     854           39 :                 if batch.wal_end_lsn == batch.available_wal_end_lsn {
     855            3 :                     break;
     856           36 :                 }
     857            1 :             }
     858            1 :         }
     859            1 : 
     860            1 :         handle.abort();
     861            1 :         let mut done = false;
     862            2 :         for _ in 0..5 {
     863            2 :             if handle.current_position().is_none() {
     864            1 :                 done = true;
     865            1 :                 break;
     866            1 :             }
     867            1 :             tokio::time::sleep(Duration::from_millis(1)).await;
     868            1 :         }
     869            1 : 
     870            1 :         assert!(done);
     871            1 : 
     872            4 :         for sender in senders {
     873            3 :             tracing::info!(
     874            1 :                 "Validating records received by sender with start_lsn={}",
     875            1 :                 sender.start_lsn
     876            1 :             );
     877            1 : 
     878            3 :             assert!(sender.received_next_record_lsns.is_sorted());
     879            3 :             let expected = next_record_lsns
     880            3 :                 .iter()
     881          600 :                 .filter(|lsn| **lsn > sender.start_lsn)
     882            3 :                 .copied()
     883            3 :                 .collect::<Vec<_>>();
     884            3 :             assert_eq!(sender.received_next_record_lsns, expected);
     885            1 :         }
     886            1 :     }
     887              : }
        

Generated by: LCOV version 2.1-beta