Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 : #![deny(clippy::undocumented_unsafe_blocks)]
5 :
6 : pub mod framed;
7 :
8 : use byteorder::{BigEndian, ReadBytesExt};
9 : use bytes::{Buf, BufMut, Bytes, BytesMut};
10 : use itertools::Itertools;
11 : use serde::{Deserialize, Serialize};
12 : use std::{borrow::Cow, fmt, io, str};
13 :
14 : // re-export for use in utils pageserver_feedback.rs
15 : pub use postgres_protocol::PG_EPOCH;
16 :
17 : pub type Oid = u32;
18 : pub type SystemId = u64;
19 :
20 : pub const INT8_OID: Oid = 20;
21 : pub const INT4_OID: Oid = 23;
22 : pub const TEXT_OID: Oid = 25;
23 :
24 : #[derive(Debug)]
25 : pub enum FeMessage {
26 : // Simple query.
27 : Query(Bytes),
28 : // Extended query protocol.
29 : Parse(FeParseMessage),
30 : Describe(FeDescribeMessage),
31 : Bind(FeBindMessage),
32 : Execute(FeExecuteMessage),
33 : Close(FeCloseMessage),
34 : Sync,
35 : Terminate,
36 : CopyData(Bytes),
37 : CopyDone,
38 : CopyFail,
39 : PasswordMessage(Bytes),
40 : }
41 :
/// Protocol version (or special request code) carried in a startup packet.
///
/// The value is packed into a single `u32`: the major number lives in the
/// upper 16 bits and the minor number in the lower 16 bits.
#[derive(Clone, Copy, PartialEq, PartialOrd)]
pub struct ProtocolVersion(u32);

impl ProtocolVersion {
    /// Pack `major` and `minor` into one version value.
    pub const fn new(major: u16, minor: u16) -> Self {
        let high = (major as u32) << 16;
        let low = minor as u32;
        Self(high | low)
    }

    /// Lower 16 bits of the packed value.
    pub const fn minor(self) -> u16 {
        (self.0 & 0xffff) as u16
    }

    /// Upper 16 bits of the packed value.
    pub const fn major(self) -> u16 {
        (self.0 >> 16) as u16
    }
}

impl fmt::Debug for ProtocolVersion {
    /// Renders the version as a two element list: `[major, minor]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list()
            .entries([self.major(), self.minor()])
            .finish()
    }
}
65 :
66 : #[derive(Debug)]
67 : pub enum FeStartupPacket {
68 : CancelRequest(CancelKeyData),
69 : SslRequest {
70 : direct: bool,
71 : },
72 : GssEncRequest,
73 : StartupMessage {
74 : version: ProtocolVersion,
75 : params: StartupMessageParams,
76 : },
77 : }
78 :
79 : #[derive(Debug, Clone, Default)]
80 : pub struct StartupMessageParamsBuilder {
81 : params: BytesMut,
82 : }
83 :
84 : impl StartupMessageParamsBuilder {
85 : /// Set parameter's value by its name.
86 : /// name and value must not contain a \0 byte
87 25 : pub fn insert(&mut self, name: &str, value: &str) {
88 25 : self.params.put(name.as_bytes());
89 25 : self.params.put(&b"\0"[..]);
90 25 : self.params.put(value.as_bytes());
91 25 : self.params.put(&b"\0"[..]);
92 25 : }
93 :
94 17 : pub fn freeze(self) -> StartupMessageParams {
95 17 : StartupMessageParams {
96 17 : params: self.params.freeze(),
97 17 : }
98 17 : }
99 : }
100 :
101 : #[derive(Debug, Clone, Default)]
102 : pub struct StartupMessageParams {
103 : pub params: Bytes,
104 : }
105 :
106 : impl StartupMessageParams {
107 : /// Get parameter's value by its name.
108 42 : pub fn get(&self, name: &str) -> Option<&str> {
109 58 : self.iter().find_map(|(k, v)| (k == name).then_some(v))
110 42 : }
111 :
112 : /// Split command-line options according to PostgreSQL's logic,
113 : /// taking into account all escape sequences but leaving them as-is.
114 : /// [`None`] means that there's no `options` in [`Self`].
115 24 : pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
116 24 : self.get("options").map(Self::parse_options_raw)
117 24 : }
118 :
119 : /// Split command-line options according to PostgreSQL's logic,
120 : /// applying all escape sequences (using owned strings as needed).
121 : /// [`None`] means that there's no `options` in [`Self`].
122 5 : pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
123 5 : self.get("options").map(Self::parse_options_escaped)
124 5 : }
125 :
126 : /// Split command-line options according to PostgreSQL's logic,
127 : /// taking into account all escape sequences but leaving them as-is.
128 30 : pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
129 30 : // See `postgres: pg_split_opts`.
130 30 : let mut last_was_escape = false;
131 30 : input
132 594 : .split(move |c: char| {
133 : // We split by non-escaped whitespace symbols.
134 594 : let should_split = c.is_ascii_whitespace() && !last_was_escape;
135 594 : last_was_escape = c == '\\' && !last_was_escape;
136 594 : should_split
137 594 : })
138 77 : .filter(|s| !s.is_empty())
139 30 : }
140 :
141 : /// Split command-line options according to PostgreSQL's logic,
142 : /// applying all escape sequences (using owned strings as needed).
143 4 : pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
144 4 : // See `postgres: pg_split_opts`.
145 7 : Self::parse_options_raw(input).map(|s| {
146 7 : let mut preserve_next_escape = false;
147 17 : let escape = |c| {
148 : // We should remove '\\' unless it's preceded by '\\'.
149 17 : let should_remove = c == '\\' && !preserve_next_escape;
150 17 : preserve_next_escape = should_remove;
151 17 : should_remove
152 17 : };
153 :
154 7 : match s.contains('\\') {
155 3 : true => Cow::Owned(s.replace(escape, "")),
156 4 : false => Cow::Borrowed(s),
157 : }
158 7 : })
159 4 : }
160 :
161 : /// Iterate through key-value pairs in an arbitrary order.
162 42 : pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
163 42 : let params =
164 42 : std::str::from_utf8(&self.params).expect("should be validated as utf8 already");
165 42 : params.split_terminator('\0').tuples()
166 42 : }
167 :
168 : // This function is mostly useful in tests.
169 : #[doc(hidden)]
170 17 : pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
171 17 : let mut b = StartupMessageParamsBuilder::default();
172 42 : for (k, v) in pairs {
173 25 : b.insert(k, v)
174 : }
175 17 : b.freeze()
176 17 : }
177 : }
178 :
179 6 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
180 : pub struct CancelKeyData {
181 : pub backend_pid: i32,
182 : pub cancel_key: i32,
183 : }
184 :
185 : impl fmt::Display for CancelKeyData {
186 1 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 1 : let hi = (self.backend_pid as u64) << 32;
188 1 : let lo = (self.cancel_key as u64) & 0xffffffff;
189 1 : let id = hi | lo;
190 1 :
191 1 : // This format is more compact and might work better for logs.
192 1 : f.debug_tuple("CancelKeyData")
193 1 : .field(&format_args!("{:x}", id))
194 1 : .finish()
195 1 : }
196 : }
197 :
198 : use rand::distributions::{Distribution, Standard};
199 : impl Distribution<CancelKeyData> for Standard {
200 1 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
201 1 : CancelKeyData {
202 1 : backend_pid: rng.gen(),
203 1 : cancel_key: rng.gen(),
204 1 : }
205 1 : }
206 : }
207 :
208 : // We only support the simple case of Parse on unnamed prepared statement and
209 : // no params
210 : #[derive(Debug)]
211 : pub struct FeParseMessage {
212 : pub query_string: Bytes,
213 : }
214 :
215 : #[derive(Debug)]
216 : pub struct FeDescribeMessage {
217 : pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
218 : // we only support unnamed prepared stmt or portal
219 : }
220 :
221 : // we only support unnamed prepared stmt and portal
222 : #[derive(Debug)]
223 : pub struct FeBindMessage;
224 :
225 : // we only support unnamed prepared stmt or portal
226 : #[derive(Debug)]
227 : pub struct FeExecuteMessage {
228 : /// max # of rows
229 : pub maxrows: i32,
230 : }
231 :
232 : // we only support unnamed prepared stmt and portal
233 : #[derive(Debug)]
234 : pub struct FeCloseMessage;
235 :
236 : /// An error occurred while parsing or serializing raw stream into Postgres
237 : /// messages.
238 : #[derive(thiserror::Error, Debug)]
239 : pub enum ProtocolError {
240 : /// Invalid packet was received from the client (e.g. unexpected message
241 : /// type or broken len).
242 : #[error("Protocol error: {0}")]
243 : Protocol(String),
244 : /// Failed to parse or, (unlikely), serialize a protocol message.
245 : #[error("Message parse error: {0}")]
246 : BadMessage(String),
247 : }
248 :
249 : impl ProtocolError {
250 : /// Proxy stream.rs uses only io::Error; provide it.
251 0 : pub fn into_io_error(self) -> io::Error {
252 0 : io::Error::new(io::ErrorKind::Other, self.to_string())
253 0 : }
254 : }
255 :
256 : impl FeMessage {
257 : /// Read and parse one message from the `buf` input buffer. If there is at
258 : /// least one valid message, returns it, advancing `buf`; redundant copies
259 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
260 : /// directly into the `buf` (processed data is garbage collected after
261 : /// parsed message is dropped).
262 : ///
263 : /// Returns None if `buf` doesn't contain enough data for a single message.
264 : /// For efficiency, tries to reserve large enough space in `buf` for the
265 : /// next message in this case to save the repeated calls.
266 : ///
267 : /// Returns Error if message is malformed, the only possible ErrorKind is
268 : /// InvalidInput.
269 : //
270 : // Inspired by rust-postgres Message::parse.
271 57 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
272 57 : // Every message contains message type byte and 4 bytes len; can't do
273 57 : // much without them.
274 57 : if buf.len() < 5 {
275 30 : let to_read = 5 - buf.len();
276 30 : buf.reserve(to_read);
277 30 : return Ok(None);
278 27 : }
279 27 :
280 27 : // We shouldn't advance `buf` as probably full message is not there yet,
281 27 : // so can't directly use Bytes::get_u32 etc.
282 27 : let tag = buf[0];
283 27 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
284 27 : if len < 4 {
285 0 : return Err(ProtocolError::Protocol(format!(
286 0 : "invalid message length {}",
287 0 : len
288 0 : )));
289 27 : }
290 27 :
291 27 : // length field includes itself, but not message type.
292 27 : let total_len = len as usize + 1;
293 27 : if buf.len() < total_len {
294 : // Don't have full message yet.
295 0 : let to_read = total_len - buf.len();
296 0 : buf.reserve(to_read);
297 0 : return Ok(None);
298 27 : }
299 27 :
300 27 : // got the message, advance buffer
301 27 : let mut msg = buf.split_to(total_len).freeze();
302 27 : msg.advance(5); // consume message type and len
303 27 :
304 27 : match tag {
305 2 : b'Q' => Ok(Some(FeMessage::Query(msg))),
306 0 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
307 0 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
308 0 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
309 0 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
310 0 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
311 0 : b'S' => Ok(Some(FeMessage::Sync)),
312 0 : b'X' => Ok(Some(FeMessage::Terminate)),
313 0 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
314 0 : b'c' => Ok(Some(FeMessage::CopyDone)),
315 0 : b'f' => Ok(Some(FeMessage::CopyFail)),
316 25 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
317 0 : tag => Err(ProtocolError::Protocol(format!(
318 0 : "unknown message tag: {tag},'{msg:?}'"
319 0 : ))),
320 : }
321 57 : }
322 : }
323 :
324 : impl FeStartupPacket {
325 : /// Read and parse startup message from the `buf` input buffer. It is
326 : /// different from [`FeMessage::parse`] because startup messages don't have
327 : /// message type byte; otherwise, its comments apply.
328 91 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
329 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
330 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
331 : const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
332 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
333 : const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
334 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
335 : const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
336 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
337 : const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
338 :
339 : // <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
340 : // First byte indicates standard SSL handshake message
341 : // (It can't be a Postgres startup length because in network byte order
342 : // that would be a startup packet hundreds of megabytes long)
343 91 : if buf.first() == Some(&0x16) {
344 0 : return Ok(Some(FeStartupPacket::SslRequest { direct: true }));
345 91 : }
346 91 :
347 91 : // need at least 4 bytes with packet len
348 91 : if buf.len() < 4 {
349 45 : let to_read = 4 - buf.len();
350 45 : buf.reserve(to_read);
351 45 : return Ok(None);
352 46 : }
353 46 :
354 46 : // We shouldn't advance `buf` as probably full message is not there yet,
355 46 : // so can't directly use Bytes::get_u32 etc.
356 46 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
357 46 : // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
358 46 : // which is less readable
359 46 : #[allow(clippy::manual_range_contains)]
360 46 : if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
361 1 : return Err(ProtocolError::Protocol(format!(
362 1 : "invalid startup packet message length {}",
363 1 : len
364 1 : )));
365 45 : }
366 45 :
367 45 : if buf.len() < len {
368 : // Don't have full message yet.
369 0 : let to_read = len - buf.len();
370 0 : buf.reserve(to_read);
371 0 : return Ok(None);
372 45 : }
373 45 :
374 45 : // got the message, advance buffer
375 45 : let mut msg = buf.split_to(len).freeze();
376 45 : msg.advance(4); // consume len
377 45 :
378 45 : let request_code = ProtocolVersion(msg.get_u32());
379 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
380 45 : let message = match request_code {
381 : CANCEL_REQUEST_CODE => {
382 0 : if msg.remaining() != 8 {
383 0 : return Err(ProtocolError::BadMessage(
384 0 : "CancelRequest message is malformed, backend PID / secret key missing"
385 0 : .to_owned(),
386 0 : ));
387 0 : }
388 0 : FeStartupPacket::CancelRequest(CancelKeyData {
389 0 : backend_pid: msg.get_i32(),
390 0 : cancel_key: msg.get_i32(),
391 0 : })
392 : }
393 : NEGOTIATE_SSL_CODE => {
394 : // Requested upgrade to SSL (aka TLS)
395 21 : FeStartupPacket::SslRequest { direct: false }
396 : }
397 : NEGOTIATE_GSS_CODE => {
398 : // Requested upgrade to GSSAPI
399 0 : FeStartupPacket::GssEncRequest
400 : }
401 24 : version if version.major() == RESERVED_INVALID_MAJOR_VERSION => {
402 0 : return Err(ProtocolError::Protocol(format!(
403 0 : "Unrecognized request code {}",
404 0 : version.minor()
405 0 : )));
406 : }
407 : // TODO bail if protocol major_version is not 3?
408 24 : version => {
409 : // StartupMessage
410 :
411 24 : let s = str::from_utf8(&msg).map_err(|_e| {
412 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
413 24 : })?;
414 24 : let s = s.strip_suffix('\0').ok_or_else(|| {
415 0 : ProtocolError::Protocol(
416 0 : "StartupMessage params: missing null terminator".to_string(),
417 0 : )
418 24 : })?;
419 :
420 24 : FeStartupPacket::StartupMessage {
421 24 : version,
422 24 : params: StartupMessageParams {
423 24 : params: msg.slice_ref(s.as_bytes()),
424 24 : },
425 24 : }
426 : }
427 : };
428 45 : Ok(Some(message))
429 91 : }
430 : }
431 :
432 : impl FeParseMessage {
433 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
434 : // FIXME: the rust-postgres driver uses a named prepared statement
435 : // for copy_out(). We're not prepared to handle that correctly. For
436 : // now, just ignore the statement name, assuming that the client never
437 : // uses more than one prepared statement at a time.
438 :
439 0 : let _pstmt_name = read_cstr(&mut buf)?;
440 0 : let query_string = read_cstr(&mut buf)?;
441 0 : if buf.remaining() < 2 {
442 0 : return Err(ProtocolError::BadMessage(
443 0 : "Parse message is malformed, nparams missing".to_string(),
444 0 : ));
445 0 : }
446 0 : let nparams = buf.get_i16();
447 0 :
448 0 : if nparams != 0 {
449 0 : return Err(ProtocolError::BadMessage(
450 0 : "query params not implemented".to_string(),
451 0 : ));
452 0 : }
453 0 :
454 0 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
455 0 : }
456 : }
457 :
458 : impl FeDescribeMessage {
459 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
460 0 : let kind = buf.get_u8();
461 0 : let _pstmt_name = read_cstr(&mut buf)?;
462 :
463 : // FIXME: see FeParseMessage::parse
464 0 : if kind != b'S' {
465 0 : return Err(ProtocolError::BadMessage(
466 0 : "only prepared statemement Describe is implemented".to_string(),
467 0 : ));
468 0 : }
469 0 :
470 0 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
471 0 : }
472 : }
473 :
474 : impl FeExecuteMessage {
475 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
476 0 : let portal_name = read_cstr(&mut buf)?;
477 0 : if buf.remaining() < 4 {
478 0 : return Err(ProtocolError::BadMessage(
479 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
480 0 : ));
481 0 : }
482 0 : let maxrows = buf.get_i32();
483 0 :
484 0 : if !portal_name.is_empty() {
485 0 : return Err(ProtocolError::BadMessage(
486 0 : "named portals not implemented".to_string(),
487 0 : ));
488 0 : }
489 0 : if maxrows != 0 {
490 0 : return Err(ProtocolError::BadMessage(
491 0 : "row limit in Execute message not implemented".to_string(),
492 0 : ));
493 0 : }
494 0 :
495 0 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
496 0 : }
497 : }
498 :
499 : impl FeBindMessage {
500 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
501 0 : let portal_name = read_cstr(&mut buf)?;
502 0 : let _pstmt_name = read_cstr(&mut buf)?;
503 :
504 : // FIXME: see FeParseMessage::parse
505 0 : if !portal_name.is_empty() {
506 0 : return Err(ProtocolError::BadMessage(
507 0 : "named portals not implemented".to_string(),
508 0 : ));
509 0 : }
510 0 :
511 0 : Ok(FeMessage::Bind(FeBindMessage))
512 0 : }
513 : }
514 :
515 : impl FeCloseMessage {
516 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
517 0 : let _kind = buf.get_u8();
518 0 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
519 :
520 : // FIXME: we do nothing with Close
521 0 : Ok(FeMessage::Close(FeCloseMessage))
522 0 : }
523 : }
524 :
525 : // Backend
526 :
527 : #[derive(Debug)]
528 : pub enum BeMessage<'a> {
529 : AuthenticationOk,
530 : AuthenticationMD5Password([u8; 4]),
531 : AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
532 : AuthenticationCleartextPassword,
533 : BackendKeyData(CancelKeyData),
534 : BindComplete,
535 : CommandComplete(&'a [u8]),
536 : CopyData(&'a [u8]),
537 : CopyDone,
538 : CopyFail,
539 : CopyInResponse,
540 : CopyOutResponse,
541 : CopyBothResponse,
542 : CloseComplete,
543 : // None means column is NULL
544 : DataRow(&'a [Option<&'a [u8]>]),
545 : // None errcode means internal_error will be sent.
546 : ErrorResponse(&'a str, Option<&'a [u8; 5]>),
547 : /// Single byte - used in response to SSLRequest/GSSENCRequest.
548 : EncryptionResponse(bool),
549 : NoData,
550 : ParameterDescription,
551 : ParameterStatus {
552 : name: &'a [u8],
553 : value: &'a [u8],
554 : },
555 : ParseComplete,
556 : ReadyForQuery,
557 : RowDescription(&'a [RowDescriptor<'a>]),
558 : XLogData(XLogDataBody<'a>),
559 : NoticeResponse(&'a str),
560 : NegotiateProtocolVersion {
561 : version: ProtocolVersion,
562 : options: &'a [&'a str],
563 : },
564 : KeepAlive(WalSndKeepAlive),
565 : /// Batch of interpreted, shard filtered WAL records,
566 : /// ready for the pageserver to ingest
567 : InterpretedWalRecords(InterpretedWalRecordsBody<'a>),
568 :
569 : Raw(u8, &'a [u8]),
570 : }
571 :
572 : /// Common shorthands.
573 : impl<'a> BeMessage<'a> {
574 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
575 : /// This is a sensible default, given that:
576 : /// * rust strings only support this encoding out of the box.
577 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
578 : ///
579 : /// TODO: do we need to report `server_encoding` as well?
580 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
581 : name: b"client_encoding",
582 : value: b"UTF8",
583 : };
584 :
585 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
586 : name: b"integer_datetimes",
587 : value: b"on",
588 : };
589 :
590 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
591 2 : pub fn server_version(version: &'a str) -> Self {
592 2 : Self::ParameterStatus {
593 2 : name: b"server_version",
594 2 : value: version.as_bytes(),
595 2 : }
596 2 : }
597 : }
598 :
599 : #[derive(Debug)]
600 : pub enum BeAuthenticationSaslMessage<'a> {
601 : Methods(&'a [&'a str]),
602 : Continue(&'a [u8]),
603 : Final(&'a [u8]),
604 : }
605 :
606 : #[derive(Debug)]
607 : pub enum BeParameterStatusMessage<'a> {
608 : Encoding(&'a str),
609 : ServerVersion(&'a str),
610 : }
611 :
612 : // One row description in RowDescription packet.
613 : #[derive(Debug)]
614 : pub struct RowDescriptor<'a> {
615 : pub name: &'a [u8],
616 : pub tableoid: Oid,
617 : pub attnum: i16,
618 : pub typoid: Oid,
619 : pub typlen: i16,
620 : pub typmod: i32,
621 : pub formatcode: i16,
622 : }
623 :
624 : impl Default for RowDescriptor<'_> {
625 0 : fn default() -> RowDescriptor<'static> {
626 0 : RowDescriptor {
627 0 : name: b"",
628 0 : tableoid: 0,
629 0 : attnum: 0,
630 0 : typoid: 0,
631 0 : typlen: 0,
632 0 : typmod: 0,
633 0 : formatcode: 0,
634 0 : }
635 0 : }
636 : }
637 :
638 : impl RowDescriptor<'_> {
639 : /// Convenience function to create a RowDescriptor message for an int8 column
640 0 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
641 0 : RowDescriptor {
642 0 : name,
643 0 : tableoid: 0,
644 0 : attnum: 0,
645 0 : typoid: INT8_OID,
646 0 : typlen: 8,
647 0 : typmod: 0,
648 0 : formatcode: 0,
649 0 : }
650 0 : }
651 :
652 2 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
653 2 : RowDescriptor {
654 2 : name,
655 2 : tableoid: 0,
656 2 : attnum: 0,
657 2 : typoid: TEXT_OID,
658 2 : typlen: -1,
659 2 : typmod: 0,
660 2 : formatcode: 0,
661 2 : }
662 2 : }
663 : }
664 :
665 : #[derive(Debug)]
666 : pub struct XLogDataBody<'a> {
667 : pub wal_start: u64,
668 : pub wal_end: u64, // current end of WAL on the server
669 : pub timestamp: i64,
670 : pub data: &'a [u8],
671 : }
672 :
673 : #[derive(Debug)]
674 : pub struct WalSndKeepAlive {
675 : pub wal_end: u64, // current end of WAL on the server
676 : pub timestamp: i64,
677 : pub request_reply: bool,
678 : }
679 :
680 : /// Batch of interpreted WAL records used in the interpreted
681 : /// safekeeper to pageserver protocol.
682 : ///
683 : /// Note that the pageserver uses the RawInterpretedWalRecordsBody
684 : /// counterpart of this from the neondatabase/rust-postgres repo.
685 : /// If you're changing this struct, you likely need to change its
686 : /// twin as well.
687 : #[derive(Debug)]
688 : pub struct InterpretedWalRecordsBody<'a> {
689 : /// End of raw WAL in [`Self::data`]
690 : pub streaming_lsn: u64,
691 : /// Current end of WAL on the server
692 : pub commit_lsn: u64,
693 : pub data: &'a [u8],
694 : }
695 :
696 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
697 :
698 : // single text column
699 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
700 : name: b"data",
701 : tableoid: 0,
702 : attnum: 0,
703 : typoid: TEXT_OID,
704 : typlen: -1,
705 : typmod: 0,
706 : formatcode: 0,
707 : }]);
708 :
709 : /// Call f() to write body of the message and prepend it with 4-byte len as
710 : /// prescribed by the protocol.
711 75 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
712 75 : let base = buf.len();
713 75 : buf.extend_from_slice(&[0; 4]);
714 75 :
715 75 : let res = f(buf);
716 75 :
717 75 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
718 75 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
719 75 :
720 75 : res
721 75 : }
722 :
723 : /// Safe write of s into buf as cstring (String in the protocol).
724 56 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
725 56 : let bytes = s.as_ref();
726 56 : if bytes.contains(&0) {
727 0 : return Err(ProtocolError::BadMessage(
728 0 : "string contains embedded null".to_owned(),
729 0 : ));
730 56 : }
731 56 : buf.put_slice(bytes);
732 56 : buf.put_u8(0);
733 56 : Ok(())
734 56 : }
735 :
736 : /// Read cstring from buf, advancing it.
737 11 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
738 11 : let pos = buf
739 11 : .iter()
740 156 : .position(|x| *x == 0)
741 11 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
742 11 : let result = buf.split_to(pos);
743 11 : buf.advance(1); // drop the null terminator
744 11 : Ok(result)
745 11 : }
746 :
747 : pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
748 : pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
749 : pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
750 :
751 : impl BeMessage<'_> {
752 : /// Serialize `message` to the given `buf`.
753 : /// Apart from smart memory managemet, BytesMut is good here as msg len
754 : /// precedes its body and it is handy to write it down first and then fill
755 : /// the length. With Write we would have to either calc it manually or have
756 : /// one more buffer.
757 96 : pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
758 96 : match message {
759 0 : BeMessage::Raw(code, data) => {
760 0 : buf.put_u8(*code);
761 0 : write_body(buf, |b| b.put_slice(data))
762 : }
763 12 : BeMessage::AuthenticationOk => {
764 12 : buf.put_u8(b'R');
765 12 : write_body(buf, |buf| {
766 12 : buf.put_i32(0); // Specifies that the authentication was successful.
767 12 : });
768 12 : }
769 :
770 2 : BeMessage::AuthenticationCleartextPassword => {
771 2 : buf.put_u8(b'R');
772 2 : write_body(buf, |buf| {
773 2 : buf.put_i32(3); // Specifies that clear text password is required.
774 2 : });
775 2 : }
776 :
777 0 : BeMessage::AuthenticationMD5Password(salt) => {
778 0 : buf.put_u8(b'R');
779 0 : write_body(buf, |buf| {
780 0 : buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
781 0 : buf.put_slice(&salt[..]);
782 0 : });
783 0 : }
784 :
785 30 : BeMessage::AuthenticationSasl(msg) => {
786 30 : buf.put_u8(b'R');
787 30 : write_body(buf, |buf| {
788 : use BeAuthenticationSaslMessage::*;
789 30 : match msg {
790 13 : Methods(methods) => {
791 13 : buf.put_i32(10); // Specifies that SASL auth method is used.
792 25 : for method in methods.iter() {
793 25 : write_cstr(method, buf)?;
794 : }
795 13 : buf.put_u8(0); // zero terminator for the list
796 : }
797 11 : Continue(extra) => {
798 11 : buf.put_i32(11); // Continue SASL auth.
799 11 : buf.put_slice(extra);
800 11 : }
801 6 : Final(extra) => {
802 6 : buf.put_i32(12); // Send final SASL message.
803 6 : buf.put_slice(extra);
804 6 : }
805 : }
806 30 : Ok(())
807 30 : })?;
808 : }
809 :
810 0 : BeMessage::BackendKeyData(key_data) => {
811 0 : buf.put_u8(b'K');
812 0 : write_body(buf, |buf| {
813 0 : buf.put_i32(key_data.backend_pid);
814 0 : buf.put_i32(key_data.cancel_key);
815 0 : });
816 0 : }
817 :
818 0 : BeMessage::BindComplete => {
819 0 : buf.put_u8(b'2');
820 0 : write_body(buf, |_| {});
821 0 : }
822 :
823 0 : BeMessage::CloseComplete => {
824 0 : buf.put_u8(b'3');
825 0 : write_body(buf, |_| {});
826 0 : }
827 :
828 2 : BeMessage::CommandComplete(cmd) => {
829 2 : buf.put_u8(b'C');
830 2 : write_body(buf, |buf| write_cstr(cmd, buf))?;
831 : }
832 :
833 0 : BeMessage::CopyData(data) => {
834 0 : buf.put_u8(b'd');
835 0 : write_body(buf, |buf| {
836 0 : buf.put_slice(data);
837 0 : });
838 0 : }
839 :
840 0 : BeMessage::CopyDone => {
841 0 : buf.put_u8(b'c');
842 0 : write_body(buf, |_| {});
843 0 : }
844 :
845 0 : BeMessage::CopyFail => {
846 0 : buf.put_u8(b'f');
847 0 : write_body(buf, |_| {});
848 0 : }
849 :
850 0 : BeMessage::CopyInResponse => {
851 0 : buf.put_u8(b'G');
852 0 : write_body(buf, |buf| {
853 0 : buf.put_u8(1); // copy_is_binary
854 0 : buf.put_i16(0); // numAttributes
855 0 : });
856 0 : }
857 :
858 0 : BeMessage::CopyOutResponse => {
859 0 : buf.put_u8(b'H');
860 0 : write_body(buf, |buf| {
861 0 : buf.put_u8(0); // copy_is_binary
862 0 : buf.put_i16(0); // numAttributes
863 0 : });
864 0 : }
865 :
866 0 : BeMessage::CopyBothResponse => {
867 0 : buf.put_u8(b'W');
868 0 : write_body(buf, |buf| {
869 0 : // doesn't matter, used only for replication
870 0 : buf.put_u8(0); // copy_is_binary
871 0 : buf.put_i16(0); // numAttributes
872 0 : });
873 0 : }
874 :
875 2 : BeMessage::DataRow(vals) => {
876 2 : buf.put_u8(b'D');
877 2 : write_body(buf, |buf| {
878 2 : buf.put_u16(vals.len() as u16); // num of cols
879 2 : for val_opt in vals.iter() {
880 2 : if let Some(val) = val_opt {
881 2 : buf.put_u32(val.len() as u32);
882 2 : buf.put_slice(val);
883 2 : } else {
884 0 : buf.put_i32(-1);
885 0 : }
886 : }
887 2 : });
888 2 : }
889 :
890 : // ErrorResponse is a zero-terminated array of zero-terminated fields.
891 : // First byte of each field represents type of this field. Set just enough fields
892 : // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
893 : // message text.
894 1 : BeMessage::ErrorResponse(error_msg, pg_error_code) => {
895 1 : // 'E' signalizes ErrorResponse messages
896 1 : buf.put_u8(b'E');
897 1 : write_body(buf, |buf| {
898 1 : buf.put_u8(b'S'); // severity
899 1 : buf.put_slice(b"ERROR\0");
900 1 :
901 1 : buf.put_u8(b'C'); // SQLSTATE error code
902 1 : buf.put_slice(&terminate_code(
903 1 : pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
904 1 : ));
905 1 :
906 1 : buf.put_u8(b'M'); // the message
907 1 : write_cstr(error_msg, buf)?;
908 :
909 1 : buf.put_u8(0); // terminator
910 1 : Ok(())
911 1 : })?;
912 : }
913 :
914 : // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
915 : // message but continue listening for ReadyForQuery or ErrorResponse"
916 0 : BeMessage::NoticeResponse(error_msg) => {
917 0 : // For all the errors set Severity to Error and error code to
918 0 : // 'internal error'.
919 0 :
920 0 : // 'N' signalizes NoticeResponse messages
921 0 : buf.put_u8(b'N');
922 0 : write_body(buf, |buf| {
923 0 : buf.put_u8(b'S'); // severity
924 0 : buf.put_slice(b"NOTICE\0");
925 0 :
926 0 : buf.put_u8(b'C'); // SQLSTATE error code
927 0 : buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
928 0 :
929 0 : buf.put_u8(b'M'); // the message
930 0 : write_cstr(error_msg.as_bytes(), buf)?;
931 :
932 0 : buf.put_u8(0); // terminator
933 0 : Ok(())
934 0 : })?;
935 : }
936 :
937 0 : BeMessage::NoData => {
938 0 : buf.put_u8(b'n');
939 0 : write_body(buf, |_| {});
940 0 : }
941 :
942 21 : BeMessage::EncryptionResponse(should_negotiate) => {
943 21 : let response = if *should_negotiate { b'S' } else { b'N' };
944 21 : buf.put_u8(response);
945 : }
946 :
947 13 : BeMessage::ParameterStatus { name, value } => {
948 13 : buf.put_u8(b'S');
949 13 : write_body(buf, |buf| {
950 13 : write_cstr(name, buf)?;
951 13 : write_cstr(value, buf)
952 13 : })?;
953 : }
954 :
955 0 : BeMessage::ParameterDescription => {
956 0 : buf.put_u8(b't');
957 0 : write_body(buf, |buf| {
958 0 : // we don't support params, so always 0
959 0 : buf.put_i16(0);
960 0 : });
961 0 : }
962 :
963 0 : BeMessage::ParseComplete => {
964 0 : buf.put_u8(b'1');
965 0 : write_body(buf, |_| {});
966 0 : }
967 :
968 11 : BeMessage::ReadyForQuery => {
969 11 : buf.put_u8(b'Z');
970 11 : write_body(buf, |buf| {
971 11 : buf.put_u8(b'I');
972 11 : });
973 11 : }
974 :
975 2 : BeMessage::RowDescription(rows) => {
976 2 : buf.put_u8(b'T');
977 2 : write_body(buf, |buf| {
978 2 : buf.put_i16(rows.len() as i16); // # of fields
979 2 : for row in rows.iter() {
980 2 : write_cstr(row.name, buf)?;
981 2 : buf.put_i32(0); /* table oid */
982 2 : buf.put_i16(0); /* attnum */
983 2 : buf.put_u32(row.typoid);
984 2 : buf.put_i16(row.typlen);
985 2 : buf.put_i32(-1); /* typmod */
986 2 : buf.put_i16(0); /* format code */
987 : }
988 2 : Ok(())
989 2 : })?;
990 : }
991 :
992 0 : BeMessage::XLogData(body) => {
993 0 : buf.put_u8(b'd');
994 0 : write_body(buf, |buf| {
995 0 : buf.put_u8(b'w');
996 0 : buf.put_u64(body.wal_start);
997 0 : buf.put_u64(body.wal_end);
998 0 : buf.put_i64(body.timestamp);
999 0 : buf.put_slice(body.data);
1000 0 : });
1001 0 : }
1002 :
1003 0 : BeMessage::KeepAlive(req) => {
1004 0 : buf.put_u8(b'd');
1005 0 : write_body(buf, |buf| {
1006 0 : buf.put_u8(b'k');
1007 0 : buf.put_u64(req.wal_end);
1008 0 : buf.put_i64(req.timestamp);
1009 0 : buf.put_u8(u8::from(req.request_reply));
1010 0 : });
1011 0 : }
1012 :
1013 0 : BeMessage::NegotiateProtocolVersion { version, options } => {
1014 0 : buf.put_u8(b'v');
1015 0 : write_body(buf, |buf| {
1016 0 : buf.put_u32(version.0);
1017 0 : buf.put_u32(options.len() as u32);
1018 0 : for option in options.iter() {
1019 0 : write_cstr(option, buf)?;
1020 : }
1021 0 : Ok(())
1022 0 : })?
1023 : }
1024 :
1025 0 : BeMessage::InterpretedWalRecords(rec) => {
1026 0 : // We use the COPY_DATA_TAG for our custom message
1027 0 : // since this tag is interpreted as raw bytes.
1028 0 : buf.put_u8(b'd');
1029 0 : write_body(buf, |buf| {
1030 0 : buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol
1031 0 : // dependency
1032 0 : buf.put_u64(rec.streaming_lsn);
1033 0 : buf.put_u64(rec.commit_lsn);
1034 0 : buf.put_slice(rec.data);
1035 0 : });
1036 0 : }
1037 : }
1038 96 : Ok(())
1039 96 : }
1040 : }
1041 :
/// Copy a 5-byte SQLSTATE code into a 6-byte array whose last byte is the
/// NUL terminator.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    let [c0, c1, c2, c3, c4] = *code;
    [c0, c1, c2, c3, c4, 0]
}
1050 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks option splitting and unescaping for missing, empty, plain and
    /// backslash-escaped `options` values.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        let params = make_params(" foo bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    /// A startup packet whose declared length (7) is below the 8-byte
    /// minimum must be rejected with an error rather than panicking.
    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }

    /// `Display` packs `backend_pid` and `cancel_key` into one hex number,
    /// including when both are negative.
    #[test]
    fn cancel_key_data() {
        let key = CancelKeyData {
            backend_pid: -1817212860,
            cancel_key: -1183897012,
        };
        assert_eq!(format!("{key}"), "CancelKeyData(93af8844b96f2a4c)");
    }
}
|