Line data Source code
1 : //! Frontend message serialization.
2 : #![allow(missing_docs)]
3 :
4 : use std::error::Error;
5 : use std::{io, marker};
6 :
7 : use byteorder::{BigEndian, ByteOrder};
8 : use bytes::{Buf, BufMut, BytesMut};
9 :
10 : use crate::{FromUsize, IsNull, Oid, write_nullable};
11 :
12 : #[inline]
13 69 : fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
14 69 : where
15 69 : F: FnOnce(&mut BytesMut) -> Result<(), E>,
16 69 : E: From<io::Error>,
17 : {
18 69 : let base = buf.len();
19 69 : buf.extend_from_slice(&[0; 4]);
20 :
21 69 : f(buf)?;
22 :
23 69 : let size = i32::from_usize(buf.len() - base)?;
24 69 : BigEndian::write_i32(&mut buf[base..], size);
25 69 : Ok(())
26 69 : }
27 :
28 : #[derive(Debug)]
29 : pub enum BindError {
30 : Conversion(Box<dyn Error + marker::Sync + Send>),
31 : Serialization(io::Error),
32 : }
33 :
34 : impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
35 : #[inline]
36 0 : fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
37 0 : BindError::Conversion(e)
38 0 : }
39 : }
40 :
41 : impl From<io::Error> for BindError {
42 : #[inline]
43 0 : fn from(e: io::Error) -> BindError {
44 0 : BindError::Serialization(e)
45 0 : }
46 : }
47 :
48 : #[inline]
49 0 : pub fn bind<I, J, F, T, K>(
50 0 : portal: &str,
51 0 : statement: &str,
52 0 : formats: I,
53 0 : values: J,
54 0 : mut serializer: F,
55 0 : result_formats: K,
56 0 : buf: &mut BytesMut,
57 0 : ) -> Result<(), BindError>
58 0 : where
59 0 : I: IntoIterator<Item = i16>,
60 0 : J: IntoIterator<Item = T>,
61 0 : F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
62 0 : K: IntoIterator<Item = i16>,
63 : {
64 0 : buf.put_u8(b'B');
65 :
66 0 : write_body(buf, |buf| {
67 0 : write_cstr(portal.as_bytes(), buf)?;
68 0 : write_cstr(statement.as_bytes(), buf)?;
69 0 : write_counted(
70 0 : formats,
71 0 : |f, buf| {
72 0 : buf.put_i16(f);
73 0 : Ok::<_, io::Error>(())
74 0 : },
75 0 : buf,
76 0 : )?;
77 0 : write_counted(
78 0 : values,
79 0 : |v, buf| write_nullable(|buf| serializer(v, buf), buf),
80 0 : buf,
81 0 : )?;
82 0 : write_counted(
83 0 : result_formats,
84 0 : |f, buf| {
85 0 : buf.put_i16(f);
86 0 : Ok::<_, io::Error>(())
87 0 : },
88 0 : buf,
89 0 : )?;
90 :
91 0 : Ok(())
92 0 : })
93 0 : }
94 :
95 : #[inline]
96 0 : fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
97 0 : where
98 0 : I: IntoIterator<Item = T>,
99 0 : F: FnMut(T, &mut BytesMut) -> Result<(), E>,
100 0 : E: From<io::Error>,
101 : {
102 0 : let base = buf.len();
103 0 : buf.extend_from_slice(&[0; 2]);
104 0 : let mut count = 0;
105 0 : for item in items {
106 0 : serializer(item, buf)?;
107 0 : count += 1;
108 : }
109 0 : let count = i16::from_usize(count)?;
110 0 : BigEndian::write_i16(&mut buf[base..], count);
111 :
112 0 : Ok(())
113 0 : }
114 :
115 : #[inline]
116 0 : pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
117 0 : write_body(buf, |buf| {
118 0 : buf.put_i32(80_877_102);
119 0 : buf.put_i32(process_id);
120 0 : buf.put_i32(secret_key);
121 0 : Ok::<_, io::Error>(())
122 0 : })
123 0 : .unwrap();
124 0 : }
125 :
126 : #[inline]
127 0 : pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
128 0 : buf.put_u8(b'C');
129 0 : write_body(buf, |buf| {
130 0 : buf.put_u8(variant);
131 0 : write_cstr(name.as_bytes(), buf)
132 0 : })
133 0 : }
134 :
135 : pub struct CopyData<T> {
136 : buf: T,
137 : len: i32,
138 : }
139 :
140 : impl<T> CopyData<T>
141 : where
142 : T: Buf,
143 : {
144 0 : pub fn new(buf: T) -> io::Result<CopyData<T>> {
145 0 : let len = buf
146 0 : .remaining()
147 0 : .checked_add(4)
148 0 : .and_then(|l| i32::try_from(l).ok())
149 0 : .ok_or_else(|| {
150 0 : io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
151 0 : })?;
152 :
153 0 : Ok(CopyData { buf, len })
154 0 : }
155 :
156 0 : pub fn write(self, out: &mut BytesMut) {
157 0 : out.put_u8(b'd');
158 0 : out.put_i32(self.len);
159 0 : out.put(self.buf);
160 0 : }
161 : }
162 :
163 : #[inline]
164 0 : pub fn copy_done(buf: &mut BytesMut) {
165 0 : buf.put_u8(b'c');
166 0 : write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
167 0 : }
168 :
169 : #[inline]
170 0 : pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
171 0 : buf.put_u8(b'f');
172 0 : write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
173 0 : }
174 :
175 : #[inline]
176 0 : pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
177 0 : buf.put_u8(b'D');
178 0 : write_body(buf, |buf| {
179 0 : buf.put_u8(variant);
180 0 : write_cstr(name.as_bytes(), buf)
181 0 : })
182 0 : }
183 :
184 : #[inline]
185 0 : pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
186 0 : buf.put_u8(b'E');
187 0 : write_body(buf, |buf| {
188 0 : write_cstr(portal.as_bytes(), buf)?;
189 0 : buf.put_i32(max_rows);
190 0 : Ok(())
191 0 : })
192 0 : }
193 :
194 : #[inline]
195 0 : pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
196 0 : where
197 0 : I: IntoIterator<Item = Oid>,
198 : {
199 0 : buf.put_u8(b'P');
200 0 : write_body(buf, |buf| {
201 0 : write_cstr(name.as_bytes(), buf)?;
202 0 : write_cstr(query.as_bytes(), buf)?;
203 0 : write_counted(
204 0 : param_types,
205 0 : |t, buf| {
206 0 : buf.put_u32(t);
207 0 : Ok::<_, io::Error>(())
208 0 : },
209 0 : buf,
210 0 : )?;
211 0 : Ok(())
212 0 : })
213 0 : }
214 :
215 : #[inline]
216 2 : pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
217 2 : buf.put_u8(b'p');
218 2 : write_body(buf, |buf| write_cstr(password, buf))
219 0 : }
220 :
221 : #[inline]
222 0 : pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
223 0 : buf.put_u8(b'Q');
224 0 : write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
225 0 : }
226 :
227 : #[inline]
228 14 : pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
229 14 : buf.put_u8(b'p');
230 14 : write_body(buf, |buf| {
231 14 : write_cstr(mechanism.as_bytes(), buf)?;
232 14 : let len = i32::from_usize(data.len())?;
233 14 : buf.put_i32(len);
234 14 : buf.put_slice(data);
235 14 : Ok(())
236 14 : })
237 0 : }
238 :
239 : #[inline]
240 11 : pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
241 11 : buf.put_u8(b'p');
242 11 : write_body(buf, |buf| {
243 11 : buf.put_slice(data);
244 11 : Ok(())
245 11 : })
246 0 : }
247 :
248 : #[inline]
249 20 : pub fn ssl_request(buf: &mut BytesMut) {
250 20 : write_body(buf, |buf| {
251 20 : buf.put_i32(80_877_103);
252 20 : Ok::<_, io::Error>(())
253 20 : })
254 20 : .unwrap();
255 0 : }
256 :
257 : #[inline]
258 22 : pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
259 22 : write_body(buf, |buf| {
260 : // postgres protocol version 3.0(196608) in bigger-endian
261 22 : buf.put_i32(0x00_03_00_00);
262 22 : buf.put_slice(¶meters.params);
263 22 : buf.put_u8(0);
264 22 : Ok(())
265 22 : })
266 0 : }
267 :
268 : #[derive(Debug, Clone, Default, PartialEq, Eq)]
269 : pub struct StartupMessageParams {
270 : pub params: BytesMut,
271 : }
272 :
273 : impl StartupMessageParams {
274 : /// Set parameter's value by its name.
275 31 : pub fn insert(&mut self, name: &str, value: &str) {
276 31 : if name.contains('\0') || value.contains('\0') {
277 0 : panic!("startup parameter name or value contained a null")
278 31 : }
279 31 : self.params.put_slice(name.as_bytes());
280 31 : self.params.put_u8(0);
281 31 : self.params.put_slice(value.as_bytes());
282 31 : self.params.put_u8(0);
283 31 : }
284 : }
285 :
286 : #[inline]
287 0 : pub fn sync(buf: &mut BytesMut) {
288 0 : buf.put_u8(b'S');
289 0 : write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
290 0 : }
291 :
292 : #[inline]
293 0 : pub fn flush(buf: &mut BytesMut) {
294 0 : buf.put_u8(b'H');
295 0 : write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
296 0 : }
297 :
298 : #[inline]
299 0 : pub fn terminate(buf: &mut BytesMut) {
300 0 : buf.put_u8(b'X');
301 0 : write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
302 0 : }
303 :
304 : #[inline]
305 16 : fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
306 16 : if s.contains(&0) {
307 0 : return Err(io::Error::new(
308 0 : io::ErrorKind::InvalidInput,
309 0 : "string contains embedded null",
310 0 : ));
311 16 : }
312 16 : buf.put_slice(s);
313 16 : buf.put_u8(0);
314 16 : Ok(())
315 16 : }
|