Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 : #![deny(clippy::undocumented_unsafe_blocks)]
5 :
6 : pub mod framed;
7 :
8 : use byteorder::{BigEndian, ReadBytesExt};
9 : use bytes::{Buf, BufMut, Bytes, BytesMut};
10 : use itertools::Itertools;
11 : use serde::{Deserialize, Serialize};
12 : use std::{borrow::Cow, fmt, io, str};
13 :
14 : // re-export for use in utils pageserver_feedback.rs
15 : pub use postgres_protocol::PG_EPOCH;
16 :
/// PostgreSQL object identifier.
pub type Oid = u32;
/// 64-bit system identifier.
pub type SystemId = u64;

/// OID of the built-in `int8` (bigint) type.
pub const INT8_OID: Oid = 20;
/// OID of the built-in `int4` (integer) type.
pub const INT4_OID: Oid = 23;
/// OID of the built-in `text` type.
pub const TEXT_OID: Oid = 25;
23 :
24 : #[derive(Debug)]
25 : pub enum FeMessage {
26 : // Simple query.
27 : Query(Bytes),
28 : // Extended query protocol.
29 : Parse(FeParseMessage),
30 : Describe(FeDescribeMessage),
31 : Bind(FeBindMessage),
32 : Execute(FeExecuteMessage),
33 : Close(FeCloseMessage),
34 : Sync,
35 : Terminate,
36 : CopyData(Bytes),
37 : CopyDone,
38 : CopyFail,
39 : PasswordMessage(Bytes),
40 : }
41 :
42 : #[derive(Debug)]
43 : pub enum FeStartupPacket {
44 : CancelRequest(CancelKeyData),
45 : SslRequest,
46 : GssEncRequest,
47 : StartupMessage {
48 : major_version: u32,
49 : minor_version: u32,
50 : params: StartupMessageParams,
51 : },
52 : }
53 :
54 : #[derive(Debug, Clone, Default)]
55 : pub struct StartupMessageParamsBuilder {
56 : params: BytesMut,
57 : }
58 :
59 : impl StartupMessageParamsBuilder {
60 : /// Set parameter's value by its name.
61 : /// name and value must not contain a \0 byte
62 62 : pub fn insert(&mut self, name: &str, value: &str) {
63 62 : self.params.put(name.as_bytes());
64 62 : self.params.put(&b"\0"[..]);
65 62 : self.params.put(value.as_bytes());
66 62 : self.params.put(&b"\0"[..]);
67 62 : }
68 :
69 46 : pub fn freeze(self) -> StartupMessageParams {
70 46 : StartupMessageParams {
71 46 : params: self.params.freeze(),
72 46 : }
73 46 : }
74 : }
75 :
76 : #[derive(Debug, Clone, Default)]
77 : pub struct StartupMessageParams {
78 : params: Bytes,
79 : }
80 :
81 : impl StartupMessageParams {
82 : /// Get parameter's value by its name.
83 96 : pub fn get(&self, name: &str) -> Option<&str> {
84 128 : self.iter().find_map(|(k, v)| (k == name).then_some(v))
85 96 : }
86 :
87 : /// Split command-line options according to PostgreSQL's logic,
88 : /// taking into account all escape sequences but leaving them as-is.
89 : /// [`None`] means that there's no `options` in [`Self`].
90 60 : pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
91 60 : self.get("options").map(Self::parse_options_raw)
92 60 : }
93 :
94 : /// Split command-line options according to PostgreSQL's logic,
95 : /// applying all escape sequences (using owned strings as needed).
96 : /// [`None`] means that there's no `options` in [`Self`].
97 10 : pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
98 10 : self.get("options").map(Self::parse_options_escaped)
99 10 : }
100 :
101 : /// Split command-line options according to PostgreSQL's logic,
102 : /// taking into account all escape sequences but leaving them as-is.
103 60 : pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
104 60 : // See `postgres: pg_split_opts`.
105 60 : let mut last_was_escape = false;
106 60 : input
107 1128 : .split(move |c: char| {
108 : // We split by non-escaped whitespace symbols.
109 1128 : let should_split = c.is_ascii_whitespace() && !last_was_escape;
110 1128 : last_was_escape = c == '\\' && !last_was_escape;
111 1128 : should_split
112 1128 : })
113 152 : .filter(|s| !s.is_empty())
114 60 : }
115 :
116 : /// Split command-line options according to PostgreSQL's logic,
117 : /// applying all escape sequences (using owned strings as needed).
118 8 : pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
119 8 : // See `postgres: pg_split_opts`.
120 14 : Self::parse_options_raw(input).map(|s| {
121 14 : let mut preserve_next_escape = false;
122 34 : let escape = |c| {
123 : // We should remove '\\' unless it's preceded by '\\'.
124 34 : let should_remove = c == '\\' && !preserve_next_escape;
125 34 : preserve_next_escape = should_remove;
126 34 : should_remove
127 34 : };
128 :
129 14 : match s.contains('\\') {
130 6 : true => Cow::Owned(s.replace(escape, "")),
131 8 : false => Cow::Borrowed(s),
132 : }
133 14 : })
134 8 : }
135 :
136 : /// Iterate through key-value pairs in an arbitrary order.
137 110 : pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
138 110 : let params =
139 110 : std::str::from_utf8(&self.params).expect("should be validated as utf8 already");
140 110 : params.split_terminator('\0').tuples()
141 110 : }
142 :
143 : // This function is mostly useful in tests.
144 : #[doc(hidden)]
145 46 : pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
146 46 : let mut b = StartupMessageParamsBuilder::default();
147 108 : for (k, v) in pairs {
148 62 : b.insert(k, v)
149 : }
150 46 : b.freeze()
151 46 : }
152 : }
153 :
154 12 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
155 : pub struct CancelKeyData {
156 : pub backend_pid: i32,
157 : pub cancel_key: i32,
158 : }
159 :
160 : impl fmt::Display for CancelKeyData {
161 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 0 : let hi = (self.backend_pid as u64) << 32;
163 0 : let lo = self.cancel_key as u64;
164 0 : let id = hi | lo;
165 0 :
166 0 : // This format is more compact and might work better for logs.
167 0 : f.debug_tuple("CancelKeyData")
168 0 : .field(&format_args!("{:x}", id))
169 0 : .finish()
170 0 : }
171 : }
172 :
173 : use rand::distributions::{Distribution, Standard};
174 : impl Distribution<CancelKeyData> for Standard {
175 2 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
176 2 : CancelKeyData {
177 2 : backend_pid: rng.gen(),
178 2 : cancel_key: rng.gen(),
179 2 : }
180 2 : }
181 : }
182 :
183 : // We only support the simple case of Parse on unnamed prepared statement and
184 : // no params
185 : #[derive(Debug)]
186 : pub struct FeParseMessage {
187 : pub query_string: Bytes,
188 : }
189 :
190 : #[derive(Debug)]
191 : pub struct FeDescribeMessage {
192 : pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
193 : // we only support unnamed prepared stmt or portal
194 : }
195 :
196 : // we only support unnamed prepared stmt and portal
197 : #[derive(Debug)]
198 : pub struct FeBindMessage;
199 :
200 : // we only support unnamed prepared stmt or portal
201 : #[derive(Debug)]
202 : pub struct FeExecuteMessage {
203 : /// max # of rows
204 : pub maxrows: i32,
205 : }
206 :
207 : // we only support unnamed prepared stmt and portal
208 : #[derive(Debug)]
209 : pub struct FeCloseMessage;
210 :
211 : /// An error occurred while parsing or serializing raw stream into Postgres
212 : /// messages.
213 0 : #[derive(thiserror::Error, Debug)]
214 : pub enum ProtocolError {
215 : /// Invalid packet was received from the client (e.g. unexpected message
216 : /// type or broken len).
217 : #[error("Protocol error: {0}")]
218 : Protocol(String),
219 : /// Failed to parse or, (unlikely), serialize a protocol message.
220 : #[error("Message parse error: {0}")]
221 : BadMessage(String),
222 : }
223 :
224 : impl ProtocolError {
225 : /// Proxy stream.rs uses only io::Error; provide it.
226 0 : pub fn into_io_error(self) -> io::Error {
227 0 : io::Error::new(io::ErrorKind::Other, self.to_string())
228 0 : }
229 : }
230 :
231 : impl FeMessage {
232 : /// Read and parse one message from the `buf` input buffer. If there is at
233 : /// least one valid message, returns it, advancing `buf`; redundant copies
234 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
235 : /// directly into the `buf` (processed data is garbage collected after
236 : /// parsed message is dropped).
237 : ///
238 : /// Returns None if `buf` doesn't contain enough data for a single message.
239 : /// For efficiency, tries to reserve large enough space in `buf` for the
240 : /// next message in this case to save the repeated calls.
241 : ///
242 : /// Returns Error if message is malformed, the only possible ErrorKind is
243 : /// InvalidInput.
244 : //
245 : // Inspired by rust-postgres Message::parse.
246 114 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
247 114 : // Every message contains message type byte and 4 bytes len; can't do
248 114 : // much without them.
249 114 : if buf.len() < 5 {
250 60 : let to_read = 5 - buf.len();
251 60 : buf.reserve(to_read);
252 60 : return Ok(None);
253 54 : }
254 54 :
255 54 : // We shouldn't advance `buf` as probably full message is not there yet,
256 54 : // so can't directly use Bytes::get_u32 etc.
257 54 : let tag = buf[0];
258 54 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
259 54 : if len < 4 {
260 0 : return Err(ProtocolError::Protocol(format!(
261 0 : "invalid message length {}",
262 0 : len
263 0 : )));
264 54 : }
265 54 :
266 54 : // length field includes itself, but not message type.
267 54 : let total_len = len as usize + 1;
268 54 : if buf.len() < total_len {
269 : // Don't have full message yet.
270 0 : let to_read = total_len - buf.len();
271 0 : buf.reserve(to_read);
272 0 : return Ok(None);
273 54 : }
274 54 :
275 54 : // got the message, advance buffer
276 54 : let mut msg = buf.split_to(total_len).freeze();
277 54 : msg.advance(5); // consume message type and len
278 54 :
279 54 : match tag {
280 4 : b'Q' => Ok(Some(FeMessage::Query(msg))),
281 0 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
282 0 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
283 0 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
284 0 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
285 0 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
286 0 : b'S' => Ok(Some(FeMessage::Sync)),
287 0 : b'X' => Ok(Some(FeMessage::Terminate)),
288 0 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
289 0 : b'c' => Ok(Some(FeMessage::CopyDone)),
290 0 : b'f' => Ok(Some(FeMessage::CopyFail)),
291 50 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
292 0 : tag => Err(ProtocolError::Protocol(format!(
293 0 : "unknown message tag: {tag},'{msg:?}'"
294 0 : ))),
295 : }
296 114 : }
297 : }
298 :
299 : impl FeStartupPacket {
300 : /// Read and parse startup message from the `buf` input buffer. It is
301 : /// different from [`FeMessage::parse`] because startup messages don't have
302 : /// message type byte; otherwise, its comments apply.
303 182 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
304 182 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
305 182 : const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
306 182 : const CANCEL_REQUEST_CODE: u32 = 5678;
307 182 : const NEGOTIATE_SSL_CODE: u32 = 5679;
308 182 : const NEGOTIATE_GSS_CODE: u32 = 5680;
309 182 :
310 182 : // need at least 4 bytes with packet len
311 182 : if buf.len() < 4 {
312 90 : let to_read = 4 - buf.len();
313 90 : buf.reserve(to_read);
314 90 : return Ok(None);
315 92 : }
316 92 :
317 92 : // We shouldn't advance `buf` as probably full message is not there yet,
318 92 : // so can't directly use Bytes::get_u32 etc.
319 92 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
320 92 : // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
321 92 : // which is less readable
322 92 : #[allow(clippy::manual_range_contains)]
323 92 : if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
324 2 : return Err(ProtocolError::Protocol(format!(
325 2 : "invalid startup packet message length {}",
326 2 : len
327 2 : )));
328 90 : }
329 90 :
330 90 : if buf.len() < len {
331 : // Don't have full message yet.
332 0 : let to_read = len - buf.len();
333 0 : buf.reserve(to_read);
334 0 : return Ok(None);
335 90 : }
336 90 :
337 90 : // got the message, advance buffer
338 90 : let mut msg = buf.split_to(len).freeze();
339 90 : msg.advance(4); // consume len
340 90 :
341 90 : let request_code = msg.get_u32();
342 90 : let req_hi = request_code >> 16;
343 90 : let req_lo = request_code & ((1 << 16) - 1);
344 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
345 90 : let message = match (req_hi, req_lo) {
346 : (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
347 0 : if msg.remaining() != 8 {
348 0 : return Err(ProtocolError::BadMessage(
349 0 : "CancelRequest message is malformed, backend PID / secret key missing"
350 0 : .to_owned(),
351 0 : ));
352 0 : }
353 0 : FeStartupPacket::CancelRequest(CancelKeyData {
354 0 : backend_pid: msg.get_i32(),
355 0 : cancel_key: msg.get_i32(),
356 0 : })
357 : }
358 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
359 : // Requested upgrade to SSL (aka TLS)
360 42 : FeStartupPacket::SslRequest
361 : }
362 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
363 : // Requested upgrade to GSSAPI
364 0 : FeStartupPacket::GssEncRequest
365 : }
366 0 : (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
367 0 : return Err(ProtocolError::Protocol(format!(
368 0 : "Unrecognized request code {unrecognized_code}"
369 0 : )));
370 : }
371 : // TODO bail if protocol major_version is not 3?
372 48 : (major_version, minor_version) => {
373 : // StartupMessage
374 :
375 48 : let s = str::from_utf8(&msg).map_err(|_e| {
376 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
377 48 : })?;
378 48 : let s = s.strip_suffix('\0').ok_or_else(|| {
379 0 : ProtocolError::Protocol(
380 0 : "StartupMessage params: missing null terminator".to_string(),
381 0 : )
382 48 : })?;
383 :
384 48 : FeStartupPacket::StartupMessage {
385 48 : major_version,
386 48 : minor_version,
387 48 : params: StartupMessageParams {
388 48 : params: msg.slice_ref(s.as_bytes()),
389 48 : },
390 48 : }
391 : }
392 : };
393 90 : Ok(Some(message))
394 182 : }
395 : }
396 :
397 : impl FeParseMessage {
398 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
399 : // FIXME: the rust-postgres driver uses a named prepared statement
400 : // for copy_out(). We're not prepared to handle that correctly. For
401 : // now, just ignore the statement name, assuming that the client never
402 : // uses more than one prepared statement at a time.
403 :
404 0 : let _pstmt_name = read_cstr(&mut buf)?;
405 0 : let query_string = read_cstr(&mut buf)?;
406 0 : if buf.remaining() < 2 {
407 0 : return Err(ProtocolError::BadMessage(
408 0 : "Parse message is malformed, nparams missing".to_string(),
409 0 : ));
410 0 : }
411 0 : let nparams = buf.get_i16();
412 0 :
413 0 : if nparams != 0 {
414 0 : return Err(ProtocolError::BadMessage(
415 0 : "query params not implemented".to_string(),
416 0 : ));
417 0 : }
418 0 :
419 0 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
420 0 : }
421 : }
422 :
423 : impl FeDescribeMessage {
424 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
425 0 : let kind = buf.get_u8();
426 0 : let _pstmt_name = read_cstr(&mut buf)?;
427 :
428 : // FIXME: see FeParseMessage::parse
429 0 : if kind != b'S' {
430 0 : return Err(ProtocolError::BadMessage(
431 0 : "only prepared statemement Describe is implemented".to_string(),
432 0 : ));
433 0 : }
434 0 :
435 0 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
436 0 : }
437 : }
438 :
439 : impl FeExecuteMessage {
440 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
441 0 : let portal_name = read_cstr(&mut buf)?;
442 0 : if buf.remaining() < 4 {
443 0 : return Err(ProtocolError::BadMessage(
444 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
445 0 : ));
446 0 : }
447 0 : let maxrows = buf.get_i32();
448 0 :
449 0 : if !portal_name.is_empty() {
450 0 : return Err(ProtocolError::BadMessage(
451 0 : "named portals not implemented".to_string(),
452 0 : ));
453 0 : }
454 0 : if maxrows != 0 {
455 0 : return Err(ProtocolError::BadMessage(
456 0 : "row limit in Execute message not implemented".to_string(),
457 0 : ));
458 0 : }
459 0 :
460 0 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
461 0 : }
462 : }
463 :
464 : impl FeBindMessage {
465 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
466 0 : let portal_name = read_cstr(&mut buf)?;
467 0 : let _pstmt_name = read_cstr(&mut buf)?;
468 :
469 : // FIXME: see FeParseMessage::parse
470 0 : if !portal_name.is_empty() {
471 0 : return Err(ProtocolError::BadMessage(
472 0 : "named portals not implemented".to_string(),
473 0 : ));
474 0 : }
475 0 :
476 0 : Ok(FeMessage::Bind(FeBindMessage))
477 0 : }
478 : }
479 :
480 : impl FeCloseMessage {
481 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
482 0 : let _kind = buf.get_u8();
483 0 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
484 :
485 : // FIXME: we do nothing with Close
486 0 : Ok(FeMessage::Close(FeCloseMessage))
487 0 : }
488 : }
489 :
490 : // Backend
491 :
492 : #[derive(Debug)]
493 : pub enum BeMessage<'a> {
494 : AuthenticationOk,
495 : AuthenticationMD5Password([u8; 4]),
496 : AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
497 : AuthenticationCleartextPassword,
498 : BackendKeyData(CancelKeyData),
499 : BindComplete,
500 : CommandComplete(&'a [u8]),
501 : CopyData(&'a [u8]),
502 : CopyDone,
503 : CopyFail,
504 : CopyInResponse,
505 : CopyOutResponse,
506 : CopyBothResponse,
507 : CloseComplete,
508 : // None means column is NULL
509 : DataRow(&'a [Option<&'a [u8]>]),
510 : // None errcode means internal_error will be sent.
511 : ErrorResponse(&'a str, Option<&'a [u8; 5]>),
512 : /// Single byte - used in response to SSLRequest/GSSENCRequest.
513 : EncryptionResponse(bool),
514 : NoData,
515 : ParameterDescription,
516 : ParameterStatus {
517 : name: &'a [u8],
518 : value: &'a [u8],
519 : },
520 : ParseComplete,
521 : ReadyForQuery,
522 : RowDescription(&'a [RowDescriptor<'a>]),
523 : XLogData(XLogDataBody<'a>),
524 : NoticeResponse(&'a str),
525 : KeepAlive(WalSndKeepAlive),
526 : }
527 :
528 : /// Common shorthands.
529 : impl<'a> BeMessage<'a> {
530 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
531 : /// This is a sensible default, given that:
532 : /// * rust strings only support this encoding out of the box.
533 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
534 : ///
535 : /// TODO: do we need to report `server_encoding` as well?
536 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
537 : name: b"client_encoding",
538 : value: b"UTF8",
539 : };
540 :
541 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
542 : name: b"integer_datetimes",
543 : value: b"on",
544 : };
545 :
546 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
547 4 : pub fn server_version(version: &'a str) -> Self {
548 4 : Self::ParameterStatus {
549 4 : name: b"server_version",
550 4 : value: version.as_bytes(),
551 4 : }
552 4 : }
553 : }
554 :
/// Payload of a SASL authentication request ('R' message).
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Mechanisms offered to the client, sent as a NUL-terminated list of
    /// cstrings.
    Methods(&'a [&'a str]),
    /// Server data for an intermediate SASL step, sent verbatim.
    Continue(&'a [u8]),
    /// Server data for the final SASL step, sent verbatim.
    Final(&'a [u8]),
}

/// A parameter status report, distinguished by parameter.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
567 :
568 : // One row description in RowDescription packet.
569 : #[derive(Debug)]
570 : pub struct RowDescriptor<'a> {
571 : pub name: &'a [u8],
572 : pub tableoid: Oid,
573 : pub attnum: i16,
574 : pub typoid: Oid,
575 : pub typlen: i16,
576 : pub typmod: i32,
577 : pub formatcode: i16,
578 : }
579 :
580 : impl Default for RowDescriptor<'_> {
581 0 : fn default() -> RowDescriptor<'static> {
582 0 : RowDescriptor {
583 0 : name: b"",
584 0 : tableoid: 0,
585 0 : attnum: 0,
586 0 : typoid: 0,
587 0 : typlen: 0,
588 0 : typmod: 0,
589 0 : formatcode: 0,
590 0 : }
591 0 : }
592 : }
593 :
594 : impl RowDescriptor<'_> {
595 : /// Convenience function to create a RowDescriptor message for an int8 column
596 0 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
597 0 : RowDescriptor {
598 0 : name,
599 0 : tableoid: 0,
600 0 : attnum: 0,
601 0 : typoid: INT8_OID,
602 0 : typlen: 8,
603 0 : typmod: 0,
604 0 : formatcode: 0,
605 0 : }
606 0 : }
607 :
608 4 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
609 4 : RowDescriptor {
610 4 : name,
611 4 : tableoid: 0,
612 4 : attnum: 0,
613 4 : typoid: TEXT_OID,
614 4 : typlen: -1,
615 4 : typmod: 0,
616 4 : formatcode: 0,
617 4 : }
618 4 : }
619 : }
620 :
/// WAL data sent to a replication client inside CopyData as a 'w' message.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    pub wal_start: u64,
    /// current end of WAL on the server
    pub wal_end: u64,
    pub timestamp: i64,
    pub data: &'a [u8],
}

/// Keepalive sent to a replication client inside CopyData as a 'k' message.
/// `request_reply` is written as a single 0/1 byte.
#[derive(Debug)]
pub struct WalSndKeepAlive {
    /// current end of WAL on the server
    pub wal_end: u64,
    pub timestamp: i64,
    pub request_reply: bool,
}
635 :
636 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
637 :
638 : // single text column
639 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
640 : name: b"data",
641 : tableoid: 0,
642 : attnum: 0,
643 : typoid: TEXT_OID,
644 : typlen: -1,
645 : typmod: 0,
646 : formatcode: 0,
647 : }]);
648 :
649 : /// Call f() to write body of the message and prepend it with 4-byte len as
650 : /// prescribed by the protocol.
651 150 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
652 150 : let base = buf.len();
653 150 : buf.extend_from_slice(&[0; 4]);
654 150 :
655 150 : let res = f(buf);
656 150 :
657 150 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
658 150 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
659 150 :
660 150 : res
661 150 : }
662 :
663 : /// Safe write of s into buf as cstring (String in the protocol).
664 112 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
665 112 : let bytes = s.as_ref();
666 112 : if bytes.contains(&0) {
667 0 : return Err(ProtocolError::BadMessage(
668 0 : "string contains embedded null".to_owned(),
669 0 : ));
670 112 : }
671 112 : buf.put_slice(bytes);
672 112 : buf.put_u8(0);
673 112 : Ok(())
674 112 : }
675 :
676 : /// Read cstring from buf, advancing it.
677 22 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
678 22 : let pos = buf
679 22 : .iter()
680 312 : .position(|x| *x == 0)
681 22 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
682 22 : let result = buf.split_to(pos);
683 22 : buf.advance(1); // drop the null terminator
684 22 : Ok(result)
685 22 : }
686 :
/// SQLSTATE `internal_error`.
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE `admin_shutdown`.
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE `successful_completion`.
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
690 :
691 : impl<'a> BeMessage<'a> {
692 : /// Serialize `message` to the given `buf`.
693 : /// Apart from smart memory managemet, BytesMut is good here as msg len
694 : /// precedes its body and it is handy to write it down first and then fill
695 : /// the length. With Write we would have to either calc it manually or have
696 : /// one more buffer.
697 192 : pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
698 192 : match message {
699 24 : BeMessage::AuthenticationOk => {
700 24 : buf.put_u8(b'R');
701 24 : write_body(buf, |buf| {
702 24 : buf.put_i32(0); // Specifies that the authentication was successful.
703 24 : });
704 24 : }
705 :
706 4 : BeMessage::AuthenticationCleartextPassword => {
707 4 : buf.put_u8(b'R');
708 4 : write_body(buf, |buf| {
709 4 : buf.put_i32(3); // Specifies that clear text password is required.
710 4 : });
711 4 : }
712 :
713 0 : BeMessage::AuthenticationMD5Password(salt) => {
714 0 : buf.put_u8(b'R');
715 0 : write_body(buf, |buf| {
716 0 : buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
717 0 : buf.put_slice(&salt[..]);
718 0 : });
719 0 : }
720 :
721 60 : BeMessage::AuthenticationSasl(msg) => {
722 60 : buf.put_u8(b'R');
723 60 : write_body(buf, |buf| {
724 60 : use BeAuthenticationSaslMessage::*;
725 60 : match msg {
726 26 : Methods(methods) => {
727 26 : buf.put_i32(10); // Specifies that SASL auth method is used.
728 50 : for method in methods.iter() {
729 50 : write_cstr(method, buf)?;
730 : }
731 26 : buf.put_u8(0); // zero terminator for the list
732 : }
733 22 : Continue(extra) => {
734 22 : buf.put_i32(11); // Continue SASL auth.
735 22 : buf.put_slice(extra);
736 22 : }
737 12 : Final(extra) => {
738 12 : buf.put_i32(12); // Send final SASL message.
739 12 : buf.put_slice(extra);
740 12 : }
741 : }
742 60 : Ok(())
743 60 : })?;
744 : }
745 :
746 0 : BeMessage::BackendKeyData(key_data) => {
747 0 : buf.put_u8(b'K');
748 0 : write_body(buf, |buf| {
749 0 : buf.put_i32(key_data.backend_pid);
750 0 : buf.put_i32(key_data.cancel_key);
751 0 : });
752 0 : }
753 :
754 0 : BeMessage::BindComplete => {
755 0 : buf.put_u8(b'2');
756 0 : write_body(buf, |_| {});
757 0 : }
758 :
759 0 : BeMessage::CloseComplete => {
760 0 : buf.put_u8(b'3');
761 0 : write_body(buf, |_| {});
762 0 : }
763 :
764 4 : BeMessage::CommandComplete(cmd) => {
765 4 : buf.put_u8(b'C');
766 4 : write_body(buf, |buf| write_cstr(cmd, buf))?;
767 : }
768 :
769 0 : BeMessage::CopyData(data) => {
770 0 : buf.put_u8(b'd');
771 0 : write_body(buf, |buf| {
772 0 : buf.put_slice(data);
773 0 : });
774 0 : }
775 :
776 0 : BeMessage::CopyDone => {
777 0 : buf.put_u8(b'c');
778 0 : write_body(buf, |_| {});
779 0 : }
780 :
781 0 : BeMessage::CopyFail => {
782 0 : buf.put_u8(b'f');
783 0 : write_body(buf, |_| {});
784 0 : }
785 :
786 0 : BeMessage::CopyInResponse => {
787 0 : buf.put_u8(b'G');
788 0 : write_body(buf, |buf| {
789 0 : buf.put_u8(1); // copy_is_binary
790 0 : buf.put_i16(0); // numAttributes
791 0 : });
792 0 : }
793 :
794 0 : BeMessage::CopyOutResponse => {
795 0 : buf.put_u8(b'H');
796 0 : write_body(buf, |buf| {
797 0 : buf.put_u8(0); // copy_is_binary
798 0 : buf.put_i16(0); // numAttributes
799 0 : });
800 0 : }
801 :
802 0 : BeMessage::CopyBothResponse => {
803 0 : buf.put_u8(b'W');
804 0 : write_body(buf, |buf| {
805 0 : // doesn't matter, used only for replication
806 0 : buf.put_u8(0); // copy_is_binary
807 0 : buf.put_i16(0); // numAttributes
808 0 : });
809 0 : }
810 :
811 4 : BeMessage::DataRow(vals) => {
812 4 : buf.put_u8(b'D');
813 4 : write_body(buf, |buf| {
814 4 : buf.put_u16(vals.len() as u16); // num of cols
815 4 : for val_opt in vals.iter() {
816 4 : if let Some(val) = val_opt {
817 4 : buf.put_u32(val.len() as u32);
818 4 : buf.put_slice(val);
819 4 : } else {
820 0 : buf.put_i32(-1);
821 0 : }
822 : }
823 4 : });
824 4 : }
825 :
826 : // ErrorResponse is a zero-terminated array of zero-terminated fields.
827 : // First byte of each field represents type of this field. Set just enough fields
828 : // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
829 : // message text.
830 2 : BeMessage::ErrorResponse(error_msg, pg_error_code) => {
831 2 : // 'E' signalizes ErrorResponse messages
832 2 : buf.put_u8(b'E');
833 2 : write_body(buf, |buf| {
834 2 : buf.put_u8(b'S'); // severity
835 2 : buf.put_slice(b"ERROR\0");
836 2 :
837 2 : buf.put_u8(b'C'); // SQLSTATE error code
838 2 : buf.put_slice(&terminate_code(
839 2 : pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
840 2 : ));
841 2 :
842 2 : buf.put_u8(b'M'); // the message
843 2 : write_cstr(error_msg, buf)?;
844 :
845 2 : buf.put_u8(0); // terminator
846 2 : Ok(())
847 2 : })?;
848 : }
849 :
850 : // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
851 : // message but continue listening for ReadyForQuery or ErrorResponse"
852 0 : BeMessage::NoticeResponse(error_msg) => {
853 0 : // For all the errors set Severity to Error and error code to
854 0 : // 'internal error'.
855 0 :
856 0 : // 'N' signalizes NoticeResponse messages
857 0 : buf.put_u8(b'N');
858 0 : write_body(buf, |buf| {
859 0 : buf.put_u8(b'S'); // severity
860 0 : buf.put_slice(b"NOTICE\0");
861 0 :
862 0 : buf.put_u8(b'C'); // SQLSTATE error code
863 0 : buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
864 0 :
865 0 : buf.put_u8(b'M'); // the message
866 0 : write_cstr(error_msg.as_bytes(), buf)?;
867 :
868 0 : buf.put_u8(0); // terminator
869 0 : Ok(())
870 0 : })?;
871 : }
872 :
873 0 : BeMessage::NoData => {
874 0 : buf.put_u8(b'n');
875 0 : write_body(buf, |_| {});
876 0 : }
877 :
878 42 : BeMessage::EncryptionResponse(should_negotiate) => {
879 42 : let response = if *should_negotiate { b'S' } else { b'N' };
880 42 : buf.put_u8(response);
881 : }
882 :
883 26 : BeMessage::ParameterStatus { name, value } => {
884 26 : buf.put_u8(b'S');
885 26 : write_body(buf, |buf| {
886 26 : write_cstr(name, buf)?;
887 26 : write_cstr(value, buf)
888 26 : })?;
889 : }
890 :
891 0 : BeMessage::ParameterDescription => {
892 0 : buf.put_u8(b't');
893 0 : write_body(buf, |buf| {
894 0 : // we don't support params, so always 0
895 0 : buf.put_i16(0);
896 0 : });
897 0 : }
898 :
899 0 : BeMessage::ParseComplete => {
900 0 : buf.put_u8(b'1');
901 0 : write_body(buf, |_| {});
902 0 : }
903 :
904 22 : BeMessage::ReadyForQuery => {
905 22 : buf.put_u8(b'Z');
906 22 : write_body(buf, |buf| {
907 22 : buf.put_u8(b'I');
908 22 : });
909 22 : }
910 :
911 4 : BeMessage::RowDescription(rows) => {
912 4 : buf.put_u8(b'T');
913 4 : write_body(buf, |buf| {
914 4 : buf.put_i16(rows.len() as i16); // # of fields
915 4 : for row in rows.iter() {
916 4 : write_cstr(row.name, buf)?;
917 4 : buf.put_i32(0); /* table oid */
918 4 : buf.put_i16(0); /* attnum */
919 4 : buf.put_u32(row.typoid);
920 4 : buf.put_i16(row.typlen);
921 4 : buf.put_i32(-1); /* typmod */
922 4 : buf.put_i16(0); /* format code */
923 : }
924 4 : Ok(())
925 4 : })?;
926 : }
927 :
928 0 : BeMessage::XLogData(body) => {
929 0 : buf.put_u8(b'd');
930 0 : write_body(buf, |buf| {
931 0 : buf.put_u8(b'w');
932 0 : buf.put_u64(body.wal_start);
933 0 : buf.put_u64(body.wal_end);
934 0 : buf.put_i64(body.timestamp);
935 0 : buf.put_slice(body.data);
936 0 : });
937 0 : }
938 :
939 0 : BeMessage::KeepAlive(req) => {
940 0 : buf.put_u8(b'd');
941 0 : write_body(buf, |buf| {
942 0 : buf.put_u8(b'k');
943 0 : buf.put_u64(req.wal_end);
944 0 : buf.put_i64(req.timestamp);
945 0 : buf.put_u8(u8::from(req.request_reply));
946 0 : });
947 0 : }
948 : }
949 192 : Ok(())
950 192 : }
951 : }
952 :
/// Append a NUL terminator to a 5-byte SQLSTATE code so it can be written as
/// a cstring field of an error/notice message.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    let mut terminated = [0u8; 6];
    terminated[..5].copy_from_slice(code);
    terminated
}
961 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        let params = make_params(" foo bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    #[test]
    fn test_startup_message_params_options_raw_keeps_escapes() {
        let params = StartupMessageParams::new([("options", "foo\\ bar  baz")]);
        let raw: Vec<&str> = params.options_raw().expect("options are None").collect();
        // Escaped whitespace does not split and the backslash is preserved;
        // runs of whitespace produce no empty pieces.
        assert_eq!(raw, ["foo\\ bar", "baz"]);
        assert_eq!(params.get("options"), Some("foo\\ bar  baz"));
        assert_eq!(params.get("user"), None);
    }

    #[test]
    fn parse_fe_query_message() {
        // 'Q', length 8 (4 for the length word + 4 for "sel\0").
        let mut buf = BytesMut::from(&b"Q\0\0\0\x08sel\0"[..]);
        match FeMessage::parse(&mut buf).unwrap() {
            Some(FeMessage::Query(query)) => assert_eq!(&query[..], b"sel\0"),
            other => panic!("unexpected parse result: {other:?}"),
        }
        assert!(buf.is_empty());
    }

    #[test]
    fn parse_fe_message_waits_for_full_body() {
        let mut buf = BytesMut::from(&b"Q\0\0\0\x08se"[..]);
        assert!(FeMessage::parse(&mut buf).unwrap().is_none());
        // Nothing is consumed until the whole message is buffered.
        assert_eq!(buf.len(), 7);
    }

    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }
}
|