Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 : #![deny(clippy::undocumented_unsafe_blocks)]
5 :
6 : pub mod framed;
7 :
8 : use byteorder::{BigEndian, ReadBytesExt};
9 : use bytes::{Buf, BufMut, Bytes, BytesMut};
10 : use serde::{Deserialize, Serialize};
11 : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
12 :
13 : // re-export for use in utils pageserver_feedback.rs
14 : pub use postgres_protocol::PG_EPOCH;
15 :
/// PostgreSQL object identifier (type OIDs, table OIDs, ...).
pub type Oid = u32;
/// 64-bit system identifier. NOTE(review): not used in this file; presumably
/// the cluster system identifier — confirm against callers.
pub type SystemId = u64;

/// OID of the built-in `int8` type (used by [`RowDescriptor::int8_col`]).
pub const INT8_OID: Oid = 20;
/// OID of the built-in `int4` type.
pub const INT4_OID: Oid = 23;
/// OID of the built-in `text` type (used by [`RowDescriptor::text_col`]).
pub const TEXT_OID: Oid = 25;
22 :
/// A message sent by the frontend (client) after the startup phase, as
/// decoded by [`FeMessage::parse`].
#[derive(Debug)]
pub enum FeMessage {
    /// Simple query ('Q'). Holds the raw message body. Per the protocol this is a
    /// null-terminated query string, and `FeMessage::parse` does not strip the
    /// terminator.
    Query(Bytes),
    // Extended query protocol.
    /// Parse ('P').
    Parse(FeParseMessage),
    /// Describe ('D').
    Describe(FeDescribeMessage),
    /// Bind ('B').
    Bind(FeBindMessage),
    /// Execute ('E').
    Execute(FeExecuteMessage),
    /// Close ('C').
    Close(FeCloseMessage),
    /// Sync ('S'). The body is ignored.
    Sync,
    /// Terminate ('X'). The body is ignored.
    Terminate,
    /// CopyData ('d') with its raw payload.
    CopyData(Bytes),
    /// CopyDone ('c'). The body is ignored.
    CopyDone,
    /// CopyFail ('f'). The body is discarded.
    CopyFail,
    /// PasswordMessage ('p') with its raw body. Interpreting the body is left
    /// to the caller.
    PasswordMessage(Bytes),
}
40 :
/// A packet sent by the frontend before the regular message flow, as decoded
/// by [`FeStartupPacket::parse`]. Startup packets have no type byte. The
/// variants are distinguished by the 32-bit request code that follows the length.
#[derive(Debug)]
pub enum FeStartupPacket {
    /// CancelRequest (request code 1234/5678), carrying the key the backend
    /// previously announced in BackendKeyData.
    CancelRequest(CancelKeyData),
    /// SSLRequest (request code 1234/5679).
    SslRequest,
    /// GSSENCRequest (request code 1234/5680).
    GssEncRequest,
    /// StartupMessage: any other request code, interpreted as a protocol
    /// version (high/low 16 bits), followed by `name\0value\0` parameter pairs.
    StartupMessage {
        major_version: u32,
        minor_version: u32,
        params: StartupMessageParams,
    },
}
52 :
/// Connection parameters received in a StartupMessage (e.g. `options`), keyed
/// by parameter name. When the client sends the same name twice, the last
/// value wins (it is a plain `HashMap` insert).
#[derive(Debug)]
pub struct StartupMessageParams {
    params: HashMap<String, String>,
}
57 :
58 : impl StartupMessageParams {
59 : /// Get parameter's value by its name.
60 148 : pub fn get(&self, name: &str) -> Option<&str> {
61 148 : self.params.get(name).map(|s| s.as_str())
62 148 : }
63 :
64 : /// Split command-line options according to PostgreSQL's logic,
65 : /// taking into account all escape sequences but leaving them as-is.
66 : /// [`None`] means that there's no `options` in [`Self`].
67 60 : pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
68 60 : self.get("options").map(Self::parse_options_raw)
69 60 : }
70 :
71 : /// Split command-line options according to PostgreSQL's logic,
72 : /// applying all escape sequences (using owned strings as needed).
73 : /// [`None`] means that there's no `options` in [`Self`].
74 10 : pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
75 10 : self.get("options").map(Self::parse_options_escaped)
76 10 : }
77 :
78 : /// Split command-line options according to PostgreSQL's logic,
79 : /// taking into account all escape sequences but leaving them as-is.
80 60 : pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
81 60 : // See `postgres: pg_split_opts`.
82 60 : let mut last_was_escape = false;
83 60 : input
84 1128 : .split(move |c: char| {
85 : // We split by non-escaped whitespace symbols.
86 1128 : let should_split = c.is_ascii_whitespace() && !last_was_escape;
87 1128 : last_was_escape = c == '\\' && !last_was_escape;
88 1128 : should_split
89 1128 : })
90 152 : .filter(|s| !s.is_empty())
91 60 : }
92 :
93 : /// Split command-line options according to PostgreSQL's logic,
94 : /// applying all escape sequences (using owned strings as needed).
95 8 : pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
96 8 : // See `postgres: pg_split_opts`.
97 14 : Self::parse_options_raw(input).map(|s| {
98 14 : let mut preserve_next_escape = false;
99 34 : let escape = |c| {
100 : // We should remove '\\' unless it's preceded by '\\'.
101 34 : let should_remove = c == '\\' && !preserve_next_escape;
102 34 : preserve_next_escape = should_remove;
103 34 : should_remove
104 34 : };
105 :
106 14 : match s.contains('\\') {
107 6 : true => Cow::Owned(s.replace(escape, "")),
108 8 : false => Cow::Borrowed(s),
109 : }
110 14 : })
111 8 : }
112 :
113 : /// Iterate through key-value pairs in an arbitrary order.
114 14 : pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
115 42 : self.params.iter().map(|(k, v)| (k.as_str(), v.as_str()))
116 14 : }
117 :
118 : // This function is mostly useful in tests.
119 : #[doc(hidden)]
120 46 : pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
121 46 : Self {
122 62 : params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
123 46 : }
124 46 : }
125 : }
126 :
/// Key identifying a backend for query cancellation. The server sends it in
/// BackendKeyData ([`BeMessage::BackendKeyData`]), and the client sends it back
/// in a CancelRequest ([`FeStartupPacket::CancelRequest`]).
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,
}
132 :
133 : impl fmt::Display for CancelKeyData {
134 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 0 : let hi = (self.backend_pid as u64) << 32;
136 0 : let lo = self.cancel_key as u64;
137 0 : let id = hi | lo;
138 0 :
139 0 : // This format is more compact and might work better for logs.
140 0 : f.debug_tuple("CancelKeyData")
141 0 : .field(&format_args!("{:x}", id))
142 0 : .finish()
143 0 : }
144 : }
145 :
146 : use rand::distributions::{Distribution, Standard};
147 : impl Distribution<CancelKeyData> for Standard {
148 2 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
149 2 : CancelKeyData {
150 2 : backend_pid: rng.gen(),
151 2 : cancel_key: rng.gen(),
152 2 : }
153 2 : }
154 : }
155 :
/// Body of a Parse ('P') message.
///
/// Only the simple case without parameters is supported. The statement name
/// is read but ignored (see the FIXME in `FeParseMessage::parse`).
#[derive(Debug)]
pub struct FeParseMessage {
    /// Query text, without its null terminator (`read_cstr` strips it).
    pub query_string: Bytes,
}
163 0 : #[derive(Debug)]
164 : pub struct FeDescribeMessage {
165 : pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
166 : // we only support unnamed prepared stmt or portal
167 : }
168 :
/// Bind ('B') message marker. Only the unnamed portal is accepted. The
/// statement name is ignored, and parameters / format codes that follow the
/// names are not parsed.
#[derive(Debug)]
pub struct FeBindMessage;
172 :
/// Body of an Execute ('E') message. Only the unnamed portal is supported.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Maximum number of rows to return. `FeExecuteMessage::parse` rejects any
    /// value other than 0, because row limits are not implemented.
    pub maxrows: i32,
}
179 :
/// Close ('C') message marker. The kind byte and name are read but otherwise
/// ignored.
#[derive(Debug)]
pub struct FeCloseMessage;
183 :
/// An error that occurred while parsing a raw stream into Postgres messages,
/// or while serializing messages into one.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
    /// Invalid packet was received from the client (e.g. unexpected message
    /// type or broken length).
    #[error("Protocol error: {0}")]
    Protocol(String),
    /// Failed to parse or (less likely) serialize a protocol message, e.g. a
    /// malformed body or a string with an embedded null byte.
    #[error("Message parse error: {0}")]
    BadMessage(String),
}
196 :
197 : impl ProtocolError {
198 : /// Proxy stream.rs uses only io::Error; provide it.
199 0 : pub fn into_io_error(self) -> io::Error {
200 0 : io::Error::new(io::ErrorKind::Other, self.to_string())
201 0 : }
202 : }
203 :
204 : impl FeMessage {
205 : /// Read and parse one message from the `buf` input buffer. If there is at
206 : /// least one valid message, returns it, advancing `buf`; redundant copies
207 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
208 : /// directly into the `buf` (processed data is garbage collected after
209 : /// parsed message is dropped).
210 : ///
211 : /// Returns None if `buf` doesn't contain enough data for a single message.
212 : /// For efficiency, tries to reserve large enough space in `buf` for the
213 : /// next message in this case to save the repeated calls.
214 : ///
215 : /// Returns Error if message is malformed, the only possible ErrorKind is
216 : /// InvalidInput.
217 : //
218 : // Inspired by rust-postgres Message::parse.
219 98 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
220 98 : // Every message contains message type byte and 4 bytes len; can't do
221 98 : // much without them.
222 98 : if buf.len() < 5 {
223 52 : let to_read = 5 - buf.len();
224 52 : buf.reserve(to_read);
225 52 : return Ok(None);
226 46 : }
227 46 :
228 46 : // We shouldn't advance `buf` as probably full message is not there yet,
229 46 : // so can't directly use Bytes::get_u32 etc.
230 46 : let tag = buf[0];
231 46 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
232 46 : if len < 4 {
233 0 : return Err(ProtocolError::Protocol(format!(
234 0 : "invalid message length {}",
235 0 : len
236 0 : )));
237 46 : }
238 46 :
239 46 : // length field includes itself, but not message type.
240 46 : let total_len = len as usize + 1;
241 46 : if buf.len() < total_len {
242 : // Don't have full message yet.
243 0 : let to_read = total_len - buf.len();
244 0 : buf.reserve(to_read);
245 0 : return Ok(None);
246 46 : }
247 46 :
248 46 : // got the message, advance buffer
249 46 : let mut msg = buf.split_to(total_len).freeze();
250 46 : msg.advance(5); // consume message type and len
251 46 :
252 46 : match tag {
253 4 : b'Q' => Ok(Some(FeMessage::Query(msg))),
254 0 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
255 0 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
256 0 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
257 0 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
258 0 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
259 0 : b'S' => Ok(Some(FeMessage::Sync)),
260 0 : b'X' => Ok(Some(FeMessage::Terminate)),
261 0 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
262 0 : b'c' => Ok(Some(FeMessage::CopyDone)),
263 0 : b'f' => Ok(Some(FeMessage::CopyFail)),
264 42 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
265 0 : tag => Err(ProtocolError::Protocol(format!(
266 0 : "unknown message tag: {tag},'{msg:?}'"
267 0 : ))),
268 : }
269 98 : }
270 : }
271 :
272 : impl FeStartupPacket {
273 : /// Read and parse startup message from the `buf` input buffer. It is
274 : /// different from [`FeMessage::parse`] because startup messages don't have
275 : /// message type byte; otherwise, its comments apply.
276 182 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
277 182 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
278 182 : const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
279 182 : const CANCEL_REQUEST_CODE: u32 = 5678;
280 182 : const NEGOTIATE_SSL_CODE: u32 = 5679;
281 182 : const NEGOTIATE_GSS_CODE: u32 = 5680;
282 182 :
283 182 : // need at least 4 bytes with packet len
284 182 : if buf.len() < 4 {
285 90 : let to_read = 4 - buf.len();
286 90 : buf.reserve(to_read);
287 90 : return Ok(None);
288 92 : }
289 92 :
290 92 : // We shouldn't advance `buf` as probably full message is not there yet,
291 92 : // so can't directly use Bytes::get_u32 etc.
292 92 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
293 92 : // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
294 92 : // which is less readable
295 92 : #[allow(clippy::manual_range_contains)]
296 92 : if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
297 2 : return Err(ProtocolError::Protocol(format!(
298 2 : "invalid startup packet message length {}",
299 2 : len
300 2 : )));
301 90 : }
302 90 :
303 90 : if buf.len() < len {
304 : // Don't have full message yet.
305 0 : let to_read = len - buf.len();
306 0 : buf.reserve(to_read);
307 0 : return Ok(None);
308 90 : }
309 90 :
310 90 : // got the message, advance buffer
311 90 : let mut msg = buf.split_to(len).freeze();
312 90 : msg.advance(4); // consume len
313 90 :
314 90 : let request_code = msg.get_u32();
315 90 : let req_hi = request_code >> 16;
316 90 : let req_lo = request_code & ((1 << 16) - 1);
317 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
318 90 : let message = match (req_hi, req_lo) {
319 : (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
320 0 : if msg.remaining() != 8 {
321 0 : return Err(ProtocolError::BadMessage(
322 0 : "CancelRequest message is malformed, backend PID / secret key missing"
323 0 : .to_owned(),
324 0 : ));
325 0 : }
326 0 : FeStartupPacket::CancelRequest(CancelKeyData {
327 0 : backend_pid: msg.get_i32(),
328 0 : cancel_key: msg.get_i32(),
329 0 : })
330 : }
331 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
332 : // Requested upgrade to SSL (aka TLS)
333 42 : FeStartupPacket::SslRequest
334 : }
335 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
336 : // Requested upgrade to GSSAPI
337 0 : FeStartupPacket::GssEncRequest
338 : }
339 0 : (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
340 0 : return Err(ProtocolError::Protocol(format!(
341 0 : "Unrecognized request code {unrecognized_code}"
342 0 : )));
343 : }
344 : // TODO bail if protocol major_version is not 3?
345 48 : (major_version, minor_version) => {
346 : // StartupMessage
347 :
348 : // Parse pairs of null-terminated strings (key, value).
349 : // See `postgres: ProcessStartupPacket, build_startup_packet`.
350 48 : let mut tokens = str::from_utf8(&msg)
351 48 : .map_err(|_e| {
352 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
353 48 : })?
354 48 : .strip_suffix('\0') // drop packet's own null
355 48 : .ok_or_else(|| {
356 0 : ProtocolError::Protocol(
357 0 : "StartupMessage params: missing null terminator".to_string(),
358 0 : )
359 48 : })?
360 48 : .split_terminator('\0');
361 48 :
362 48 : let mut params = HashMap::new();
363 186 : while let Some(name) = tokens.next() {
364 138 : let value = tokens.next().ok_or_else(|| {
365 0 : ProtocolError::Protocol(
366 0 : "StartupMessage params: key without value".to_string(),
367 0 : )
368 138 : })?;
369 :
370 138 : params.insert(name.to_owned(), value.to_owned());
371 : }
372 :
373 48 : FeStartupPacket::StartupMessage {
374 48 : major_version,
375 48 : minor_version,
376 48 : params: StartupMessageParams { params },
377 48 : }
378 : }
379 : };
380 90 : Ok(Some(message))
381 182 : }
382 : }
383 :
384 : impl FeParseMessage {
385 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
386 : // FIXME: the rust-postgres driver uses a named prepared statement
387 : // for copy_out(). We're not prepared to handle that correctly. For
388 : // now, just ignore the statement name, assuming that the client never
389 : // uses more than one prepared statement at a time.
390 :
391 0 : let _pstmt_name = read_cstr(&mut buf)?;
392 0 : let query_string = read_cstr(&mut buf)?;
393 0 : if buf.remaining() < 2 {
394 0 : return Err(ProtocolError::BadMessage(
395 0 : "Parse message is malformed, nparams missing".to_string(),
396 0 : ));
397 0 : }
398 0 : let nparams = buf.get_i16();
399 0 :
400 0 : if nparams != 0 {
401 0 : return Err(ProtocolError::BadMessage(
402 0 : "query params not implemented".to_string(),
403 0 : ));
404 0 : }
405 0 :
406 0 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
407 0 : }
408 : }
409 :
410 : impl FeDescribeMessage {
411 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
412 0 : let kind = buf.get_u8();
413 0 : let _pstmt_name = read_cstr(&mut buf)?;
414 :
415 : // FIXME: see FeParseMessage::parse
416 0 : if kind != b'S' {
417 0 : return Err(ProtocolError::BadMessage(
418 0 : "only prepared statemement Describe is implemented".to_string(),
419 0 : ));
420 0 : }
421 0 :
422 0 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
423 0 : }
424 : }
425 :
426 : impl FeExecuteMessage {
427 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
428 0 : let portal_name = read_cstr(&mut buf)?;
429 0 : if buf.remaining() < 4 {
430 0 : return Err(ProtocolError::BadMessage(
431 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
432 0 : ));
433 0 : }
434 0 : let maxrows = buf.get_i32();
435 0 :
436 0 : if !portal_name.is_empty() {
437 0 : return Err(ProtocolError::BadMessage(
438 0 : "named portals not implemented".to_string(),
439 0 : ));
440 0 : }
441 0 : if maxrows != 0 {
442 0 : return Err(ProtocolError::BadMessage(
443 0 : "row limit in Execute message not implemented".to_string(),
444 0 : ));
445 0 : }
446 0 :
447 0 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
448 0 : }
449 : }
450 :
451 : impl FeBindMessage {
452 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
453 0 : let portal_name = read_cstr(&mut buf)?;
454 0 : let _pstmt_name = read_cstr(&mut buf)?;
455 :
456 : // FIXME: see FeParseMessage::parse
457 0 : if !portal_name.is_empty() {
458 0 : return Err(ProtocolError::BadMessage(
459 0 : "named portals not implemented".to_string(),
460 0 : ));
461 0 : }
462 0 :
463 0 : Ok(FeMessage::Bind(FeBindMessage))
464 0 : }
465 : }
466 :
467 : impl FeCloseMessage {
468 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
469 0 : let _kind = buf.get_u8();
470 0 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
471 :
472 : // FIXME: we do nothing with Close
473 0 : Ok(FeMessage::Close(FeCloseMessage))
474 0 : }
475 : }
476 :
477 : // Backend
478 :
/// A message sent by the backend (server) to the frontend, serialized by
/// [`BeMessage::write`].
#[derive(Debug)]
pub enum BeMessage<'a> {
    /// AuthenticationOk ('R', code 0).
    AuthenticationOk,
    /// AuthenticationMD5Password ('R', code 5) carrying the 4-byte salt.
    AuthenticationMD5Password([u8; 4]),
    /// One step of SASL authentication ('R', code 10, 11 or 12).
    AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
    /// AuthenticationCleartextPassword ('R', code 3).
    AuthenticationCleartextPassword,
    /// BackendKeyData ('K'): the pid/secret pair later used in a CancelRequest.
    BackendKeyData(CancelKeyData),
    /// BindComplete ('2').
    BindComplete,
    /// CommandComplete ('C') with the command tag. Sent as a cstring, so it
    /// must not contain a null byte.
    CommandComplete(&'a [u8]),
    /// CopyData ('d') with an opaque payload.
    CopyData(&'a [u8]),
    /// CopyDone ('c').
    CopyDone,
    /// CopyFail ('f'), sent with an empty body.
    CopyFail,
    /// CopyInResponse ('G'), advertising binary format and zero columns.
    CopyInResponse,
    /// CopyOutResponse ('H'), advertising text format and zero columns.
    CopyOutResponse,
    /// CopyBothResponse ('W'), advertising text format and zero columns.
    CopyBothResponse,
    /// CloseComplete ('3').
    CloseComplete,
    /// DataRow ('D'). A `None` column is sent as NULL.
    DataRow(&'a [Option<&'a [u8]>]),
    /// ErrorResponse ('E') with message text and optional SQLSTATE code.
    /// A `None` errcode means internal_error (`XX000`) will be sent.
    ErrorResponse(&'a str, Option<&'a [u8; 5]>),
    /// Single byte - used in response to SSLRequest/GSSENCRequest
    /// (`'S'` for `true`, `'N'` for `false`). Unlike the others, it has no
    /// length prefix.
    EncryptionResponse(bool),
    /// NoData ('n').
    NoData,
    /// ParameterDescription ('t'), always announcing zero parameters.
    ParameterDescription,
    /// ParameterStatus ('S'): a `name`/`value` pair reported to the client.
    ParameterStatus {
        name: &'a [u8],
        value: &'a [u8],
    },
    /// ParseComplete ('1').
    ParseComplete,
    /// ReadyForQuery ('Z'), always with transaction status byte `'I'`.
    ReadyForQuery,
    /// RowDescription ('T').
    RowDescription(&'a [RowDescriptor<'a>]),
    /// XLogData ('w'), wrapped in a CopyData ('d') frame.
    XLogData(XLogDataBody<'a>),
    /// NoticeResponse ('N'), sent with severity NOTICE and SQLSTATE `XX000`.
    NoticeResponse(&'a str),
    /// Primary keepalive ('k'), wrapped in a CopyData ('d') frame.
    KeepAlive(WalSndKeepAlive),
}
514 :
515 : /// Common shorthands.
516 : impl<'a> BeMessage<'a> {
517 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
518 : /// This is a sensible default, given that:
519 : /// * rust strings only support this encoding out of the box.
520 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
521 : ///
522 : /// TODO: do we need to report `server_encoding` as well?
523 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
524 : name: b"client_encoding",
525 : value: b"UTF8",
526 : };
527 :
528 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
529 : name: b"integer_datetimes",
530 : value: b"on",
531 : };
532 :
533 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
534 4 : pub fn server_version(version: &'a str) -> Self {
535 4 : Self::ParameterStatus {
536 4 : name: b"server_version",
537 4 : value: version.as_bytes(),
538 4 : }
539 4 : }
540 : }
541 :
/// Payload of [`BeMessage::AuthenticationSasl`].
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Code 10: the list of SASL mechanism names offered. Each name is sent as
    /// a cstring, and the list ends with an empty string.
    Methods(&'a [&'a str]),
    /// Code 11: mechanism-specific data, sent verbatim.
    Continue(&'a [u8]),
    /// Code 12: final mechanism-specific data, sent verbatim.
    Final(&'a [u8]),
}
548 :
/// Typed parameter-status values.
///
/// NOTE(review): this type is not referenced anywhere else in this file
/// (`BeMessage::ParameterStatus` uses raw name/value bytes instead). It may be
/// unused; confirm before relying on it.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
554 :
/// One field description in a RowDescription packet.
///
/// NOTE(review): `BeMessage::write` currently serializes only `name`, `typoid`
/// and `typlen` from this struct. It always sends 0 for table OID, attnum and
/// format code, and -1 for the type modifier, whatever values are stored here.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
    /// Column name. Written as a cstring, so it must not contain a null byte.
    pub name: &'a [u8],
    pub tableoid: Oid,
    pub attnum: i16,
    /// Type OID, e.g. [`TEXT_OID`] or [`INT8_OID`].
    pub typoid: Oid,
    /// Type length; the constructors use -1 for `text` and 8 for `int8`.
    pub typlen: i16,
    pub typmod: i32,
    pub formatcode: i16,
}
566 :
567 : impl Default for RowDescriptor<'_> {
568 0 : fn default() -> RowDescriptor<'static> {
569 0 : RowDescriptor {
570 0 : name: b"",
571 0 : tableoid: 0,
572 0 : attnum: 0,
573 0 : typoid: 0,
574 0 : typlen: 0,
575 0 : typmod: 0,
576 0 : formatcode: 0,
577 0 : }
578 0 : }
579 : }
580 :
581 : impl RowDescriptor<'_> {
582 : /// Convenience function to create a RowDescriptor message for an int8 column
583 0 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
584 0 : RowDescriptor {
585 0 : name,
586 0 : tableoid: 0,
587 0 : attnum: 0,
588 0 : typoid: INT8_OID,
589 0 : typlen: 8,
590 0 : typmod: 0,
591 0 : formatcode: 0,
592 0 : }
593 0 : }
594 :
595 4 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
596 4 : RowDescriptor {
597 4 : name,
598 4 : tableoid: 0,
599 4 : attnum: 0,
600 4 : typoid: TEXT_OID,
601 4 : typlen: -1,
602 4 : typmod: 0,
603 4 : formatcode: 0,
604 4 : }
605 4 : }
606 : }
607 :
/// Payload of an XLogData ('w') message, sent wrapped in a CopyData ('d') frame.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    /// Start position of `data` in the WAL. NOTE(review): presumably an LSN; confirm.
    pub wal_start: u64,
    /// Current end of WAL on the server.
    pub wal_end: u64,
    /// Send timestamp. NOTE(review): presumably microseconds since `PG_EPOCH`; confirm.
    pub timestamp: i64,
    /// WAL bytes, sent verbatim.
    pub data: &'a [u8],
}
615 :
/// Payload of a keepalive ('k') message, sent wrapped in a CopyData ('d') frame.
#[derive(Debug)]
pub struct WalSndKeepAlive {
    /// Current end of WAL on the server.
    pub wal_end: u64,
    /// Send timestamp. NOTE(review): presumably microseconds since `PG_EPOCH`; confirm.
    pub timestamp: i64,
    /// Asks the receiver to reply. Serialized as a single byte (1 = true, 0 = false).
    pub request_reply: bool,
}
622 :
623 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
624 :
625 : // single text column
626 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
627 : name: b"data",
628 : tableoid: 0,
629 : attnum: 0,
630 : typoid: TEXT_OID,
631 : typlen: -1,
632 : typmod: 0,
633 : formatcode: 0,
634 : }]);
635 :
636 : /// Call f() to write body of the message and prepend it with 4-byte len as
637 : /// prescribed by the protocol.
638 134 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
639 134 : let base = buf.len();
640 134 : buf.extend_from_slice(&[0; 4]);
641 134 :
642 134 : let res = f(buf);
643 134 :
644 134 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
645 134 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
646 134 :
647 134 : res
648 134 : }
649 :
650 : /// Safe write of s into buf as cstring (String in the protocol).
651 110 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
652 110 : let bytes = s.as_ref();
653 110 : if bytes.contains(&0) {
654 0 : return Err(ProtocolError::BadMessage(
655 0 : "string contains embedded null".to_owned(),
656 0 : ));
657 110 : }
658 110 : buf.put_slice(bytes);
659 110 : buf.put_u8(0);
660 110 : Ok(())
661 110 : }
662 :
663 : /// Read cstring from buf, advancing it.
664 22 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
665 22 : let pos = buf
666 22 : .iter()
667 312 : .position(|x| *x == 0)
668 22 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
669 22 : let result = buf.split_to(pos);
670 22 : buf.advance(1); // drop the null terminator
671 22 : Ok(result)
672 22 : }
673 :
/// SQLSTATE `XX000` (internal_error). Sent when an `ErrorResponse` has no
/// explicit code, and always used for `NoticeResponse`.
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE `57P01` (admin_shutdown).
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE `00000` (successful_completion).
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
677 :
678 : impl<'a> BeMessage<'a> {
679 : /// Serialize `message` to the given `buf`.
680 : /// Apart from smart memory managemet, BytesMut is good here as msg len
681 : /// precedes its body and it is handy to write it down first and then fill
682 : /// the length. With Write we would have to either calc it manually or have
683 : /// one more buffer.
684 176 : pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
685 176 : match message {
686 18 : BeMessage::AuthenticationOk => {
687 18 : buf.put_u8(b'R');
688 18 : write_body(buf, |buf| {
689 18 : buf.put_i32(0); // Specifies that the authentication was successful.
690 18 : });
691 18 : }
692 :
693 0 : BeMessage::AuthenticationCleartextPassword => {
694 0 : buf.put_u8(b'R');
695 0 : write_body(buf, |buf| {
696 0 : buf.put_i32(3); // Specifies that clear text password is required.
697 0 : });
698 0 : }
699 :
700 0 : BeMessage::AuthenticationMD5Password(salt) => {
701 0 : buf.put_u8(b'R');
702 0 : write_body(buf, |buf| {
703 0 : buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
704 0 : buf.put_slice(&salt[..]);
705 0 : });
706 0 : }
707 :
708 54 : BeMessage::AuthenticationSasl(msg) => {
709 54 : buf.put_u8(b'R');
710 54 : write_body(buf, |buf| {
711 54 : use BeAuthenticationSaslMessage::*;
712 54 : match msg {
713 24 : Methods(methods) => {
714 24 : buf.put_i32(10); // Specifies that SASL auth method is used.
715 48 : for method in methods.iter() {
716 48 : write_cstr(method, buf)?;
717 : }
718 24 : buf.put_u8(0); // zero terminator for the list
719 : }
720 20 : Continue(extra) => {
721 20 : buf.put_i32(11); // Continue SASL auth.
722 20 : buf.put_slice(extra);
723 20 : }
724 10 : Final(extra) => {
725 10 : buf.put_i32(12); // Send final SASL message.
726 10 : buf.put_slice(extra);
727 10 : }
728 : }
729 54 : Ok(())
730 54 : })?;
731 : }
732 :
733 0 : BeMessage::BackendKeyData(key_data) => {
734 0 : buf.put_u8(b'K');
735 0 : write_body(buf, |buf| {
736 0 : buf.put_i32(key_data.backend_pid);
737 0 : buf.put_i32(key_data.cancel_key);
738 0 : });
739 0 : }
740 :
741 0 : BeMessage::BindComplete => {
742 0 : buf.put_u8(b'2');
743 0 : write_body(buf, |_| {});
744 0 : }
745 :
746 0 : BeMessage::CloseComplete => {
747 0 : buf.put_u8(b'3');
748 0 : write_body(buf, |_| {});
749 0 : }
750 :
751 4 : BeMessage::CommandComplete(cmd) => {
752 4 : buf.put_u8(b'C');
753 4 : write_body(buf, |buf| write_cstr(cmd, buf))?;
754 : }
755 :
756 0 : BeMessage::CopyData(data) => {
757 0 : buf.put_u8(b'd');
758 0 : write_body(buf, |buf| {
759 0 : buf.put_slice(data);
760 0 : });
761 0 : }
762 :
763 0 : BeMessage::CopyDone => {
764 0 : buf.put_u8(b'c');
765 0 : write_body(buf, |_| {});
766 0 : }
767 :
768 0 : BeMessage::CopyFail => {
769 0 : buf.put_u8(b'f');
770 0 : write_body(buf, |_| {});
771 0 : }
772 :
773 0 : BeMessage::CopyInResponse => {
774 0 : buf.put_u8(b'G');
775 0 : write_body(buf, |buf| {
776 0 : buf.put_u8(1); // copy_is_binary
777 0 : buf.put_i16(0); // numAttributes
778 0 : });
779 0 : }
780 :
781 0 : BeMessage::CopyOutResponse => {
782 0 : buf.put_u8(b'H');
783 0 : write_body(buf, |buf| {
784 0 : buf.put_u8(0); // copy_is_binary
785 0 : buf.put_i16(0); // numAttributes
786 0 : });
787 0 : }
788 :
789 0 : BeMessage::CopyBothResponse => {
790 0 : buf.put_u8(b'W');
791 0 : write_body(buf, |buf| {
792 0 : // doesn't matter, used only for replication
793 0 : buf.put_u8(0); // copy_is_binary
794 0 : buf.put_i16(0); // numAttributes
795 0 : });
796 0 : }
797 :
798 4 : BeMessage::DataRow(vals) => {
799 4 : buf.put_u8(b'D');
800 4 : write_body(buf, |buf| {
801 4 : buf.put_u16(vals.len() as u16); // num of cols
802 4 : for val_opt in vals.iter() {
803 4 : if let Some(val) = val_opt {
804 4 : buf.put_u32(val.len() as u32);
805 4 : buf.put_slice(val);
806 4 : } else {
807 0 : buf.put_i32(-1);
808 0 : }
809 : }
810 4 : });
811 4 : }
812 :
813 : // ErrorResponse is a zero-terminated array of zero-terminated fields.
814 : // First byte of each field represents type of this field. Set just enough fields
815 : // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
816 : // message text.
817 2 : BeMessage::ErrorResponse(error_msg, pg_error_code) => {
818 2 : // 'E' signalizes ErrorResponse messages
819 2 : buf.put_u8(b'E');
820 2 : write_body(buf, |buf| {
821 2 : buf.put_u8(b'S'); // severity
822 2 : buf.put_slice(b"ERROR\0");
823 2 :
824 2 : buf.put_u8(b'C'); // SQLSTATE error code
825 2 : buf.put_slice(&terminate_code(
826 2 : pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
827 2 : ));
828 2 :
829 2 : buf.put_u8(b'M'); // the message
830 2 : write_cstr(error_msg, buf)?;
831 :
832 2 : buf.put_u8(0); // terminator
833 2 : Ok(())
834 2 : })?;
835 : }
836 :
837 : // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
838 : // message but continue listening for ReadyForQuery or ErrorResponse"
839 0 : BeMessage::NoticeResponse(error_msg) => {
840 0 : // For all the errors set Severity to Error and error code to
841 0 : // 'internal error'.
842 0 :
843 0 : // 'N' signalizes NoticeResponse messages
844 0 : buf.put_u8(b'N');
845 0 : write_body(buf, |buf| {
846 0 : buf.put_u8(b'S'); // severity
847 0 : buf.put_slice(b"NOTICE\0");
848 0 :
849 0 : buf.put_u8(b'C'); // SQLSTATE error code
850 0 : buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
851 0 :
852 0 : buf.put_u8(b'M'); // the message
853 0 : write_cstr(error_msg.as_bytes(), buf)?;
854 :
855 0 : buf.put_u8(0); // terminator
856 0 : Ok(())
857 0 : })?;
858 : }
859 :
860 0 : BeMessage::NoData => {
861 0 : buf.put_u8(b'n');
862 0 : write_body(buf, |_| {});
863 0 : }
864 :
865 42 : BeMessage::EncryptionResponse(should_negotiate) => {
866 42 : let response = if *should_negotiate { b'S' } else { b'N' };
867 42 : buf.put_u8(response);
868 : }
869 :
870 26 : BeMessage::ParameterStatus { name, value } => {
871 26 : buf.put_u8(b'S');
872 26 : write_body(buf, |buf| {
873 26 : write_cstr(name, buf)?;
874 26 : write_cstr(value, buf)
875 26 : })?;
876 : }
877 :
878 0 : BeMessage::ParameterDescription => {
879 0 : buf.put_u8(b't');
880 0 : write_body(buf, |buf| {
881 0 : // we don't support params, so always 0
882 0 : buf.put_i16(0);
883 0 : });
884 0 : }
885 :
886 0 : BeMessage::ParseComplete => {
887 0 : buf.put_u8(b'1');
888 0 : write_body(buf, |_| {});
889 0 : }
890 :
891 22 : BeMessage::ReadyForQuery => {
892 22 : buf.put_u8(b'Z');
893 22 : write_body(buf, |buf| {
894 22 : buf.put_u8(b'I');
895 22 : });
896 22 : }
897 :
898 4 : BeMessage::RowDescription(rows) => {
899 4 : buf.put_u8(b'T');
900 4 : write_body(buf, |buf| {
901 4 : buf.put_i16(rows.len() as i16); // # of fields
902 4 : for row in rows.iter() {
903 4 : write_cstr(row.name, buf)?;
904 4 : buf.put_i32(0); /* table oid */
905 4 : buf.put_i16(0); /* attnum */
906 4 : buf.put_u32(row.typoid);
907 4 : buf.put_i16(row.typlen);
908 4 : buf.put_i32(-1); /* typmod */
909 4 : buf.put_i16(0); /* format code */
910 : }
911 4 : Ok(())
912 4 : })?;
913 : }
914 :
915 0 : BeMessage::XLogData(body) => {
916 0 : buf.put_u8(b'd');
917 0 : write_body(buf, |buf| {
918 0 : buf.put_u8(b'w');
919 0 : buf.put_u64(body.wal_start);
920 0 : buf.put_u64(body.wal_end);
921 0 : buf.put_i64(body.timestamp);
922 0 : buf.put_slice(body.data);
923 0 : });
924 0 : }
925 :
926 0 : BeMessage::KeepAlive(req) => {
927 0 : buf.put_u8(b'd');
928 0 : write_body(buf, |buf| {
929 0 : buf.put_u8(b'k');
930 0 : buf.put_u64(req.wal_end);
931 0 : buf.put_i64(req.timestamp);
932 0 : buf.put_u8(u8::from(req.request_reply));
933 0 : });
934 0 : }
935 : }
936 176 : Ok(())
937 176 : }
938 : }
939 :
/// Return the 5-byte SQLSTATE `code` followed by a null terminator, ready to
/// be written as a protocol cstring field.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    let mut terminated = [0u8; 6];
    terminated[..code.len()].copy_from_slice(code);
    terminated
}
948 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        let params = make_params(" foo  bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // Note the two spaces before `lol`: the first is escaped and stays part
        // of `baz `, the second separates the options.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    #[test]
    fn test_startup_message_params_options_raw() {
        let params = StartupMessageParams::new([("options", "foo\\ bar \\ \\\\ baz\\  lol")]);
        let raw: Vec<&str> = params.options_raw().expect("options are None").collect();
        // Escapes are honoured for splitting but left in place.
        assert_eq!(raw, ["foo\\ bar", "\\ \\\\", "baz\\ ", "lol"]);
    }

    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }
}
|