TLA Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 : #![deny(clippy::undocumented_unsafe_blocks)]
5 :
6 : pub mod framed;
7 :
8 : use byteorder::{BigEndian, ReadBytesExt};
9 : use bytes::{Buf, BufMut, Bytes, BytesMut};
10 : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
11 :
12 : // re-export for use in utils pageserver_feedback.rs
13 : pub use postgres_protocol::PG_EPOCH;
14 :
/// Postgres object identifier (`oid`), an unsigned 32-bit integer.
pub type Oid = u32;
/// 64-bit Postgres system identifier.
pub type SystemId = u64;

/// `pg_type` OID of `int8` (bigint).
pub const INT8_OID: Oid = 20;
/// `pg_type` OID of `int4` (integer).
pub const INT4_OID: Oid = 23;
/// `pg_type` OID of `text`.
pub const TEXT_OID: Oid = 25;
21 :
22 CBC 50 : #[derive(Debug)]
23 : pub enum FeMessage {
24 : // Simple query.
25 : Query(Bytes),
26 : // Extended query protocol.
27 : Parse(FeParseMessage),
28 : Describe(FeDescribeMessage),
29 : Bind(FeBindMessage),
30 : Execute(FeExecuteMessage),
31 : Close(FeCloseMessage),
32 : Sync,
33 : Terminate,
34 : CopyData(Bytes),
35 : CopyDone,
36 : CopyFail,
37 : PasswordMessage(Bytes),
38 : }
39 :
40 218 : #[derive(Debug)]
41 : pub enum FeStartupPacket {
42 : CancelRequest(CancelKeyData),
43 : SslRequest,
44 : GssEncRequest,
45 : StartupMessage {
46 : major_version: u32,
47 : minor_version: u32,
48 : params: StartupMessageParams,
49 : },
50 : }
51 :
/// Key/value parameters carried by a `StartupMessage` (e.g. `options`).
#[derive(Debug)]
pub struct StartupMessageParams {
    params: HashMap<String, String>,
}

impl StartupMessageParams {
    /// Look up the value of parameter `name`, if the client sent it.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Tokenize the `options` parameter the way PostgreSQL does, keeping
    /// escape sequences verbatim in each token.
    /// [`None`] means that there's no `options` in [`Self`].
    pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
        let options = self.get("options")?;
        Some(Self::parse_options_raw(options))
    }

    /// Tokenize the `options` parameter the way PostgreSQL does, resolving
    /// escape sequences (allocating only for tokens that contain them).
    /// [`None`] means that there's no `options` in [`Self`].
    pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
        let options = self.get("options")?;
        Some(Self::parse_options_escaped(options))
    }

    /// Tokenize `input` the way PostgreSQL does, keeping escape sequences
    /// verbatim in each token. Empty tokens are skipped.
    pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
        // See `postgres: pg_split_opts`.
        let mut escaped = false;
        input
            .split(move |ch: char| {
                // Whitespace separates tokens unless a backslash escaped it.
                let is_separator = !escaped && ch.is_ascii_whitespace();
                // A backslash escapes the next char, unless it was itself escaped.
                escaped = !escaped && ch == '\\';
                is_separator
            })
            .filter(|token| !token.is_empty())
    }

    /// Tokenize `input` the way PostgreSQL does and resolve escape sequences:
    /// each unescaped backslash is dropped, and `\\` becomes a single `\`.
    pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
        // See `postgres: pg_split_opts`.
        Self::parse_options_raw(input).map(|token| {
            if !token.contains('\\') {
                return Cow::Borrowed(token);
            }

            let mut keep_next_backslash = false;
            let unescaped: String = token
                .chars()
                .filter(|&ch| {
                    // Drop a backslash unless the previous one escaped it.
                    let drop = ch == '\\' && !keep_next_backslash;
                    keep_next_backslash = drop;
                    !drop
                })
                .collect();
            Cow::Owned(unescaped)
        })
    }

    /// Iterate over all `(name, value)` pairs, in no particular order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.params
            .iter()
            .map(|(key, value)| (key.as_str(), value.as_str()))
    }

    // This function is mostly useful in tests.
    #[doc(hidden)]
    pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
        let owned = pairs.map(|(key, value)| (key.to_string(), value.to_string()));
        Self {
            params: HashMap::from(owned),
        }
    }
}
125 :
/// Identifies a backend for query cancellation: the pair sent to the client
/// in `BackendKeyData` and echoed back by it in a `CancelRequest`.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,
}

impl fmt::Display for CancelKeyData {
    /// Formats both halves as a single 64-bit hex id,
    /// `CancelKeyData(<backend_pid bits><cancel_key bits>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reinterpret each i32 as its raw 32 bits (`as u32`) before widening.
        // Widening an i32 straight to u64 sign-extends, so a negative
        // `cancel_key` would set all of the upper 32 bits and clobber the pid.
        let hi = u64::from(self.backend_pid as u32) << 32;
        let lo = u64::from(self.cancel_key as u32);
        let id = hi | lo;

        // This format is more compact and might work better for logs.
        f.debug_tuple("CancelKeyData")
            .field(&format_args!("{:x}", id))
            .finish()
    }
}
144 :
145 : use rand::distributions::{Distribution, Standard};
146 : impl Distribution<CancelKeyData> for Standard {
147 50 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
148 50 : CancelKeyData {
149 50 : backend_pid: rng.gen(),
150 50 : cancel_key: rng.gen(),
151 50 : }
152 50 : }
153 : }
154 :
155 : // We only support the simple case of Parse on unnamed prepared statement and
156 : // no params
157 UBC 0 : #[derive(Debug)]
158 : pub struct FeParseMessage {
159 : pub query_string: Bytes,
160 : }
161 :
162 0 : #[derive(Debug)]
163 : pub struct FeDescribeMessage {
164 : pub kind: u8, // 'S' to describe a prepared statement; or 'P' to describe a portal.
165 : // we only support unnamed prepared stmt or portal
166 : }
167 :
168 : // we only support unnamed prepared stmt and portal
169 0 : #[derive(Debug)]
170 : pub struct FeBindMessage;
171 :
172 : // we only support unnamed prepared stmt or portal
173 0 : #[derive(Debug)]
174 : pub struct FeExecuteMessage {
175 : /// max # of rows
176 : pub maxrows: i32,
177 : }
178 :
179 : // we only support unnamed prepared stmt and portal
180 0 : #[derive(Debug)]
181 : pub struct FeCloseMessage;
182 :
183 : /// An error occurred while parsing or serializing raw stream into Postgres
184 : /// messages.
185 CBC 5 : #[derive(thiserror::Error, Debug)]
186 : pub enum ProtocolError {
187 : /// Invalid packet was received from the client (e.g. unexpected message
188 : /// type or broken len).
189 : #[error("Protocol error: {0}")]
190 : Protocol(String),
191 : /// Failed to parse or, (unlikely), serialize a protocol message.
192 : #[error("Message parse error: {0}")]
193 : BadMessage(String),
194 : }
195 :
196 : impl ProtocolError {
197 : /// Proxy stream.rs uses only io::Error; provide it.
198 UBC 0 : pub fn into_io_error(self) -> io::Error {
199 0 : io::Error::new(io::ErrorKind::Other, self.to_string())
200 0 : }
201 : }
202 :
203 : impl FeMessage {
204 : /// Read and parse one message from the `buf` input buffer. If there is at
205 : /// least one valid message, returns it, advancing `buf`; redundant copies
206 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
207 : /// directly into the `buf` (processed data is garbage collected after
208 : /// parsed message is dropped).
209 : ///
210 : /// Returns None if `buf` doesn't contain enough data for a single message.
211 : /// For efficiency, tries to reserve large enough space in `buf` for the
212 : /// next message in this case to save the repeated calls.
213 : ///
214 : /// Returns Error if message is malformed, the only possible ErrorKind is
215 : /// InvalidInput.
216 : //
217 : // Inspired by rust-postgres Message::parse.
218 CBC 11812587 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
219 11812587 : // Every message contains message type byte and 4 bytes len; can't do
220 11812587 : // much without them.
221 11812587 : if buf.len() < 5 {
222 5371100 : let to_read = 5 - buf.len();
223 5371100 : buf.reserve(to_read);
224 5371100 : return Ok(None);
225 6441487 : }
226 6441487 :
227 6441487 : // We shouldn't advance `buf` as probably full message is not there yet,
228 6441487 : // so can't directly use Bytes::get_u32 etc.
229 6441487 : let tag = buf[0];
230 6441487 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
231 6441487 : if len < 4 {
232 UBC 0 : return Err(ProtocolError::Protocol(format!(
233 0 : "invalid message length {}",
234 0 : len
235 0 : )));
236 CBC 6441487 : }
237 6441487 :
238 6441487 : // length field includes itself, but not message type.
239 6441487 : let total_len = len as usize + 1;
240 6441487 : if buf.len() < total_len {
241 : // Don't have full message yet.
242 149702 : let to_read = total_len - buf.len();
243 149702 : buf.reserve(to_read);
244 149702 : return Ok(None);
245 6291785 : }
246 6291785 :
247 6291785 : // got the message, advance buffer
248 6291785 : let mut msg = buf.split_to(total_len).freeze();
249 6291785 : msg.advance(5); // consume message type and len
250 6291785 :
251 6291785 : match tag {
252 11393 : b'Q' => Ok(Some(FeMessage::Query(msg))),
253 552 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
254 552 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
255 552 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
256 552 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
257 551 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
258 1669 : b'S' => Ok(Some(FeMessage::Sync)),
259 3391 : b'X' => Ok(Some(FeMessage::Terminate)),
260 6272188 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
261 12 : b'c' => Ok(Some(FeMessage::CopyDone)),
262 67 : b'f' => Ok(Some(FeMessage::CopyFail)),
263 306 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
264 UBC 0 : tag => Err(ProtocolError::Protocol(format!(
265 0 : "unknown message tag: {tag},'{msg:?}'"
266 0 : ))),
267 : }
268 CBC 11812587 : }
269 : }
270 :
271 : impl FeStartupPacket {
272 : /// Read and parse startup message from the `buf` input buffer. It is
273 : /// different from [`FeMessage::parse`] because startup messages don't have
274 : /// message type byte; otherwise, its comments apply.
275 41362 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
276 41362 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
277 41362 : const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
278 41362 : const CANCEL_REQUEST_CODE: u32 = 5678;
279 41362 : const NEGOTIATE_SSL_CODE: u32 = 5679;
280 41362 : const NEGOTIATE_GSS_CODE: u32 = 5680;
281 41362 :
282 41362 : // need at least 4 bytes with packet len
283 41362 : if buf.len() < 4 {
284 20681 : let to_read = 4 - buf.len();
285 20681 : buf.reserve(to_read);
286 20681 : return Ok(None);
287 20681 : }
288 20681 :
289 20681 : // We shouldn't advance `buf` as probably full message is not there yet,
290 20681 : // so can't directly use Bytes::get_u32 etc.
291 20681 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
292 20681 : // The proposed replacement is `!(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
293 20681 : // which is less readable
294 20681 : #[allow(clippy::manual_range_contains)]
295 20681 : if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
296 1 : return Err(ProtocolError::Protocol(format!(
297 1 : "invalid startup packet message length {}",
298 1 : len
299 1 : )));
300 20680 : }
301 20680 :
302 20680 : if buf.len() < len {
303 : // Don't have full message yet.
304 UBC 0 : let to_read = len - buf.len();
305 0 : buf.reserve(to_read);
306 0 : return Ok(None);
307 CBC 20680 : }
308 20680 :
309 20680 : // got the message, advance buffer
310 20680 : let mut msg = buf.split_to(len).freeze();
311 20680 : msg.advance(4); // consume len
312 20680 :
313 20680 : let request_code = msg.get_u32();
314 20680 : let req_hi = request_code >> 16;
315 20680 : let req_lo = request_code & ((1 << 16) - 1);
316 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
317 20680 : let message = match (req_hi, req_lo) {
318 : (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
319 UBC 0 : if msg.remaining() != 8 {
320 0 : return Err(ProtocolError::BadMessage(
321 0 : "CancelRequest message is malformed, backend PID / secret key missing"
322 0 : .to_owned(),
323 0 : ));
324 0 : }
325 0 : FeStartupPacket::CancelRequest(CancelKeyData {
326 0 : backend_pid: msg.get_i32(),
327 0 : cancel_key: msg.get_i32(),
328 0 : })
329 : }
330 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
331 : // Requested upgrade to SSL (aka TLS)
332 CBC 9389 : FeStartupPacket::SslRequest
333 : }
334 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
335 : // Requested upgrade to GSSAPI
336 UBC 0 : FeStartupPacket::GssEncRequest
337 : }
338 0 : (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
339 0 : return Err(ProtocolError::Protocol(format!(
340 0 : "Unrecognized request code {unrecognized_code}"
341 0 : )));
342 : }
343 : // TODO bail if protocol major_version is not 3?
344 CBC 11291 : (major_version, minor_version) => {
345 : // StartupMessage
346 :
347 : // Parse pairs of null-terminated strings (key, value).
348 : // See `postgres: ProcessStartupPacket, build_startup_packet`.
349 11291 : let mut tokens = str::from_utf8(&msg)
350 11291 : .map_err(|_e| {
351 UBC 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
352 CBC 11291 : })?
353 11291 : .strip_suffix('\0') // drop packet's own null
354 11291 : .ok_or_else(|| {
355 UBC 0 : ProtocolError::Protocol(
356 0 : "StartupMessage params: missing null terminator".to_string(),
357 0 : )
358 CBC 11291 : })?
359 11291 : .split_terminator('\0');
360 11291 :
361 11291 : let mut params = HashMap::new();
362 37868 : while let Some(name) = tokens.next() {
363 26577 : let value = tokens.next().ok_or_else(|| {
364 UBC 0 : ProtocolError::Protocol(
365 0 : "StartupMessage params: key without value".to_string(),
366 0 : )
367 CBC 26577 : })?;
368 :
369 26577 : params.insert(name.to_owned(), value.to_owned());
370 : }
371 :
372 11291 : FeStartupPacket::StartupMessage {
373 11291 : major_version,
374 11291 : minor_version,
375 11291 : params: StartupMessageParams { params },
376 11291 : }
377 : }
378 : };
379 20680 : Ok(Some(message))
380 41362 : }
381 : }
382 :
383 : impl FeParseMessage {
384 552 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
385 : // FIXME: the rust-postgres driver uses a named prepared statement
386 : // for copy_out(). We're not prepared to handle that correctly. For
387 : // now, just ignore the statement name, assuming that the client never
388 : // uses more than one prepared statement at a time.
389 :
390 552 : let _pstmt_name = read_cstr(&mut buf)?;
391 552 : let query_string = read_cstr(&mut buf)?;
392 552 : if buf.remaining() < 2 {
393 UBC 0 : return Err(ProtocolError::BadMessage(
394 0 : "Parse message is malformed, nparams missing".to_string(),
395 0 : ));
396 CBC 552 : }
397 552 : let nparams = buf.get_i16();
398 552 :
399 552 : if nparams != 0 {
400 UBC 0 : return Err(ProtocolError::BadMessage(
401 0 : "query params not implemented".to_string(),
402 0 : ));
403 CBC 552 : }
404 552 :
405 552 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
406 552 : }
407 : }
408 :
409 : impl FeDescribeMessage {
410 552 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
411 552 : let kind = buf.get_u8();
412 552 : let _pstmt_name = read_cstr(&mut buf)?;
413 :
414 : // FIXME: see FeParseMessage::parse
415 552 : if kind != b'S' {
416 UBC 0 : return Err(ProtocolError::BadMessage(
417 0 : "only prepared statemement Describe is implemented".to_string(),
418 0 : ));
419 CBC 552 : }
420 552 :
421 552 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
422 552 : }
423 : }
424 :
425 : impl FeExecuteMessage {
426 552 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
427 552 : let portal_name = read_cstr(&mut buf)?;
428 552 : if buf.remaining() < 4 {
429 UBC 0 : return Err(ProtocolError::BadMessage(
430 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
431 0 : ));
432 CBC 552 : }
433 552 : let maxrows = buf.get_i32();
434 552 :
435 552 : if !portal_name.is_empty() {
436 UBC 0 : return Err(ProtocolError::BadMessage(
437 0 : "named portals not implemented".to_string(),
438 0 : ));
439 CBC 552 : }
440 552 : if maxrows != 0 {
441 UBC 0 : return Err(ProtocolError::BadMessage(
442 0 : "row limit in Execute message not implemented".to_string(),
443 0 : ));
444 CBC 552 : }
445 552 :
446 552 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
447 552 : }
448 : }
449 :
450 : impl FeBindMessage {
451 552 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
452 552 : let portal_name = read_cstr(&mut buf)?;
453 552 : let _pstmt_name = read_cstr(&mut buf)?;
454 :
455 : // FIXME: see FeParseMessage::parse
456 552 : if !portal_name.is_empty() {
457 UBC 0 : return Err(ProtocolError::BadMessage(
458 0 : "named portals not implemented".to_string(),
459 0 : ));
460 CBC 552 : }
461 552 :
462 552 : Ok(FeMessage::Bind(FeBindMessage))
463 552 : }
464 : }
465 :
466 : impl FeCloseMessage {
467 551 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
468 551 : let _kind = buf.get_u8();
469 551 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
470 :
471 : // FIXME: we do nothing with Close
472 551 : Ok(FeMessage::Close(FeCloseMessage))
473 551 : }
474 : }
475 :
476 : // Backend
477 :
478 UBC 0 : #[derive(Debug)]
479 : pub enum BeMessage<'a> {
480 : AuthenticationOk,
481 : AuthenticationMD5Password([u8; 4]),
482 : AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
483 : AuthenticationCleartextPassword,
484 : BackendKeyData(CancelKeyData),
485 : BindComplete,
486 : CommandComplete(&'a [u8]),
487 : CopyData(&'a [u8]),
488 : CopyDone,
489 : CopyFail,
490 : CopyInResponse,
491 : CopyOutResponse,
492 : CopyBothResponse,
493 : CloseComplete,
494 : // None means column is NULL
495 : DataRow(&'a [Option<&'a [u8]>]),
496 : // None errcode means internal_error will be sent.
497 : ErrorResponse(&'a str, Option<&'a [u8; 5]>),
498 : /// Single byte - used in response to SSLRequest/GSSENCRequest.
499 : EncryptionResponse(bool),
500 : NoData,
501 : ParameterDescription,
502 : ParameterStatus {
503 : name: &'a [u8],
504 : value: &'a [u8],
505 : },
506 : ParseComplete,
507 : ReadyForQuery,
508 : RowDescription(&'a [RowDescriptor<'a>]),
509 : XLogData(XLogDataBody<'a>),
510 : NoticeResponse(&'a str),
511 : KeepAlive(WalSndKeepAlive),
512 : }
513 :
514 : /// Common shorthands.
515 : impl<'a> BeMessage<'a> {
516 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
517 : /// This is a sensible default, given that:
518 : /// * rust strings only support this encoding out of the box.
519 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
520 : ///
521 : /// TODO: do we need to report `server_encoding` as well?
522 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
523 : name: b"client_encoding",
524 : value: b"UTF8",
525 : };
526 :
527 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
528 : name: b"integer_datetimes",
529 : value: b"on",
530 : };
531 :
532 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
533 CBC 10994 : pub fn server_version(version: &'a str) -> Self {
534 10994 : Self::ParameterStatus {
535 10994 : name: b"server_version",
536 10994 : value: version.as_bytes(),
537 10994 : }
538 10994 : }
539 : }
540 :
/// Payload of a SASL authentication request.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Names of the offered SASL mechanisms.
    Methods(&'a [&'a str]),
    /// Raw mechanism data for an intermediate SASL step.
    Continue(&'a [u8]),
    /// Raw mechanism data for the final SASL step.
    Final(&'a [u8]),
}

/// Typed variants of known `ParameterStatus` parameters.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
553 :
554 : // One row description in RowDescription packet.
555 0 : #[derive(Debug)]
556 : pub struct RowDescriptor<'a> {
557 : pub name: &'a [u8],
558 : pub tableoid: Oid,
559 : pub attnum: i16,
560 : pub typoid: Oid,
561 : pub typlen: i16,
562 : pub typmod: i32,
563 : pub formatcode: i16,
564 : }
565 :
566 : impl Default for RowDescriptor<'_> {
567 CBC 2899 : fn default() -> RowDescriptor<'static> {
568 2899 : RowDescriptor {
569 2899 : name: b"",
570 2899 : tableoid: 0,
571 2899 : attnum: 0,
572 2899 : typoid: 0,
573 2899 : typlen: 0,
574 2899 : typmod: 0,
575 2899 : formatcode: 0,
576 2899 : }
577 2899 : }
578 : }
579 :
580 : impl RowDescriptor<'_> {
581 : /// Convenience function to create a RowDescriptor message for an int8 column
582 45 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
583 45 : RowDescriptor {
584 45 : name,
585 45 : tableoid: 0,
586 45 : attnum: 0,
587 45 : typoid: INT8_OID,
588 45 : typlen: 8,
589 45 : typmod: 0,
590 45 : formatcode: 0,
591 45 : }
592 45 : }
593 :
594 1268 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
595 1268 : RowDescriptor {
596 1268 : name,
597 1268 : tableoid: 0,
598 1268 : attnum: 0,
599 1268 : typoid: TEXT_OID,
600 1268 : typlen: -1,
601 1268 : typmod: 0,
602 1268 : formatcode: 0,
603 1268 : }
604 1268 : }
605 : }
606 :
/// Payload of a replication `XLogData` message (sent inside `CopyData`).
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    /// WAL position where `data` starts.
    pub wal_start: u64,
    /// Current end of WAL on the server.
    pub wal_end: u64,
    /// Send time. NOTE(review): presumably microseconds since `PG_EPOCH`,
    /// as in upstream Postgres; confirm against callers.
    pub timestamp: i64,
    /// Raw WAL bytes.
    pub data: &'a [u8],
}

/// Payload of a replication keepalive message (sent inside `CopyData`).
#[derive(Debug)]
pub struct WalSndKeepAlive {
    /// Current end of WAL on the server.
    pub wal_end: u64,
    /// Send time. NOTE(review): presumably microseconds since `PG_EPOCH`;
    /// confirm against callers.
    pub timestamp: i64,
    /// Whether the client is asked to reply right away (sent as one 0/1 byte).
    pub request_reply: bool,
}
621 :
622 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
623 :
624 : // single text column
625 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
626 : name: b"data",
627 : tableoid: 0,
628 : attnum: 0,
629 : typoid: TEXT_OID,
630 : typlen: -1,
631 : typmod: 0,
632 : formatcode: 0,
633 : }]);
634 :
635 : /// Call f() to write body of the message and prepend it with 4-byte len as
636 : /// prescribed by the protocol.
637 CBC 6499104 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
638 6499104 : let base = buf.len();
639 6499104 : buf.extend_from_slice(&[0; 4]);
640 6499104 :
641 6499104 : let res = f(buf);
642 6499104 :
643 6499104 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
644 6499104 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
645 6499104 :
646 6499104 : res
647 6499104 : }
648 :
649 : /// Safe write of s into buf as cstring (String in the protocol).
650 74334 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
651 74334 : let bytes = s.as_ref();
652 74334 : if bytes.contains(&0) {
653 UBC 0 : return Err(ProtocolError::BadMessage(
654 0 : "string contains embedded null".to_owned(),
655 0 : ));
656 CBC 74334 : }
657 74334 : buf.put_slice(bytes);
658 74334 : buf.put_u8(0);
659 74334 : Ok(())
660 74334 : }
661 :
662 : /// Read cstring from buf, advancing it.
663 2832929 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
664 2832929 : let pos = buf
665 2832929 : .iter()
666 40226045 : .position(|x| *x == 0)
667 2832929 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
668 2832929 : let result = buf.split_to(pos);
669 2832929 : buf.advance(1); // drop the null terminator
670 2832929 : Ok(result)
671 2832929 : }
672 :
/// SQLSTATE `XX000` (internal_error); the code sent when no specific one is given.
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE `57P01` (admin_shutdown).
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE `00000` (successful_completion).
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
676 :
677 : impl<'a> BeMessage<'a> {
678 : /// Serialize `message` to the given `buf`.
679 : /// Apart from smart memory managemet, BytesMut is good here as msg len
680 : /// precedes its body and it is handy to write it down first and then fill
681 : /// the length. With Write we would have to either calc it manually or have
682 : /// one more buffer.
683 6508493 : pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
684 6508493 : match message {
685 11248 : BeMessage::AuthenticationOk => {
686 11248 : buf.put_u8(b'R');
687 11248 : write_body(buf, |buf| {
688 11248 : buf.put_i32(0); // Specifies that the authentication was successful.
689 11248 : });
690 11248 : }
691 :
692 218 : BeMessage::AuthenticationCleartextPassword => {
693 218 : buf.put_u8(b'R');
694 218 : write_body(buf, |buf| {
695 218 : buf.put_i32(3); // Specifies that clear text password is required.
696 218 : });
697 218 : }
698 :
699 UBC 0 : BeMessage::AuthenticationMD5Password(salt) => {
700 0 : buf.put_u8(b'R');
701 0 : write_body(buf, |buf| {
702 0 : buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
703 0 : buf.put_slice(&salt[..]);
704 0 : });
705 0 : }
706 :
707 CBC 132 : BeMessage::AuthenticationSasl(msg) => {
708 132 : buf.put_u8(b'R');
709 132 : write_body(buf, |buf| {
710 132 : use BeAuthenticationSaslMessage::*;
711 132 : match msg {
712 48 : Methods(methods) => {
713 48 : buf.put_i32(10); // Specifies that SASL auth method is used.
714 96 : for method in methods.iter() {
715 96 : write_cstr(method, buf)?;
716 : }
717 48 : buf.put_u8(0); // zero terminator for the list
718 : }
719 46 : Continue(extra) => {
720 46 : buf.put_i32(11); // Continue SASL auth.
721 46 : buf.put_slice(extra);
722 46 : }
723 38 : Final(extra) => {
724 38 : buf.put_i32(12); // Send final SASL message.
725 38 : buf.put_slice(extra);
726 38 : }
727 : }
728 132 : Ok(())
729 132 : })?;
730 : }
731 :
732 38 : BeMessage::BackendKeyData(key_data) => {
733 38 : buf.put_u8(b'K');
734 38 : write_body(buf, |buf| {
735 38 : buf.put_i32(key_data.backend_pid);
736 38 : buf.put_i32(key_data.cancel_key);
737 38 : });
738 38 : }
739 :
740 552 : BeMessage::BindComplete => {
741 552 : buf.put_u8(b'2');
742 552 : write_body(buf, |_| {});
743 552 : }
744 :
745 551 : BeMessage::CloseComplete => {
746 551 : buf.put_u8(b'3');
747 551 : write_body(buf, |_| {});
748 551 : }
749 :
750 1969 : BeMessage::CommandComplete(cmd) => {
751 1969 : buf.put_u8(b'C');
752 1969 : write_body(buf, |buf| write_cstr(cmd, buf))?;
753 : }
754 :
755 5840052 : BeMessage::CopyData(data) => {
756 5840052 : buf.put_u8(b'd');
757 5840052 : write_body(buf, |buf| {
758 5840052 : buf.put_slice(data);
759 5840052 : });
760 5840052 : }
761 :
762 557 : BeMessage::CopyDone => {
763 557 : buf.put_u8(b'c');
764 557 : write_body(buf, |_| {});
765 557 : }
766 :
767 UBC 0 : BeMessage::CopyFail => {
768 0 : buf.put_u8(b'f');
769 0 : write_body(buf, |_| {});
770 0 : }
771 :
772 CBC 13 : BeMessage::CopyInResponse => {
773 13 : buf.put_u8(b'G');
774 13 : write_body(buf, |buf| {
775 13 : buf.put_u8(1); // copy_is_binary
776 13 : buf.put_i16(0); // numAttributes
777 13 : });
778 13 : }
779 :
780 559 : BeMessage::CopyOutResponse => {
781 559 : buf.put_u8(b'H');
782 559 : write_body(buf, |buf| {
783 559 : buf.put_u8(0); // copy_is_binary
784 559 : buf.put_i16(0); // numAttributes
785 559 : });
786 559 : }
787 :
788 9440 : BeMessage::CopyBothResponse => {
789 9440 : buf.put_u8(b'W');
790 9440 : write_body(buf, |buf| {
791 9440 : // doesn't matter, used only for replication
792 9440 : buf.put_u8(0); // copy_is_binary
793 9440 : buf.put_i16(0); // numAttributes
794 9440 : });
795 9440 : }
796 :
797 927 : BeMessage::DataRow(vals) => {
798 927 : buf.put_u8(b'D');
799 927 : write_body(buf, |buf| {
800 927 : buf.put_u16(vals.len() as u16); // num of cols
801 3329 : for val_opt in vals.iter() {
802 3329 : if let Some(val) = val_opt {
803 2605 : buf.put_u32(val.len() as u32);
804 2605 : buf.put_slice(val);
805 2605 : } else {
806 724 : buf.put_i32(-1);
807 724 : }
808 : }
809 927 : });
810 927 : }
811 :
812 : // ErrorResponse is a zero-terminated array of zero-terminated fields.
813 : // First byte of each field represents type of this field. Set just enough fields
814 : // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
815 : // message text.
816 658 : BeMessage::ErrorResponse(error_msg, pg_error_code) => {
817 658 : // 'E' signalizes ErrorResponse messages
818 658 : buf.put_u8(b'E');
819 658 : write_body(buf, |buf| {
820 658 : buf.put_u8(b'S'); // severity
821 658 : buf.put_slice(b"ERROR\0");
822 658 :
823 658 : buf.put_u8(b'C'); // SQLSTATE error code
824 658 : buf.put_slice(&terminate_code(
825 658 : pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
826 658 : ));
827 658 :
828 658 : buf.put_u8(b'M'); // the message
829 658 : write_cstr(error_msg, buf)?;
830 :
831 658 : buf.put_u8(0); // terminator
832 658 : Ok(())
833 658 : })?;
834 : }
835 :
836 : // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
837 : // message but continue listening for ReadyForQuery or ErrorResponse"
838 6 : BeMessage::NoticeResponse(error_msg) => {
839 6 : // For all the errors set Severity to Error and error code to
840 6 : // 'internal error'.
841 6 :
842 6 : // 'N' signalizes NoticeResponse messages
843 6 : buf.put_u8(b'N');
844 6 : write_body(buf, |buf| {
845 6 : buf.put_u8(b'S'); // severity
846 6 : buf.put_slice(b"NOTICE\0");
847 6 :
848 6 : buf.put_u8(b'C'); // SQLSTATE error code
849 6 : buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
850 6 :
851 6 : buf.put_u8(b'M'); // the message
852 6 : write_cstr(error_msg.as_bytes(), buf)?;
853 :
854 6 : buf.put_u8(0); // terminator
855 6 : Ok(())
856 6 : })?;
857 : }
858 :
859 552 : BeMessage::NoData => {
860 552 : buf.put_u8(b'n');
861 552 : write_body(buf, |_| {});
862 552 : }
863 :
864 9389 : BeMessage::EncryptionResponse(should_negotiate) => {
865 9389 : let response = if *should_negotiate { b'S' } else { b'N' };
866 9389 : buf.put_u8(response);
867 : }
868 :
869 33695 : BeMessage::ParameterStatus { name, value } => {
870 33695 : buf.put_u8(b'S');
871 33695 : write_body(buf, |buf| {
872 33695 : write_cstr(name, buf)?;
873 33695 : write_cstr(value, buf)
874 33695 : })?;
875 : }
876 :
877 552 : BeMessage::ParameterDescription => {
878 552 : buf.put_u8(b't');
879 552 : write_body(buf, |buf| {
880 552 : // we don't support params, so always 0
881 552 : buf.put_i16(0);
882 552 : });
883 552 : }
884 :
885 552 : BeMessage::ParseComplete => {
886 552 : buf.put_u8(b'1');
887 552 : write_body(buf, |_| {});
888 552 : }
889 :
890 23672 : BeMessage::ReadyForQuery => {
891 23672 : buf.put_u8(b'Z');
892 23672 : write_body(buf, |buf| {
893 23672 : buf.put_u8(b'I');
894 23672 : });
895 23672 : }
896 :
897 1370 : BeMessage::RowDescription(rows) => {
898 1370 : buf.put_u8(b'T');
899 1370 : write_body(buf, |buf| {
900 1370 : buf.put_i16(rows.len() as i16); // # of fields
901 4215 : for row in rows.iter() {
902 4215 : write_cstr(row.name, buf)?;
903 4215 : buf.put_i32(0); /* table oid */
904 4215 : buf.put_i16(0); /* attnum */
905 4215 : buf.put_u32(row.typoid);
906 4215 : buf.put_i16(row.typlen);
907 4215 : buf.put_i32(-1); /* typmod */
908 4215 : buf.put_i16(0); /* format code */
909 : }
910 1370 : Ok(())
911 1370 : })?;
912 : }
913 :
914 568377 : BeMessage::XLogData(body) => {
915 568377 : buf.put_u8(b'd');
916 568377 : write_body(buf, |buf| {
917 568377 : buf.put_u8(b'w');
918 568377 : buf.put_u64(body.wal_start);
919 568377 : buf.put_u64(body.wal_end);
920 568377 : buf.put_i64(body.timestamp);
921 568377 : buf.put_slice(body.data);
922 568377 : });
923 568377 : }
924 :
925 3414 : BeMessage::KeepAlive(req) => {
926 3414 : buf.put_u8(b'd');
927 3414 : write_body(buf, |buf| {
928 3414 : buf.put_u8(b'k');
929 3414 : buf.put_u64(req.wal_end);
930 3414 : buf.put_i64(req.timestamp);
931 3414 : buf.put_u8(u8::from(req.request_reply));
932 3414 : });
933 3414 : }
934 : }
935 6508493 : Ok(())
936 6508493 : }
937 : }
938 :
/// Returns the 5-byte SQLSTATE `code` followed by a null terminator, ready to
/// be written as a protocol String.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    let mut terminated = [0u8; 6];
    terminated[..code.len()].copy_from_slice(code);
    terminated
}
947 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks PostgreSQL-style splitting and unescaping of the `options`
    /// startup parameter.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        // Leading, trailing and repeated whitespace yields no empty tokens.
        let params = make_params(" foo  bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // `\ ` keeps the space inside a token, `\\` becomes `\`; the second of
        // the two spaces after `baz\ ` is an unescaped separator.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }

    /// A startup packet whose length field is below the 8-byte minimum must
    /// be rejected.
    #[test]
    fn parse_fe_startup_packet_regression() {
        let data = [0, 0, 0, 7, 0, 0, 0, 0];
        FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err();
    }
}
|