Line data Source code
//! FIXME: most of this is copy-pasted from `mgmt_api.rs`; dedupe both into a shared `reqwest_utils::Client` crate.
2 : use pageserver_client::mgmt_api::{Error, ResponseErrorMessageExt};
3 : use serde::{Deserialize, Serialize};
4 : use tokio_util::sync::CancellationToken;
5 : use tracing::error;
6 :
7 : use crate::config::PageServerConf;
8 : use reqwest::Method;
9 :
10 : use super::importbucket_format::Spec;
11 :
12 : pub struct Client {
13 : base_url: String,
14 : authorization_header: Option<String>,
15 : client: reqwest::Client,
16 : cancel: CancellationToken,
17 : }
18 :
19 : pub type Result<T> = std::result::Result<T, Error>;
20 :
21 0 : #[derive(Serialize, Deserialize, Debug)]
22 : struct ImportProgressRequest {
23 : // no fields yet, not sure if there every will be any
24 : }
25 :
26 0 : #[derive(Serialize, Deserialize, Debug)]
27 : struct ImportProgressResponse {
28 : // we don't care
29 : }
30 :
31 : impl Client {
32 0 : pub fn new(conf: &PageServerConf, cancel: CancellationToken) -> anyhow::Result<Self> {
33 0 : let Some(ref base_url) = conf.import_pgdata_upcall_api else {
34 0 : anyhow::bail!("import_pgdata_upcall_api is not configured")
35 : };
36 0 : Ok(Self {
37 0 : base_url: base_url.to_string(),
38 0 : client: reqwest::Client::new(),
39 0 : cancel,
40 0 : authorization_header: conf
41 0 : .import_pgdata_upcall_api_token
42 0 : .as_ref()
43 0 : .map(|secret_string| secret_string.get_contents())
44 0 : .map(|jwt| format!("Bearer {jwt}")),
45 0 : })
46 0 : }
47 :
48 0 : fn start_request<U: reqwest::IntoUrl>(
49 0 : &self,
50 0 : method: Method,
51 0 : uri: U,
52 0 : ) -> reqwest::RequestBuilder {
53 0 : let req = self.client.request(method, uri);
54 0 : if let Some(value) = &self.authorization_header {
55 0 : req.header(reqwest::header::AUTHORIZATION, value)
56 : } else {
57 0 : req
58 : }
59 0 : }
60 :
61 0 : async fn request_noerror<B: serde::Serialize, U: reqwest::IntoUrl>(
62 0 : &self,
63 0 : method: Method,
64 0 : uri: U,
65 0 : body: B,
66 0 : ) -> Result<reqwest::Response> {
67 0 : self.start_request(method, uri)
68 0 : .json(&body)
69 0 : .send()
70 0 : .await
71 0 : .map_err(Error::ReceiveBody)
72 0 : }
73 :
74 0 : async fn request<B: serde::Serialize, U: reqwest::IntoUrl>(
75 0 : &self,
76 0 : method: Method,
77 0 : uri: U,
78 0 : body: B,
79 0 : ) -> Result<reqwest::Response> {
80 0 : let res = self.request_noerror(method, uri, body).await?;
81 0 : let response = res.error_from_body().await?;
82 0 : Ok(response)
83 0 : }
84 :
85 0 : pub async fn send_progress_once(&self, spec: &Spec) -> Result<()> {
86 0 : let url = format!(
87 0 : "{}/projects/{}/branches/{}/import_progress",
88 0 : self.base_url, spec.project_id, spec.branch_id
89 0 : );
90 0 : let ImportProgressResponse {} = self
91 0 : .request(Method::POST, url, &ImportProgressRequest {})
92 0 : .await?
93 0 : .json()
94 0 : .await
95 0 : .map_err(Error::ReceiveBody)?;
96 0 : Ok(())
97 0 : }
98 :
99 0 : pub async fn send_progress_until_success(&self, spec: &Spec) -> anyhow::Result<()> {
100 : loop {
101 0 : match self.send_progress_once(spec).await {
102 0 : Ok(()) => return Ok(()),
103 0 : Err(Error::Cancelled) => return Err(anyhow::anyhow!("cancelled")),
104 0 : Err(err) => {
105 0 : error!(?err, "error sending progress, retrying");
106 0 : if tokio::time::timeout(
107 0 : std::time::Duration::from_secs(10),
108 0 : self.cancel.cancelled(),
109 0 : )
110 0 : .await
111 0 : .is_ok()
112 : {
113 0 : anyhow::bail!("cancelled while sending early progress update");
114 0 : }
115 : }
116 : }
117 : }
118 0 : }
119 : }
|