Line data Source code
1 : use std::net::SocketAddr;
2 :
3 : use arc_swap::ArcSwapOption;
4 : use postgres_client::config::SslMode;
5 : use tokio::sync::Semaphore;
6 :
7 : use super::jwt::{AuthRule, FetchAuthRules};
8 : use crate::auth::backend::jwt::FetchAuthRulesError;
9 : use crate::compute::ConnectInfo;
10 : use crate::compute_ctl::ComputeCtlApi;
11 : use crate::context::RequestContext;
12 : use crate::control_plane::NodeInfo;
13 : use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, MetricsAuxInfo};
14 : use crate::http;
15 : use crate::intern::{BranchIdTag, EndpointIdTag, InternId, ProjectIdTag};
16 : use crate::types::EndpointId;
17 : use crate::url::ApiUrl;
18 :
19 : pub struct LocalBackend {
20 : pub(crate) initialize: Semaphore,
21 : pub(crate) compute_ctl: ComputeCtlApi,
22 : pub(crate) node_info: NodeInfo,
23 : }
24 :
25 : impl LocalBackend {
26 0 : pub fn new(postgres_addr: SocketAddr, compute_ctl: ApiUrl) -> Self {
27 0 : LocalBackend {
28 0 : initialize: Semaphore::new(1),
29 0 : compute_ctl: ComputeCtlApi {
30 0 : api: http::Endpoint::new(compute_ctl, http::new_client()),
31 0 : },
32 0 : node_info: NodeInfo {
33 0 : conn_info: ConnectInfo {
34 0 : host_addr: Some(postgres_addr.ip()),
35 0 : host: postgres_addr.ip().to_string().into(),
36 0 : port: postgres_addr.port(),
37 0 : ssl_mode: SslMode::Disable,
38 0 : },
39 0 : // TODO(conrad): make this better reflect compute info rather than endpoint info.
40 0 : aux: MetricsAuxInfo {
41 0 : endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"),
42 0 : project_id: ProjectIdTag::get_interner().get_or_intern("local"),
43 0 : branch_id: BranchIdTag::get_interner().get_or_intern("local"),
44 0 : compute_id: "local".into(),
45 0 : cold_start_info: ColdStartInfo::WarmCached,
46 0 : },
47 0 : },
48 0 : }
49 0 : }
50 : }
51 :
52 : #[derive(Clone, Copy)]
53 : pub(crate) struct StaticAuthRules;
54 :
55 : pub static JWKS_ROLE_MAP: ArcSwapOption<EndpointJwksResponse> = ArcSwapOption::const_empty();
56 :
57 : impl FetchAuthRules for StaticAuthRules {
58 0 : async fn fetch_auth_rules(
59 0 : &self,
60 0 : _ctx: &RequestContext,
61 0 : _endpoint: EndpointId,
62 0 : ) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
63 0 : let mappings = JWKS_ROLE_MAP.load();
64 0 : let role_mappings = mappings
65 0 : .as_deref()
66 0 : .ok_or(FetchAuthRulesError::RoleJwksNotConfigured)?;
67 0 : let mut rules = vec![];
68 0 : for setting in &role_mappings.jwks {
69 0 : rules.push(AuthRule {
70 0 : id: setting.id.clone(),
71 0 : jwks_url: setting.jwks_url.clone(),
72 0 : audience: setting.jwt_audience.clone(),
73 0 : role_names: setting.role_names.clone(),
74 0 : });
75 0 : }
76 :
77 0 : Ok(rules)
78 0 : }
79 : }
|