Line data Source code
1 : use anyhow::Result;
2 : use futures::{Stream, StreamExt as _, TryStreamExt as _};
3 : use tokio::io::AsyncRead;
4 : use tokio_util::io::StreamReader;
5 : use tonic::metadata::AsciiMetadataValue;
6 : use tonic::metadata::errors::InvalidMetadataValue;
7 : use tonic::transport::Channel;
8 : use tonic::{Request, Streaming};
9 :
10 : use utils::id::TenantId;
11 : use utils::id::TimelineId;
12 : use utils::shard::ShardIndex;
13 :
14 : use crate::model;
15 : use crate::proto;
16 :
17 : ///
18 : /// AuthInterceptor adds tenant, timeline, and auth header to the channel. These
19 : /// headers are required at the pageserver.
20 : ///
21 : #[derive(Clone)]
22 : struct AuthInterceptor {
23 : tenant_id: AsciiMetadataValue,
24 : timeline_id: AsciiMetadataValue,
25 : shard_id: AsciiMetadataValue,
26 : auth_header: Option<AsciiMetadataValue>, // including "Bearer " prefix
27 : }
28 :
29 : impl AuthInterceptor {
30 0 : fn new(
31 0 : tenant_id: TenantId,
32 0 : timeline_id: TimelineId,
33 0 : auth_token: Option<String>,
34 0 : shard_id: ShardIndex,
35 0 : ) -> Result<Self, InvalidMetadataValue> {
36 0 : let tenant_ascii: AsciiMetadataValue = tenant_id.to_string().try_into()?;
37 0 : let timeline_ascii: AsciiMetadataValue = timeline_id.to_string().try_into()?;
38 0 : let shard_ascii: AsciiMetadataValue = shard_id.to_string().try_into()?;
39 :
40 0 : let auth_header: Option<AsciiMetadataValue> = match auth_token {
41 0 : Some(token) => Some(format!("Bearer {token}").try_into()?),
42 0 : None => None,
43 : };
44 :
45 0 : Ok(Self {
46 0 : tenant_id: tenant_ascii,
47 0 : shard_id: shard_ascii,
48 0 : timeline_id: timeline_ascii,
49 0 : auth_header,
50 0 : })
51 0 : }
52 : }
53 :
54 : impl tonic::service::Interceptor for AuthInterceptor {
55 0 : fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
56 0 : req.metadata_mut()
57 0 : .insert("neon-tenant-id", self.tenant_id.clone());
58 0 : req.metadata_mut()
59 0 : .insert("neon-shard-id", self.shard_id.clone());
60 0 : req.metadata_mut()
61 0 : .insert("neon-timeline-id", self.timeline_id.clone());
62 0 : if let Some(auth_header) = &self.auth_header {
63 0 : req.metadata_mut()
64 0 : .insert("authorization", auth_header.clone());
65 0 : }
66 0 : Ok(req)
67 0 : }
68 : }
69 :
70 : #[derive(Clone)]
71 : pub struct Client {
72 : client: proto::PageServiceClient<
73 : tonic::service::interceptor::InterceptedService<Channel, AuthInterceptor>,
74 : >,
75 : }
76 :
77 : impl Client {
78 0 : pub async fn new<T: TryInto<tonic::transport::Endpoint> + Send + Sync + 'static>(
79 0 : into_endpoint: T,
80 0 : tenant_id: TenantId,
81 0 : timeline_id: TimelineId,
82 0 : shard_id: ShardIndex,
83 0 : auth_header: Option<String>,
84 0 : compression: Option<tonic::codec::CompressionEncoding>,
85 0 : ) -> anyhow::Result<Self> {
86 0 : let endpoint: tonic::transport::Endpoint = into_endpoint
87 0 : .try_into()
88 0 : .map_err(|_e| anyhow::anyhow!("failed to convert endpoint"))?;
89 0 : let channel = endpoint.connect().await?;
90 0 : let auth = AuthInterceptor::new(tenant_id, timeline_id, auth_header, shard_id)
91 0 : .map_err(|e| anyhow::anyhow!(e.to_string()))?;
92 0 : let mut client = proto::PageServiceClient::with_interceptor(channel, auth);
93 :
94 0 : if let Some(compression) = compression {
95 0 : // TODO: benchmark this (including network latency).
96 0 : client = client
97 0 : .accept_compressed(compression)
98 0 : .send_compressed(compression);
99 0 : }
100 :
101 0 : Ok(Self { client })
102 0 : }
103 :
104 : /// Returns whether a relation exists.
105 0 : pub async fn check_rel_exists(
106 0 : &mut self,
107 0 : req: model::CheckRelExistsRequest,
108 0 : ) -> Result<model::CheckRelExistsResponse, tonic::Status> {
109 0 : let proto_req = proto::CheckRelExistsRequest::from(req);
110 :
111 0 : let response = self.client.check_rel_exists(proto_req).await?;
112 :
113 0 : let proto_resp = response.into_inner();
114 0 : Ok(proto_resp.into())
115 0 : }
116 :
117 : /// Fetches a base backup.
118 0 : pub async fn get_base_backup(
119 0 : &mut self,
120 0 : req: model::GetBaseBackupRequest,
121 0 : ) -> Result<impl AsyncRead + use<>, tonic::Status> {
122 0 : let req = proto::GetBaseBackupRequest::from(req);
123 0 : let chunks = self.client.get_base_backup(req).await?.into_inner();
124 0 : let reader = StreamReader::new(
125 0 : chunks
126 0 : .map_ok(|resp| resp.chunk)
127 0 : .map_err(std::io::Error::other),
128 : );
129 0 : Ok(reader)
130 0 : }
131 :
132 : /// Returns the total size of a database, as # of bytes.
133 0 : pub async fn get_db_size(
134 0 : &mut self,
135 0 : req: model::GetDbSizeRequest,
136 0 : ) -> Result<u64, tonic::Status> {
137 0 : let proto_req = proto::GetDbSizeRequest::from(req);
138 :
139 0 : let response = self.client.get_db_size(proto_req).await?;
140 0 : Ok(response.into_inner().into())
141 0 : }
142 :
143 : /// Fetches pages.
144 : ///
145 : /// This is implemented as a bidirectional streaming RPC for performance.
146 : /// Per-request errors are often returned as status_code instead of errors,
147 : /// to avoid tearing down the entire stream via tonic::Status.
148 0 : pub async fn get_pages<ReqSt>(
149 0 : &mut self,
150 0 : inbound: ReqSt,
151 0 : ) -> Result<
152 0 : impl Stream<Item = Result<model::GetPageResponse, tonic::Status>> + Send + 'static,
153 0 : tonic::Status,
154 0 : >
155 0 : where
156 0 : ReqSt: Stream<Item = model::GetPageRequest> + Send + 'static,
157 0 : {
158 0 : let outbound_proto = inbound.map(|domain_req| domain_req.into());
159 :
160 0 : let req_new = Request::new(outbound_proto);
161 :
162 0 : let response_stream: Streaming<proto::GetPageResponse> =
163 0 : self.client.get_pages(req_new).await?.into_inner();
164 :
165 0 : let domain_stream = response_stream.map_ok(model::GetPageResponse::from);
166 :
167 0 : Ok(domain_stream)
168 0 : }
169 :
170 : /// Returns the size of a relation, as # of blocks.
171 0 : pub async fn get_rel_size(
172 0 : &mut self,
173 0 : req: model::GetRelSizeRequest,
174 0 : ) -> Result<model::GetRelSizeResponse, tonic::Status> {
175 0 : let proto_req = proto::GetRelSizeRequest::from(req);
176 0 : let response = self.client.get_rel_size(proto_req).await?;
177 0 : let proto_resp = response.into_inner();
178 0 : Ok(proto_resp.into())
179 0 : }
180 :
181 : /// Fetches an SLRU segment.
182 0 : pub async fn get_slru_segment(
183 0 : &mut self,
184 0 : req: model::GetSlruSegmentRequest,
185 0 : ) -> Result<model::GetSlruSegmentResponse, tonic::Status> {
186 0 : let proto_req = proto::GetSlruSegmentRequest::from(req);
187 0 : let response = self.client.get_slru_segment(proto_req).await?;
188 0 : Ok(response.into_inner().try_into()?)
189 0 : }
190 :
191 : /// Acquires or extends a lease on the given LSN. This guarantees that the Pageserver won't
192 : /// garbage collect the LSN until the lease expires. Must be acquired on all relevant shards.
193 : ///
194 : /// Returns the lease expiration time, or a FailedPrecondition status if the lease could not be
195 : /// acquired because the LSN has already been garbage collected.
196 0 : pub async fn lease_lsn(
197 0 : &mut self,
198 0 : req: model::LeaseLsnRequest,
199 0 : ) -> Result<model::LeaseLsnResponse, tonic::Status> {
200 0 : let req = proto::LeaseLsnRequest::from(req);
201 0 : Ok(self.client.lease_lsn(req).await?.into_inner().try_into()?)
202 0 : }
203 : }
|