Line data Source code
1 : use serde::Deserialize;
2 : use std::fmt;
3 :
4 : use crate::auth::IpPattern;
5 :
6 : use crate::{BranchId, EndpointId, ProjectId};
7 :
8 : /// Generic error response with human-readable description.
9 : /// Note that we can't always present it to user as is.
10 0 : #[derive(Debug, Deserialize)]
11 : pub struct ConsoleError {
12 : pub error: Box<str>,
13 : }
14 :
15 : /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
16 : /// Returned by the `/proxy_get_role_secret` API method.
17 30 : #[derive(Deserialize)]
18 : pub struct GetRoleSecret {
19 : pub role_secret: Box<str>,
20 : pub allowed_ips: Option<Vec<IpPattern>>,
21 : pub project_id: Option<ProjectId>,
22 : }
23 :
24 : // Manually implement debug to omit sensitive info.
25 : impl fmt::Debug for GetRoleSecret {
26 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 0 : f.debug_struct("GetRoleSecret").finish_non_exhaustive()
28 0 : }
29 : }
30 :
31 : /// Response which holds compute node's `host:port` pair.
32 : /// Returned by the `/proxy_wake_compute` API method.
33 10 : #[derive(Debug, Deserialize)]
34 : pub struct WakeCompute {
35 : pub address: Box<str>,
36 : pub aux: MetricsAuxInfo,
37 : }
38 :
39 : /// Async response which concludes the link auth flow.
40 : /// Also known as `kickResponse` in the console.
41 12 : #[derive(Debug, Deserialize)]
42 : pub struct KickSession<'a> {
43 : /// Session ID is assigned by the proxy.
44 : pub session_id: &'a str,
45 :
46 : /// Compute node connection params.
47 : #[serde(deserialize_with = "KickSession::parse_db_info")]
48 : pub result: DatabaseInfo,
49 : }
50 :
51 : impl KickSession<'_> {
52 2 : fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
53 2 : where
54 2 : D: serde::Deserializer<'de>,
55 2 : {
56 4 : #[derive(Deserialize)]
57 2 : enum Wrapper {
58 2 : // Currently, console only reports `Success`.
59 2 : // `Failure(String)` used to be here... RIP.
60 2 : Success(DatabaseInfo),
61 2 : }
62 2 :
63 2 : Wrapper::deserialize(des).map(|x| match x {
64 2 : Wrapper::Success(info) => info,
65 2 : })
66 2 : }
67 : }
68 :
69 : /// Compute node connection params.
70 100 : #[derive(Deserialize)]
71 : pub struct DatabaseInfo {
72 : pub host: Box<str>,
73 : pub port: u16,
74 : pub dbname: Box<str>,
75 : pub user: Box<str>,
76 : /// Console always provides a password, but it might
77 : /// be inconvenient for debug with local PG instance.
78 : pub password: Option<Box<str>>,
79 : pub aux: MetricsAuxInfo,
80 : }
81 :
82 : // Manually implement debug to omit sensitive info.
83 : impl fmt::Debug for DatabaseInfo {
84 0 : fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85 0 : f.debug_struct("DatabaseInfo")
86 0 : .field("host", &self.host)
87 0 : .field("port", &self.port)
88 0 : .field("dbname", &self.dbname)
89 0 : .field("user", &self.user)
90 0 : .finish_non_exhaustive()
91 0 : }
92 : }
93 :
94 : /// Various labels for prometheus metrics.
95 : /// Also known as `ProxyMetricsAuxInfo` in the console.
96 70 : #[derive(Debug, Deserialize, Clone, Default)]
97 : pub struct MetricsAuxInfo {
98 : pub endpoint_id: EndpointId,
99 : pub project_id: ProjectId,
100 : pub branch_id: BranchId,
101 : pub is_cold_start: Option<bool>,
102 : }
103 :
104 : #[cfg(test)]
105 : mod tests {
106 : use super::*;
107 : use serde_json::json;
108 :
109 10 : fn dummy_aux() -> serde_json::Value {
110 10 : json!({
111 10 : "endpoint_id": "endpoint",
112 10 : "project_id": "project",
113 10 : "branch_id": "branch",
114 10 : })
115 10 : }
116 :
117 2 : #[test]
118 2 : fn parse_kick_session() -> anyhow::Result<()> {
119 2 : // This is what the console's kickResponse looks like.
120 2 : let json = json!({
121 2 : "session_id": "deadbeef",
122 2 : "result": {
123 2 : "Success": {
124 2 : "host": "localhost",
125 2 : "port": 5432,
126 2 : "dbname": "postgres",
127 2 : "user": "john_doe",
128 2 : "password": "password",
129 2 : "aux": dummy_aux(),
130 2 : }
131 2 : }
132 2 : });
133 2 : let _: KickSession = serde_json::from_str(&json.to_string())?;
134 :
135 2 : Ok(())
136 2 : }
137 :
138 2 : #[test]
139 2 : fn parse_db_info() -> anyhow::Result<()> {
140 : // with password
141 2 : let _: DatabaseInfo = serde_json::from_value(json!({
142 2 : "host": "localhost",
143 2 : "port": 5432,
144 2 : "dbname": "postgres",
145 2 : "user": "john_doe",
146 2 : "password": "password",
147 2 : "aux": dummy_aux(),
148 2 : }))?;
149 :
150 : // without password
151 2 : let _: DatabaseInfo = serde_json::from_value(json!({
152 2 : "host": "localhost",
153 2 : "port": 5432,
154 2 : "dbname": "postgres",
155 2 : "user": "john_doe",
156 2 : "aux": dummy_aux(),
157 2 : }))?;
158 :
159 : // new field (forward compatibility)
160 2 : let _: DatabaseInfo = serde_json::from_value(json!({
161 2 : "host": "localhost",
162 2 : "port": 5432,
163 2 : "dbname": "postgres",
164 2 : "user": "john_doe",
165 2 : "project": "hello_world",
166 2 : "N.E.W": "forward compatibility check",
167 2 : "aux": dummy_aux(),
168 2 : }))?;
169 :
170 2 : Ok(())
171 2 : }
172 :
173 2 : #[test]
174 2 : fn parse_wake_compute() -> anyhow::Result<()> {
175 2 : let json = json!({
176 2 : "address": "0.0.0.0",
177 2 : "aux": dummy_aux(),
178 2 : });
179 2 : let _: WakeCompute = serde_json::from_str(&json.to_string())?;
180 2 : Ok(())
181 2 : }
182 :
183 2 : #[test]
184 2 : fn parse_get_role_secret() -> anyhow::Result<()> {
185 2 : // Empty `allowed_ips` field.
186 2 : let json = json!({
187 2 : "role_secret": "secret",
188 2 : });
189 2 : let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
190 2 : let json = json!({
191 2 : "role_secret": "secret",
192 2 : "allowed_ips": ["8.8.8.8"],
193 2 : });
194 2 : let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
195 2 : let json = json!({
196 2 : "role_secret": "secret",
197 2 : "allowed_ips": ["8.8.8.8"],
198 2 : "project_id": "project",
199 2 : });
200 2 : let _: GetRoleSecret = serde_json::from_str(&json.to_string())?;
201 :
202 2 : Ok(())
203 2 : }
204 : }
|