Line data Source code
1 : use std::collections::HashMap;
2 : use std::num::NonZeroUsize;
3 : use std::ops::Range;
4 : use std::pin::Pin;
5 : use std::sync::atomic::{AtomicU64, Ordering};
6 : use std::sync::{Arc, Mutex};
7 : use std::time::Instant;
8 :
9 : use anyhow::anyhow;
10 : use futures::TryStreamExt as _;
11 : use pageserver_api::shard::TenantShardId;
12 : use pageserver_client::mgmt_api::ForceAwaitLogicalSize;
13 : use pageserver_client::page_service::BasebackupRequest;
14 : use pageserver_page_api as page_api;
15 : use rand::prelude::*;
16 : use tokio::io::AsyncRead;
17 : use tokio::sync::Barrier;
18 : use tokio::task::JoinSet;
19 : use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _};
20 : use tokio_util::io::StreamReader;
21 : use tonic::async_trait;
22 : use tracing::{info, instrument};
23 : use url::Url;
24 : use utils::id::TenantTimelineId;
25 : use utils::lsn::Lsn;
26 : use utils::shard::ShardIndex;
27 :
28 : use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
29 : use crate::util::{request_stats, tokio_thread_local_stats};
30 :
31 : /// basebackup@LatestLSN
32 : #[derive(clap::Parser)]
33 : pub(crate) struct Args {
34 : #[clap(long, default_value = "http://localhost:9898")]
35 : mgmt_api_endpoint: String,
36 : /// The Pageserver to connect to. Use postgresql:// for libpq, or grpc:// for gRPC.
37 : #[clap(long, default_value = "postgresql://postgres@localhost:64000")]
38 : page_service_connstring: String,
39 : #[clap(long)]
40 : pageserver_jwt: Option<String>,
41 : #[clap(long, default_value = "1")]
42 : num_clients: NonZeroUsize,
43 : #[clap(long)]
44 : no_compression: bool,
45 : #[clap(long)]
46 : runtime: Option<humantime::Duration>,
47 : #[clap(long)]
48 : limit_to_first_n_targets: Option<usize>,
49 : targets: Option<Vec<TenantTimelineId>>,
50 : }
51 :
/// Counter shared between the workers and the periodic RPS reporter.
#[derive(Debug, Default)]
struct LiveStats {
    /// Requests finished since the counter was last read and reset.
    completed_requests: AtomicU64,
}

impl LiveStats {
    /// Records one completed request.
    fn inc(&self) {
        let counter = &self.completed_requests;
        counter.fetch_add(1, Ordering::Relaxed);
    }
}
62 :
63 : struct Target {
64 : timeline: TenantTimelineId,
65 : lsn_range: Option<Range<Lsn>>,
66 : }
67 :
68 : #[derive(serde::Serialize)]
69 : struct Output {
70 : total: request_stats::Output,
71 : }
72 :
73 : tokio_thread_local_stats::declare!(STATS: request_stats::Stats);
74 :
75 0 : pub(crate) fn main(args: Args) -> anyhow::Result<()> {
76 0 : tokio_thread_local_stats::main!(STATS, move |thread_local_stats| {
77 0 : main_impl(args, thread_local_stats)
78 0 : })
79 0 : }
80 :
81 0 : async fn main_impl(
82 0 : args: Args,
83 0 : all_thread_local_stats: AllThreadLocalStats<request_stats::Stats>,
84 0 : ) -> anyhow::Result<()> {
85 0 : let args: &'static Args = Box::leak(Box::new(args));
86 :
87 0 : let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new(
88 0 : reqwest::Client::new(), // TODO: support ssl_ca_file for https APIs in pagebench.
89 0 : args.mgmt_api_endpoint.clone(),
90 0 : args.pageserver_jwt.as_deref(),
91 : ));
92 :
93 : // discover targets
94 0 : let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
95 0 : &mgmt_api_client,
96 0 : crate::util::cli::targets::Spec {
97 0 : limit_to_first_n_targets: args.limit_to_first_n_targets,
98 0 : targets: args.targets.clone(),
99 0 : },
100 0 : )
101 0 : .await?;
102 0 : let mut js = JoinSet::new();
103 0 : for timeline in &timelines {
104 0 : js.spawn({
105 0 : let timeline = *timeline;
106 0 : let info = mgmt_api_client
107 0 : .timeline_info(
108 0 : TenantShardId::unsharded(timeline.tenant_id),
109 0 : timeline.timeline_id,
110 0 : ForceAwaitLogicalSize::No,
111 0 : )
112 0 : .await
113 0 : .unwrap();
114 0 : async move {
115 0 : anyhow::Ok(Target {
116 0 : timeline,
117 0 : // TODO: support lsn_range != latest LSN
118 0 : lsn_range: Some(info.last_record_lsn..(info.last_record_lsn + 1)),
119 0 : })
120 0 : }
121 : });
122 : }
123 0 : let mut all_targets: Vec<Target> = Vec::new();
124 0 : while let Some(res) = js.join_next().await {
125 0 : all_targets.push(res.unwrap().unwrap());
126 0 : }
127 :
128 0 : let live_stats = Arc::new(LiveStats::default());
129 :
130 0 : let num_client_tasks = timelines.len();
131 0 : let num_live_stats_dump = 1;
132 0 : let num_work_sender_tasks = 1;
133 :
134 0 : let start_work_barrier = Arc::new(tokio::sync::Barrier::new(
135 0 : num_client_tasks + num_live_stats_dump + num_work_sender_tasks,
136 : ));
137 0 : let all_work_done_barrier = Arc::new(tokio::sync::Barrier::new(num_client_tasks));
138 :
139 0 : tokio::spawn({
140 0 : let stats = Arc::clone(&live_stats);
141 0 : let start_work_barrier = Arc::clone(&start_work_barrier);
142 0 : async move {
143 0 : start_work_barrier.wait().await;
144 : loop {
145 0 : let start = std::time::Instant::now();
146 0 : tokio::time::sleep(std::time::Duration::from_secs(1)).await;
147 0 : let completed_requests = stats.completed_requests.swap(0, Ordering::Relaxed);
148 0 : let elapsed = start.elapsed();
149 0 : info!(
150 0 : "RPS: {:.0}",
151 0 : completed_requests as f64 / elapsed.as_secs_f64()
152 : );
153 : }
154 : }
155 : });
156 :
157 0 : let mut work_senders = HashMap::new();
158 0 : let mut tasks = Vec::new();
159 0 : let scheme = match Url::parse(&args.page_service_connstring) {
160 0 : Ok(url) => url.scheme().to_lowercase().to_string(),
161 0 : Err(url::ParseError::RelativeUrlWithoutBase) => "postgresql".to_string(),
162 0 : Err(err) => return Err(anyhow!("invalid connstring: {err}")),
163 : };
164 0 : for &tl in &timelines {
165 0 : let (sender, receiver) = tokio::sync::mpsc::channel(1); // TODO: not sure what the implications of this are
166 0 : work_senders.insert(tl, sender);
167 :
168 0 : let client: Box<dyn Client> = match scheme.as_str() {
169 0 : "postgresql" | "postgres" => Box::new(
170 0 : LibpqClient::new(&args.page_service_connstring, tl, !args.no_compression).await?,
171 : ),
172 0 : "grpc" => Box::new(
173 0 : GrpcClient::new(&args.page_service_connstring, tl, !args.no_compression).await?,
174 : ),
175 0 : scheme => return Err(anyhow!("invalid scheme {scheme}")),
176 : };
177 :
178 0 : tasks.push(tokio::spawn(run_worker(
179 0 : client,
180 0 : Arc::clone(&start_work_barrier),
181 0 : receiver,
182 0 : Arc::clone(&all_work_done_barrier),
183 0 : Arc::clone(&live_stats),
184 : )));
185 : }
186 :
187 0 : let work_sender = async move {
188 0 : start_work_barrier.wait().await;
189 : loop {
190 0 : let (timeline, work) = {
191 0 : let mut rng = rand::thread_rng();
192 0 : let target = all_targets.choose(&mut rng).unwrap();
193 0 : let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r));
194 0 : (target.timeline, Work { lsn })
195 : };
196 0 : let sender = work_senders.get(&timeline).unwrap();
197 : // TODO: what if this blocks?
198 0 : sender.send(work).await.ok().unwrap();
199 : }
200 : };
201 :
202 0 : if let Some(runtime) = args.runtime {
203 0 : match tokio::time::timeout(runtime.into(), work_sender).await {
204 0 : Ok(()) => unreachable!("work sender never terminates"),
205 0 : Err(_timeout) => {
206 0 : // this implicitly drops the work_senders, making all the clients exit
207 0 : }
208 : }
209 : } else {
210 0 : work_sender.await;
211 0 : unreachable!("work sender never terminates");
212 : }
213 :
214 0 : for t in tasks {
215 0 : t.await.unwrap();
216 : }
217 :
218 0 : let output = Output {
219 : total: {
220 0 : let mut agg_stats = request_stats::Stats::new();
221 0 : for stats in all_thread_local_stats.lock().unwrap().iter() {
222 0 : let stats = stats.lock().unwrap();
223 0 : agg_stats.add(&stats);
224 0 : }
225 0 : agg_stats.output()
226 : },
227 : };
228 :
229 0 : let output = serde_json::to_string_pretty(&output).unwrap();
230 0 : println!("{output}");
231 :
232 0 : anyhow::Ok(())
233 0 : }
234 :
235 : #[derive(Copy, Clone)]
236 : struct Work {
237 : lsn: Option<Lsn>,
238 : }
239 :
240 : #[instrument(skip_all)]
241 : async fn run_worker(
242 : mut client: Box<dyn Client>,
243 : start_work_barrier: Arc<Barrier>,
244 : mut work: tokio::sync::mpsc::Receiver<Work>,
245 : all_work_done_barrier: Arc<Barrier>,
246 : live_stats: Arc<LiveStats>,
247 : ) {
248 : start_work_barrier.wait().await;
249 :
250 : while let Some(Work { lsn }) = work.recv().await {
251 : let start = Instant::now();
252 : let stream = client.basebackup(lsn).await.unwrap();
253 :
254 : let size = futures::io::copy(stream.compat(), &mut tokio::io::sink().compat_write())
255 : .await
256 : .unwrap();
257 : info!("basebackup size is {size} bytes");
258 : let elapsed = start.elapsed();
259 : live_stats.inc();
260 0 : STATS.with(|stats| {
261 0 : stats.borrow().lock().unwrap().observe(elapsed).unwrap();
262 0 : });
263 : }
264 :
265 : all_work_done_barrier.wait().await;
266 : }
267 :
268 : /// A basebackup client. This allows switching out the client protocol implementation.
269 : #[async_trait]
270 : trait Client: Send {
271 : async fn basebackup(
272 : &mut self,
273 : lsn: Option<Lsn>,
274 : ) -> anyhow::Result<Pin<Box<dyn AsyncRead + Send>>>;
275 : }
276 :
277 : /// A libpq-based Pageserver client.
278 : struct LibpqClient {
279 : inner: pageserver_client::page_service::Client,
280 : ttid: TenantTimelineId,
281 : compression: bool,
282 : }
283 :
284 : impl LibpqClient {
285 0 : async fn new(
286 0 : connstring: &str,
287 0 : ttid: TenantTimelineId,
288 0 : compression: bool,
289 0 : ) -> anyhow::Result<Self> {
290 : Ok(Self {
291 0 : inner: pageserver_client::page_service::Client::new(connstring.to_string()).await?,
292 0 : ttid,
293 0 : compression,
294 : })
295 0 : }
296 : }
297 :
298 : #[async_trait]
299 : impl Client for LibpqClient {
300 0 : async fn basebackup(
301 : &mut self,
302 : lsn: Option<Lsn>,
303 0 : ) -> anyhow::Result<Pin<Box<dyn AsyncRead + Send + 'static>>> {
304 0 : let req = BasebackupRequest {
305 0 : tenant_id: self.ttid.tenant_id,
306 0 : timeline_id: self.ttid.timeline_id,
307 0 : lsn,
308 0 : gzip: self.compression,
309 0 : };
310 0 : let stream = self.inner.basebackup(&req).await?;
311 0 : Ok(Box::pin(StreamReader::new(
312 0 : stream.map_err(std::io::Error::other),
313 0 : )))
314 0 : }
315 : }
316 :
317 : /// A gRPC Pageserver client.
318 : struct GrpcClient {
319 : inner: page_api::Client,
320 : compression: page_api::BaseBackupCompression,
321 : }
322 :
323 : impl GrpcClient {
324 0 : async fn new(
325 0 : connstring: &str,
326 0 : ttid: TenantTimelineId,
327 0 : compression: bool,
328 0 : ) -> anyhow::Result<Self> {
329 0 : let inner = page_api::Client::connect(
330 0 : connstring.to_string(),
331 0 : ttid.tenant_id,
332 0 : ttid.timeline_id,
333 0 : ShardIndex::unsharded(),
334 0 : None,
335 0 : None, // NB: uses payload compression
336 0 : )
337 0 : .await?;
338 0 : let compression = match compression {
339 0 : true => page_api::BaseBackupCompression::Gzip,
340 0 : false => page_api::BaseBackupCompression::None,
341 : };
342 0 : Ok(Self { inner, compression })
343 0 : }
344 : }
345 :
346 : #[async_trait]
347 : impl Client for GrpcClient {
348 0 : async fn basebackup(
349 : &mut self,
350 : lsn: Option<Lsn>,
351 0 : ) -> anyhow::Result<Pin<Box<dyn AsyncRead + Send + 'static>>> {
352 0 : let req = page_api::GetBaseBackupRequest {
353 0 : lsn,
354 0 : replica: false,
355 0 : full: false,
356 0 : compression: self.compression,
357 0 : };
358 0 : Ok(Box::pin(self.inner.get_base_backup(req).await?))
359 0 : }
360 : }
|