Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 : #![deny(clippy::undocumented_unsafe_blocks)]
5 :
6 : pub mod framed;
7 :
8 : use std::borrow::Cow;
9 : use std::{fmt, io, str};
10 :
11 : use byteorder::{BigEndian, ReadBytesExt};
12 : use bytes::{Buf, BufMut, Bytes, BytesMut};
13 : use itertools::Itertools;
14 : // re-export for use in utils pageserver_feedback.rs
15 : pub use postgres_protocol::PG_EPOCH;
16 : use serde::{Deserialize, Serialize};
17 :
18 : pub type Oid = u32;
19 : pub type SystemId = u64;
20 :
21 : pub const INT8_OID: Oid = 20;
22 : pub const INT4_OID: Oid = 23;
23 : pub const TEXT_OID: Oid = 25;
24 :
/// A tagged message received from the frontend (client), as produced by
/// [`FeMessage::parse`].
///
/// Variants holding [`Bytes`] carry the message body exactly as received
/// (message type byte and length field already stripped).
#[derive(Debug)]
pub enum FeMessage {
    // Simple query.
    Query(Bytes),
    // Extended query protocol.
    Parse(FeParseMessage),
    Describe(FeDescribeMessage),
    Bind(FeBindMessage),
    Execute(FeExecuteMessage),
    Close(FeCloseMessage),
    Sync,
    Terminate,
    // COPY sub-protocol; the bodies of CopyDone / CopyFail are discarded.
    CopyData(Bytes),
    CopyDone,
    CopyFail,
    // Raw body of a message tagged 'p'.
    PasswordMessage(Bytes),
}
42 :
/// Protocol version (or special request code) packed into one `u32`: the
/// major number lives in the upper 16 bits, the minor number in the lower 16.
#[derive(Clone, Copy, PartialEq, PartialOrd)]
pub struct ProtocolVersion(u32);

impl ProtocolVersion {
    /// Pack `major` and `minor` into a single value.
    pub const fn new(major: u16, minor: u16) -> Self {
        let high = (major as u32) << 16;
        let low = minor as u32;
        Self(high | low)
    }

    /// The lower 16 bits.
    pub const fn minor(self) -> u16 {
        (self.0 & 0xffff) as u16
    }

    /// The upper 16 bits.
    pub const fn major(self) -> u16 {
        let Self(raw) = self;
        (raw >> 16) as u16
    }
}

impl fmt::Debug for ProtocolVersion {
    /// Formats the version as a two-element list: `[major, minor]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list()
            .entries([self.major(), self.minor()])
            .finish()
    }
}
66 :
/// Packets the frontend can send at the very start of a connection. Unlike
/// [`FeMessage`], they have no message type byte; see
/// [`FeStartupPacket::parse`].
#[derive(Debug)]
pub enum FeStartupPacket {
    /// Cancel request carrying the backend PID and cancel key.
    CancelRequest(CancelKeyData),
    /// SSL (TLS) upgrade request. `direct` is `true` when no request packet
    /// was sent and the buffer instead starts with byte `0x16`, i.e. the
    /// client began the TLS handshake right away.
    SslRequest {
        direct: bool,
    },
    /// GSSAPI encryption upgrade request.
    GssEncRequest,
    /// Regular startup message with the requested protocol version and the
    /// connection parameters.
    StartupMessage {
        version: ProtocolVersion,
        params: StartupMessageParams,
    },
}
79 :
80 : #[derive(Debug, Clone, Default)]
81 : pub struct StartupMessageParamsBuilder {
82 : params: BytesMut,
83 : }
84 :
85 : impl StartupMessageParamsBuilder {
86 : /// Set parameter's value by its name.
87 : /// name and value must not contain a \0 byte
88 4 : pub fn insert(&mut self, name: &str, value: &str) {
89 4 : self.params.put(name.as_bytes());
90 4 : self.params.put(&b"\0"[..]);
91 4 : self.params.put(value.as_bytes());
92 4 : self.params.put(&b"\0"[..]);
93 4 : }
94 :
95 5 : pub fn freeze(self) -> StartupMessageParams {
96 5 : StartupMessageParams {
97 5 : params: self.params.freeze(),
98 5 : }
99 5 : }
100 : }
101 :
102 : #[derive(Debug, Clone, Default)]
103 : pub struct StartupMessageParams {
104 : pub params: Bytes,
105 : }
106 :
107 : impl StartupMessageParams {
108 : /// Get parameter's value by its name.
109 5 : pub fn get(&self, name: &str) -> Option<&str> {
110 5 : self.iter().find_map(|(k, v)| (k == name).then_some(v))
111 5 : }
112 :
113 : /// Split command-line options according to PostgreSQL's logic,
114 : /// taking into account all escape sequences but leaving them as-is.
115 : /// [`None`] means that there's no `options` in [`Self`].
116 0 : pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
117 0 : self.get("options").map(Self::parse_options_raw)
118 0 : }
119 :
120 : /// Split command-line options according to PostgreSQL's logic,
121 : /// applying all escape sequences (using owned strings as needed).
122 : /// [`None`] means that there's no `options` in [`Self`].
123 5 : pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
124 5 : self.get("options").map(Self::parse_options_escaped)
125 5 : }
126 :
127 : /// Split command-line options according to PostgreSQL's logic,
128 : /// taking into account all escape sequences but leaving them as-is.
129 4 : pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
130 : // See `postgres: pg_split_opts`.
131 4 : let mut last_was_escape = false;
132 4 : input
133 36 : .split(move |c: char| {
134 : // We split by non-escaped whitespace symbols.
135 36 : let should_split = c.is_ascii_whitespace() && !last_was_escape;
136 36 : last_was_escape = c == '\\' && !last_was_escape;
137 36 : should_split
138 36 : })
139 11 : .filter(|s| !s.is_empty())
140 4 : }
141 :
142 : /// Split command-line options according to PostgreSQL's logic,
143 : /// applying all escape sequences (using owned strings as needed).
144 4 : pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
145 : // See `postgres: pg_split_opts`.
146 7 : Self::parse_options_raw(input).map(|s| {
147 7 : let mut preserve_next_escape = false;
148 17 : let escape = |c| {
149 : // We should remove '\\' unless it's preceded by '\\'.
150 17 : let should_remove = c == '\\' && !preserve_next_escape;
151 17 : preserve_next_escape = should_remove;
152 17 : should_remove
153 17 : };
154 :
155 7 : match s.contains('\\') {
156 3 : true => Cow::Owned(s.replace(escape, "")),
157 4 : false => Cow::Borrowed(s),
158 : }
159 7 : })
160 4 : }
161 :
162 : /// Iterate through key-value pairs in an arbitrary order.
163 5 : pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
164 5 : let params =
165 5 : std::str::from_utf8(&self.params).expect("should be validated as utf8 already");
166 5 : params.split_terminator('\0').tuples()
167 5 : }
168 :
169 : // This function is mostly useful in tests.
170 : #[doc(hidden)]
171 5 : pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
172 5 : let mut b = StartupMessageParamsBuilder::default();
173 9 : for (k, v) in pairs {
174 4 : b.insert(k, v)
175 : }
176 5 : b.freeze()
177 5 : }
178 : }
179 :
/// The two 32-bit values that identify a backend for query cancellation:
/// read from a CancelRequest startup packet and written out in a
/// BackendKeyData message.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,
}
185 :
186 0 : pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
187 0 : CancelKeyData {
188 0 : backend_pid: (id >> 32) as i32,
189 0 : cancel_key: (id & 0xffffffff) as i32,
190 0 : }
191 0 : }
192 :
193 : impl fmt::Display for CancelKeyData {
194 1 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 1 : let hi = (self.backend_pid as u64) << 32;
196 1 : let lo = (self.cancel_key as u64) & 0xffffffff;
197 1 : let id = hi | lo;
198 :
199 : // This format is more compact and might work better for logs.
200 1 : f.debug_tuple("CancelKeyData")
201 1 : .field(&format_args!("{id:x}"))
202 1 : .finish()
203 1 : }
204 : }
205 :
use rand::distributions::{Distribution, Standard};
/// Allows drawing a random [`CancelKeyData`] from the `Standard`
/// distribution; each field is generated by its own call to the `rng`.
impl Distribution<CancelKeyData> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
        CancelKeyData {
            backend_pid: rng.r#gen(),
            cancel_key: rng.r#gen(),
        }
    }
}
215 :
/// Parsed body of a Parse message.
///
/// We only support the simple case of Parse on unnamed prepared statement and
/// no params.
#[derive(Debug)]
pub struct FeParseMessage {
    /// Query text without its null terminator.
    pub query_string: Bytes,
}

/// Parsed body of a Describe message.
#[derive(Debug)]
pub struct FeDescribeMessage {
    pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
                  // we only support unnamed prepared stmt or portal
}

/// Marker for a received Bind message; no fields are retained.
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeBindMessage;

/// Parsed body of an Execute message.
// we only support unnamed prepared stmt or portal
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// max # of rows
    pub maxrows: i32,
}

/// Marker for a received Close message; no fields are retained.
// we only support unnamed prepared stmt and portal
#[derive(Debug)]
pub struct FeCloseMessage;
243 :
244 : /// An error occurred while parsing or serializing raw stream into Postgres
245 : /// messages.
246 : #[derive(thiserror::Error, Debug)]
247 : pub enum ProtocolError {
248 : /// Invalid packet was received from the client (e.g. unexpected message
249 : /// type or broken len).
250 : #[error("Protocol error: {0}")]
251 : Protocol(String),
252 : /// Failed to parse or, (unlikely), serialize a protocol message.
253 : #[error("Message parse error: {0}")]
254 : BadMessage(String),
255 : }
256 :
257 : impl ProtocolError {
258 : /// Proxy stream.rs uses only io::Error; provide it.
259 0 : pub fn into_io_error(self) -> io::Error {
260 0 : io::Error::other(self.to_string())
261 0 : }
262 : }
263 :
264 : impl FeMessage {
265 : /// Read and parse one message from the `buf` input buffer. If there is at
266 : /// least one valid message, returns it, advancing `buf`; redundant copies
267 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
268 : /// directly into the `buf` (processed data is garbage collected after
269 : /// parsed message is dropped).
270 : ///
271 : /// Returns None if `buf` doesn't contain enough data for a single message.
272 : /// For efficiency, tries to reserve large enough space in `buf` for the
273 : /// next message in this case to save the repeated calls.
274 : ///
275 : /// Returns Error if message is malformed, the only possible ErrorKind is
276 : /// InvalidInput.
277 : //
278 : // Inspired by rust-postgres Message::parse.
279 6 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
280 : // Every message contains message type byte and 4 bytes len; can't do
281 : // much without them.
282 6 : if buf.len() < 5 {
283 4 : let to_read = 5 - buf.len();
284 4 : buf.reserve(to_read);
285 4 : return Ok(None);
286 2 : }
287 :
288 : // We shouldn't advance `buf` as probably full message is not there yet,
289 : // so can't directly use Bytes::get_u32 etc.
290 2 : let tag = buf[0];
291 2 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
292 2 : if len < 4 {
293 0 : return Err(ProtocolError::Protocol(format!(
294 0 : "invalid message length {len}"
295 0 : )));
296 2 : }
297 :
298 : // length field includes itself, but not message type.
299 2 : let total_len = len as usize + 1;
300 2 : if buf.len() < total_len {
301 : // Don't have full message yet.
302 0 : let to_read = total_len - buf.len();
303 0 : buf.reserve(to_read);
304 0 : return Ok(None);
305 2 : }
306 :
307 : // got the message, advance buffer
308 2 : let mut msg = buf.split_to(total_len).freeze();
309 2 : msg.advance(5); // consume message type and len
310 :
311 2 : match tag {
312 2 : b'Q' => Ok(Some(FeMessage::Query(msg))),
313 0 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
314 0 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
315 0 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
316 0 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
317 0 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
318 0 : b'S' => Ok(Some(FeMessage::Sync)),
319 0 : b'X' => Ok(Some(FeMessage::Terminate)),
320 0 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
321 0 : b'c' => Ok(Some(FeMessage::CopyDone)),
322 0 : b'f' => Ok(Some(FeMessage::CopyFail)),
323 0 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
324 0 : tag => Err(ProtocolError::Protocol(format!(
325 0 : "unknown message tag: {tag},'{msg:?}'"
326 0 : ))),
327 : }
328 6 : }
329 : }
330 :
331 : impl FeStartupPacket {
332 : /// Read and parse startup message from the `buf` input buffer. It is
333 : /// different from [`FeMessage::parse`] because startup messages don't have
334 : /// message type byte; otherwise, its comments apply.
335 7 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
336 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
337 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
338 : const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
339 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
340 : const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
341 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
342 : const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
343 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
344 : const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
345 :
346 : // <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
347 : // First byte indicates standard SSL handshake message
348 : // (It can't be a Postgres startup length because in network byte order
349 : // that would be a startup packet hundreds of megabytes long)
350 7 : if buf.first() == Some(&0x16) {
351 0 : return Ok(Some(FeStartupPacket::SslRequest { direct: true }));
352 7 : }
353 :
354 : // need at least 4 bytes with packet len
355 7 : if buf.len() < 4 {
356 3 : let to_read = 4 - buf.len();
357 3 : buf.reserve(to_read);
358 3 : return Ok(None);
359 4 : }
360 :
361 : // We shouldn't advance `buf` as probably full message is not there yet,
362 : // so can't directly use Bytes::get_u32 etc.
363 4 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
364 : // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
365 : // which is less readable
366 : #[allow(clippy::manual_range_contains)]
367 4 : if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
368 1 : return Err(ProtocolError::Protocol(format!(
369 1 : "invalid startup packet message length {len}"
370 1 : )));
371 3 : }
372 :
373 3 : if buf.len() < len {
374 : // Don't have full message yet.
375 0 : let to_read = len - buf.len();
376 0 : buf.reserve(to_read);
377 0 : return Ok(None);
378 3 : }
379 :
380 : // got the message, advance buffer
381 3 : let mut msg = buf.split_to(len).freeze();
382 3 : msg.advance(4); // consume len
383 :
384 3 : let request_code = ProtocolVersion(msg.get_u32());
385 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
386 3 : let message = match request_code {
387 : CANCEL_REQUEST_CODE => {
388 0 : if msg.remaining() != 8 {
389 0 : return Err(ProtocolError::BadMessage(
390 0 : "CancelRequest message is malformed, backend PID / secret key missing"
391 0 : .to_owned(),
392 0 : ));
393 0 : }
394 0 : FeStartupPacket::CancelRequest(CancelKeyData {
395 0 : backend_pid: msg.get_i32(),
396 0 : cancel_key: msg.get_i32(),
397 0 : })
398 : }
399 : NEGOTIATE_SSL_CODE => {
400 : // Requested upgrade to SSL (aka TLS)
401 1 : FeStartupPacket::SslRequest { direct: false }
402 : }
403 : NEGOTIATE_GSS_CODE => {
404 : // Requested upgrade to GSSAPI
405 0 : FeStartupPacket::GssEncRequest
406 : }
407 2 : version if version.major() == RESERVED_INVALID_MAJOR_VERSION => {
408 0 : return Err(ProtocolError::Protocol(format!(
409 0 : "Unrecognized request code {}",
410 0 : version.minor()
411 0 : )));
412 : }
413 : // TODO bail if protocol major_version is not 3?
414 2 : version => {
415 : // StartupMessage
416 :
417 2 : let s = str::from_utf8(&msg).map_err(|_e| {
418 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
419 0 : })?;
420 2 : let s = s.strip_suffix('\0').ok_or_else(|| {
421 0 : ProtocolError::Protocol(
422 0 : "StartupMessage params: missing null terminator".to_string(),
423 0 : )
424 0 : })?;
425 :
426 2 : FeStartupPacket::StartupMessage {
427 2 : version,
428 2 : params: StartupMessageParams {
429 2 : params: msg.slice_ref(s.as_bytes()),
430 2 : },
431 2 : }
432 : }
433 : };
434 3 : Ok(Some(message))
435 7 : }
436 : }
437 :
438 : impl FeParseMessage {
439 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
440 : // FIXME: the rust-postgres driver uses a named prepared statement
441 : // for copy_out(). We're not prepared to handle that correctly. For
442 : // now, just ignore the statement name, assuming that the client never
443 : // uses more than one prepared statement at a time.
444 :
445 0 : let _pstmt_name = read_cstr(&mut buf)?;
446 0 : let query_string = read_cstr(&mut buf)?;
447 0 : if buf.remaining() < 2 {
448 0 : return Err(ProtocolError::BadMessage(
449 0 : "Parse message is malformed, nparams missing".to_string(),
450 0 : ));
451 0 : }
452 0 : let nparams = buf.get_i16();
453 :
454 0 : if nparams != 0 {
455 0 : return Err(ProtocolError::BadMessage(
456 0 : "query params not implemented".to_string(),
457 0 : ));
458 0 : }
459 :
460 0 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
461 0 : }
462 : }
463 :
464 : impl FeDescribeMessage {
465 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
466 0 : let kind = buf.get_u8();
467 0 : let _pstmt_name = read_cstr(&mut buf)?;
468 :
469 : // FIXME: see FeParseMessage::parse
470 0 : if kind != b'S' {
471 0 : return Err(ProtocolError::BadMessage(
472 0 : "only prepared statemement Describe is implemented".to_string(),
473 0 : ));
474 0 : }
475 :
476 0 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
477 0 : }
478 : }
479 :
480 : impl FeExecuteMessage {
481 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
482 0 : let portal_name = read_cstr(&mut buf)?;
483 0 : if buf.remaining() < 4 {
484 0 : return Err(ProtocolError::BadMessage(
485 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
486 0 : ));
487 0 : }
488 0 : let maxrows = buf.get_i32();
489 :
490 0 : if !portal_name.is_empty() {
491 0 : return Err(ProtocolError::BadMessage(
492 0 : "named portals not implemented".to_string(),
493 0 : ));
494 0 : }
495 0 : if maxrows != 0 {
496 0 : return Err(ProtocolError::BadMessage(
497 0 : "row limit in Execute message not implemented".to_string(),
498 0 : ));
499 0 : }
500 :
501 0 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
502 0 : }
503 : }
504 :
505 : impl FeBindMessage {
506 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
507 0 : let portal_name = read_cstr(&mut buf)?;
508 0 : let _pstmt_name = read_cstr(&mut buf)?;
509 :
510 : // FIXME: see FeParseMessage::parse
511 0 : if !portal_name.is_empty() {
512 0 : return Err(ProtocolError::BadMessage(
513 0 : "named portals not implemented".to_string(),
514 0 : ));
515 0 : }
516 :
517 0 : Ok(FeMessage::Bind(FeBindMessage))
518 0 : }
519 : }
520 :
521 : impl FeCloseMessage {
522 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
523 0 : let _kind = buf.get_u8();
524 0 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
525 :
526 : // FIXME: we do nothing with Close
527 0 : Ok(FeMessage::Close(FeCloseMessage))
528 0 : }
529 : }
530 :
531 : // Backend
532 :
/// A message sent by the backend (server). All payloads are borrowed; use
/// [`BeMessage::write`] to serialize a message into a buffer.
#[derive(Debug)]
pub enum BeMessage<'a> {
    AuthenticationOk,
    /// Carries the 4-byte salt.
    AuthenticationMD5Password([u8; 4]),
    AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
    AuthenticationCleartextPassword,
    BackendKeyData(CancelKeyData),
    BindComplete,
    CommandComplete(&'a [u8]),
    CopyData(&'a [u8]),
    CopyDone,
    CopyFail,
    CopyInResponse,
    CopyOutResponse,
    CopyBothResponse,
    CloseComplete,
    // None means column is NULL
    DataRow(&'a [Option<&'a [u8]>]),
    // None errcode means internal_error will be sent.
    ErrorResponse(&'a str, Option<&'a [u8; 5]>),
    /// Single byte - used in response to SSLRequest/GSSENCRequest.
    EncryptionResponse(bool),
    NoData,
    ParameterDescription,
    ParameterStatus {
        name: &'a [u8],
        value: &'a [u8],
    },
    ParseComplete,
    ReadyForQuery,
    RowDescription(&'a [RowDescriptor<'a>]),
    XLogData(XLogDataBody<'a>),
    NoticeResponse(&'a str),
    NegotiateProtocolVersion {
        version: ProtocolVersion,
        options: &'a [&'a str],
    },
    KeepAlive(WalSndKeepAlive),
    /// Batch of interpreted, shard filtered WAL records,
    /// ready for the pageserver to ingest
    InterpretedWalRecords(InterpretedWalRecordsBody<'a>),

    /// Arbitrary message: the type byte and the body to send after the
    /// length field.
    Raw(u8, &'a [u8]),
}
577 :
578 : /// Common shorthands.
579 : impl<'a> BeMessage<'a> {
580 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
581 : /// This is a sensible default, given that:
582 : /// * rust strings only support this encoding out of the box.
583 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
584 : ///
585 : /// TODO: do we need to report `server_encoding` as well?
586 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
587 : name: b"client_encoding",
588 : value: b"UTF8",
589 : };
590 :
591 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
592 : name: b"integer_datetimes",
593 : value: b"on",
594 : };
595 :
596 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
597 2 : pub fn server_version(version: &'a str) -> Self {
598 2 : Self::ParameterStatus {
599 2 : name: b"server_version",
600 2 : value: version.as_bytes(),
601 2 : }
602 2 : }
603 : }
604 :
/// Payload of a [`BeMessage::AuthenticationSasl`] message.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// List of offered SASL method names.
    Methods(&'a [&'a str]),
    /// Data for a "continue" step of the SASL exchange.
    Continue(&'a [u8]),
    /// Data for the final SASL message.
    Final(&'a [u8]),
}

// NOTE(review): not referenced anywhere in the visible part of this file.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
617 :
618 : // One row description in RowDescription packet.
619 : #[derive(Debug)]
620 : pub struct RowDescriptor<'a> {
621 : pub name: &'a [u8],
622 : pub tableoid: Oid,
623 : pub attnum: i16,
624 : pub typoid: Oid,
625 : pub typlen: i16,
626 : pub typmod: i32,
627 : pub formatcode: i16,
628 : }
629 :
630 : impl Default for RowDescriptor<'_> {
631 0 : fn default() -> RowDescriptor<'static> {
632 0 : RowDescriptor {
633 0 : name: b"",
634 0 : tableoid: 0,
635 0 : attnum: 0,
636 0 : typoid: 0,
637 0 : typlen: 0,
638 0 : typmod: 0,
639 0 : formatcode: 0,
640 0 : }
641 0 : }
642 : }
643 :
644 : impl RowDescriptor<'_> {
645 : /// Convenience function to create a RowDescriptor message for an int8 column
646 0 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
647 0 : RowDescriptor {
648 0 : name,
649 0 : tableoid: 0,
650 0 : attnum: 0,
651 0 : typoid: INT8_OID,
652 0 : typlen: 8,
653 0 : typmod: 0,
654 0 : formatcode: 0,
655 0 : }
656 0 : }
657 :
658 2 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
659 2 : RowDescriptor {
660 2 : name,
661 2 : tableoid: 0,
662 2 : attnum: 0,
663 2 : typoid: TEXT_OID,
664 2 : typlen: -1,
665 2 : typmod: 0,
666 2 : formatcode: 0,
667 2 : }
668 2 : }
669 : }
670 :
/// Payload of a [`BeMessage::XLogData`] message.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    pub wal_start: u64,
    pub wal_end: u64, // current end of WAL on the server
    pub timestamp: i64,
    /// Bytes appended verbatim after the header fields.
    pub data: &'a [u8],
}

/// Payload of a [`BeMessage::KeepAlive`] message.
#[derive(Debug)]
pub struct WalSndKeepAlive {
    pub wal_end: u64, // current end of WAL on the server
    pub timestamp: i64,
    /// Serialized as a single byte (1 or 0).
    pub request_reply: bool,
}

/// Batch of interpreted WAL records used in the interpreted
/// safekeeper to pageserver protocol.
///
/// Note that the pageserver uses the RawInterpretedWalRecordsBody
/// counterpart of this from the neondatabase/rust-postgres repo.
/// If you're changing this struct, you likely need to change its
/// twin as well.
#[derive(Debug)]
pub struct InterpretedWalRecordsBody<'a> {
    /// End of raw WAL in [`Self::data`]
    pub streaming_lsn: u64,
    /// Current end of WAL on the server
    pub commit_lsn: u64,
    pub data: &'a [u8],
}
701 :
/// Ready-made DataRow with a single non-NULL column containing
/// `hello world`.
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);

/// Ready-made RowDescription with a single text column named `data`.
// single text column
pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
    name: b"data",
    tableoid: 0,
    attnum: 0,
    typoid: TEXT_OID,
    typlen: -1,
    typmod: 0,
    formatcode: 0,
}]);
714 :
715 : /// Call f() to write body of the message and prepend it with 4-byte len as
716 : /// prescribed by the protocol.
717 18 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
718 18 : let base = buf.len();
719 18 : buf.extend_from_slice(&[0; 4]);
720 :
721 18 : let res = f(buf);
722 :
723 18 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
724 18 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
725 :
726 18 : res
727 18 : }
728 :
729 : /// Safe write of s into buf as cstring (String in the protocol).
730 16 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
731 16 : let bytes = s.as_ref();
732 16 : if bytes.contains(&0) {
733 0 : return Err(ProtocolError::BadMessage(
734 0 : "string contains embedded null".to_owned(),
735 0 : ));
736 16 : }
737 16 : buf.put_slice(bytes);
738 16 : buf.put_u8(0);
739 16 : Ok(())
740 16 : }
741 :
742 : /// Read cstring from buf, advancing it.
743 11 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
744 11 : let pos = buf
745 11 : .iter()
746 156 : .position(|x| *x == 0)
747 11 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
748 11 : let result = buf.split_to(pos);
749 11 : buf.advance(1); // drop the null terminator
750 11 : Ok(result)
751 11 : }
752 :
// Five-character SQLSTATE codes, usable as the error code of
// `BeMessage::ErrorResponse`.
/// Sent in ErrorResponse when no explicit code is given, and in every
/// NoticeResponse.
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
756 :
757 : impl BeMessage<'_> {
758 : /// Serialize `message` to the given `buf`.
759 : /// Apart from smart memory managemet, BytesMut is good here as msg len
760 : /// precedes its body and it is handy to write it down first and then fill
761 : /// the length. With Write we would have to either calc it manually or have
762 : /// one more buffer.
763 19 : pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
764 19 : match message {
765 0 : BeMessage::Raw(code, data) => {
766 0 : buf.put_u8(*code);
767 0 : write_body(buf, |b| b.put_slice(data))
768 : }
769 : BeMessage::AuthenticationOk => {
770 2 : buf.put_u8(b'R');
771 2 : write_body(buf, |buf| {
772 2 : buf.put_i32(0); // Specifies that the authentication was successful.
773 2 : });
774 : }
775 :
776 : BeMessage::AuthenticationCleartextPassword => {
777 0 : buf.put_u8(b'R');
778 0 : write_body(buf, |buf| {
779 0 : buf.put_i32(3); // Specifies that clear text password is required.
780 0 : });
781 : }
782 :
783 0 : BeMessage::AuthenticationMD5Password(salt) => {
784 0 : buf.put_u8(b'R');
785 0 : write_body(buf, |buf| {
786 0 : buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
787 0 : buf.put_slice(&salt[..]);
788 0 : });
789 : }
790 :
791 0 : BeMessage::AuthenticationSasl(msg) => {
792 0 : buf.put_u8(b'R');
793 0 : write_body(buf, |buf| {
794 : use BeAuthenticationSaslMessage::*;
795 0 : match msg {
796 0 : Methods(methods) => {
797 0 : buf.put_i32(10); // Specifies that SASL auth method is used.
798 0 : for method in methods.iter() {
799 0 : write_cstr(method, buf)?;
800 : }
801 0 : buf.put_u8(0); // zero terminator for the list
802 : }
803 0 : Continue(extra) => {
804 0 : buf.put_i32(11); // Continue SASL auth.
805 0 : buf.put_slice(extra);
806 0 : }
807 0 : Final(extra) => {
808 0 : buf.put_i32(12); // Send final SASL message.
809 0 : buf.put_slice(extra);
810 0 : }
811 : }
812 0 : Ok(())
813 0 : })?;
814 : }
815 :
816 0 : BeMessage::BackendKeyData(key_data) => {
817 0 : buf.put_u8(b'K');
818 0 : write_body(buf, |buf| {
819 0 : buf.put_i32(key_data.backend_pid);
820 0 : buf.put_i32(key_data.cancel_key);
821 0 : });
822 : }
823 :
824 : BeMessage::BindComplete => {
825 0 : buf.put_u8(b'2');
826 0 : write_body(buf, |_| {});
827 : }
828 :
829 : BeMessage::CloseComplete => {
830 0 : buf.put_u8(b'3');
831 0 : write_body(buf, |_| {});
832 : }
833 :
834 2 : BeMessage::CommandComplete(cmd) => {
835 2 : buf.put_u8(b'C');
836 2 : write_body(buf, |buf| write_cstr(cmd, buf))?;
837 : }
838 :
839 0 : BeMessage::CopyData(data) => {
840 0 : buf.put_u8(b'd');
841 0 : write_body(buf, |buf| {
842 0 : buf.put_slice(data);
843 0 : });
844 : }
845 :
846 : BeMessage::CopyDone => {
847 0 : buf.put_u8(b'c');
848 0 : write_body(buf, |_| {});
849 : }
850 :
851 : BeMessage::CopyFail => {
852 0 : buf.put_u8(b'f');
853 0 : write_body(buf, |_| {});
854 : }
855 :
856 : BeMessage::CopyInResponse => {
857 0 : buf.put_u8(b'G');
858 0 : write_body(buf, |buf| {
859 0 : buf.put_u8(1); // copy_is_binary
860 0 : buf.put_i16(0); // numAttributes
861 0 : });
862 : }
863 :
864 : BeMessage::CopyOutResponse => {
865 0 : buf.put_u8(b'H');
866 0 : write_body(buf, |buf| {
867 0 : buf.put_u8(0); // copy_is_binary
868 0 : buf.put_i16(0); // numAttributes
869 0 : });
870 : }
871 :
872 : BeMessage::CopyBothResponse => {
873 0 : buf.put_u8(b'W');
874 0 : write_body(buf, |buf| {
875 : // doesn't matter, used only for replication
876 0 : buf.put_u8(0); // copy_is_binary
877 0 : buf.put_i16(0); // numAttributes
878 0 : });
879 : }
880 :
881 2 : BeMessage::DataRow(vals) => {
882 2 : buf.put_u8(b'D');
883 2 : write_body(buf, |buf| {
884 2 : buf.put_u16(vals.len() as u16); // num of cols
885 2 : for val_opt in vals.iter() {
886 2 : if let Some(val) = val_opt {
887 2 : buf.put_u32(val.len() as u32);
888 2 : buf.put_slice(val);
889 2 : } else {
890 0 : buf.put_i32(-1);
891 0 : }
892 : }
893 2 : });
894 : }
895 :
896 : // ErrorResponse is a zero-terminated array of zero-terminated fields.
897 : // First byte of each field represents type of this field. Set just enough fields
898 : // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
899 : // message text.
900 0 : BeMessage::ErrorResponse(error_msg, pg_error_code) => {
901 : // 'E' signalizes ErrorResponse messages
902 0 : buf.put_u8(b'E');
903 0 : write_body(buf, |buf| {
904 0 : buf.put_u8(b'S'); // severity
905 0 : buf.put_slice(b"ERROR\0");
906 :
907 0 : buf.put_u8(b'C'); // SQLSTATE error code
908 0 : buf.put_slice(&terminate_code(
909 0 : pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
910 0 : ));
911 :
912 0 : buf.put_u8(b'M'); // the message
913 0 : write_cstr(error_msg, buf)?;
914 :
915 0 : buf.put_u8(0); // terminator
916 0 : Ok(())
917 0 : })?;
918 : }
919 :
920 : // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
921 : // message but continue listening for ReadyForQuery or ErrorResponse"
922 0 : BeMessage::NoticeResponse(error_msg) => {
923 : // For all the errors set Severity to Error and error code to
924 : // 'internal error'.
925 :
926 : // 'N' signalizes NoticeResponse messages
927 0 : buf.put_u8(b'N');
928 0 : write_body(buf, |buf| {
929 0 : buf.put_u8(b'S'); // severity
930 0 : buf.put_slice(b"NOTICE\0");
931 :
932 0 : buf.put_u8(b'C'); // SQLSTATE error code
933 0 : buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
934 :
935 0 : buf.put_u8(b'M'); // the message
936 0 : write_cstr(error_msg.as_bytes(), buf)?;
937 :
938 0 : buf.put_u8(0); // terminator
939 0 : Ok(())
940 0 : })?;
941 : }
942 :
943 : BeMessage::NoData => {
944 0 : buf.put_u8(b'n');
945 0 : write_body(buf, |_| {});
946 : }
947 :
948 1 : BeMessage::EncryptionResponse(should_negotiate) => {
949 1 : let response = if *should_negotiate { b'S' } else { b'N' };
950 1 : buf.put_u8(response);
951 : }
952 :
953 6 : BeMessage::ParameterStatus { name, value } => {
954 6 : buf.put_u8(b'S');
955 6 : write_body(buf, |buf| {
956 6 : write_cstr(name, buf)?;
957 6 : write_cstr(value, buf)
958 6 : })?;
959 : }
960 :
961 : BeMessage::ParameterDescription => {
962 0 : buf.put_u8(b't');
963 0 : write_body(buf, |buf| {
964 : // we don't support params, so always 0
965 0 : buf.put_i16(0);
966 0 : });
967 : }
968 :
969 : BeMessage::ParseComplete => {
970 0 : buf.put_u8(b'1');
971 0 : write_body(buf, |_| {});
972 : }
973 :
974 : BeMessage::ReadyForQuery => {
975 4 : buf.put_u8(b'Z');
976 4 : write_body(buf, |buf| {
977 4 : buf.put_u8(b'I');
978 4 : });
979 : }
980 :
981 2 : BeMessage::RowDescription(rows) => {
982 2 : buf.put_u8(b'T');
983 2 : write_body(buf, |buf| {
984 2 : buf.put_i16(rows.len() as i16); // # of fields
985 2 : for row in rows.iter() {
986 2 : write_cstr(row.name, buf)?;
987 2 : buf.put_i32(0); /* table oid */
988 2 : buf.put_i16(0); /* attnum */
989 2 : buf.put_u32(row.typoid);
990 2 : buf.put_i16(row.typlen);
991 2 : buf.put_i32(-1); /* typmod */
992 2 : buf.put_i16(0); /* format code */
993 : }
994 2 : Ok(())
995 2 : })?;
996 : }
997 :
998 0 : BeMessage::XLogData(body) => {
999 0 : buf.put_u8(b'd');
1000 0 : write_body(buf, |buf| {
1001 0 : buf.put_u8(b'w');
1002 0 : buf.put_u64(body.wal_start);
1003 0 : buf.put_u64(body.wal_end);
1004 0 : buf.put_i64(body.timestamp);
1005 0 : buf.put_slice(body.data);
1006 0 : });
1007 : }
1008 :
1009 0 : BeMessage::KeepAlive(req) => {
1010 0 : buf.put_u8(b'd');
1011 0 : write_body(buf, |buf| {
1012 0 : buf.put_u8(b'k');
1013 0 : buf.put_u64(req.wal_end);
1014 0 : buf.put_i64(req.timestamp);
1015 0 : buf.put_u8(u8::from(req.request_reply));
1016 0 : });
1017 : }
1018 :
1019 0 : BeMessage::NegotiateProtocolVersion { version, options } => {
1020 0 : buf.put_u8(b'v');
1021 0 : write_body(buf, |buf| {
1022 0 : buf.put_u32(version.0);
1023 0 : buf.put_u32(options.len() as u32);
1024 0 : for option in options.iter() {
1025 0 : write_cstr(option, buf)?;
1026 : }
1027 0 : Ok(())
1028 0 : })?
1029 : }
1030 :
1031 0 : BeMessage::InterpretedWalRecords(rec) => {
1032 : // We use the COPY_DATA_TAG for our custom message
1033 : // since this tag is interpreted as raw bytes.
1034 0 : buf.put_u8(b'd');
1035 0 : write_body(buf, |buf| {
1036 0 : buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol
1037 : // dependency
1038 0 : buf.put_u64(rec.streaming_lsn);
1039 0 : buf.put_u64(rec.commit_lsn);
1040 0 : buf.put_slice(rec.data);
1041 0 : });
1042 : }
1043 : }
1044 19 : Ok(())
1045 19 : }
1046 : }
1047 :
/// Return the 5-byte SQLSTATE `code` followed by a null terminator.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    // The last byte keeps its initial zero and serves as the terminator.
    let mut terminated = [0u8; 6];
    terminated[..code.len()].copy_from_slice(code);
    terminated
}
1056 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Splitting and unescaping of the `options` startup parameter.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        // No `options` parameter at all.
        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        // Leading, trailing and repeated whitespace produce no empty items.
        let params = make_params(" foo bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // Escaped spaces stay inside an item; `\\\\` collapses to one backslash.
        let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    /// A startup packet announcing a length below 8 must be rejected.
    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }

    /// `Display` prints both halves as a single hex number.
    #[test]
    fn cancel_key_data() {
        let key = CancelKeyData {
            backend_pid: -1817212860,
            cancel_key: -1183897012,
        };
        assert_eq!(format!("{key}"), "CancelKeyData(93af8844b96f2a4c)");
    }
}
|