Line data Source code
//! FIXME: most of this is copy-pasted from mgmt_api.rs; dedupe into a `reqwest_utils::Client` crate.
2 : use pageserver_client::mgmt_api::{Error, ResponseErrorMessageExt};
3 : use reqwest::Method;
4 : use serde::{Deserialize, Serialize};
5 : use tokio_util::sync::CancellationToken;
6 : use tracing::error;
7 :
8 : use super::importbucket_format::Spec;
9 : use crate::config::PageServerConf;
10 :
11 : pub struct Client {
12 : base_url: String,
13 : authorization_header: Option<String>,
14 : client: reqwest::Client,
15 : cancel: CancellationToken,
16 : }
17 :
18 : pub type Result<T> = std::result::Result<T, Error>;
19 :
20 0 : #[derive(Serialize, Deserialize, Debug)]
21 : struct ImportProgressRequest {
22 : // no fields yet, not sure if there every will be any
23 : }
24 :
25 0 : #[derive(Serialize, Deserialize, Debug)]
26 : struct ImportProgressResponse {
27 : // we don't care
28 : }
29 :
30 : impl Client {
31 0 : pub fn new(conf: &PageServerConf, cancel: CancellationToken) -> anyhow::Result<Self> {
32 0 : let Some(ref base_url) = conf.import_pgdata_upcall_api else {
33 0 : anyhow::bail!("import_pgdata_upcall_api is not configured")
34 : };
35 0 : let mut http_client = reqwest::Client::builder();
36 0 : for cert in &conf.ssl_ca_certs {
37 0 : http_client = http_client.add_root_certificate(cert.clone());
38 0 : }
39 0 : let http_client = http_client.build()?;
40 :
41 0 : Ok(Self {
42 0 : base_url: base_url.to_string(),
43 0 : client: http_client,
44 0 : cancel,
45 0 : authorization_header: conf
46 0 : .import_pgdata_upcall_api_token
47 0 : .as_ref()
48 0 : .map(|secret_string| secret_string.get_contents())
49 0 : .map(|jwt| format!("Bearer {jwt}")),
50 0 : })
51 0 : }
52 :
53 0 : fn start_request<U: reqwest::IntoUrl>(
54 0 : &self,
55 0 : method: Method,
56 0 : uri: U,
57 0 : ) -> reqwest::RequestBuilder {
58 0 : let req = self.client.request(method, uri);
59 0 : if let Some(value) = &self.authorization_header {
60 0 : req.header(reqwest::header::AUTHORIZATION, value)
61 : } else {
62 0 : req
63 : }
64 0 : }
65 :
66 0 : async fn request_noerror<B: serde::Serialize, U: reqwest::IntoUrl>(
67 0 : &self,
68 0 : method: Method,
69 0 : uri: U,
70 0 : body: B,
71 0 : ) -> Result<reqwest::Response> {
72 0 : self.start_request(method, uri)
73 0 : .json(&body)
74 0 : .send()
75 0 : .await
76 0 : .map_err(Error::ReceiveBody)
77 0 : }
78 :
79 0 : async fn request<B: serde::Serialize, U: reqwest::IntoUrl>(
80 0 : &self,
81 0 : method: Method,
82 0 : uri: U,
83 0 : body: B,
84 0 : ) -> Result<reqwest::Response> {
85 0 : let res = self.request_noerror(method, uri, body).await?;
86 0 : let response = res.error_from_body().await?;
87 0 : Ok(response)
88 0 : }
89 :
90 0 : pub async fn send_progress_once(&self, spec: &Spec) -> Result<()> {
91 0 : let url = format!(
92 0 : "{}/projects/{}/branches/{}/import_progress",
93 0 : self.base_url, spec.project_id, spec.branch_id
94 0 : );
95 0 : let ImportProgressResponse {} = self
96 0 : .request(Method::POST, url, &ImportProgressRequest {})
97 0 : .await?
98 0 : .json()
99 0 : .await
100 0 : .map_err(Error::ReceiveBody)?;
101 0 : Ok(())
102 0 : }
103 :
104 0 : pub async fn send_progress_until_success(&self, spec: &Spec) -> anyhow::Result<()> {
105 : loop {
106 0 : match self.send_progress_once(spec).await {
107 0 : Ok(()) => return Ok(()),
108 0 : Err(Error::Cancelled) => return Err(anyhow::anyhow!("cancelled")),
109 0 : Err(err) => {
110 0 : error!(?err, "error sending progress, retrying");
111 0 : if tokio::time::timeout(
112 0 : std::time::Duration::from_secs(10),
113 0 : self.cancel.cancelled(),
114 0 : )
115 0 : .await
116 0 : .is_ok()
117 : {
118 0 : anyhow::bail!("cancelled while sending early progress update");
119 0 : }
120 : }
121 : }
122 : }
123 0 : }
124 : }
|