Line data Source code
1 : use core::fmt;
2 : use std::borrow::Cow;
3 : use std::str::FromStr;
4 :
5 : use anyhow::anyhow;
6 : use hyper::body::HttpBody;
7 : use hyper::{Body, Request};
8 : use routerify::ext::RequestExt;
9 :
10 : use super::error::ApiError;
11 :
12 0 : pub fn get_request_param<'a>(
13 0 : request: &'a Request<Body>,
14 0 : param_name: &str,
15 0 : ) -> Result<&'a str, ApiError> {
16 0 : match request.param(param_name) {
17 0 : Some(arg) => Ok(arg),
18 0 : None => Err(ApiError::BadRequest(anyhow!(
19 0 : "no {param_name} specified in path param",
20 0 : ))),
21 : }
22 0 : }
23 :
24 0 : pub fn parse_request_param<T: FromStr>(
25 0 : request: &Request<Body>,
26 0 : param_name: &str,
27 0 : ) -> Result<T, ApiError> {
28 0 : match get_request_param(request, param_name)?.parse() {
29 0 : Ok(v) => Ok(v),
30 0 : Err(_) => Err(ApiError::BadRequest(anyhow!(
31 0 : "failed to parse {param_name}",
32 0 : ))),
33 : }
34 0 : }
35 :
36 4 : pub fn get_query_param<'a>(
37 4 : request: &'a Request<Body>,
38 4 : param_name: &str,
39 4 : ) -> Result<Option<Cow<'a, str>>, ApiError> {
40 4 : let query = match request.uri().query() {
41 3 : Some(q) => q,
42 1 : None => return Ok(None),
43 : };
44 3 : let values = url::form_urlencoded::parse(query.as_bytes())
45 6 : .filter_map(|(k, v)| if k == param_name { Some(v) } else { None })
46 : // we call .next() twice below. If it's None the first time, .fuse() ensures it's None afterwards
47 3 : .fuse();
48 :
49 : // Work around an issue with Alloy's pyroscope scrape where the "seconds"
50 : // parameter is added several times. https://github.com/grafana/alloy/issues/3026
51 : // TODO: revert after Alloy is fixed.
52 3 : let value1 = values
53 3 : .map(Ok)
54 3 : .reduce(|acc, i| {
55 2 : match acc {
56 1 : Err(_) => acc,
57 :
58 : // It's okay to have duplicates as along as they have the same value.
59 2 : Ok(ref a) if a == &i.unwrap() => acc,
60 :
61 1 : _ => Err(ApiError::BadRequest(anyhow!(
62 1 : "param {param_name} specified more than once"
63 1 : ))),
64 : }
65 3 : })
66 3 : .transpose()?;
67 : // if values.next().is_some() {
68 : // return Err(ApiError::BadRequest(anyhow!(
69 : // "param {param_name} specified more than once"
70 : // )));
71 : // }
72 :
73 2 : Ok(value1)
74 4 : }
75 :
76 0 : pub fn must_get_query_param<'a>(
77 0 : request: &'a Request<Body>,
78 0 : param_name: &str,
79 0 : ) -> Result<Cow<'a, str>, ApiError> {
80 0 : get_query_param(request, param_name)?.ok_or_else(|| {
81 0 : ApiError::BadRequest(anyhow!("no {param_name} specified in query parameters"))
82 0 : })
83 0 : }
84 :
85 0 : pub fn parse_query_param<E: fmt::Display, T: FromStr<Err = E>>(
86 0 : request: &Request<Body>,
87 0 : param_name: &str,
88 0 : ) -> Result<Option<T>, ApiError> {
89 0 : get_query_param(request, param_name)?
90 0 : .map(|v| {
91 0 : v.parse().map_err(|e| {
92 0 : ApiError::BadRequest(anyhow!("cannot parse query param {param_name}: {e}"))
93 0 : })
94 0 : })
95 0 : .transpose()
96 0 : }
97 :
98 0 : pub fn must_parse_query_param<E: fmt::Display, T: FromStr<Err = E>>(
99 0 : request: &Request<Body>,
100 0 : param_name: &str,
101 0 : ) -> Result<T, ApiError> {
102 0 : parse_query_param(request, param_name)?.ok_or_else(|| {
103 0 : ApiError::BadRequest(anyhow!("no {param_name} specified in query parameters"))
104 0 : })
105 0 : }
106 :
107 0 : pub async fn ensure_no_body(request: &mut Request<Body>) -> Result<(), ApiError> {
108 0 : match request.body_mut().data().await {
109 0 : Some(_) => Err(ApiError::BadRequest(anyhow!("Unexpected request body"))),
110 0 : None => Ok(()),
111 : }
112 0 : }
113 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a body-less request for `uri`.
    fn request_with_uri(uri: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .body(hyper::Body::empty())
            .unwrap()
    }

    #[test]
    fn test_get_query_param_duplicate() {
        // A single occurrence is returned as-is.
        let req = request_with_uri("http://localhost:12345/testuri?testparam=1");
        let value = get_query_param(&req, "testparam").unwrap();
        assert_eq!(value.unwrap(), "1");

        // Repeats with an identical value are tolerated (Alloy workaround).
        let req = request_with_uri("http://localhost:12345/testuri?testparam=1&testparam=1");
        let value = get_query_param(&req, "testparam").unwrap();
        assert_eq!(value.unwrap(), "1");

        // No query string at all.
        let req = request_with_uri("http://localhost:12345/testuri");
        let value = get_query_param(&req, "testparam").unwrap();
        assert!(value.is_none());

        // Conflicting repeats are rejected.
        let uri = "http://localhost:12345/testuri?testparam=1&testparam=2&testparam=3";
        let req = request_with_uri(uri);
        let value = get_query_param(&req, "testparam");
        assert!(value.is_err());
    }

    #[test]
    fn test_get_query_param_filtering_and_decoding() {
        // Unrelated parameters around the requested one are ignored.
        let uri = "http://localhost:12345/testuri?other=5&testparam=1&another=x";
        let req = request_with_uri(uri);
        let value = get_query_param(&req, "testparam").unwrap();
        assert_eq!(value.unwrap(), "1");

        // A query string that lacks the parameter yields None, not an error.
        let req = request_with_uri("http://localhost:12345/testuri?other=5");
        let value = get_query_param(&req, "testparam").unwrap();
        assert!(value.is_none());

        // A conflict is rejected even if the first value reappears afterwards.
        let uri = "http://localhost:12345/testuri?testparam=1&testparam=2&testparam=1";
        let req = request_with_uri(uri);
        assert!(get_query_param(&req, "testparam").is_err());

        // Percent-encoded values are decoded.
        let req = request_with_uri("http://localhost:12345/testuri?testparam=a%20b");
        let value = get_query_param(&req, "testparam").unwrap();
        assert_eq!(value.unwrap(), "a b");
    }
}
|