Line data Source code
1 : use std::{
2 : pin::Pin,
3 : task::{Context, Poll},
4 : };
5 :
6 : use bytes::Bytes;
7 : use futures::{stream::BoxStream, Stream, StreamExt};
8 : use utils::lsn::Lsn;
9 :
10 : use crate::{send_wal::EndWatch, timeline::WalResidentTimeline, wal_storage::WalReader};
11 : use safekeeper_api::Term;
12 :
/// A chunk of raw WAL together with the LSN range it covers and the end of
/// the WAL that was known to be available when the chunk was produced.
13 : #[derive(PartialEq, Eq, Debug)]
14 : pub(crate) struct WalBytes {
15 : /// Raw PG WAL
16 : pub(crate) wal: Bytes,
17 : /// Start LSN of [`Self::wal`]
18 : #[allow(dead_code)]
19 : pub(crate) wal_start_lsn: Lsn,
20 : /// End LSN of [`Self::wal`], i.e. the start LSN plus the number of bytes in the chunk.
21 : pub(crate) wal_end_lsn: Lsn,
22 : /// End LSN of WAL available on the safekeeper.
23 : ///
24 : /// For pageservers this will be the commit LSN,
25 : /// while for the compute it will be the flush LSN.
26 : pub(crate) available_wal_end_lsn: Lsn,
27 : }
28 :
/// A WAL reader bundled with the LSN range it is currently working on.
29 : struct PositionedWalReader {
/// LSN the next read starts from; advanced after every successful read.
30 : start: Lsn,
/// End of the WAL currently known to be available for reading.
31 : end: Lsn,
/// Underlying reader. `None` until a reader is created for `start`.
32 : reader: Option<WalReader>,
33 : }
34 :
35 : /// A streaming WAL reader wrapper which can be reset while running
36 : pub(crate) struct StreamingWalReader {
/// Boxed stream producing WAL chunks and reset confirmations.
37 : stream: BoxStream<'static, WalOrReset>,
/// Sending half of the watch channel used to hand a new start LSN to `stream`.
38 : start_changed_tx: tokio::sync::watch::Sender<Lsn>,
39 : }
40 :
/// Item yielded by [`StreamingWalReader`].
41 : pub(crate) enum WalOrReset {
/// Outcome of one read attempt: a chunk of WAL, or the error that occurred.
42 : Wal(anyhow::Result<WalBytes>),
/// The stream's read position was moved to the contained LSN.
43 : Reset(Lsn),
44 : }
45 :
46 : impl WalOrReset {
47 65 : pub(crate) fn get_wal(self) -> Option<anyhow::Result<WalBytes>> {
48 65 : match self {
49 65 : WalOrReset::Wal(wal) => Some(wal),
50 0 : WalOrReset::Reset(_) => None,
51 : }
52 65 : }
53 : }
54 :
55 : impl StreamingWalReader {
56 3 : pub(crate) fn new(
57 3 : tli: WalResidentTimeline,
58 3 : term: Option<Term>,
59 3 : start: Lsn,
60 3 : end: Lsn,
61 3 : end_watch: EndWatch,
62 3 : buffer_size: usize,
63 3 : ) -> Self {
64 3 : let (start_changed_tx, start_changed_rx) = tokio::sync::watch::channel(start);
65 3 :
66 3 : let state = WalReaderStreamState {
67 3 : tli,
68 3 : wal_reader: PositionedWalReader {
69 3 : start,
70 3 : end,
71 3 : reader: None,
72 3 : },
73 3 : term,
74 3 : end_watch,
75 3 : buffer: vec![0; buffer_size],
76 3 : buffer_size,
77 3 : };
78 3 :
79 3 : // When a change notification is received while polling the internal
80 3 : // reader, stop polling the read future and service the change.
81 3 : let stream = futures::stream::unfold(
82 3 : (state, start_changed_rx),
83 70 : |(mut state, mut rx)| async move {
84 70 : let wal_or_reset = tokio::select! {
85 70 : read_res = state.read() => { WalOrReset::Wal(read_res) },
86 70 : changed_res = rx.changed() => {
87 3 : if changed_res.is_err() {
88 0 : return None;
89 3 : }
90 3 :
91 3 : let new_start_pos = rx.borrow_and_update();
92 3 : WalOrReset::Reset(*new_start_pos)
93 : }
94 : };
95 :
96 68 : if let WalOrReset::Reset(lsn) = wal_or_reset {
97 3 : state.wal_reader.start = lsn;
98 3 : state.wal_reader.reader = None;
99 65 : }
100 :
101 68 : Some((wal_or_reset, (state, rx)))
102 138 : },
103 3 : )
104 3 : .boxed();
105 3 :
106 3 : Self {
107 3 : stream,
108 3 : start_changed_tx,
109 3 : }
110 3 : }
111 :
112 : /// Reset the stream to a given position.
113 3 : pub(crate) async fn reset(&mut self, start: Lsn) {
114 3 : self.start_changed_tx.send(start).unwrap();
115 3 : while let Some(wal_or_reset) = self.stream.next().await {
116 3 : match wal_or_reset {
117 3 : WalOrReset::Reset(at) => {
118 3 : // Stream confirmed the reset.
119 3 : // There may only one ongoing reset at any given time,
120 3 : // hence the assertion.
121 3 : assert_eq!(at, start);
122 3 : break;
123 : }
124 0 : WalOrReset::Wal(_) => {
125 0 : // Ignore wal generated before reset was handled
126 0 : }
127 : }
128 : }
129 3 : }
130 : }
131 :
132 : impl Stream for StreamingWalReader {
133 : type Item = WalOrReset;
134 :
135 144 : fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
136 144 : Pin::new(&mut self.stream).poll_next(cx)
137 144 : }
138 : }
139 :
/// State carried from one iteration to the next by the stream built in
/// [`StreamingWalReader::new`].
140 : struct WalReaderStreamState {
/// Timeline the WAL is read from.
141 : tli: WalResidentTimeline,
/// Reader and the LSN range still to be read.
142 : wal_reader: PositionedWalReader,
/// Optional term; when set it is passed to `wait_for_lsn` and acquired around each read.
143 : term: Option<Term>,
/// Waited on for more WAL once the read position reaches the known end.
144 : end_watch: EndWatch,
/// Scratch buffer that reads are performed into.
145 : buffer: Vec<u8>,
/// Upper bound, in bytes, on the size of a single chunk.
146 : buffer_size: usize,
147 : }
148 :
149 : impl WalReaderStreamState {
150 70 : async fn read(&mut self) -> anyhow::Result<WalBytes> {
151 69 : // Create reader if needed
152 69 : if self.wal_reader.reader.is_none() {
153 5 : self.wal_reader.reader = Some(self.tli.get_walreader(self.wal_reader.start).await?);
154 64 : }
155 :
156 69 : let have_something_to_send = self.wal_reader.end > self.wal_reader.start;
157 69 : if !have_something_to_send {
158 4 : tracing::debug!(
159 0 : "Waiting for wal: start={}, end={}",
160 : self.wal_reader.end,
161 : self.wal_reader.start
162 : );
163 4 : self.wal_reader.end = self
164 4 : .end_watch
165 4 : .wait_for_lsn(self.wal_reader.start, self.term)
166 4 : .await?;
167 0 : tracing::debug!(
168 0 : "Done waiting for wal: start={}, end={}",
169 : self.wal_reader.end,
170 : self.wal_reader.start
171 : );
172 65 : }
173 :
174 65 : assert!(
175 65 : self.wal_reader.end > self.wal_reader.start,
176 0 : "nothing to send after waiting for WAL"
177 : );
178 :
179 : // Calculate chunk size
180 65 : let mut chunk_end_pos = self.wal_reader.start + self.buffer_size as u64;
181 65 : if chunk_end_pos >= self.wal_reader.end {
182 5 : chunk_end_pos = self.wal_reader.end;
183 60 : } else {
184 60 : chunk_end_pos = chunk_end_pos
185 60 : .checked_sub(chunk_end_pos.block_offset())
186 60 : .unwrap();
187 60 : }
188 :
189 65 : let send_size = (chunk_end_pos.0 - self.wal_reader.start.0) as usize;
190 65 : let buffer = &mut self.buffer[..send_size];
191 :
192 : // Read WAL
193 65 : let send_size = {
194 65 : let _term_guard = if let Some(t) = self.term {
195 0 : Some(self.tli.acquire_term(t).await?)
196 : } else {
197 65 : None
198 : };
199 65 : self.wal_reader
200 65 : .reader
201 65 : .as_mut()
202 65 : .unwrap()
203 65 : .read(buffer)
204 65 : .await?
205 : };
206 :
207 65 : let wal = Bytes::copy_from_slice(&buffer[..send_size]);
208 65 : let result = WalBytes {
209 65 : wal,
210 65 : wal_start_lsn: self.wal_reader.start,
211 65 : wal_end_lsn: self.wal_reader.start + send_size as u64,
212 65 : available_wal_end_lsn: self.wal_reader.end,
213 65 : };
214 65 :
215 65 : self.wal_reader.start += send_size as u64;
216 65 :
217 65 : Ok(result)
218 65 : }
219 : }
220 :
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use futures::StreamExt;
    use postgres_ffi::MAX_SEND_SIZE;
    use utils::{
        id::{NodeId, TenantTimelineId},
        lsn::Lsn,
    };

    use crate::{test_utils::Env, wal_reader_stream::StreamingWalReader};

    /// Collects WAL chunks from `reader` until a chunk ends exactly at the
    /// available end of WAL, or the stream finishes.
    ///
    /// Panics if the stream yields a reset or a read error.
    async fn read_until_caught_up(reader: &mut StreamingWalReader) -> Vec<super::WalBytes> {
        let mut chunks = Vec::new();
        while let Some(item) = reader.next().await {
            let chunk = item.get_wal().unwrap().unwrap();
            let caught_up = chunk.available_wal_end_lsn == chunk.wal_end_lsn;
            chunks.push(chunk);

            if caught_up {
                break;
            }
        }
        chunks
    }

    /// Reading the same range before and after a reset to the original start
    /// position must produce identical chunks.
    #[tokio::test]
    async fn test_streaming_wal_reader_reset() {
        let _ = env_logger::builder().is_test(true).try_init();

        const SIZE: usize = 8 * 1024;
        const MSG_COUNT: usize = 200;

        let start_lsn = Lsn::from_str("0/149FD18").unwrap();
        let env = Env::new(true).unwrap();
        let tli = env
            .make_timeline(NodeId(1), TenantTimelineId::generate(), start_lsn)
            .await
            .unwrap();

        let resident_tli = tli.wal_residence_guard().await.unwrap();
        let end_watch = Env::write_wal(tli, start_lsn, SIZE, MSG_COUNT, None)
            .await
            .unwrap();
        let end_pos = end_watch.get();

        tracing::info!("Doing first round of reads ...");

        let mut streaming_wal_reader = StreamingWalReader::new(
            resident_tli,
            None,
            start_lsn,
            end_pos,
            end_watch,
            MAX_SEND_SIZE,
        );

        let before_reset = read_until_caught_up(&mut streaming_wal_reader).await;

        tracing::info!("Resetting the WAL stream ...");

        streaming_wal_reader.reset(start_lsn).await;

        tracing::info!("Doing second round of reads ...");

        let after_reset = read_until_caught_up(&mut streaming_wal_reader).await;

        assert_eq!(before_reset, after_reset);
    }
}
|