Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 : #![deny(clippy::undocumented_unsafe_blocks)]
5 :
6 : pub mod framed;
7 :
8 : use std::borrow::Cow;
9 : use std::{fmt, io, str};
10 :
11 : use byteorder::{BigEndian, ReadBytesExt};
12 : use bytes::{Buf, BufMut, Bytes, BytesMut};
13 : use itertools::Itertools;
14 : // re-export for use in utils pageserver_feedback.rs
15 : pub use postgres_protocol::PG_EPOCH;
16 : use serde::{Deserialize, Serialize};
17 :
/// Postgres object identifier, as sent in `RowDescription` type OIDs.
pub type Oid = u32;
/// 64-bit system identifier (presumably the Postgres `system_identifier`; confirm with callers).
pub type SystemId = u64;

/// OID of the built-in `text` type.
pub const TEXT_OID: Oid = 25;
/// OID of the built-in `int4` type.
pub const INT4_OID: Oid = 23;
/// OID of the built-in `int8` type.
pub const INT8_OID: Oid = 20;
24 :
25 : #[derive(Debug)]
26 : pub enum FeMessage {
27 : // Simple query.
28 : Query(Bytes),
29 : // Extended query protocol.
30 : Parse(FeParseMessage),
31 : Describe(FeDescribeMessage),
32 : Bind(FeBindMessage),
33 : Execute(FeExecuteMessage),
34 : Close(FeCloseMessage),
35 : Sync,
36 : Terminate,
37 : CopyData(Bytes),
38 : CopyDone,
39 : CopyFail,
40 : PasswordMessage(Bytes),
41 : }
42 :
/// Protocol version (or special request code) from a startup packet, kept in
/// its wire form: major version in the high 16 bits, minor in the low 16 bits.
#[derive(Clone, Copy, PartialEq, PartialOrd)]
pub struct ProtocolVersion(u32);

impl ProtocolVersion {
    /// Pack `major`/`minor` into the wire representation.
    pub const fn new(major: u16, minor: u16) -> Self {
        let packed = ((major as u32) << 16) | (minor as u32);
        Self(packed)
    }

    /// Low 16 bits of the packed value.
    pub const fn minor(self) -> u16 {
        (self.0 & 0xffff) as u16
    }

    /// High 16 bits of the packed value.
    pub const fn major(self) -> u16 {
        (self.0 >> 16) as u16
    }
}

impl fmt::Debug for ProtocolVersion {
    /// Renders as `[major, minor]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list()
            .entries([self.major(), self.minor()])
            .finish()
    }
}
66 :
67 : #[derive(Debug)]
68 : pub enum FeStartupPacket {
69 : CancelRequest(CancelKeyData),
70 : SslRequest {
71 : direct: bool,
72 : },
73 : GssEncRequest,
74 : StartupMessage {
75 : version: ProtocolVersion,
76 : params: StartupMessageParams,
77 : },
78 : }
79 :
80 : #[derive(Debug, Clone, Default)]
81 : pub struct StartupMessageParamsBuilder {
82 : params: BytesMut,
83 : }
84 :
85 : impl StartupMessageParamsBuilder {
86 : /// Set parameter's value by its name.
87 : /// name and value must not contain a \0 byte
88 25 : pub fn insert(&mut self, name: &str, value: &str) {
89 25 : self.params.put(name.as_bytes());
90 25 : self.params.put(&b"\0"[..]);
91 25 : self.params.put(value.as_bytes());
92 25 : self.params.put(&b"\0"[..]);
93 25 : }
94 :
95 17 : pub fn freeze(self) -> StartupMessageParams {
96 17 : StartupMessageParams {
97 17 : params: self.params.freeze(),
98 17 : }
99 17 : }
100 : }
101 :
102 : #[derive(Debug, Clone, Default)]
103 : pub struct StartupMessageParams {
104 : pub params: Bytes,
105 : }
106 :
107 : impl StartupMessageParams {
108 : /// Get parameter's value by its name.
109 42 : pub fn get(&self, name: &str) -> Option<&str> {
110 58 : self.iter().find_map(|(k, v)| (k == name).then_some(v))
111 42 : }
112 :
113 : /// Split command-line options according to PostgreSQL's logic,
114 : /// taking into account all escape sequences but leaving them as-is.
115 : /// [`None`] means that there's no `options` in [`Self`].
116 24 : pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
117 24 : self.get("options").map(Self::parse_options_raw)
118 24 : }
119 :
120 : /// Split command-line options according to PostgreSQL's logic,
121 : /// applying all escape sequences (using owned strings as needed).
122 : /// [`None`] means that there's no `options` in [`Self`].
123 5 : pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
124 5 : self.get("options").map(Self::parse_options_escaped)
125 5 : }
126 :
127 : /// Split command-line options according to PostgreSQL's logic,
128 : /// taking into account all escape sequences but leaving them as-is.
129 30 : pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
130 30 : // See `postgres: pg_split_opts`.
131 30 : let mut last_was_escape = false;
132 30 : input
133 36 : .split(move |c: char| {
134 : // We split by non-escaped whitespace symbols.
135 594 : let should_split = c.is_ascii_whitespace() && !last_was_escape;
136 594 : last_was_escape = c == '\\' && !last_was_escape;
137 594 : should_split
138 594 : })
139 77 : .filter(|s| !s.is_empty())
140 30 : }
141 :
142 : /// Split command-line options according to PostgreSQL's logic,
143 : /// applying all escape sequences (using owned strings as needed).
144 4 : pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
145 4 : // See `postgres: pg_split_opts`.
146 7 : Self::parse_options_raw(input).map(|s| {
147 7 : let mut preserve_next_escape = false;
148 17 : let escape = |c| {
149 : // We should remove '\\' unless it's preceded by '\\'.
150 17 : let should_remove = c == '\\' && !preserve_next_escape;
151 17 : preserve_next_escape = should_remove;
152 17 : should_remove
153 17 : };
154 :
155 7 : match s.contains('\\') {
156 3 : true => Cow::Owned(s.replace(escape, "")),
157 4 : false => Cow::Borrowed(s),
158 : }
159 7 : })
160 4 : }
161 :
162 : /// Iterate through key-value pairs in an arbitrary order.
163 42 : pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
164 42 : let params =
165 42 : std::str::from_utf8(&self.params).expect("should be validated as utf8 already");
166 42 : params.split_terminator('\0').tuples()
167 42 : }
168 :
169 : // This function is mostly useful in tests.
170 : #[doc(hidden)]
171 17 : pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
172 17 : let mut b = StartupMessageParamsBuilder::default();
173 42 : for (k, v) in pairs {
174 25 : b.insert(k, v)
175 : }
176 17 : b.freeze()
177 17 : }
178 : }
179 :
180 0 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
181 : pub struct CancelKeyData {
182 : pub backend_pid: i32,
183 : pub cancel_key: i32,
184 : }
185 :
186 1 : pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
187 1 : CancelKeyData {
188 1 : backend_pid: (id >> 32) as i32,
189 1 : cancel_key: (id & 0xffffffff) as i32,
190 1 : }
191 1 : }
192 :
193 : impl fmt::Display for CancelKeyData {
194 1 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 1 : let hi = (self.backend_pid as u64) << 32;
196 1 : let lo = (self.cancel_key as u64) & 0xffffffff;
197 1 : let id = hi | lo;
198 1 :
199 1 : // This format is more compact and might work better for logs.
200 1 : f.debug_tuple("CancelKeyData")
201 1 : .field(&format_args!("{:x}", id))
202 1 : .finish()
203 1 : }
204 : }
205 :
206 : use rand::distributions::{Distribution, Standard};
207 : impl Distribution<CancelKeyData> for Standard {
208 0 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
209 0 : CancelKeyData {
210 0 : backend_pid: rng.r#gen(),
211 0 : cancel_key: rng.r#gen(),
212 0 : }
213 0 : }
214 : }
215 :
216 : // We only support the simple case of Parse on unnamed prepared statement and
217 : // no params
218 : #[derive(Debug)]
219 : pub struct FeParseMessage {
220 : pub query_string: Bytes,
221 : }
222 :
223 : #[derive(Debug)]
224 : pub struct FeDescribeMessage {
225 : pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
226 : // we only support unnamed prepared stmt or portal
227 : }
228 :
229 : // we only support unnamed prepared stmt and portal
230 : #[derive(Debug)]
231 : pub struct FeBindMessage;
232 :
233 : // we only support unnamed prepared stmt or portal
234 : #[derive(Debug)]
235 : pub struct FeExecuteMessage {
236 : /// max # of rows
237 : pub maxrows: i32,
238 : }
239 :
240 : // we only support unnamed prepared stmt and portal
241 : #[derive(Debug)]
242 : pub struct FeCloseMessage;
243 :
244 : /// An error occurred while parsing or serializing raw stream into Postgres
245 : /// messages.
246 : #[derive(thiserror::Error, Debug)]
247 : pub enum ProtocolError {
248 : /// Invalid packet was received from the client (e.g. unexpected message
249 : /// type or broken len).
250 : #[error("Protocol error: {0}")]
251 : Protocol(String),
252 : /// Failed to parse or, (unlikely), serialize a protocol message.
253 : #[error("Message parse error: {0}")]
254 : BadMessage(String),
255 : }
256 :
257 : impl ProtocolError {
258 : /// Proxy stream.rs uses only io::Error; provide it.
259 0 : pub fn into_io_error(self) -> io::Error {
260 0 : io::Error::new(io::ErrorKind::Other, self.to_string())
261 0 : }
262 : }
263 :
264 : impl FeMessage {
265 : /// Read and parse one message from the `buf` input buffer. If there is at
266 : /// least one valid message, returns it, advancing `buf`; redundant copies
267 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
268 : /// directly into the `buf` (processed data is garbage collected after
269 : /// parsed message is dropped).
270 : ///
271 : /// Returns None if `buf` doesn't contain enough data for a single message.
272 : /// For efficiency, tries to reserve large enough space in `buf` for the
273 : /// next message in this case to save the repeated calls.
274 : ///
275 : /// Returns Error if message is malformed, the only possible ErrorKind is
276 : /// InvalidInput.
277 : //
278 : // Inspired by rust-postgres Message::parse.
279 57 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
280 57 : // Every message contains message type byte and 4 bytes len; can't do
281 57 : // much without them.
282 57 : if buf.len() < 5 {
283 30 : let to_read = 5 - buf.len();
284 30 : buf.reserve(to_read);
285 30 : return Ok(None);
286 27 : }
287 27 :
288 27 : // We shouldn't advance `buf` as probably full message is not there yet,
289 27 : // so can't directly use Bytes::get_u32 etc.
290 27 : let tag = buf[0];
291 27 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
292 27 : if len < 4 {
293 0 : return Err(ProtocolError::Protocol(format!(
294 0 : "invalid message length {}",
295 0 : len
296 0 : )));
297 27 : }
298 27 :
299 27 : // length field includes itself, but not message type.
300 27 : let total_len = len as usize + 1;
301 27 : if buf.len() < total_len {
302 : // Don't have full message yet.
303 0 : let to_read = total_len - buf.len();
304 0 : buf.reserve(to_read);
305 0 : return Ok(None);
306 27 : }
307 27 :
308 27 : // got the message, advance buffer
309 27 : let mut msg = buf.split_to(total_len).freeze();
310 27 : msg.advance(5); // consume message type and len
311 27 :
312 27 : match tag {
313 2 : b'Q' => Ok(Some(FeMessage::Query(msg))),
314 0 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
315 0 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
316 0 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
317 0 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
318 0 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
319 0 : b'S' => Ok(Some(FeMessage::Sync)),
320 0 : b'X' => Ok(Some(FeMessage::Terminate)),
321 0 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
322 0 : b'c' => Ok(Some(FeMessage::CopyDone)),
323 0 : b'f' => Ok(Some(FeMessage::CopyFail)),
324 25 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
325 0 : tag => Err(ProtocolError::Protocol(format!(
326 0 : "unknown message tag: {tag},'{msg:?}'"
327 0 : ))),
328 : }
329 57 : }
330 : }
331 :
332 : impl FeStartupPacket {
333 : /// Read and parse startup message from the `buf` input buffer. It is
334 : /// different from [`FeMessage::parse`] because startup messages don't have
335 : /// message type byte; otherwise, its comments apply.
336 91 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
337 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
338 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
339 : const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
340 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
341 : const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
342 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
343 : const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
344 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
345 : const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
346 :
347 : // <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
348 : // First byte indicates standard SSL handshake message
349 : // (It can't be a Postgres startup length because in network byte order
350 : // that would be a startup packet hundreds of megabytes long)
351 91 : if buf.first() == Some(&0x16) {
352 0 : return Ok(Some(FeStartupPacket::SslRequest { direct: true }));
353 91 : }
354 91 :
355 91 : // need at least 4 bytes with packet len
356 91 : if buf.len() < 4 {
357 45 : let to_read = 4 - buf.len();
358 45 : buf.reserve(to_read);
359 45 : return Ok(None);
360 46 : }
361 46 :
362 46 : // We shouldn't advance `buf` as probably full message is not there yet,
363 46 : // so can't directly use Bytes::get_u32 etc.
364 46 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
365 46 : // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
366 46 : // which is less readable
367 46 : #[allow(clippy::manual_range_contains)]
368 46 : if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
369 1 : return Err(ProtocolError::Protocol(format!(
370 1 : "invalid startup packet message length {}",
371 1 : len
372 1 : )));
373 45 : }
374 45 :
375 45 : if buf.len() < len {
376 : // Don't have full message yet.
377 0 : let to_read = len - buf.len();
378 0 : buf.reserve(to_read);
379 0 : return Ok(None);
380 45 : }
381 45 :
382 45 : // got the message, advance buffer
383 45 : let mut msg = buf.split_to(len).freeze();
384 45 : msg.advance(4); // consume len
385 45 :
386 45 : let request_code = ProtocolVersion(msg.get_u32());
387 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
388 45 : let message = match request_code {
389 : CANCEL_REQUEST_CODE => {
390 0 : if msg.remaining() != 8 {
391 0 : return Err(ProtocolError::BadMessage(
392 0 : "CancelRequest message is malformed, backend PID / secret key missing"
393 0 : .to_owned(),
394 0 : ));
395 0 : }
396 0 : FeStartupPacket::CancelRequest(CancelKeyData {
397 0 : backend_pid: msg.get_i32(),
398 0 : cancel_key: msg.get_i32(),
399 0 : })
400 : }
401 : NEGOTIATE_SSL_CODE => {
402 : // Requested upgrade to SSL (aka TLS)
403 21 : FeStartupPacket::SslRequest { direct: false }
404 : }
405 : NEGOTIATE_GSS_CODE => {
406 : // Requested upgrade to GSSAPI
407 0 : FeStartupPacket::GssEncRequest
408 : }
409 24 : version if version.major() == RESERVED_INVALID_MAJOR_VERSION => {
410 0 : return Err(ProtocolError::Protocol(format!(
411 0 : "Unrecognized request code {}",
412 0 : version.minor()
413 0 : )));
414 : }
415 : // TODO bail if protocol major_version is not 3?
416 24 : version => {
417 : // StartupMessage
418 :
419 24 : let s = str::from_utf8(&msg).map_err(|_e| {
420 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
421 24 : })?;
422 24 : let s = s.strip_suffix('\0').ok_or_else(|| {
423 0 : ProtocolError::Protocol(
424 0 : "StartupMessage params: missing null terminator".to_string(),
425 0 : )
426 24 : })?;
427 :
428 24 : FeStartupPacket::StartupMessage {
429 24 : version,
430 24 : params: StartupMessageParams {
431 24 : params: msg.slice_ref(s.as_bytes()),
432 24 : },
433 24 : }
434 : }
435 : };
436 45 : Ok(Some(message))
437 91 : }
438 : }
439 :
440 : impl FeParseMessage {
441 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
442 : // FIXME: the rust-postgres driver uses a named prepared statement
443 : // for copy_out(). We're not prepared to handle that correctly. For
444 : // now, just ignore the statement name, assuming that the client never
445 : // uses more than one prepared statement at a time.
446 :
447 0 : let _pstmt_name = read_cstr(&mut buf)?;
448 0 : let query_string = read_cstr(&mut buf)?;
449 0 : if buf.remaining() < 2 {
450 0 : return Err(ProtocolError::BadMessage(
451 0 : "Parse message is malformed, nparams missing".to_string(),
452 0 : ));
453 0 : }
454 0 : let nparams = buf.get_i16();
455 0 :
456 0 : if nparams != 0 {
457 0 : return Err(ProtocolError::BadMessage(
458 0 : "query params not implemented".to_string(),
459 0 : ));
460 0 : }
461 0 :
462 0 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
463 0 : }
464 : }
465 :
466 : impl FeDescribeMessage {
467 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
468 0 : let kind = buf.get_u8();
469 0 : let _pstmt_name = read_cstr(&mut buf)?;
470 :
471 : // FIXME: see FeParseMessage::parse
472 0 : if kind != b'S' {
473 0 : return Err(ProtocolError::BadMessage(
474 0 : "only prepared statemement Describe is implemented".to_string(),
475 0 : ));
476 0 : }
477 0 :
478 0 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
479 0 : }
480 : }
481 :
482 : impl FeExecuteMessage {
483 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
484 0 : let portal_name = read_cstr(&mut buf)?;
485 0 : if buf.remaining() < 4 {
486 0 : return Err(ProtocolError::BadMessage(
487 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
488 0 : ));
489 0 : }
490 0 : let maxrows = buf.get_i32();
491 0 :
492 0 : if !portal_name.is_empty() {
493 0 : return Err(ProtocolError::BadMessage(
494 0 : "named portals not implemented".to_string(),
495 0 : ));
496 0 : }
497 0 : if maxrows != 0 {
498 0 : return Err(ProtocolError::BadMessage(
499 0 : "row limit in Execute message not implemented".to_string(),
500 0 : ));
501 0 : }
502 0 :
503 0 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
504 0 : }
505 : }
506 :
507 : impl FeBindMessage {
508 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
509 0 : let portal_name = read_cstr(&mut buf)?;
510 0 : let _pstmt_name = read_cstr(&mut buf)?;
511 :
512 : // FIXME: see FeParseMessage::parse
513 0 : if !portal_name.is_empty() {
514 0 : return Err(ProtocolError::BadMessage(
515 0 : "named portals not implemented".to_string(),
516 0 : ));
517 0 : }
518 0 :
519 0 : Ok(FeMessage::Bind(FeBindMessage))
520 0 : }
521 : }
522 :
523 : impl FeCloseMessage {
524 0 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
525 0 : let _kind = buf.get_u8();
526 0 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
527 :
528 : // FIXME: we do nothing with Close
529 0 : Ok(FeMessage::Close(FeCloseMessage))
530 0 : }
531 : }
532 :
533 : // Backend
534 :
535 : #[derive(Debug)]
536 : pub enum BeMessage<'a> {
537 : AuthenticationOk,
538 : AuthenticationMD5Password([u8; 4]),
539 : AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
540 : AuthenticationCleartextPassword,
541 : BackendKeyData(CancelKeyData),
542 : BindComplete,
543 : CommandComplete(&'a [u8]),
544 : CopyData(&'a [u8]),
545 : CopyDone,
546 : CopyFail,
547 : CopyInResponse,
548 : CopyOutResponse,
549 : CopyBothResponse,
550 : CloseComplete,
551 : // None means column is NULL
552 : DataRow(&'a [Option<&'a [u8]>]),
553 : // None errcode means internal_error will be sent.
554 : ErrorResponse(&'a str, Option<&'a [u8; 5]>),
555 : /// Single byte - used in response to SSLRequest/GSSENCRequest.
556 : EncryptionResponse(bool),
557 : NoData,
558 : ParameterDescription,
559 : ParameterStatus {
560 : name: &'a [u8],
561 : value: &'a [u8],
562 : },
563 : ParseComplete,
564 : ReadyForQuery,
565 : RowDescription(&'a [RowDescriptor<'a>]),
566 : XLogData(XLogDataBody<'a>),
567 : NoticeResponse(&'a str),
568 : NegotiateProtocolVersion {
569 : version: ProtocolVersion,
570 : options: &'a [&'a str],
571 : },
572 : KeepAlive(WalSndKeepAlive),
573 : /// Batch of interpreted, shard filtered WAL records,
574 : /// ready for the pageserver to ingest
575 : InterpretedWalRecords(InterpretedWalRecordsBody<'a>),
576 :
577 : Raw(u8, &'a [u8]),
578 : }
579 :
580 : /// Common shorthands.
581 : impl<'a> BeMessage<'a> {
582 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
583 : /// This is a sensible default, given that:
584 : /// * rust strings only support this encoding out of the box.
585 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
586 : ///
587 : /// TODO: do we need to report `server_encoding` as well?
588 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
589 : name: b"client_encoding",
590 : value: b"UTF8",
591 : };
592 :
593 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
594 : name: b"integer_datetimes",
595 : value: b"on",
596 : };
597 :
598 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
599 2 : pub fn server_version(version: &'a str) -> Self {
600 2 : Self::ParameterStatus {
601 2 : name: b"server_version",
602 2 : value: version.as_bytes(),
603 2 : }
604 2 : }
605 : }
606 :
/// Payload of a SASL authentication step.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Offer the listed SASL mechanism names.
    Methods(&'a [&'a str]),
    /// Server challenge for the next SASL step.
    Continue(&'a [u8]),
    /// Final SASL message from the server.
    Final(&'a [u8]),
}

/// A server parameter that can be reported to the client.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    /// Client encoding name.
    Encoding(&'a str),
    /// Server version string.
    ServerVersion(&'a str),
}
619 :
620 : // One row description in RowDescription packet.
621 : #[derive(Debug)]
622 : pub struct RowDescriptor<'a> {
623 : pub name: &'a [u8],
624 : pub tableoid: Oid,
625 : pub attnum: i16,
626 : pub typoid: Oid,
627 : pub typlen: i16,
628 : pub typmod: i32,
629 : pub formatcode: i16,
630 : }
631 :
632 : impl Default for RowDescriptor<'_> {
633 0 : fn default() -> RowDescriptor<'static> {
634 0 : RowDescriptor {
635 0 : name: b"",
636 0 : tableoid: 0,
637 0 : attnum: 0,
638 0 : typoid: 0,
639 0 : typlen: 0,
640 0 : typmod: 0,
641 0 : formatcode: 0,
642 0 : }
643 0 : }
644 : }
645 :
646 : impl RowDescriptor<'_> {
647 : /// Convenience function to create a RowDescriptor message for an int8 column
648 0 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
649 0 : RowDescriptor {
650 0 : name,
651 0 : tableoid: 0,
652 0 : attnum: 0,
653 0 : typoid: INT8_OID,
654 0 : typlen: 8,
655 0 : typmod: 0,
656 0 : formatcode: 0,
657 0 : }
658 0 : }
659 :
660 2 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
661 2 : RowDescriptor {
662 2 : name,
663 2 : tableoid: 0,
664 2 : attnum: 0,
665 2 : typoid: TEXT_OID,
666 2 : typlen: -1,
667 2 : typmod: 0,
668 2 : formatcode: 0,
669 2 : }
670 2 : }
671 : }
672 :
/// Body of an XLogData (`'w'`) replication message.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    /// Start position of `data` in the WAL.
    pub wal_start: u64,
    /// Current end of WAL on the server.
    pub wal_end: u64,
    /// Server timestamp, written as a big-endian `i64`.
    pub timestamp: i64,
    /// Raw WAL bytes.
    pub data: &'a [u8],
}

/// Body of a primary keepalive (`'k'`) replication message.
#[derive(Debug)]
pub struct WalSndKeepAlive {
    /// Current end of WAL on the server.
    pub wal_end: u64,
    /// Server timestamp, written as a big-endian `i64`.
    pub timestamp: i64,
    /// Whether the client should reply immediately (written as a 0/1 byte).
    pub request_reply: bool,
}

/// Batch of interpreted WAL records used in the interpreted
/// safekeeper to pageserver protocol.
///
/// Note that the pageserver uses the RawInterpretedWalRecordsBody
/// counterpart of this from the neondatabase/rust-postgres repo.
/// If you're changing this struct, you likely need to change its
/// twin as well.
#[derive(Debug)]
pub struct InterpretedWalRecordsBody<'a> {
    /// End of raw WAL in [`Self::data`]
    pub streaming_lsn: u64,
    /// Current end of WAL on the server
    pub commit_lsn: u64,
    /// Serialized records.
    pub data: &'a [u8],
}
703 :
704 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
705 :
706 : // single text column
707 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
708 : name: b"data",
709 : tableoid: 0,
710 : attnum: 0,
711 : typoid: TEXT_OID,
712 : typlen: -1,
713 : typmod: 0,
714 : formatcode: 0,
715 : }]);
716 :
717 : /// Call f() to write body of the message and prepend it with 4-byte len as
718 : /// prescribed by the protocol.
719 75 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
720 75 : let base = buf.len();
721 75 : buf.extend_from_slice(&[0; 4]);
722 75 :
723 75 : let res = f(buf);
724 75 :
725 75 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
726 75 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
727 75 :
728 75 : res
729 75 : }
730 :
731 : /// Safe write of s into buf as cstring (String in the protocol).
732 56 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
733 56 : let bytes = s.as_ref();
734 56 : if bytes.contains(&0) {
735 0 : return Err(ProtocolError::BadMessage(
736 0 : "string contains embedded null".to_owned(),
737 0 : ));
738 56 : }
739 56 : buf.put_slice(bytes);
740 56 : buf.put_u8(0);
741 56 : Ok(())
742 56 : }
743 :
744 : /// Read cstring from buf, advancing it.
745 11 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
746 11 : let pos = buf
747 11 : .iter()
748 156 : .position(|x| *x == 0)
749 11 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
750 11 : let result = buf.split_to(pos);
751 11 : buf.advance(1); // drop the null terminator
752 11 : Ok(result)
753 11 : }
754 :
/// SQLSTATE `internal_error`; the default code for [`BeMessage::ErrorResponse`].
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE `admin_shutdown`.
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE `successful_completion`.
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
758 :
759 : impl BeMessage<'_> {
760 : /// Serialize `message` to the given `buf`.
761 : /// Apart from smart memory managemet, BytesMut is good here as msg len
762 : /// precedes its body and it is handy to write it down first and then fill
763 : /// the length. With Write we would have to either calc it manually or have
764 : /// one more buffer.
765 96 : pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
766 96 : match message {
767 0 : BeMessage::Raw(code, data) => {
768 0 : buf.put_u8(*code);
769 0 : write_body(buf, |b| b.put_slice(data))
770 : }
771 12 : BeMessage::AuthenticationOk => {
772 12 : buf.put_u8(b'R');
773 12 : write_body(buf, |buf| {
774 12 : buf.put_i32(0); // Specifies that the authentication was successful.
775 12 : });
776 12 : }
777 :
778 2 : BeMessage::AuthenticationCleartextPassword => {
779 2 : buf.put_u8(b'R');
780 2 : write_body(buf, |buf| {
781 2 : buf.put_i32(3); // Specifies that clear text password is required.
782 2 : });
783 2 : }
784 :
785 0 : BeMessage::AuthenticationMD5Password(salt) => {
786 0 : buf.put_u8(b'R');
787 0 : write_body(buf, |buf| {
788 0 : buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
789 0 : buf.put_slice(&salt[..]);
790 0 : });
791 0 : }
792 :
793 30 : BeMessage::AuthenticationSasl(msg) => {
794 30 : buf.put_u8(b'R');
795 30 : write_body(buf, |buf| {
796 : use BeAuthenticationSaslMessage::*;
797 30 : match msg {
798 13 : Methods(methods) => {
799 13 : buf.put_i32(10); // Specifies that SASL auth method is used.
800 25 : for method in methods.iter() {
801 25 : write_cstr(method, buf)?;
802 : }
803 13 : buf.put_u8(0); // zero terminator for the list
804 : }
805 11 : Continue(extra) => {
806 11 : buf.put_i32(11); // Continue SASL auth.
807 11 : buf.put_slice(extra);
808 11 : }
809 6 : Final(extra) => {
810 6 : buf.put_i32(12); // Send final SASL message.
811 6 : buf.put_slice(extra);
812 6 : }
813 : }
814 30 : Ok(())
815 30 : })?;
816 : }
817 :
818 0 : BeMessage::BackendKeyData(key_data) => {
819 0 : buf.put_u8(b'K');
820 0 : write_body(buf, |buf| {
821 0 : buf.put_i32(key_data.backend_pid);
822 0 : buf.put_i32(key_data.cancel_key);
823 0 : });
824 0 : }
825 :
826 0 : BeMessage::BindComplete => {
827 0 : buf.put_u8(b'2');
828 0 : write_body(buf, |_| {});
829 0 : }
830 :
831 0 : BeMessage::CloseComplete => {
832 0 : buf.put_u8(b'3');
833 0 : write_body(buf, |_| {});
834 0 : }
835 :
836 2 : BeMessage::CommandComplete(cmd) => {
837 2 : buf.put_u8(b'C');
838 2 : write_body(buf, |buf| write_cstr(cmd, buf))?;
839 : }
840 :
841 0 : BeMessage::CopyData(data) => {
842 0 : buf.put_u8(b'd');
843 0 : write_body(buf, |buf| {
844 0 : buf.put_slice(data);
845 0 : });
846 0 : }
847 :
848 0 : BeMessage::CopyDone => {
849 0 : buf.put_u8(b'c');
850 0 : write_body(buf, |_| {});
851 0 : }
852 :
853 0 : BeMessage::CopyFail => {
854 0 : buf.put_u8(b'f');
855 0 : write_body(buf, |_| {});
856 0 : }
857 :
858 0 : BeMessage::CopyInResponse => {
859 0 : buf.put_u8(b'G');
860 0 : write_body(buf, |buf| {
861 0 : buf.put_u8(1); // copy_is_binary
862 0 : buf.put_i16(0); // numAttributes
863 0 : });
864 0 : }
865 :
866 0 : BeMessage::CopyOutResponse => {
867 0 : buf.put_u8(b'H');
868 0 : write_body(buf, |buf| {
869 0 : buf.put_u8(0); // copy_is_binary
870 0 : buf.put_i16(0); // numAttributes
871 0 : });
872 0 : }
873 :
874 0 : BeMessage::CopyBothResponse => {
875 0 : buf.put_u8(b'W');
876 0 : write_body(buf, |buf| {
877 0 : // doesn't matter, used only for replication
878 0 : buf.put_u8(0); // copy_is_binary
879 0 : buf.put_i16(0); // numAttributes
880 0 : });
881 0 : }
882 :
883 2 : BeMessage::DataRow(vals) => {
884 2 : buf.put_u8(b'D');
885 2 : write_body(buf, |buf| {
886 2 : buf.put_u16(vals.len() as u16); // num of cols
887 2 : for val_opt in vals.iter() {
888 2 : if let Some(val) = val_opt {
889 2 : buf.put_u32(val.len() as u32);
890 2 : buf.put_slice(val);
891 2 : } else {
892 0 : buf.put_i32(-1);
893 0 : }
894 : }
895 2 : });
896 2 : }
897 :
898 : // ErrorResponse is a zero-terminated array of zero-terminated fields.
899 : // First byte of each field represents type of this field. Set just enough fields
900 : // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
901 : // message text.
902 1 : BeMessage::ErrorResponse(error_msg, pg_error_code) => {
903 1 : // 'E' signalizes ErrorResponse messages
904 1 : buf.put_u8(b'E');
905 1 : write_body(buf, |buf| {
906 1 : buf.put_u8(b'S'); // severity
907 1 : buf.put_slice(b"ERROR\0");
908 1 :
909 1 : buf.put_u8(b'C'); // SQLSTATE error code
910 1 : buf.put_slice(&terminate_code(
911 1 : pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
912 1 : ));
913 1 :
914 1 : buf.put_u8(b'M'); // the message
915 1 : write_cstr(error_msg, buf)?;
916 :
917 1 : buf.put_u8(0); // terminator
918 1 : Ok(())
919 1 : })?;
920 : }
921 :
922 : // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
923 : // message but continue listening for ReadyForQuery or ErrorResponse"
924 0 : BeMessage::NoticeResponse(error_msg) => {
925 0 : // For all the errors set Severity to Error and error code to
926 0 : // 'internal error'.
927 0 :
928 0 : // 'N' signalizes NoticeResponse messages
929 0 : buf.put_u8(b'N');
930 0 : write_body(buf, |buf| {
931 0 : buf.put_u8(b'S'); // severity
932 0 : buf.put_slice(b"NOTICE\0");
933 0 :
934 0 : buf.put_u8(b'C'); // SQLSTATE error code
935 0 : buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
936 0 :
937 0 : buf.put_u8(b'M'); // the message
938 0 : write_cstr(error_msg.as_bytes(), buf)?;
939 :
940 0 : buf.put_u8(0); // terminator
941 0 : Ok(())
942 0 : })?;
943 : }
944 :
945 0 : BeMessage::NoData => {
946 0 : buf.put_u8(b'n');
947 0 : write_body(buf, |_| {});
948 0 : }
949 :
950 21 : BeMessage::EncryptionResponse(should_negotiate) => {
951 21 : let response = if *should_negotiate { b'S' } else { b'N' };
952 21 : buf.put_u8(response);
953 : }
954 :
955 13 : BeMessage::ParameterStatus { name, value } => {
956 13 : buf.put_u8(b'S');
957 13 : write_body(buf, |buf| {
958 13 : write_cstr(name, buf)?;
959 13 : write_cstr(value, buf)
960 13 : })?;
961 : }
962 :
963 0 : BeMessage::ParameterDescription => {
964 0 : buf.put_u8(b't');
965 0 : write_body(buf, |buf| {
966 0 : // we don't support params, so always 0
967 0 : buf.put_i16(0);
968 0 : });
969 0 : }
970 :
971 0 : BeMessage::ParseComplete => {
972 0 : buf.put_u8(b'1');
973 0 : write_body(buf, |_| {});
974 0 : }
975 :
976 11 : BeMessage::ReadyForQuery => {
977 11 : buf.put_u8(b'Z');
978 11 : write_body(buf, |buf| {
979 11 : buf.put_u8(b'I');
980 11 : });
981 11 : }
982 :
983 2 : BeMessage::RowDescription(rows) => {
984 2 : buf.put_u8(b'T');
985 2 : write_body(buf, |buf| {
986 2 : buf.put_i16(rows.len() as i16); // # of fields
987 2 : for row in rows.iter() {
988 2 : write_cstr(row.name, buf)?;
989 2 : buf.put_i32(0); /* table oid */
990 2 : buf.put_i16(0); /* attnum */
991 2 : buf.put_u32(row.typoid);
992 2 : buf.put_i16(row.typlen);
993 2 : buf.put_i32(-1); /* typmod */
994 2 : buf.put_i16(0); /* format code */
995 : }
996 2 : Ok(())
997 2 : })?;
998 : }
999 :
1000 0 : BeMessage::XLogData(body) => {
1001 0 : buf.put_u8(b'd');
1002 0 : write_body(buf, |buf| {
1003 0 : buf.put_u8(b'w');
1004 0 : buf.put_u64(body.wal_start);
1005 0 : buf.put_u64(body.wal_end);
1006 0 : buf.put_i64(body.timestamp);
1007 0 : buf.put_slice(body.data);
1008 0 : });
1009 0 : }
1010 :
1011 0 : BeMessage::KeepAlive(req) => {
1012 0 : buf.put_u8(b'd');
1013 0 : write_body(buf, |buf| {
1014 0 : buf.put_u8(b'k');
1015 0 : buf.put_u64(req.wal_end);
1016 0 : buf.put_i64(req.timestamp);
1017 0 : buf.put_u8(u8::from(req.request_reply));
1018 0 : });
1019 0 : }
1020 :
1021 0 : BeMessage::NegotiateProtocolVersion { version, options } => {
1022 0 : buf.put_u8(b'v');
1023 0 : write_body(buf, |buf| {
1024 0 : buf.put_u32(version.0);
1025 0 : buf.put_u32(options.len() as u32);
1026 0 : for option in options.iter() {
1027 0 : write_cstr(option, buf)?;
1028 : }
1029 0 : Ok(())
1030 0 : })?
1031 : }
1032 :
1033 0 : BeMessage::InterpretedWalRecords(rec) => {
1034 0 : // We use the COPY_DATA_TAG for our custom message
1035 0 : // since this tag is interpreted as raw bytes.
1036 0 : buf.put_u8(b'd');
1037 0 : write_body(buf, |buf| {
1038 0 : buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol
1039 0 : // dependency
1040 0 : buf.put_u64(rec.streaming_lsn);
1041 0 : buf.put_u64(rec.commit_lsn);
1042 0 : buf.put_slice(rec.data);
1043 0 : });
1044 0 : }
1045 : }
1046 96 : Ok(())
1047 96 : }
1048 : }
1049 :
/// Append a NUL terminator to a 5-byte SQLSTATE code.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    let [a, b, c, d, e] = *code;
    [a, b, c, d, e, 0]
}
1058 :
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_startup_message_params_options_escaped() {
        // Collect the escaped options, requiring `options` to be present.
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        // No `options` parameter at all.
        assert!(StartupMessageParams::new([]).options_escaped().is_none());

        let empty = make_params("");
        assert!(split_options(&empty).is_empty());

        let single = make_params("foo");
        assert_eq!(split_options(&single), ["foo"]);

        let padded = make_params(" foo bar ");
        assert_eq!(split_options(&padded), ["foo", "bar"]);

        // Escaped spaces stay inside an option; `\\` collapses to `\`.
        let escaped = make_params("foo\\ bar \\ \\\\ baz\\ lol");
        assert_eq!(split_options(&escaped), ["foo bar", " \\", "baz ", "lol"]);
    }

    #[test]
    fn parse_fe_startup_packet_regression() {
        // Declared length 7 is below the minimum of 8 and must be rejected.
        let packet = [0, 0, 0, 7, 0, 0, 0, 0];
        let mut buf = BytesMut::from_iter(packet);
        FeStartupPacket::parse(&mut buf).unwrap_err();
    }

    #[test]
    fn cancel_key_data() {
        let key = CancelKeyData {
            backend_pid: -1817212860,
            cancel_key: -1183897012,
        };
        let rendered = format!("{key}");
        assert_eq!(rendered, "CancelKeyData(93af8844b96f2a4c)");
    }
}
|