Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt::Display;
3 : use std::sync::Arc;
4 : use std::time::Duration;
5 :
6 : use anyhow::{Context, anyhow};
7 : use futures::StreamExt;
8 : use futures::future::Either;
9 : use pageserver_api::shard::ShardIdentity;
10 : use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend};
11 : use postgres_ffi::get_current_timestamp;
12 : use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder};
13 : use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive};
14 : use tokio::io::{AsyncRead, AsyncWrite};
15 : use tokio::sync::mpsc::error::SendError;
16 : use tokio::task::JoinHandle;
17 : use tokio::time::MissedTickBehavior;
18 : use tracing::{Instrument, error, info, info_span};
19 : use utils::critical;
20 : use utils::lsn::Lsn;
21 : use utils::postgres_client::{Compression, InterpretedFormat};
22 : use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
23 : use wal_decoder::wire_format::ToWireFormat;
24 :
25 : use crate::metrics::WAL_READERS;
26 : use crate::send_wal::{EndWatchView, WalSenderGuard};
27 : use crate::timeline::WalResidentTimeline;
28 : use crate::wal_reader_stream::{StreamingWalReader, WalBytes};
29 :
30 : /// Identifier used to differentiate between senders of the same
31 : /// shard.
32 : ///
33 : /// In the steady state there's only one, but two pageservers may
34 : /// temporarily have the same shard attached and attempt to ingest
35 : /// WAL for it. See also [`ShardSenderId`].
36 : #[derive(Hash, Eq, PartialEq, Copy, Clone)]
37 : struct SenderId(u8);
38 :
39 : impl SenderId {
40 5 : fn first() -> Self {
41 5 : SenderId(0)
42 5 : }
43 :
44 2 : fn next(&self) -> Self {
45 2 : SenderId(self.0.checked_add(1).expect("few senders"))
46 2 : }
47 : }
48 :
49 : #[derive(Hash, Eq, PartialEq)]
50 : struct ShardSenderId {
51 : shard: ShardIdentity,
52 : sender_id: SenderId,
53 : }
54 :
55 : impl Display for ShardSenderId {
56 0 : fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 0 : write!(f, "{}{}", self.sender_id.0, self.shard.shard_slug())
58 0 : }
59 : }
60 :
61 : impl ShardSenderId {
62 867 : fn new(shard: ShardIdentity, sender_id: SenderId) -> Self {
63 867 : ShardSenderId { shard, sender_id }
64 867 : }
65 :
66 0 : fn shard(&self) -> ShardIdentity {
67 0 : self.shard
68 0 : }
69 : }
70 :
71 : /// Shard-aware fan-out interpreted record reader.
72 : /// Reads WAL from disk, decodes it, interprets it, and sends
73 : /// it to any [`InterpretedWalSender`] connected to it.
74 : /// Each [`InterpretedWalSender`] corresponds to one shard
75 : /// and gets interpreted records concerning that shard only.
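    : ///
    : /// Rough data flow through the read loop (a sketch, not exhaustive):
    : ///
    : /// ```text
    : /// StreamingWalReader -> WalStreamDecoder -> InterpretedWalRecord::from_bytes_filtered
    : ///                                             |-> channel to shard 0 sender(s)
    : ///                                             |-> channel to shard 1 sender(s)
    : /// ```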
76 : pub(crate) struct InterpretedWalReader {
77 : wal_stream: StreamingWalReader,
78 : shard_senders: HashMap<ShardIdentity, smallvec::SmallVec<[ShardSenderState; 1]>>,
79 : shard_notification_rx: Option<tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>>,
80 : state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
81 : pg_version: u32,
82 : }
83 :
84 : /// A handle for [`InterpretedWalReader`] which allows interacting with it
85 : /// when it runs as a separate tokio task.
86 : #[derive(Debug)]
87 : pub(crate) struct InterpretedWalReaderHandle {
88 : join_handle: JoinHandle<Result<(), InterpretedWalReaderError>>,
89 : state: Arc<std::sync::RwLock<InterpretedWalReaderState>>,
90 : shard_notification_tx: tokio::sync::mpsc::UnboundedSender<AttachShardNotification>,
91 : }
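    :
    : // A rough usage sketch (illustrative only: the variable names are made up and the
    : // calling pattern mirrors the tests at the bottom of this file). The first shard's
    : // channel is handed to `spawn`; later shards attach through `fanout` on the handle:
    : //
    : //     let handle = InterpretedWalReader::spawn(
    : //         wal_stream, start_pos, shard_0_tx, shard_0, pg_version, &appname);
    : //     handle.fanout(shard_1, shard_1_tx, shard_1_start_pos)?;
    : //
    : // Dropping the handle aborts the reader task (see the `Drop` impl below).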
92 :
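    : /// Per-sender state tracked by the reader: the channel to the sender and
    : /// `next_record_lsn`, which records how far the sender has already been fed.
    : /// Only records with a strictly larger `next_record_lsn` are forwarded to it.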
93 : struct ShardSenderState {
94 : sender_id: SenderId,
95 : tx: tokio::sync::mpsc::Sender<Batch>,
96 : next_record_lsn: Lsn,
97 : }
98 :
99 : /// State of [`InterpretedWalReader`] visible outside of the task running it.
100 : #[derive(Debug)]
101 : pub(crate) enum InterpretedWalReaderState {
102 : Running {
103 : current_position: Lsn,
104 : /// Tracks the start of the PG WAL LSN from which the current batch of
105 : /// interpreted records originated.
106 : current_batch_wal_start: Option<Lsn>,
107 : },
108 : Done,
109 : }
110 :
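    : /// One unit of work sent from the reader to a shard sender: the interpreted
    : /// records decoded from a chunk of raw WAL, the LSN at which that chunk ends
    : /// (`wal_end_lsn`) and the end of currently available WAL on the safekeeper
    : /// (`available_wal_end_lsn`, forwarded to the pageserver as the commit LSN).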
111 : pub(crate) struct Batch {
112 : wal_end_lsn: Lsn,
113 : available_wal_end_lsn: Lsn,
114 : records: InterpretedWalRecords,
115 : }
116 :
117 : #[derive(thiserror::Error, Debug)]
118 : pub enum InterpretedWalReaderError {
119 : /// Handler initiates the end of streaming.
120 : #[error("decode error: {0}")]
121 : Decode(#[from] WalDecodeError),
122 : #[error("read or interpret error: {0}")]
123 : ReadOrInterpret(#[from] anyhow::Error),
124 : #[error("wal stream closed")]
125 : WalStreamClosed,
126 : }
127 :
128 : enum CurrentPositionUpdate {
129 : Reset { from: Lsn, to: Lsn },
130 : NotReset(Lsn),
131 : }
132 :
133 : impl CurrentPositionUpdate {
134 0 : fn current_position(&self) -> Lsn {
135 0 : match self {
136 0 : CurrentPositionUpdate::Reset { from: _, to } => *to,
137 0 : CurrentPositionUpdate::NotReset(lsn) => *lsn,
138 : }
139 0 : }
140 :
141 0 : fn previous_position(&self) -> Lsn {
142 0 : match self {
143 0 : CurrentPositionUpdate::Reset { from, to: _ } => *from,
144 0 : CurrentPositionUpdate::NotReset(lsn) => *lsn,
145 : }
146 0 : }
147 : }
148 :
149 : impl InterpretedWalReaderState {
150 4 : fn current_position(&self) -> Option<Lsn> {
151 4 : match self {
152 : InterpretedWalReaderState::Running {
153 2 : current_position, ..
154 2 : } => Some(*current_position),
155 2 : InterpretedWalReaderState::Done => None,
156 : }
157 4 : }
158 :
159 : #[cfg(test)]
160 362 : fn current_batch_wal_start(&self) -> Option<Lsn> {
161 362 : match self {
162 : InterpretedWalReaderState::Running {
163 362 : current_batch_wal_start,
164 362 : ..
165 362 : } => *current_batch_wal_start,
166 0 : InterpretedWalReaderState::Done => None,
167 : }
168 362 : }
169 :
170 : // Reset the current position of the WAL reader if the requested starting position
171 : // of the new shard is smaller than the current value.
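    : //
    : // Illustrative example (LSNs are made up): if the reader is currently at 0/2000
    : // and a shard attaches wanting to start at 0/1000, the position is reset to
    : // 0/1000 and the in-progress batch start is cleared; the caller then also resets
    : // the WAL stream and the decoder.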
172 4 : fn maybe_reset(&mut self, new_shard_start_pos: Lsn) -> CurrentPositionUpdate {
173 4 : match self {
174 : InterpretedWalReaderState::Running {
175 4 : current_position,
176 4 : current_batch_wal_start,
177 4 : } => {
178 4 : if new_shard_start_pos < *current_position {
179 3 : let from = *current_position;
180 3 : *current_position = new_shard_start_pos;
181 3 : *current_batch_wal_start = None;
182 3 : CurrentPositionUpdate::Reset {
183 3 : from,
184 3 : to: *current_position,
185 3 : }
186 : } else {
187 1 : CurrentPositionUpdate::NotReset(*current_position)
188 : }
189 : }
190 : InterpretedWalReaderState::Done => {
191 0 : panic!("maybe_reset called on finished reader")
192 : }
193 : }
194 4 : }
195 :
196 49 : fn update_current_batch_wal_start(&mut self, lsn: Lsn) {
197 49 : match self {
198 : InterpretedWalReaderState::Running {
199 49 : current_batch_wal_start,
200 49 : ..
201 49 : } => {
202 49 : if current_batch_wal_start.is_none() {
203 41 : *current_batch_wal_start = Some(lsn);
204 41 : }
205 : }
206 : InterpretedWalReaderState::Done => {
207 0 : panic!("update_current_batch_wal_start called on finished reader")
208 : }
209 : }
210 49 : }
211 :
212 40 : fn take_current_batch_wal_start(&mut self) -> Lsn {
213 40 : match self {
214 : InterpretedWalReaderState::Running {
215 40 : current_batch_wal_start,
216 40 : ..
217 40 : } => current_batch_wal_start.take().unwrap(),
218 : InterpretedWalReaderState::Done => {
219 0 : panic!("take_current_batch_wal_start called on finished reader")
220 : }
221 : }
222 40 : }
223 :
224 40 : fn update_current_position(&mut self, lsn: Lsn) {
225 40 : match self {
226 : InterpretedWalReaderState::Running {
227 40 : current_position, ..
228 40 : } => {
229 40 : *current_position = lsn;
230 40 : }
231 : InterpretedWalReaderState::Done => {
232 0 : panic!("update_current_position called on finished reader")
233 : }
234 : }
235 40 : }
236 : }
237 :
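    : /// Notification sent via [`InterpretedWalReaderHandle::fanout`] to attach an
    : /// additional shard sender to a running reader, starting from `start_pos`.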
238 : pub(crate) struct AttachShardNotification {
239 : shard_id: ShardIdentity,
240 : sender: tokio::sync::mpsc::Sender<Batch>,
241 : start_pos: Lsn,
242 : }
243 :
244 : impl InterpretedWalReader {
245 : /// Spawn the reader in a separate tokio task and return a handle
246 2 : pub(crate) fn spawn(
247 2 : wal_stream: StreamingWalReader,
248 2 : start_pos: Lsn,
249 2 : tx: tokio::sync::mpsc::Sender<Batch>,
250 2 : shard: ShardIdentity,
251 2 : pg_version: u32,
252 2 : appname: &Option<String>,
253 2 : ) -> InterpretedWalReaderHandle {
254 2 : let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
255 2 : current_position: start_pos,
256 2 : current_batch_wal_start: None,
257 2 : }));
258 2 :
259 2 : let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
260 :
261 2 : let reader = InterpretedWalReader {
262 2 : wal_stream,
263 2 : shard_senders: HashMap::from([(
264 2 : shard,
265 2 : smallvec::smallvec![ShardSenderState {
266 0 : sender_id: SenderId::first(),
267 0 : tx,
268 0 : next_record_lsn: start_pos,
269 0 : }],
270 : )]),
271 2 : shard_notification_rx: Some(shard_notification_rx),
272 2 : state: state.clone(),
273 2 : pg_version,
274 2 : };
275 2 :
276 2 : let metric = WAL_READERS
277 2 : .get_metric_with_label_values(&["task", appname.as_deref().unwrap_or("safekeeper")])
278 2 : .unwrap();
279 :
280 2 : let join_handle = tokio::task::spawn(
281 2 : async move {
282 2 : metric.inc();
283 2 : scopeguard::defer! {
284 2 : metric.dec();
285 2 : }
286 2 :
287 2 : reader
288 2 : .run_impl(start_pos)
289 2 : .await
290 0 : .inspect_err(|err| critical!("failed to read WAL record: {err:?}"))
291 0 : }
292 2 : .instrument(info_span!("interpreted wal reader")),
293 : );
294 :
295 2 : InterpretedWalReaderHandle {
296 2 : join_handle,
297 2 : state,
298 2 : shard_notification_tx,
299 2 : }
300 2 : }
301 :
302 : /// Construct the reader without spawning anything.
303 : /// Callers should drive the future returned by [`Self::run`].
304 1 : pub(crate) fn new(
305 1 : wal_stream: StreamingWalReader,
306 1 : start_pos: Lsn,
307 1 : tx: tokio::sync::mpsc::Sender<Batch>,
308 1 : shard: ShardIdentity,
309 1 : pg_version: u32,
310 1 : shard_notification_rx: Option<
311 1 : tokio::sync::mpsc::UnboundedReceiver<AttachShardNotification>,
312 1 : >,
313 1 : ) -> InterpretedWalReader {
314 1 : let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running {
315 1 : current_position: start_pos,
316 1 : current_batch_wal_start: None,
317 1 : }));
318 1 :
319 1 : InterpretedWalReader {
320 1 : wal_stream,
321 1 : shard_senders: HashMap::from([(
322 1 : shard,
323 1 : smallvec::smallvec![ShardSenderState {
324 0 : sender_id: SenderId::first(),
325 0 : tx,
326 0 : next_record_lsn: start_pos,
327 0 : }],
328 : )]),
329 1 : shard_notification_rx,
330 1 : state: state.clone(),
331 1 : pg_version,
332 1 : }
333 1 : }
334 :
335 : /// Entry point for the future-based (polled) WAL reader.
336 1 : pub(crate) async fn run(
337 1 : self,
338 1 : start_pos: Lsn,
339 1 : appname: &Option<String>,
340 1 : ) -> Result<(), CopyStreamHandlerEnd> {
341 1 : let metric = WAL_READERS
342 1 : .get_metric_with_label_values(&["future", appname.as_deref().unwrap_or("safekeeper")])
343 1 : .unwrap();
344 1 :
345 1 : metric.inc();
346 1 : scopeguard::defer! {
347 1 : metric.dec();
348 1 : }
349 :
350 1 : if let Err(err) = self.run_impl(start_pos).await {
351 0 : critical!("failed to read WAL record: {err:?}");
352 : } else {
353 0 : info!("interpreted wal reader exiting");
354 : }
355 :
356 0 : Err(CopyStreamHandlerEnd::Other(anyhow!(
357 0 : "interpreted wal reader finished"
358 0 : )))
359 0 : }
360 :
361 : /// Send interpreted WAL to one or more [`InterpretedWalSender`]s.
362 : /// Stops when an error is encountered or when the [`InterpretedWalReaderHandle`]
363 : /// goes out of scope.
364 3 : async fn run_impl(mut self, start_pos: Lsn) -> Result<(), InterpretedWalReaderError> {
365 3 : let defer_state = self.state.clone();
366 3 : scopeguard::defer! {
367 3 : *defer_state.write().unwrap() = InterpretedWalReaderState::Done;
368 3 : }
369 3 :
370 3 : let mut wal_decoder = WalStreamDecoder::new(start_pos, self.pg_version);
371 :
372 : loop {
373 56 : tokio::select! {
374 : // Main branch for reading WAL and forwarding it
375 56 : wal_or_reset = self.wal_stream.next() => {
376 49 : let wal = wal_or_reset.map(|wor| wor.get_wal().expect("reset handled in select branch below"));
377 : let WalBytes {
378 49 : wal,
379 49 : wal_start_lsn,
380 49 : wal_end_lsn,
381 49 : available_wal_end_lsn,
382 49 : } = match wal {
383 49 : Some(some) => some.map_err(InterpretedWalReaderError::ReadOrInterpret)?,
384 : None => {
385 : // [`StreamingWalReader::next`] is an endless stream of WAL.
386 : // It shouldn't ever finish unless it panicked or became internally
387 : // inconsistent.
388 0 : return Result::Err(InterpretedWalReaderError::WalStreamClosed);
389 : }
390 : };
391 :
392 49 : self.state.write().unwrap().update_current_batch_wal_start(wal_start_lsn);
393 49 :
394 49 : wal_decoder.feed_bytes(&wal);
395 49 :
396 49 : // Deserialize and interpret WAL records from this batch of WAL.
397 49 : // Interpreted records for each shard are collected separately.
398 49 : let shard_ids = self.shard_senders.keys().copied().collect::<Vec<_>>();
399 49 : let mut records_by_sender: HashMap<ShardSenderId, Vec<InterpretedWalRecord>> = HashMap::new();
400 49 : let mut max_next_record_lsn = None;
401 646 : while let Some((next_record_lsn, recdata)) = wal_decoder.poll_decode()?
402 : {
403 597 : assert!(next_record_lsn.is_aligned());
404 597 : max_next_record_lsn = Some(next_record_lsn);
405 :
406 597 : let interpreted = InterpretedWalRecord::from_bytes_filtered(
407 597 : recdata,
408 597 : &shard_ids,
409 597 : next_record_lsn,
410 597 : self.pg_version,
411 597 : )
412 597 : .with_context(|| "Failed to interpret WAL")?;
413 :
414 1394 : for (shard, record) in interpreted {
415 797 : if record.is_empty() {
416 200 : continue;
417 597 : }
418 597 :
419 597 : let mut states_iter = self.shard_senders
420 597 : .get(&shard)
421 597 : .expect("keys collected above")
422 597 : .iter()
423 993 : .filter(|state| record.next_record_lsn > state.next_record_lsn)
424 597 : .peekable();
425 987 : while let Some(state) = states_iter.next() {
426 787 : let shard_sender_id = ShardSenderId::new(shard, state.sender_id);
427 787 :
428 787 : // The most commont case is one sender per shard. Peek and break to avoid the
429 787 : // clone in that situation.
430 787 : if states_iter.peek().is_none() {
431 397 : records_by_sender.entry(shard_sender_id).or_default().push(record);
432 397 : break;
433 390 : } else {
434 390 : records_by_sender.entry(shard_sender_id).or_default().push(record.clone());
435 390 : }
436 : }
437 : }
438 : }
439 :
440 49 : let max_next_record_lsn = match max_next_record_lsn {
441 40 : Some(lsn) => lsn,
442 : None => {
443 9 : continue;
444 : }
445 : };
446 :
447 : // Update the current position such that new receivers can decide
448 : // whether to attach to us or spawn a new WAL reader.
449 40 : let batch_wal_start_lsn = {
450 40 : let mut guard = self.state.write().unwrap();
451 40 : guard.update_current_position(max_next_record_lsn);
452 40 : guard.take_current_batch_wal_start()
453 40 : };
454 40 :
455 40 : // Send interpreted records downstream. Anything that has already been seen
456 40 : // by a shard is filtered out.
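    : // For each sender there are three cases, handled below:
    : //   1. the batch holds records the sender has not seen yet -> send them,
    : //   2. the batch is empty for this sender but maps to a WAL range it has not
    : //      seen -> send an empty batch so the raw WAL range is still accounted for,
    : //   3. the sender has already seen this WAL -> skip it.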
457 40 : let mut shard_senders_to_remove = Vec::new();
458 94 : for (shard, states) in &mut self.shard_senders {
459 134 : for state in states {
460 80 : let shard_sender_id = ShardSenderId::new(*shard, state.sender_id);
461 :
462 80 : let batch = if max_next_record_lsn > state.next_record_lsn {
463 : // This batch contains at least one record that this shard has not
464 : // seen yet.
465 66 : let records = records_by_sender.remove(&shard_sender_id).unwrap_or_default();
466 66 :
467 66 : InterpretedWalRecords {
468 66 : records,
469 66 : next_record_lsn: max_next_record_lsn,
470 66 : raw_wal_start_lsn: Some(batch_wal_start_lsn),
471 66 : }
472 14 : } else if wal_end_lsn > state.next_record_lsn {
473 : // All the records in this batch were seen by the shard.
474 : // However, the batch maps to a chunk of WAL that the
475 : // shard has not yet seen. Notify it of the start LSN
476 : // of the PG WAL chunk such that it doesn't look like a gap.
477 0 : InterpretedWalRecords {
478 0 : records: Vec::default(),
479 0 : next_record_lsn: state.next_record_lsn,
480 0 : raw_wal_start_lsn: Some(batch_wal_start_lsn),
481 0 : }
482 : } else {
483 : // The shard has seen this chunk of WAL before. Skip it.
484 14 : continue;
485 : };
486 :
487 66 : let res = state.tx.send(Batch {
488 66 : wal_end_lsn,
489 66 : available_wal_end_lsn,
490 66 : records: batch,
491 66 : }).await;
492 :
493 66 : if res.is_err() {
494 0 : shard_senders_to_remove.push(shard_sender_id);
495 66 : } else {
496 66 : state.next_record_lsn = std::cmp::max(state.next_record_lsn, max_next_record_lsn);
497 66 : }
498 : }
499 : }
500 :
501 : // Clean up any shard senders that have dropped out.
502 : // This is inefficient, but such events are rare (pageserver connection termination)
503 : // and the number of subscriptions on the same shard is very small (only one
504 : // for the steady state).
505 40 : for to_remove in shard_senders_to_remove {
506 0 : let shard_senders = self.shard_senders.get_mut(&to_remove.shard()).expect("saw it above");
507 0 : if let Some(idx) = shard_senders.iter().position(|s| s.sender_id == to_remove.sender_id) {
508 0 : shard_senders.remove(idx);
509 0 : tracing::info!("Removed shard sender {}", to_remove);
510 0 : }
511 :
512 0 : if shard_senders.is_empty() {
513 0 : self.shard_senders.remove(&to_remove.shard());
514 0 : }
515 : }
516 : },
517 : // Listen for new shards that want to attach to this reader.
518 : // If the reader is not running as a task, then this is not supported
519 : // (see the pending branch below).
520 56 : notification = match self.shard_notification_rx.as_mut() {
521 56 : Some(rx) => Either::Left(rx.recv()),
522 0 : None => Either::Right(std::future::pending())
523 : } => {
524 4 : if let Some(n) = notification {
525 4 : let AttachShardNotification { shard_id, sender, start_pos } = n;
526 4 :
527 4 : // Update internal and external state, then reset the WAL stream
528 4 : // if required.
529 4 : let senders = self.shard_senders.entry(shard_id).or_default();
530 4 : let new_sender_id = match senders.last() {
531 2 : Some(sender) => sender.sender_id.next(),
532 2 : None => SenderId::first()
533 : };
534 :
535 4 : senders.push(ShardSenderState { sender_id: new_sender_id, tx: sender, next_record_lsn: start_pos});
536 4 :
537 4 : // If the shard is subscribing below the current position, then we need
538 4 : // to update the cursor that tracks where we are in the WAL
539 4 : // ([`Self::state`]) and reset the WAL stream itself
540 4 : // ([`Self::wal_stream`]). This must be done atomically from the POV of
541 4 : // anything outside the select statement.
542 4 : let position_reset = self.state.write().unwrap().maybe_reset(start_pos);
543 4 : match position_reset {
544 3 : CurrentPositionUpdate::Reset { from: _, to } => {
545 3 : self.wal_stream.reset(to).await;
546 3 : wal_decoder = WalStreamDecoder::new(to, self.pg_version);
547 : },
548 1 : CurrentPositionUpdate::NotReset(_) => {}
549 : };
550 :
551 4 : tracing::info!(
552 0 : "Added shard sender {} with start_pos={} previous_pos={} current_pos={}",
553 0 : ShardSenderId::new(shard_id, new_sender_id),
554 0 : start_pos,
555 0 : position_reset.previous_position(),
556 0 : position_reset.current_position(),
557 : );
558 0 : }
559 : }
560 : }
561 : }
562 0 : }
563 :
564 : #[cfg(test)]
565 1 : fn state(&self) -> Arc<std::sync::RwLock<InterpretedWalReaderState>> {
566 1 : self.state.clone()
567 1 : }
568 : }
569 :
570 : impl InterpretedWalReaderHandle {
571 : /// Fan-out the reader by attaching a new shard to it
572 3 : pub(crate) fn fanout(
573 3 : &self,
574 3 : shard_id: ShardIdentity,
575 3 : sender: tokio::sync::mpsc::Sender<Batch>,
576 3 : start_pos: Lsn,
577 3 : ) -> Result<(), SendError<AttachShardNotification>> {
578 3 : self.shard_notification_tx.send(AttachShardNotification {
579 3 : shard_id,
580 3 : sender,
581 3 : start_pos,
582 3 : })
583 3 : }
584 :
585 : /// Get the current WAL position of the reader
586 4 : pub(crate) fn current_position(&self) -> Option<Lsn> {
587 4 : self.state.read().unwrap().current_position()
588 4 : }
589 :
590 4 : pub(crate) fn abort(&self) {
591 4 : self.join_handle.abort()
592 4 : }
593 : }
594 :
595 : impl Drop for InterpretedWalReaderHandle {
596 2 : fn drop(&mut self) {
597 2 : tracing::info!("Aborting interpreted wal reader");
598 2 : self.abort()
599 2 : }
600 : }
601 :
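    : /// Receives [`Batch`]es from an [`InterpretedWalReader`] over `rx`, serializes
    : /// them according to `format` and `compression`, and writes them to the
    : /// pageserver connection held by `pgb`. Streaming stops once the receiver is
    : /// caught up and there is no active compute.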
602 : pub(crate) struct InterpretedWalSender<'a, IO> {
603 : pub(crate) format: InterpretedFormat,
604 : pub(crate) compression: Option<Compression>,
605 : pub(crate) appname: Option<String>,
606 :
607 : pub(crate) tli: WalResidentTimeline,
608 : pub(crate) start_lsn: Lsn,
609 :
610 : pub(crate) pgb: &'a mut PostgresBackend<IO>,
611 : pub(crate) end_watch_view: EndWatchView,
612 : pub(crate) wal_sender_guard: Arc<WalSenderGuard>,
613 : pub(crate) rx: tokio::sync::mpsc::Receiver<Batch>,
614 : }
615 :
616 : impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
617 : /// Send interpreted WAL records over the network.
618 : /// Also manages keep-alives if nothing was sent for a while.
619 0 : pub(crate) async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> {
620 0 : let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1));
621 0 : keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
622 0 : keepalive_ticker.reset();
623 0 :
624 0 : let mut wal_position = self.start_lsn;
625 :
626 : loop {
627 0 : tokio::select! {
628 0 : batch = self.rx.recv() => {
629 0 : let batch = match batch {
630 0 : Some(b) => b,
631 : None => {
632 0 : return Result::Err(
633 0 : CopyStreamHandlerEnd::Other(anyhow!("Interpreted WAL reader exited early"))
634 0 : );
635 : }
636 : };
637 :
638 0 : wal_position = batch.wal_end_lsn;
639 :
640 0 : let buf = batch
641 0 : .records
642 0 : .to_wire(self.format, self.compression)
643 0 : .await
644 0 : .with_context(|| "Failed to serialize interpreted WAL")
645 0 : .map_err(CopyStreamHandlerEnd::from)?;
646 :
647 : // Reset the keep alive ticker since we are sending something
648 : // over the wire now.
649 0 : keepalive_ticker.reset();
650 0 :
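    : // `streaming_lsn` is the end of the raw WAL chunk this batch was decoded from;
    : // `commit_lsn` carries the end of WAL currently available on this safekeeper.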
651 0 : self.pgb
652 0 : .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody {
653 0 : streaming_lsn: batch.wal_end_lsn.0,
654 0 : commit_lsn: batch.available_wal_end_lsn.0,
655 0 : data: &buf,
656 0 : })).await?;
657 : }
658 : // Send a periodic keep alive when the connection has been idle for a while.
659 : // Since we've been idle, also check if we can stop streaming.
660 0 : _ = keepalive_ticker.tick() => {
661 0 : if let Some(remote_consistent_lsn) = self.wal_sender_guard
662 0 : .walsenders()
663 0 : .get_ws_remote_consistent_lsn(self.wal_sender_guard.id())
664 : {
665 0 : if self.tli.should_walsender_stop(remote_consistent_lsn).await {
666 : // Stop streaming if the receivers are caught up and
667 : // there's no active compute. This causes the loop in
668 : // [`crate::send_interpreted_wal::InterpretedWalSender::run`]
669 : // to exit and terminate the WAL stream.
670 0 : break;
671 0 : }
672 0 : }
673 :
674 0 : self.pgb
675 0 : .write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
676 0 : wal_end: self.end_watch_view.get().0,
677 0 : timestamp: get_current_timestamp(),
678 0 : request_reply: true,
679 0 : }))
680 0 : .await?;
681 : },
682 : }
683 : }
684 :
685 0 : Err(CopyStreamHandlerEnd::ServerInitiated(format!(
686 0 : "ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
687 0 : self.appname, wal_position,
688 0 : )))
689 0 : }
690 : }
691 : #[cfg(test)]
692 : mod tests {
693 : use std::collections::HashMap;
694 : use std::str::FromStr;
695 : use std::time::Duration;
696 :
697 : use pageserver_api::shard::{ShardIdentity, ShardStripeSize};
698 : use postgres_ffi::MAX_SEND_SIZE;
699 : use tokio::sync::mpsc::error::TryRecvError;
700 : use utils::id::{NodeId, TenantTimelineId};
701 : use utils::lsn::Lsn;
702 : use utils::shard::{ShardCount, ShardNumber};
703 :
704 : use crate::send_interpreted_wal::{AttachShardNotification, Batch, InterpretedWalReader};
705 : use crate::test_utils::Env;
706 : use crate::wal_reader_stream::StreamingWalReader;
707 :
708 : #[tokio::test]
709 1 : async fn test_interpreted_wal_reader_fanout() {
710 1 : let _ = env_logger::builder().is_test(true).try_init();
711 1 :
712 1 : const SIZE: usize = 8 * 1024;
713 1 : const MSG_COUNT: usize = 200;
714 1 : const PG_VERSION: u32 = 17;
715 1 : const SHARD_COUNT: u8 = 2;
716 1 :
717 1 : let start_lsn = Lsn::from_str("0/149FD18").unwrap();
718 1 : let env = Env::new(true).unwrap();
719 1 : let tli = env
720 1 : .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
721 1 : .await
722 1 : .unwrap();
723 1 :
724 1 : let resident_tli = tli.wal_residence_guard().await.unwrap();
725 1 : let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, None)
726 1 : .await
727 1 : .unwrap();
728 1 : let end_pos = end_watch.get();
729 1 :
730 1 : tracing::info!("Doing first round of reads ...");
731 1 :
732 1 : let streaming_wal_reader = StreamingWalReader::new(
733 1 : resident_tli,
734 1 : None,
735 1 : start_lsn,
736 1 : end_pos,
737 1 : end_watch,
738 1 : MAX_SEND_SIZE,
739 1 : );
740 1 :
741 1 : let shard_0 = ShardIdentity::new(
742 1 : ShardNumber(0),
743 1 : ShardCount(SHARD_COUNT),
744 1 : ShardStripeSize::default(),
745 1 : )
746 1 : .unwrap();
747 1 :
748 1 : let shard_1 = ShardIdentity::new(
749 1 : ShardNumber(1),
750 1 : ShardCount(SHARD_COUNT),
751 1 : ShardStripeSize::default(),
752 1 : )
753 1 : .unwrap();
754 1 :
755 1 : let mut shards = HashMap::new();
756 1 :
757 3 : for shard_number in 0..SHARD_COUNT {
758 2 : let shard_id = ShardIdentity::new(
759 2 : ShardNumber(shard_number),
760 2 : ShardCount(SHARD_COUNT),
761 2 : ShardStripeSize::default(),
762 2 : )
763 2 : .unwrap();
764 2 : let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
765 2 : shards.insert(shard_id, (Some(tx), Some(rx)));
766 2 : }
767 1 :
768 1 : let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
769 1 : let mut shard_0_rx = shards.get_mut(&shard_0).unwrap().1.take().unwrap();
770 1 :
771 1 : let handle = InterpretedWalReader::spawn(
772 1 : streaming_wal_reader,
773 1 : start_lsn,
774 1 : shard_0_tx,
775 1 : shard_0,
776 1 : PG_VERSION,
777 1 : &Some("pageserver".to_string()),
778 1 : );
779 1 :
780 1 : tracing::info!("Reading all WAL with only shard 0 attached ...");
781 1 :
782 1 : let mut shard_0_interpreted_records = Vec::new();
783 13 : while let Some(batch) = shard_0_rx.recv().await {
784 13 : shard_0_interpreted_records.push(batch.records);
785 13 : if batch.wal_end_lsn == batch.available_wal_end_lsn {
786 1 : break;
787 12 : }
788 1 : }
789 1 :
790 1 : let shard_1_tx = shards.get_mut(&shard_1).unwrap().0.take().unwrap();
791 1 : let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
792 1 :
793 1 : tracing::info!("Attaching shard 1 to the reader at start of WAL");
794 1 : handle.fanout(shard_1, shard_1_tx, start_lsn).unwrap();
795 1 :
796 1 : tracing::info!("Reading all WAL with shard 0 and shard 1 attached ...");
797 1 :
798 1 : let mut shard_1_interpreted_records = Vec::new();
799 13 : while let Some(batch) = shard_1_rx.recv().await {
800 13 : shard_1_interpreted_records.push(batch.records);
801 13 : if batch.wal_end_lsn == batch.available_wal_end_lsn {
802 1 : break;
803 12 : }
804 1 : }
805 1 :
806 1 : // This test uses logical messages. Those only go to shard 0. Check that the
807 1 : // filtering worked and shard 1 did not get any.
808 1 : assert!(
809 1 : shard_1_interpreted_records
810 1 : .iter()
811 13 : .all(|recs| recs.records.is_empty())
812 1 : );
813 1 :
814 1 : // Shard 0 should not receive anything more since the reader is
815 1 : // going through WAL that it has already processed.
816 1 : let res = shard_0_rx.try_recv();
817 1 : if let Ok(ref ok) = res {
818 1 : tracing::error!(
819 1 : "Shard 0 received batch: wal_end_lsn={} available_wal_end_lsn={}",
820 1 : ok.wal_end_lsn,
821 1 : ok.available_wal_end_lsn
822 1 : );
823 1 : }
824 1 : assert!(matches!(res, Err(TryRecvError::Empty)));
825 1 :
826 1 : // Check that the next record LSNs received by the two shards match up.
827 1 : let shard_0_next_lsns = shard_0_interpreted_records
828 1 : .iter()
829 13 : .map(|recs| recs.next_record_lsn)
830 1 : .collect::<Vec<_>>();
831 1 : let shard_1_next_lsns = shard_1_interpreted_records
832 1 : .iter()
833 13 : .map(|recs| recs.next_record_lsn)
834 1 : .collect::<Vec<_>>();
835 1 : assert_eq!(shard_0_next_lsns, shard_1_next_lsns);
836 1 :
837 1 : handle.abort();
838 1 : let mut done = false;
839 2 : for _ in 0..5 {
840 2 : if handle.current_position().is_none() {
841 1 : done = true;
842 1 : break;
843 1 : }
844 1 : tokio::time::sleep(Duration::from_millis(1)).await;
845 1 : }
846 1 :
847 1 : assert!(done);
848 1 : }
849 :
850 : #[tokio::test]
851 1 : async fn test_interpreted_wal_reader_same_shard_fanout() {
852 1 : let _ = env_logger::builder().is_test(true).try_init();
853 1 :
854 1 : const SIZE: usize = 8 * 1024;
855 1 : const MSG_COUNT: usize = 200;
856 1 : const PG_VERSION: u32 = 17;
857 1 : const SHARD_COUNT: u8 = 2;
858 1 :
859 1 : let start_lsn = Lsn::from_str("0/149FD18").unwrap();
860 1 : let env = Env::new(true).unwrap();
861 1 : let tli = env
862 1 : .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
863 1 : .await
864 1 : .unwrap();
865 1 :
866 1 : let resident_tli = tli.wal_residence_guard().await.unwrap();
867 1 : let mut next_record_lsns = Vec::default();
868 1 : let end_watch =
869 1 : Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, Some(&mut next_record_lsns))
870 1 : .await
871 1 : .unwrap();
872 1 : let end_pos = end_watch.get();
873 1 :
874 1 : let streaming_wal_reader = StreamingWalReader::new(
875 1 : resident_tli,
876 1 : None,
877 1 : start_lsn,
878 1 : end_pos,
879 1 : end_watch,
880 1 : MAX_SEND_SIZE,
881 1 : );
882 1 :
883 1 : let shard_0 = ShardIdentity::new(
884 1 : ShardNumber(0),
885 1 : ShardCount(SHARD_COUNT),
886 1 : ShardStripeSize::default(),
887 1 : )
888 1 : .unwrap();
889 1 :
890 1 : struct Sender {
891 1 : tx: Option<tokio::sync::mpsc::Sender<Batch>>,
892 1 : rx: tokio::sync::mpsc::Receiver<Batch>,
893 1 : shard: ShardIdentity,
894 1 : start_lsn: Lsn,
895 1 : received_next_record_lsns: Vec<Lsn>,
896 1 : }
897 1 :
898 1 : impl Sender {
899 3 : fn new(start_lsn: Lsn, shard: ShardIdentity) -> Self {
900 3 : let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
901 3 : Self {
902 3 : tx: Some(tx),
903 3 : rx,
904 3 : shard,
905 3 : start_lsn,
906 3 : received_next_record_lsns: Vec::default(),
907 3 : }
908 3 : }
909 1 : }
910 1 :
911 1 : assert!(next_record_lsns.len() > 7);
912 1 : let start_lsns = vec![
913 1 : next_record_lsns[5],
914 1 : next_record_lsns[1],
915 1 : next_record_lsns[3],
916 1 : ];
917 1 : let mut senders = start_lsns
918 1 : .into_iter()
919 3 : .map(|lsn| Sender::new(lsn, shard_0))
920 1 : .collect::<Vec<_>>();
921 1 :
922 1 : let first_sender = senders.first_mut().unwrap();
923 1 : let handle = InterpretedWalReader::spawn(
924 1 : streaming_wal_reader,
925 1 : first_sender.start_lsn,
926 1 : first_sender.tx.take().unwrap(),
927 1 : first_sender.shard,
928 1 : PG_VERSION,
929 1 : &Some("pageserver".to_string()),
930 1 : );
931 1 :
932 2 : for sender in senders.iter_mut().skip(1) {
933 2 : handle
934 2 : .fanout(sender.shard, sender.tx.take().unwrap(), sender.start_lsn)
935 2 : .unwrap();
936 2 : }
937 1 :
938 3 : for sender in senders.iter_mut() {
939 1 : loop {
940 39 : let batch = sender.rx.recv().await.unwrap();
941 39 : tracing::info!(
942 1 : "Sender with start_lsn={} received batch ending at {} with {} records",
943 0 : sender.start_lsn,
944 0 : batch.wal_end_lsn,
945 0 : batch.records.records.len()
946 1 : );
947 1 :
948 627 : for rec in batch.records.records {
949 588 : sender.received_next_record_lsns.push(rec.next_record_lsn);
950 588 : }
951 1 :
952 39 : if batch.wal_end_lsn == batch.available_wal_end_lsn {
953 3 : break;
954 36 : }
955 1 : }
956 1 : }
957 1 :
958 1 : handle.abort();
959 1 : let mut done = false;
960 2 : for _ in 0..5 {
961 2 : if handle.current_position().is_none() {
962 1 : done = true;
963 1 : break;
964 1 : }
965 1 : tokio::time::sleep(Duration::from_millis(1)).await;
966 1 : }
967 1 :
968 1 : assert!(done);
969 1 :
970 4 : for sender in senders {
971 3 : tracing::info!(
972 1 : "Validating records received by sender with start_lsn={}",
973 1 : sender.start_lsn
974 1 : );
975 1 :
976 3 : assert!(sender.received_next_record_lsns.is_sorted());
977 3 : let expected = next_record_lsns
978 3 : .iter()
979 600 : .filter(|lsn| **lsn > sender.start_lsn)
980 3 : .copied()
981 3 : .collect::<Vec<_>>();
982 3 : assert_eq!(sender.received_next_record_lsns, expected);
983 1 : }
984 1 : }
985 :
986 : #[tokio::test]
987 1 : async fn test_batch_start_tracking_on_reset() {
988 1 : // When the WAL stream is reset to an older LSN,
989 1 : // the current batch start LSN should be invalidated.
990 1 : // This test constructs such a scenario:
991 1 : // 1. Shard 0 is reading somewhere ahead
992 1 : // 2. Reader reads some WAL, but does not decode a full record (partial read)
993 1 : // 3. Shard 1 attaches to the reader and resets it to an older LSN
994 1 : // 4. Shard 1 should get the correct batch WAL start LSN
995 1 : let _ = env_logger::builder().is_test(true).try_init();
996 1 :
997 1 : const SIZE: usize = 64 * 1024;
998 1 : const MSG_COUNT: usize = 10;
999 1 : const PG_VERSION: u32 = 17;
1000 1 : const SHARD_COUNT: u8 = 2;
1001 1 : const WAL_READER_BATCH_SIZE: usize = 8192;
1002 1 :
1003 1 : let start_lsn = Lsn::from_str("0/149FD18").unwrap();
1004 1 : let env = Env::new(true).unwrap();
1005 1 : let mut next_record_lsns = Vec::default();
1006 1 : let tli = env
1007 1 : .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
1008 1 : .await
1009 1 : .unwrap();
1010 1 :
1011 1 : let resident_tli = tli.wal_residence_guard().await.unwrap();
1012 1 : let end_watch =
1013 1 : Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, Some(&mut next_record_lsns))
1014 1 : .await
1015 1 : .unwrap();
1016 1 :
1017 1 : assert!(next_record_lsns.len() > 3);
1018 1 : let shard_0_start_lsn = next_record_lsns[3];
1019 1 :
1020 1 : let end_pos = end_watch.get();
1021 1 :
1022 1 : let streaming_wal_reader = StreamingWalReader::new(
1023 1 : resident_tli,
1024 1 : None,
1025 1 : shard_0_start_lsn,
1026 1 : end_pos,
1027 1 : end_watch,
1028 1 : WAL_READER_BATCH_SIZE,
1029 1 : );
1030 1 :
1031 1 : let shard_0 = ShardIdentity::new(
1032 1 : ShardNumber(0),
1033 1 : ShardCount(SHARD_COUNT),
1034 1 : ShardStripeSize::default(),
1035 1 : )
1036 1 : .unwrap();
1037 1 :
1038 1 : let shard_1 = ShardIdentity::new(
1039 1 : ShardNumber(1),
1040 1 : ShardCount(SHARD_COUNT),
1041 1 : ShardStripeSize::default(),
1042 1 : )
1043 1 : .unwrap();
1044 1 :
1045 1 : let mut shards = HashMap::new();
1046 1 :
1047 3 : for shard_number in 0..SHARD_COUNT {
1048 2 : let shard_id = ShardIdentity::new(
1049 2 : ShardNumber(shard_number),
1050 2 : ShardCount(SHARD_COUNT),
1051 2 : ShardStripeSize::default(),
1052 2 : )
1053 2 : .unwrap();
1054 2 : let (tx, rx) = tokio::sync::mpsc::channel::<Batch>(MSG_COUNT * 2);
1055 2 : shards.insert(shard_id, (Some(tx), Some(rx)));
1056 2 : }
1057 1 :
1058 1 : let shard_0_tx = shards.get_mut(&shard_0).unwrap().0.take().unwrap();
1059 1 :
1060 1 : let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
1061 1 :
1062 1 : let reader = InterpretedWalReader::new(
1063 1 : streaming_wal_reader,
1064 1 : shard_0_start_lsn,
1065 1 : shard_0_tx,
1066 1 : shard_0,
1067 1 : PG_VERSION,
1068 1 : Some(shard_notification_rx),
1069 1 : );
1070 1 :
1071 1 : let reader_state = reader.state();
1072 1 : let mut reader_fut = std::pin::pin!(reader.run(shard_0_start_lsn, &None));
1073 1 : loop {
1074 362 : let poll = futures::poll!(reader_fut.as_mut());
1075 362 : assert!(poll.is_pending());
1076 1 :
1077 362 : let guard = reader_state.read().unwrap();
1078 362 : if guard.current_batch_wal_start().is_some() {
1079 1 : break;
1080 361 : }
1081 1 : }
1082 1 :
1083 1 : shard_notification_tx
1084 1 : .send(AttachShardNotification {
1085 1 : shard_id: shard_1,
1086 1 : sender: shards.get_mut(&shard_1).unwrap().0.take().unwrap(),
1087 1 : start_pos: start_lsn,
1088 1 : })
1089 1 : .unwrap();
1090 1 :
1091 1 : let mut shard_1_rx = shards.get_mut(&shard_1).unwrap().1.take().unwrap();
1092 1 : loop {
1093 117 : let poll = futures::poll!(reader_fut.as_mut());
1094 117 : assert!(poll.is_pending());
1095 1 :
1096 117 : let try_recv_res = shard_1_rx.try_recv();
1097 116 : match try_recv_res {
1098 1 : Ok(batch) => {
1099 1 : assert_eq!(batch.records.raw_wal_start_lsn.unwrap(), start_lsn);
1100 1 : break;
1101 1 : }
1102 116 : Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {}
1103 1 : Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
1104 1 : unreachable!();
1105 1 : }
1106 1 : }
1107 1 : }
1108 1 : }
1109 : }