Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 : #![deny(clippy::undocumented_unsafe_blocks)]
5 :
6 : pub mod framed;
7 :
8 : use byteorder::{BigEndian, ReadBytesExt};
9 : use bytes::{Buf, BufMut, Bytes, BytesMut};
10 : use serde::{Deserialize, Serialize};
11 : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
12 :
13 : // re-export for use in utils pageserver_feedback.rs
14 : pub use postgres_protocol::PG_EPOCH;
15 :
/// PostgreSQL object identifier; serialized as an unsigned 32-bit integer.
pub type Oid = u32;
/// NOTE(review): presumably the cluster's system identifier; it is not used in
/// this file — confirm the meaning against callers.
pub type SystemId = u64;

/// Type OID of the built-in `int8` type (used by `RowDescriptor::int8_col`).
pub const INT8_OID: Oid = 20;
/// Type OID of the built-in `int4` type.
pub const INT4_OID: Oid = 23;
/// Type OID of the built-in `text` type (used by `RowDescriptor::text_col`).
pub const TEXT_OID: Oid = 25;
22 :
/// A regular (post-startup) message received from the frontend, as produced by
/// `FeMessage::parse`. The comment on each variant gives the message type byte
/// it is decoded from.
#[derive(Debug)]
pub enum FeMessage {
    // Simple query.
    /// 'Q': the raw query body (including its trailing NUL), not decoded.
    Query(Bytes),
    // Extended query protocol.
    /// 'P'
    Parse(FeParseMessage),
    /// 'D'
    Describe(FeDescribeMessage),
    /// 'B'
    Bind(FeBindMessage),
    /// 'E'
    Execute(FeExecuteMessage),
    /// 'C'
    Close(FeCloseMessage),
    /// 'S'
    Sync,
    /// 'X'
    Terminate,
    /// 'd': the raw CopyData payload.
    CopyData(Bytes),
    /// 'c'
    CopyDone,
    /// 'f': the failure message body is discarded.
    CopyFail,
    /// 'p': the raw password/authentication payload, not decoded.
    PasswordMessage(Bytes),
}
40 :
/// A packet received before the regular message flow starts. These packets
/// have no type byte and are distinguished by their 32-bit request code (see
/// `FeStartupPacket::parse`).
#[derive(Debug)]
pub enum FeStartupPacket {
    /// Request to cancel the query running on the backend identified by the key.
    CancelRequest(CancelKeyData),
    /// Request to upgrade the connection to SSL/TLS.
    SslRequest,
    /// Request to upgrade the connection to GSSAPI encryption.
    GssEncRequest,
    /// A regular startup message: protocol version plus `name=value` parameters.
    StartupMessage {
        major_version: u32,
        minor_version: u32,
        params: StartupMessageParams,
    },
}
52 :
/// The `name -> value` parameters sent by a client in its `StartupMessage`.
#[derive(Debug)]
pub struct StartupMessageParams {
    params: HashMap<String, String>,
}

impl StartupMessageParams {
    /// Look up a parameter by name; `None` if the client did not send it.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Tokenize the `options` parameter the way PostgreSQL does, keeping every
    /// escape sequence verbatim in the returned tokens.
    /// [`None`] means that there's no `options` in [`Self`].
    pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
        let options = self.get("options")?;
        Some(Self::parse_options_raw(options))
    }

    /// Tokenize the `options` parameter the way PostgreSQL does and resolve
    /// escape sequences (allocating only for tokens that contain a backslash).
    /// [`None`] means that there's no `options` in [`Self`].
    pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
        let options = self.get("options")?;
        Some(Self::parse_options_escaped(options))
    }

    /// Tokenize `input` as PostgreSQL's `pg_split_opts` does: tokens are separated
    /// by ASCII whitespace unless that whitespace follows an unescaped backslash.
    /// Escape sequences are left in place and empty tokens are skipped.
    pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
        // `escaped` is true right after a backslash that was not itself escaped.
        let mut escaped = false;
        let is_separator = move |ch: char| {
            let separator = !escaped && ch.is_ascii_whitespace();
            escaped = !escaped && ch == '\\';
            separator
        };
        input.split(is_separator).filter(|token| !token.is_empty())
    }

    /// Like [`Self::parse_options_raw`], but with escapes resolved: each
    /// unescaped backslash is dropped, and a doubled backslash becomes a single one.
    /// Tokens without a backslash are returned borrowed.
    pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
        // See `postgres: pg_split_opts`.
        Self::parse_options_raw(input).map(|token| {
            if !token.contains('\\') {
                return Cow::Borrowed(token);
            }

            // A backslash is dropped unless the previous character was a dropped
            // backslash (i.e. it is the second half of `\\`).
            let mut keep_next_backslash = false;
            let unescaped: String = token
                .chars()
                .filter(|&ch| {
                    let drop_it = ch == '\\' && !keep_next_backslash;
                    keep_next_backslash = drop_it;
                    !drop_it
                })
                .collect();
            Cow::Owned(unescaped)
        })
    }

    /// Iterate over all `(name, value)` pairs; the order is unspecified.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.params
            .iter()
            .map(|(key, value)| (key.as_str(), value.as_str()))
    }

    // This function is mostly useful in tests.
    #[doc(hidden)]
    pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
        let mut params = HashMap::with_capacity(N);
        for (key, value) in pairs {
            params.insert(key.to_owned(), value.to_owned());
        }
        Self { params }
    }
}
126 :
/// Key identifying a backend for query cancellation. It is sent to the client
/// in `BackendKeyData` and read back from a `CancelRequest` startup packet.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,
}
132 :
133 : impl fmt::Display for CancelKeyData {
134 164 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 164 : let hi = (self.backend_pid as u64) << 32;
136 164 : let lo = self.cancel_key as u64;
137 164 : let id = hi | lo;
138 164 :
139 164 : // This format is more compact and might work better for logs.
140 164 : f.debug_tuple("CancelKeyData")
141 164 : .field(&format_args!("{:x}", id))
142 164 : .finish()
143 164 : }
144 : }
145 :
146 : use rand::distributions::{Distribution, Standard};
147 : impl Distribution<CancelKeyData> for Standard {
148 43 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
149 43 : CancelKeyData {
150 43 : backend_pid: rng.gen(),
151 43 : cancel_key: rng.gen(),
152 43 : }
153 43 : }
154 : }
155 :
// We only support the simple case of Parse on unnamed prepared statement and
// no params
/// Decoded Parse ('P') message; only the query text is kept.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Query text without its NUL terminator.
    pub query_string: Bytes,
}

/// Decoded Describe ('D') message.
#[derive(Debug)]
pub struct FeDescribeMessage {
    pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
    // we only support unnamed prepared stmt or portal
    // NOTE: `FeDescribeMessage::parse` currently rejects every kind except 'S'.
}

// we only support unnamed prepared stmt and portal
/// Decoded Bind ('B') message; carries no data.
#[derive(Debug)]
pub struct FeBindMessage;

// we only support unnamed prepared stmt or portal
/// Decoded Execute ('E') message.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// max # of rows (parsing rejects anything but 0, i.e. "no limit")
    pub maxrows: i32,
}

// we only support unnamed prepared stmt and portal
/// Decoded Close ('C') message; its contents are ignored.
#[derive(Debug)]
pub struct FeCloseMessage;
183 :
/// An error occurred while parsing or serializing raw stream into Postgres
/// messages.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
    /// Invalid packet was received from the client (e.g. unexpected message
    /// type or broken len).
    #[error("Protocol error: {0}")]
    Protocol(String),
    /// Failed to parse or, (unlikely), serialize a protocol message, e.g. a
    /// missing C-string terminator or a string with an embedded NUL.
    #[error("Message parse error: {0}")]
    BadMessage(String),
}
196 :
197 : impl ProtocolError {
198 : /// Proxy stream.rs uses only io::Error; provide it.
199 0 : pub fn into_io_error(self) -> io::Error {
200 0 : io::Error::new(io::ErrorKind::Other, self.to_string())
201 0 : }
202 : }
203 :
204 : impl FeMessage {
205 : /// Read and parse one message from the `buf` input buffer. If there is at
206 : /// least one valid message, returns it, advancing `buf`; redundant copies
207 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
208 : /// directly into the `buf` (processed data is garbage collected after
209 : /// parsed message is dropped).
210 : ///
211 : /// Returns None if `buf` doesn't contain enough data for a single message.
212 : /// For efficiency, tries to reserve large enough space in `buf` for the
213 : /// next message in this case to save the repeated calls.
214 : ///
215 : /// Returns Error if message is malformed, the only possible ErrorKind is
216 : /// InvalidInput.
217 : //
218 : // Inspired by rust-postgres Message::parse.
219 14509451 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
220 14509451 : // Every message contains message type byte and 4 bytes len; can't do
221 14509451 : // much without them.
222 14509451 : if buf.len() < 5 {
223 6600309 : let to_read = 5 - buf.len();
224 6600309 : buf.reserve(to_read);
225 6600309 : return Ok(None);
226 7909142 : }
227 7909142 :
228 7909142 : // We shouldn't advance `buf` as probably full message is not there yet,
229 7909142 : // so can't directly use Bytes::get_u32 etc.
230 7909142 : let tag = buf[0];
231 7909142 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
232 7909142 : if len < 4 {
233 0 : return Err(ProtocolError::Protocol(format!(
234 0 : "invalid message length {}",
235 0 : len
236 0 : )));
237 7909142 : }
238 7909142 :
239 7909142 : // length field includes itself, but not message type.
240 7909142 : let total_len = len as usize + 1;
241 7909142 : if buf.len() < total_len {
242 : // Don't have full message yet.
243 170419 : let to_read = total_len - buf.len();
244 170419 : buf.reserve(to_read);
245 170419 : return Ok(None);
246 7738723 : }
247 7738723 :
248 7738723 : // got the message, advance buffer
249 7738723 : let mut msg = buf.split_to(total_len).freeze();
250 7738723 : msg.advance(5); // consume message type and len
251 7738723 :
252 7738723 : match tag {
253 14089 : b'Q' => Ok(Some(FeMessage::Query(msg))),
254 630 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
255 630 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
256 630 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
257 630 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
258 629 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
259 1909 : b'S' => Ok(Some(FeMessage::Sync)),
260 2170 : b'X' => Ok(Some(FeMessage::Terminate)),
261 7716986 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
262 13 : b'c' => Ok(Some(FeMessage::CopyDone)),
263 73 : b'f' => Ok(Some(FeMessage::CopyFail)),
264 334 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
265 0 : tag => Err(ProtocolError::Protocol(format!(
266 0 : "unknown message tag: {tag},'{msg:?}'"
267 0 : ))),
268 : }
269 14509451 : }
270 : }
271 :
272 : impl FeStartupPacket {
273 : /// Read and parse startup message from the `buf` input buffer. It is
274 : /// different from [`FeMessage::parse`] because startup messages don't have
275 : /// message type byte; otherwise, its comments apply.
276 52036 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
277 52036 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
278 52036 : const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
279 52036 : const CANCEL_REQUEST_CODE: u32 = 5678;
280 52036 : const NEGOTIATE_SSL_CODE: u32 = 5679;
281 52036 : const NEGOTIATE_GSS_CODE: u32 = 5680;
282 52036 :
283 52036 : // need at least 4 bytes with packet len
284 52036 : if buf.len() < 4 {
285 26018 : let to_read = 4 - buf.len();
286 26018 : buf.reserve(to_read);
287 26018 : return Ok(None);
288 26018 : }
289 26018 :
290 26018 : // We shouldn't advance `buf` as probably full message is not there yet,
291 26018 : // so can't directly use Bytes::get_u32 etc.
292 26018 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
293 26018 : // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
294 26018 : // which is less readable
295 26018 : #[allow(clippy::manual_range_contains)]
296 26018 : if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
297 2 : return Err(ProtocolError::Protocol(format!(
298 2 : "invalid startup packet message length {}",
299 2 : len
300 2 : )));
301 26016 : }
302 26016 :
303 26016 : if buf.len() < len {
304 : // Don't have full message yet.
305 0 : let to_read = len - buf.len();
306 0 : buf.reserve(to_read);
307 0 : return Ok(None);
308 26016 : }
309 26016 :
310 26016 : // got the message, advance buffer
311 26016 : let mut msg = buf.split_to(len).freeze();
312 26016 : msg.advance(4); // consume len
313 26016 :
314 26016 : let request_code = msg.get_u32();
315 26016 : let req_hi = request_code >> 16;
316 26016 : let req_lo = request_code & ((1 << 16) - 1);
317 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
318 26016 : let message = match (req_hi, req_lo) {
319 : (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
320 0 : if msg.remaining() != 8 {
321 0 : return Err(ProtocolError::BadMessage(
322 0 : "CancelRequest message is malformed, backend PID / secret key missing"
323 0 : .to_owned(),
324 0 : ));
325 0 : }
326 0 : FeStartupPacket::CancelRequest(CancelKeyData {
327 0 : backend_pid: msg.get_i32(),
328 0 : cancel_key: msg.get_i32(),
329 0 : })
330 : }
331 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
332 : // Requested upgrade to SSL (aka TLS)
333 11974 : FeStartupPacket::SslRequest
334 : }
335 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
336 : // Requested upgrade to GSSAPI
337 0 : FeStartupPacket::GssEncRequest
338 : }
339 0 : (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
340 0 : return Err(ProtocolError::Protocol(format!(
341 0 : "Unrecognized request code {unrecognized_code}"
342 0 : )));
343 : }
344 : // TODO bail if protocol major_version is not 3?
345 14042 : (major_version, minor_version) => {
346 : // StartupMessage
347 :
348 : // Parse pairs of null-terminated strings (key, value).
349 : // See `postgres: ProcessStartupPacket, build_startup_packet`.
350 14042 : let mut tokens = str::from_utf8(&msg)
351 14042 : .map_err(|_e| {
352 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
353 14042 : })?
354 14042 : .strip_suffix('\0') // drop packet's own null
355 14042 : .ok_or_else(|| {
356 0 : ProtocolError::Protocol(
357 0 : "StartupMessage params: missing null terminator".to_string(),
358 0 : )
359 14042 : })?
360 14042 : .split_terminator('\0');
361 14042 :
362 14042 : let mut params = HashMap::new();
363 46390 : while let Some(name) = tokens.next() {
364 32348 : let value = tokens.next().ok_or_else(|| {
365 0 : ProtocolError::Protocol(
366 0 : "StartupMessage params: key without value".to_string(),
367 0 : )
368 32348 : })?;
369 :
370 32348 : params.insert(name.to_owned(), value.to_owned());
371 : }
372 :
373 14042 : FeStartupPacket::StartupMessage {
374 14042 : major_version,
375 14042 : minor_version,
376 14042 : params: StartupMessageParams { params },
377 14042 : }
378 : }
379 : };
380 26016 : Ok(Some(message))
381 52036 : }
382 : }
383 :
384 : impl FeParseMessage {
385 630 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
386 : // FIXME: the rust-postgres driver uses a named prepared statement
387 : // for copy_out(). We're not prepared to handle that correctly. For
388 : // now, just ignore the statement name, assuming that the client never
389 : // uses more than one prepared statement at a time.
390 :
391 630 : let _pstmt_name = read_cstr(&mut buf)?;
392 630 : let query_string = read_cstr(&mut buf)?;
393 630 : if buf.remaining() < 2 {
394 0 : return Err(ProtocolError::BadMessage(
395 0 : "Parse message is malformed, nparams missing".to_string(),
396 0 : ));
397 630 : }
398 630 : let nparams = buf.get_i16();
399 630 :
400 630 : if nparams != 0 {
401 0 : return Err(ProtocolError::BadMessage(
402 0 : "query params not implemented".to_string(),
403 0 : ));
404 630 : }
405 630 :
406 630 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
407 630 : }
408 : }
409 :
410 : impl FeDescribeMessage {
411 630 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
412 630 : let kind = buf.get_u8();
413 630 : let _pstmt_name = read_cstr(&mut buf)?;
414 :
415 : // FIXME: see FeParseMessage::parse
416 630 : if kind != b'S' {
417 0 : return Err(ProtocolError::BadMessage(
418 0 : "only prepared statemement Describe is implemented".to_string(),
419 0 : ));
420 630 : }
421 630 :
422 630 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
423 630 : }
424 : }
425 :
426 : impl FeExecuteMessage {
427 630 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
428 630 : let portal_name = read_cstr(&mut buf)?;
429 630 : if buf.remaining() < 4 {
430 0 : return Err(ProtocolError::BadMessage(
431 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
432 0 : ));
433 630 : }
434 630 : let maxrows = buf.get_i32();
435 630 :
436 630 : if !portal_name.is_empty() {
437 0 : return Err(ProtocolError::BadMessage(
438 0 : "named portals not implemented".to_string(),
439 0 : ));
440 630 : }
441 630 : if maxrows != 0 {
442 0 : return Err(ProtocolError::BadMessage(
443 0 : "row limit in Execute message not implemented".to_string(),
444 0 : ));
445 630 : }
446 630 :
447 630 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
448 630 : }
449 : }
450 :
451 : impl FeBindMessage {
452 630 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
453 630 : let portal_name = read_cstr(&mut buf)?;
454 630 : let _pstmt_name = read_cstr(&mut buf)?;
455 :
456 : // FIXME: see FeParseMessage::parse
457 630 : if !portal_name.is_empty() {
458 0 : return Err(ProtocolError::BadMessage(
459 0 : "named portals not implemented".to_string(),
460 0 : ));
461 630 : }
462 630 :
463 630 : Ok(FeMessage::Bind(FeBindMessage))
464 630 : }
465 : }
466 :
467 : impl FeCloseMessage {
468 629 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
469 629 : let _kind = buf.get_u8();
470 629 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
471 :
472 : // FIXME: we do nothing with Close
473 629 : Ok(FeMessage::Close(FeCloseMessage))
474 629 : }
475 : }
476 :
477 : // Backend
478 :
/// A backend (server -> client) message, serialized by `BeMessage::write`.
/// Borrowed payloads avoid copying before serialization.
#[derive(Debug)]
pub enum BeMessage<'a> {
    AuthenticationOk,
    /// Carries the 4-byte salt sent to the client.
    AuthenticationMD5Password([u8; 4]),
    AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
    AuthenticationCleartextPassword,
    BackendKeyData(CancelKeyData),
    BindComplete,
    /// Command tag, sent as a C string (must not contain NUL).
    CommandComplete(&'a [u8]),
    CopyData(&'a [u8]),
    CopyDone,
    CopyFail,
    CopyInResponse,
    CopyOutResponse,
    CopyBothResponse,
    CloseComplete,
    // None means column is NULL
    /// One row; each element is a column value, `None` meaning SQL NULL.
    DataRow(&'a [Option<&'a [u8]>]),
    // None errcode means internal_error will be sent.
    /// Error text plus an optional 5-byte SQLSTATE (defaults to `XX000`).
    ErrorResponse(&'a str, Option<&'a [u8; 5]>),
    /// Single byte - used in response to SSLRequest/GSSENCRequest.
    EncryptionResponse(bool),
    NoData,
    ParameterDescription,
    /// A `name = value` report; both are sent as C strings.
    ParameterStatus {
        name: &'a [u8],
        value: &'a [u8],
    },
    ParseComplete,
    ReadyForQuery,
    RowDescription(&'a [RowDescriptor<'a>]),
    /// Sent wrapped in a CopyData ('d') message with sub-tag 'w'.
    XLogData(XLogDataBody<'a>),
    /// Notice text; sent with severity NOTICE and SQLSTATE `XX000`.
    NoticeResponse(&'a str),
    /// Sent wrapped in a CopyData ('d') message with sub-tag 'k'.
    KeepAlive(WalSndKeepAlive),
}
514 :
515 : /// Common shorthands.
516 : impl<'a> BeMessage<'a> {
517 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
518 : /// This is a sensible default, given that:
519 : /// * rust strings only support this encoding out of the box.
520 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
521 : ///
522 : /// TODO: do we need to report `server_encoding` as well?
523 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
524 : name: b"client_encoding",
525 : value: b"UTF8",
526 : };
527 :
528 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
529 : name: b"integer_datetimes",
530 : value: b"on",
531 : };
532 :
533 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
534 13719 : pub fn server_version(version: &'a str) -> Self {
535 13719 : Self::ParameterStatus {
536 13719 : name: b"server_version",
537 13719 : value: version.as_bytes(),
538 13719 : }
539 13719 : }
540 : }
541 :
/// Payload of a SASL `AuthenticationSasl` ('R') message. `BeMessage::write`
/// serializes the variants with auth codes 10, 11 and 12 respectively.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Offered mechanism names, sent as a NUL-terminated list of C strings.
    Methods(&'a [&'a str]),
    /// Server challenge data, sent verbatim.
    Continue(&'a [u8]),
    /// Final server data, sent verbatim.
    Final(&'a [u8]),
}
548 :
/// Typed parameter-status values.
/// NOTE(review): not referenced anywhere in this file (`BeMessage::ParameterStatus`
/// uses raw name/value bytes instead); possibly unused — confirm before relying on it.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
554 :
// One row description in RowDescription packet.
/// Describes one column (field) of a `RowDescription` message.
///
/// NOTE(review): `BeMessage::write` currently serializes only `name`, `typoid`
/// and `typlen`. It always sends 0 for table oid / attnum / format code and -1
/// for typmod, ignoring the corresponding fields here.
#[derive(Debug)]
pub struct RowDescriptor<'a> {
    pub name: &'a [u8],
    pub tableoid: Oid,
    pub attnum: i16,
    pub typoid: Oid,
    pub typlen: i16,
    pub typmod: i32,
    pub formatcode: i16,
}
566 :
567 : impl Default for RowDescriptor<'_> {
568 3091 : fn default() -> RowDescriptor<'static> {
569 3091 : RowDescriptor {
570 3091 : name: b"",
571 3091 : tableoid: 0,
572 3091 : attnum: 0,
573 3091 : typoid: 0,
574 3091 : typlen: 0,
575 3091 : typmod: 0,
576 3091 : formatcode: 0,
577 3091 : }
578 3091 : }
579 : }
580 :
581 : impl RowDescriptor<'_> {
582 : /// Convenience function to create a RowDescriptor message for an int8 column
583 45 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
584 45 : RowDescriptor {
585 45 : name,
586 45 : tableoid: 0,
587 45 : attnum: 0,
588 45 : typoid: INT8_OID,
589 45 : typlen: 8,
590 45 : typmod: 0,
591 45 : formatcode: 0,
592 45 : }
593 45 : }
594 :
595 1342 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
596 1342 : RowDescriptor {
597 1342 : name,
598 1342 : tableoid: 0,
599 1342 : attnum: 0,
600 1342 : typoid: TEXT_OID,
601 1342 : typlen: -1,
602 1342 : typmod: 0,
603 1342 : formatcode: 0,
604 1342 : }
605 1342 : }
606 : }
607 :
/// Body of a WAL data message. `BeMessage::write` sends it inside CopyData ('d')
/// as: 'w', wal_start (u64), wal_end (u64), timestamp (i64), then `data`
/// verbatim.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    pub wal_start: u64,
    pub wal_end: u64, // current end of WAL on the server
    pub timestamp: i64,
    pub data: &'a [u8],
}
615 :
/// Body of a keepalive message. `BeMessage::write` sends it inside CopyData ('d')
/// as: 'k', wal_end (u64), timestamp (i64), request_reply (one byte, 0 or 1).
#[derive(Debug)]
pub struct WalSndKeepAlive {
    pub wal_end: u64, // current end of WAL on the server
    pub timestamp: i64,
    pub request_reply: bool,
}
622 :
623 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
624 :
625 : // single text column
626 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
627 : name: b"data",
628 : tableoid: 0,
629 : attnum: 0,
630 : typoid: TEXT_OID,
631 : typlen: -1,
632 : typmod: 0,
633 : formatcode: 0,
634 : }]);
635 :
636 : /// Call f() to write body of the message and prepend it with 4-byte len as
637 : /// prescribed by the protocol.
638 7828537 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
639 7828537 : let base = buf.len();
640 7828537 : buf.extend_from_slice(&[0; 4]);
641 7828537 :
642 7828537 : let res = f(buf);
643 7828537 :
644 7828537 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
645 7828537 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
646 7828537 :
647 7828537 : res
648 7828537 : }
649 :
650 : /// Safe write of s into buf as cstring (String in the protocol).
651 90737 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
652 90737 : let bytes = s.as_ref();
653 90737 : if bytes.contains(&0) {
654 0 : return Err(ProtocolError::BadMessage(
655 0 : "string contains embedded null".to_owned(),
656 0 : ));
657 90737 : }
658 90737 : buf.put_slice(bytes);
659 90737 : buf.put_u8(0);
660 90737 : Ok(())
661 90737 : }
662 :
663 : /// Read cstring from buf, advancing it.
664 3764746 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
665 3764746 : let pos = buf
666 3764746 : .iter()
667 53457811 : .position(|x| *x == 0)
668 3764746 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
669 3764746 : let result = buf.split_to(pos);
670 3764746 : buf.advance(1); // drop the null terminator
671 3764746 : Ok(result)
672 3764746 : }
673 :
/// SQLSTATE `internal_error`; `BeMessage::write` uses it when an
/// `ErrorResponse` carries no code, and for every `NoticeResponse`.
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE `admin_shutdown`.
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE `successful_completion`.
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
677 :
impl<'a> BeMessage<'a> {
    /// Serialize `message` to the given `buf`.
    /// Apart from smart memory management, BytesMut is good here as msg len
    /// precedes its body and it is handy to write it down first and then fill
    /// the length. With Write we would have to either calc it manually or have
    /// one more buffer.
    ///
    /// # Errors
    /// Returns [`ProtocolError::BadMessage`] if a value that must be sent as a C
    /// string (see `write_cstr`) contains an embedded NUL byte. In that case `buf`
    /// may be left holding a partially written message.
    ///
    /// # Panics
    /// Panics (in `write_body`) if an encoded message length does not fit in an `i32`.
    pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
        match message {
            BeMessage::AuthenticationOk => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    buf.put_i32(0); // Specifies that the authentication was successful.
                });
            }

            BeMessage::AuthenticationCleartextPassword => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    buf.put_i32(3); // Specifies that clear text password is required.
                });
            }

            BeMessage::AuthenticationMD5Password(salt) => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
                    buf.put_slice(&salt[..]);
                });
            }

            BeMessage::AuthenticationSasl(msg) => {
                buf.put_u8(b'R');
                write_body(buf, |buf| {
                    use BeAuthenticationSaslMessage::*;
                    match msg {
                        Methods(methods) => {
                            buf.put_i32(10); // Specifies that SASL auth method is used.
                            for method in methods.iter() {
                                write_cstr(method, buf)?;
                            }
                            buf.put_u8(0); // zero terminator for the list
                        }
                        Continue(extra) => {
                            buf.put_i32(11); // Continue SASL auth.
                            buf.put_slice(extra);
                        }
                        Final(extra) => {
                            buf.put_i32(12); // Send final SASL message.
                            buf.put_slice(extra);
                        }
                    }
                    Ok(())
                })?;
            }

            BeMessage::BackendKeyData(key_data) => {
                buf.put_u8(b'K');
                write_body(buf, |buf| {
                    buf.put_i32(key_data.backend_pid);
                    buf.put_i32(key_data.cancel_key);
                });
            }

            BeMessage::BindComplete => {
                buf.put_u8(b'2');
                write_body(buf, |_| {});
            }

            BeMessage::CloseComplete => {
                buf.put_u8(b'3');
                write_body(buf, |_| {});
            }

            BeMessage::CommandComplete(cmd) => {
                buf.put_u8(b'C');
                write_body(buf, |buf| write_cstr(cmd, buf))?;
            }

            BeMessage::CopyData(data) => {
                buf.put_u8(b'd');
                write_body(buf, |buf| {
                    buf.put_slice(data);
                });
            }

            BeMessage::CopyDone => {
                buf.put_u8(b'c');
                write_body(buf, |_| {});
            }

            BeMessage::CopyFail => {
                buf.put_u8(b'f');
                write_body(buf, |_| {});
            }

            BeMessage::CopyInResponse => {
                buf.put_u8(b'G');
                write_body(buf, |buf| {
                    buf.put_u8(1); // copy_is_binary
                    buf.put_i16(0); // numAttributes
                });
            }

            BeMessage::CopyOutResponse => {
                buf.put_u8(b'H');
                write_body(buf, |buf| {
                    buf.put_u8(0); // copy_is_binary
                    buf.put_i16(0); // numAttributes
                });
            }

            BeMessage::CopyBothResponse => {
                buf.put_u8(b'W');
                write_body(buf, |buf| {
                    // doesn't matter, used only for replication
                    buf.put_u8(0); // copy_is_binary
                    buf.put_i16(0); // numAttributes
                });
            }

            BeMessage::DataRow(vals) => {
                buf.put_u8(b'D');
                write_body(buf, |buf| {
                    // NOTE(review): `as u16` silently truncates a column count
                    // above u16::MAX; callers are assumed to stay well below it.
                    buf.put_u16(vals.len() as u16); // num of cols
                    for val_opt in vals.iter() {
                        if let Some(val) = val_opt {
                            buf.put_u32(val.len() as u32);
                            buf.put_slice(val);
                        } else {
                            // A length of -1 marks a NULL column value.
                            buf.put_i32(-1);
                        }
                    }
                });
            }

            // ErrorResponse is a zero-terminated array of zero-terminated fields.
            // First byte of each field represents type of this field. Set just enough fields
            // to satisfy rust-postgres client: 'S' -- severity, 'C' -- SQLSTATE code,
            // 'M' -- error message text.
            BeMessage::ErrorResponse(error_msg, pg_error_code) => {
                // 'E' signalizes ErrorResponse messages
                buf.put_u8(b'E');
                write_body(buf, |buf| {
                    buf.put_u8(b'S'); // severity
                    buf.put_slice(b"ERROR\0");

                    buf.put_u8(b'C'); // SQLSTATE error code
                    buf.put_slice(&terminate_code(
                        pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
                    ));

                    buf.put_u8(b'M'); // the message
                    write_cstr(error_msg, buf)?;

                    buf.put_u8(0); // terminator
                    Ok(())
                })?;
            }

            // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
            // message but continue listening for ReadyForQuery or ErrorResponse"
            BeMessage::NoticeResponse(error_msg) => {
                // Severity is always NOTICE and the SQLSTATE is always
                // 'internal error' (XX000).

                // 'N' signalizes NoticeResponse messages
                buf.put_u8(b'N');
                write_body(buf, |buf| {
                    buf.put_u8(b'S'); // severity
                    buf.put_slice(b"NOTICE\0");

                    buf.put_u8(b'C'); // SQLSTATE error code
                    buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));

                    buf.put_u8(b'M'); // the message
                    write_cstr(error_msg.as_bytes(), buf)?;

                    buf.put_u8(0); // terminator
                    Ok(())
                })?;
            }

            BeMessage::NoData => {
                buf.put_u8(b'n');
                write_body(buf, |_| {});
            }

            // A bare single byte with no length prefix: 'S' accepts the
            // encryption request, 'N' declines it.
            BeMessage::EncryptionResponse(should_negotiate) => {
                let response = if *should_negotiate { b'S' } else { b'N' };
                buf.put_u8(response);
            }

            BeMessage::ParameterStatus { name, value } => {
                buf.put_u8(b'S');
                write_body(buf, |buf| {
                    write_cstr(name, buf)?;
                    write_cstr(value, buf)
                })?;
            }

            BeMessage::ParameterDescription => {
                buf.put_u8(b't');
                write_body(buf, |buf| {
                    // we don't support params, so always 0
                    buf.put_i16(0);
                });
            }

            BeMessage::ParseComplete => {
                buf.put_u8(b'1');
                write_body(buf, |_| {});
            }

            BeMessage::ReadyForQuery => {
                buf.put_u8(b'Z');
                write_body(buf, |buf| {
                    // Transaction-status byte; always 'I' here (presumably
                    // "idle", per the protocol docs).
                    buf.put_u8(b'I');
                });
            }

            BeMessage::RowDescription(rows) => {
                buf.put_u8(b'T');
                write_body(buf, |buf| {
                    buf.put_i16(rows.len() as i16); // # of fields
                    for row in rows.iter() {
                        // NOTE(review): only name/typoid/typlen come from `row`;
                        // its tableoid, attnum, typmod and formatcode fields are
                        // ignored in favour of the constants below.
                        write_cstr(row.name, buf)?;
                        buf.put_i32(0); /* table oid */
                        buf.put_i16(0); /* attnum */
                        buf.put_u32(row.typoid);
                        buf.put_i16(row.typlen);
                        buf.put_i32(-1); /* typmod */
                        buf.put_i16(0); /* format code */
                    }
                    Ok(())
                })?;
            }

            // XLogData and KeepAlive travel inside CopyData ('d') with a sub-tag.
            BeMessage::XLogData(body) => {
                buf.put_u8(b'd');
                write_body(buf, |buf| {
                    buf.put_u8(b'w');
                    buf.put_u64(body.wal_start);
                    buf.put_u64(body.wal_end);
                    buf.put_i64(body.timestamp);
                    buf.put_slice(body.data);
                });
            }

            BeMessage::KeepAlive(req) => {
                buf.put_u8(b'd');
                write_body(buf, |buf| {
                    buf.put_u8(b'k');
                    buf.put_u64(req.wal_end);
                    buf.put_i64(req.timestamp);
                    buf.put_u8(u8::from(req.request_reply));
                });
            }
        }
        Ok(())
    }
}
939 :
/// Returns the 5-byte SQLSTATE `code` followed by a NUL, ready to be emitted
/// as a protocol C string.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    let mut terminated = [0u8; 6];
    terminated[..code.len()].copy_from_slice(code);
    terminated
}
948 :
#[cfg(test)]
mod tests {
    use super::*;

    /// `options` splitting must follow PostgreSQL's `pg_split_opts`: unescaped
    /// ASCII whitespace separates tokens, and backslash escapes are resolved.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        let params = make_params(" foo bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // Note the two spaces after `baz\`: the first is escaped and stays in
        // the token, the second separates it from `lol`.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    /// A startup packet whose length is below the 8-byte minimum must be
    /// rejected rather than accepted or panicking.
    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }
}
|