Line data Source code
1 : //! Frontend message serialization.
2 : #![allow(missing_docs)]
3 :
4 : use byteorder::{BigEndian, ByteOrder};
5 : use bytes::{Buf, BufMut, BytesMut};
6 : use std::error::Error;
7 : use std::io;
8 : use std::marker;
9 :
10 : use crate::{write_nullable, FromUsize, IsNull, Oid};
11 :
12 : #[inline]
13 69 : fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
14 69 : where
15 69 : F: FnOnce(&mut BytesMut) -> Result<(), E>,
16 69 : E: From<io::Error>,
17 69 : {
18 69 : let base = buf.len();
19 69 : buf.extend_from_slice(&[0; 4]);
20 69 :
21 69 : f(buf)?;
22 :
23 69 : let size = i32::from_usize(buf.len() - base)?;
24 69 : BigEndian::write_i32(&mut buf[base..], size);
25 69 : Ok(())
26 69 : }
27 :
28 : pub enum BindError {
29 : Conversion(Box<dyn Error + marker::Sync + Send>),
30 : Serialization(io::Error),
31 : }
32 :
33 : impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
34 : #[inline]
35 0 : fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
36 0 : BindError::Conversion(e)
37 0 : }
38 : }
39 :
40 : impl From<io::Error> for BindError {
41 : #[inline]
42 0 : fn from(e: io::Error) -> BindError {
43 0 : BindError::Serialization(e)
44 0 : }
45 : }
46 :
47 : #[inline]
48 0 : pub fn bind<I, J, F, T, K>(
49 0 : portal: &str,
50 0 : statement: &str,
51 0 : formats: I,
52 0 : values: J,
53 0 : mut serializer: F,
54 0 : result_formats: K,
55 0 : buf: &mut BytesMut,
56 0 : ) -> Result<(), BindError>
57 0 : where
58 0 : I: IntoIterator<Item = i16>,
59 0 : J: IntoIterator<Item = T>,
60 0 : F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
61 0 : K: IntoIterator<Item = i16>,
62 0 : {
63 0 : buf.put_u8(b'B');
64 0 :
65 0 : write_body(buf, |buf| {
66 0 : write_cstr(portal.as_bytes(), buf)?;
67 0 : write_cstr(statement.as_bytes(), buf)?;
68 0 : write_counted(
69 0 : formats,
70 0 : |f, buf| {
71 0 : buf.put_i16(f);
72 0 : Ok::<_, io::Error>(())
73 0 : },
74 0 : buf,
75 0 : )?;
76 0 : write_counted(
77 0 : values,
78 0 : |v, buf| write_nullable(|buf| serializer(v, buf), buf),
79 0 : buf,
80 0 : )?;
81 0 : write_counted(
82 0 : result_formats,
83 0 : |f, buf| {
84 0 : buf.put_i16(f);
85 0 : Ok::<_, io::Error>(())
86 0 : },
87 0 : buf,
88 0 : )?;
89 :
90 0 : Ok(())
91 0 : })
92 0 : }
93 :
94 : #[inline]
95 0 : fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
96 0 : where
97 0 : I: IntoIterator<Item = T>,
98 0 : F: FnMut(T, &mut BytesMut) -> Result<(), E>,
99 0 : E: From<io::Error>,
100 0 : {
101 0 : let base = buf.len();
102 0 : buf.extend_from_slice(&[0; 2]);
103 0 : let mut count = 0;
104 0 : for item in items {
105 0 : serializer(item, buf)?;
106 0 : count += 1;
107 : }
108 0 : let count = i16::from_usize(count)?;
109 0 : BigEndian::write_i16(&mut buf[base..], count);
110 0 :
111 0 : Ok(())
112 0 : }
113 :
114 : #[inline]
115 0 : pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
116 0 : write_body(buf, |buf| {
117 0 : buf.put_i32(80_877_102);
118 0 : buf.put_i32(process_id);
119 0 : buf.put_i32(secret_key);
120 0 : Ok::<_, io::Error>(())
121 0 : })
122 0 : .unwrap();
123 0 : }
124 :
125 : #[inline]
126 0 : pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
127 0 : buf.put_u8(b'C');
128 0 : write_body(buf, |buf| {
129 0 : buf.put_u8(variant);
130 0 : write_cstr(name.as_bytes(), buf)
131 0 : })
132 0 : }
133 :
134 : pub struct CopyData<T> {
135 : buf: T,
136 : len: i32,
137 : }
138 :
139 : impl<T> CopyData<T>
140 : where
141 : T: Buf,
142 : {
143 0 : pub fn new(buf: T) -> io::Result<CopyData<T>> {
144 0 : let len = buf
145 0 : .remaining()
146 0 : .checked_add(4)
147 0 : .and_then(|l| i32::try_from(l).ok())
148 0 : .ok_or_else(|| {
149 0 : io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
150 0 : })?;
151 :
152 0 : Ok(CopyData { buf, len })
153 0 : }
154 :
155 0 : pub fn write(self, out: &mut BytesMut) {
156 0 : out.put_u8(b'd');
157 0 : out.put_i32(self.len);
158 0 : out.put(self.buf);
159 0 : }
160 : }
161 :
162 : #[inline]
163 0 : pub fn copy_done(buf: &mut BytesMut) {
164 0 : buf.put_u8(b'c');
165 0 : write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
166 0 : }
167 :
168 : #[inline]
169 0 : pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
170 0 : buf.put_u8(b'f');
171 0 : write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
172 0 : }
173 :
174 : #[inline]
175 0 : pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
176 0 : buf.put_u8(b'D');
177 0 : write_body(buf, |buf| {
178 0 : buf.put_u8(variant);
179 0 : write_cstr(name.as_bytes(), buf)
180 0 : })
181 0 : }
182 :
183 : #[inline]
184 0 : pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
185 0 : buf.put_u8(b'E');
186 0 : write_body(buf, |buf| {
187 0 : write_cstr(portal.as_bytes(), buf)?;
188 0 : buf.put_i32(max_rows);
189 0 : Ok(())
190 0 : })
191 0 : }
192 :
193 : #[inline]
194 0 : pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
195 0 : where
196 0 : I: IntoIterator<Item = Oid>,
197 0 : {
198 0 : buf.put_u8(b'P');
199 0 : write_body(buf, |buf| {
200 0 : write_cstr(name.as_bytes(), buf)?;
201 0 : write_cstr(query.as_bytes(), buf)?;
202 0 : write_counted(
203 0 : param_types,
204 0 : |t, buf| {
205 0 : buf.put_u32(t);
206 0 : Ok::<_, io::Error>(())
207 0 : },
208 0 : buf,
209 0 : )?;
210 0 : Ok(())
211 0 : })
212 0 : }
213 :
214 : #[inline]
215 2 : pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
216 2 : buf.put_u8(b'p');
217 2 : write_body(buf, |buf| write_cstr(password, buf))
218 2 : }
219 :
220 : #[inline]
221 0 : pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
222 0 : buf.put_u8(b'Q');
223 0 : write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
224 0 : }
225 :
226 : #[inline]
227 14 : pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
228 14 : buf.put_u8(b'p');
229 14 : write_body(buf, |buf| {
230 14 : write_cstr(mechanism.as_bytes(), buf)?;
231 14 : let len = i32::from_usize(data.len())?;
232 14 : buf.put_i32(len);
233 14 : buf.put_slice(data);
234 14 : Ok(())
235 14 : })
236 14 : }
237 :
238 : #[inline]
239 11 : pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
240 11 : buf.put_u8(b'p');
241 11 : write_body(buf, |buf| {
242 11 : buf.put_slice(data);
243 11 : Ok(())
244 11 : })
245 11 : }
246 :
247 : #[inline]
248 20 : pub fn ssl_request(buf: &mut BytesMut) {
249 20 : write_body(buf, |buf| {
250 20 : buf.put_i32(80_877_103);
251 20 : Ok::<_, io::Error>(())
252 20 : })
253 20 : .unwrap();
254 20 : }
255 :
256 : #[inline]
257 22 : pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
258 22 : write_body(buf, |buf| {
259 22 : // postgres protocol version 3.0(196608) in bigger-endian
260 22 : buf.put_i32(0x00_03_00_00);
261 22 : buf.put_slice(¶meters.params);
262 22 : buf.put_u8(0);
263 22 : Ok(())
264 22 : })
265 22 : }
266 :
267 : #[derive(Debug, Clone, Default, PartialEq, Eq)]
268 : pub struct StartupMessageParams {
269 : pub params: BytesMut,
270 : }
271 :
272 : impl StartupMessageParams {
273 : /// Set parameter's value by its name.
274 31 : pub fn insert(&mut self, name: &str, value: &str) {
275 31 : if name.contains('\0') || value.contains('\0') {
276 0 : panic!("startup parameter name or value contained a null")
277 31 : }
278 31 : self.params.put_slice(name.as_bytes());
279 31 : self.params.put_u8(0);
280 31 : self.params.put_slice(value.as_bytes());
281 31 : self.params.put_u8(0);
282 31 : }
283 : }
284 :
285 : #[inline]
286 0 : pub fn sync(buf: &mut BytesMut) {
287 0 : buf.put_u8(b'S');
288 0 : write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
289 0 : }
290 :
291 : #[inline]
292 0 : pub fn terminate(buf: &mut BytesMut) {
293 0 : buf.put_u8(b'X');
294 0 : write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
295 0 : }
296 :
297 : #[inline]
298 16 : fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
299 16 : if s.contains(&0) {
300 0 : return Err(io::Error::new(
301 0 : io::ErrorKind::InvalidInput,
302 0 : "string contains embedded null",
303 0 : ));
304 16 : }
305 16 : buf.put_slice(s);
306 16 : buf.put_u8(0);
307 16 : Ok(())
308 16 : }
|