TLA Line data Source code
1 : //! Postgres protocol messages serialization-deserialization. See
2 : //! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
3 : //! on message formats.
4 :
5 : pub mod framed;
6 :
7 : use byteorder::{BigEndian, ReadBytesExt};
8 : use bytes::{Buf, BufMut, Bytes, BytesMut};
9 : use std::{borrow::Cow, collections::HashMap, fmt, io, str};
10 :
11 : // re-export for use in utils pageserver_feedback.rs
12 : pub use postgres_protocol::PG_EPOCH;
13 :
/// PostgreSQL object identifier as transmitted on the wire.
pub type Oid = u32;
/// Identifier of a Postgres system (cluster); presumably the value reported
/// by `IDENTIFY_SYSTEM` -- confirm against callers.
pub type SystemId = u64;

/// Type OID of the built-in `int8` type.
pub const INT8_OID: Oid = 20;
/// Type OID of the built-in `int4` type.
pub const INT4_OID: Oid = 23;
/// Type OID of the built-in `text` type.
pub const TEXT_OID: Oid = 25;
20 :
21 CBC 59 : #[derive(Debug)]
22 : pub enum FeMessage {
23 : // Simple query.
24 : Query(Bytes),
25 : // Extended query protocol.
26 : Parse(FeParseMessage),
27 : Describe(FeDescribeMessage),
28 : Bind(FeBindMessage),
29 : Execute(FeExecuteMessage),
30 : Close(FeCloseMessage),
31 : Sync,
32 : Terminate,
33 : CopyData(Bytes),
34 : CopyDone,
35 : CopyFail,
36 : PasswordMessage(Bytes),
37 : }
38 :
39 140 : #[derive(Debug)]
40 : pub enum FeStartupPacket {
41 : CancelRequest(CancelKeyData),
42 : SslRequest,
43 : GssEncRequest,
44 : StartupMessage {
45 : major_version: u32,
46 : minor_version: u32,
47 : params: StartupMessageParams,
48 : },
49 : }
50 :
/// Key/value parameters carried by a startup message.
#[derive(Debug)]
pub struct StartupMessageParams {
    params: HashMap<String, String>,
}

impl StartupMessageParams {
    /// Look up the value of the parameter called `name`.
    pub fn get(&self, name: &str) -> Option<&str> {
        let value = self.params.get(name)?;
        Some(value.as_str())
    }

    /// Tokenize the `options` parameter the way PostgreSQL does, honouring
    /// escape sequences while splitting but returning the tokens verbatim.
    /// Yields [`None`] when [`Self`] has no `options` parameter.
    pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
        let options = self.get("options")?;
        Some(Self::parse_options_raw(options))
    }

    /// Tokenize the `options` parameter the way PostgreSQL does and resolve
    /// escape sequences (allocating only for tokens that contain any).
    /// Yields [`None`] when [`Self`] has no `options` parameter.
    pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
        let options = self.get("options")?;
        Some(Self::parse_options_escaped(options))
    }

    /// Tokenize `input` on unescaped ASCII whitespace; backslashes are kept
    /// in the produced tokens. Empty tokens are skipped.
    pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
        // See `postgres: pg_split_opts`.
        let mut prev_is_escape = false;
        let is_separator = move |c: char| {
            // Whitespace separates tokens unless a backslash protects it.
            let separator = !prev_is_escape && c.is_ascii_whitespace();
            // A backslash escapes the next char unless it is itself escaped.
            prev_is_escape = !prev_is_escape && c == '\\';
            separator
        };

        input.split(is_separator).filter(|token| !token.is_empty())
    }

    /// Tokenize `input` like [`Self::parse_options_raw`] and then strip the
    /// escaping backslashes from every token.
    pub fn parse_options_escaped(input: &str) -> impl Iterator<Item = Cow<'_, str>> {
        // See `postgres: pg_split_opts`.
        Self::parse_options_raw(input).map(|token| {
            if !token.contains('\\') {
                // Nothing to unescape, hand out the borrowed slice.
                return Cow::Borrowed(token);
            }

            // Drop every backslash except one directly following a dropped
            // backslash (i.e. `\\` collapses into a single `\`).
            let mut keep_next_backslash = false;
            let unescaped = token.replace(
                |c: char| {
                    let drop_it = c == '\\' && !keep_next_backslash;
                    keep_next_backslash = drop_it;
                    drop_it
                },
                "",
            );
            Cow::Owned(unescaped)
        })
    }

    /// Visit all `(name, value)` pairs; the order is unspecified.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
        self.params
            .iter()
            .map(|(name, value)| (name.as_str(), value.as_str()))
    }

    // Mostly handy for building fixtures in tests.
    #[doc(hidden)]
    pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
        let owned = pairs.map(|(name, value)| (name.to_owned(), value.to_owned()));
        Self {
            params: HashMap::from(owned),
        }
    }
}
124 :
/// Identifies a backend session for the purpose of query cancellation:
/// the pair sent in `BackendKeyData` and echoed back in `CancelRequest`.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
    pub backend_pid: i32,
    pub cancel_key: i32,
}

impl fmt::Display for CancelKeyData {
    /// Prints both halves as a single hexadecimal number,
    /// `backend_pid` in the upper 32 bits and `cancel_key` in the lower.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reinterpret each half as `u32` before widening: a direct
        // `i32 as u64` cast sign-extends, so a negative `cancel_key` would
        // set all upper bits and hide `backend_pid`.
        let hi = u64::from(self.backend_pid as u32) << 32;
        let lo = u64::from(self.cancel_key as u32);
        let id = hi | lo;

        // This format is more compact and might work better for logs.
        f.debug_tuple("CancelKeyData")
            .field(&format_args!("{:x}", id))
            .finish()
    }
}
143 :
144 : use rand::distributions::{Distribution, Standard};
145 : impl Distribution<CancelKeyData> for Standard {
146 34 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
147 34 : CancelKeyData {
148 34 : backend_pid: rng.gen(),
149 34 : cancel_key: rng.gen(),
150 34 : }
151 34 : }
152 : }
153 :
154 : // We only support the simple case of Parse on unnamed prepared statement and
155 : // no params
156 UBC 0 : #[derive(Debug)]
157 : pub struct FeParseMessage {
158 : pub query_string: Bytes,
159 : }
160 :
/// Body of a `Describe` message.
///
/// Only the unnamed prepared statement or portal is supported.
#[derive(Debug)]
pub struct FeDescribeMessage {
    /// `b'S'` to describe a prepared statement, `b'P'` to describe a portal.
    pub kind: u8,
}
166 :
/// Marker for a `Bind` message. It carries no data because only the unnamed
/// prepared statement and portal are supported.
#[derive(Debug)]
pub struct FeBindMessage;
170 :
/// Body of an `Execute` message.
///
/// Only the unnamed prepared statement or portal is supported.
#[derive(Debug)]
pub struct FeExecuteMessage {
    /// Maximum number of rows to return.
    pub maxrows: i32,
}
177 :
/// Marker for a `Close` message. It carries no data because only the unnamed
/// prepared statement and portal are supported.
#[derive(Debug)]
pub struct FeCloseMessage;
181 :
182 : /// An error occurred while parsing or serializing raw stream into Postgres
183 : /// messages.
184 CBC 7 : #[derive(thiserror::Error, Debug)]
185 : pub enum ProtocolError {
186 : /// Invalid packet was received from the client (e.g. unexpected message
187 : /// type or broken len).
188 : #[error("Protocol error: {0}")]
189 : Protocol(String),
190 : /// Failed to parse or, (unlikely), serialize a protocol message.
191 : #[error("Message parse error: {0}")]
192 : BadMessage(String),
193 : }
194 :
195 : impl ProtocolError {
196 : /// Proxy stream.rs uses only io::Error; provide it.
197 UBC 0 : pub fn into_io_error(self) -> io::Error {
198 0 : io::Error::new(io::ErrorKind::Other, self.to_string())
199 0 : }
200 : }
201 :
202 : impl FeMessage {
203 : /// Read and parse one message from the `buf` input buffer. If there is at
204 : /// least one valid message, returns it, advancing `buf`; redundant copies
205 : /// are avoided, as thanks to `bytes` crate ptrs in parsed message point
206 : /// directly into the `buf` (processed data is garbage collected after
207 : /// parsed message is dropped).
208 : ///
209 : /// Returns None if `buf` doesn't contain enough data for a single message.
210 : /// For efficiency, tries to reserve large enough space in `buf` for the
211 : /// next message in this case to save the repeated calls.
212 : ///
213 : /// Returns Error if message is malformed, the only possible ErrorKind is
214 : /// InvalidInput.
215 : //
216 : // Inspired by rust-postgres Message::parse.
217 CBC 14442968 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
218 14442968 : // Every message contains message type byte and 4 bytes len; can't do
219 14442968 : // much without them.
220 14442968 : if buf.len() < 5 {
221 6811831 : let to_read = 5 - buf.len();
222 6811831 : buf.reserve(to_read);
223 6811831 : return Ok(None);
224 7631137 : }
225 7631137 :
226 7631137 : // We shouldn't advance `buf` as probably full message is not there yet,
227 7631137 : // so can't directly use Bytes::get_u32 etc.
228 7631137 : let tag = buf[0];
229 7631137 : let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
230 7631137 : if len < 4 {
231 UBC 0 : return Err(ProtocolError::Protocol(format!(
232 0 : "invalid message length {}",
233 0 : len
234 0 : )));
235 CBC 7631137 : }
236 7631137 :
237 7631137 : // length field includes itself, but not message type.
238 7631137 : let total_len = len as usize + 1;
239 7631137 : if buf.len() < total_len {
240 : // Don't have full message yet.
241 112724 : let to_read = total_len - buf.len();
242 112724 : buf.reserve(to_read);
243 112724 : return Ok(None);
244 7518413 : }
245 7518413 :
246 7518413 : // got the message, advance buffer
247 7518413 : let mut msg = buf.split_to(total_len).freeze();
248 7518413 : msg.advance(5); // consume message type and len
249 7518413 :
250 7518413 : match tag {
251 8744 : b'Q' => Ok(Some(FeMessage::Query(msg))),
252 640 : b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
253 640 : b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
254 640 : b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
255 640 : b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
256 639 : b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
257 1934 : b'S' => Ok(Some(FeMessage::Sync)),
258 1525 : b'X' => Ok(Some(FeMessage::Terminate)),
259 7502651 : b'd' => Ok(Some(FeMessage::CopyData(msg))),
260 6 : b'c' => Ok(Some(FeMessage::CopyDone)),
261 70 : b'f' => Ok(Some(FeMessage::CopyFail)),
262 284 : b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
263 UBC 0 : tag => Err(ProtocolError::Protocol(format!(
264 0 : "unknown message tag: {tag},'{msg:?}'"
265 0 : ))),
266 : }
267 CBC 14442968 : }
268 : }
269 :
270 : impl FeStartupPacket {
271 : /// Read and parse startup message from the `buf` input buffer. It is
272 : /// different from [`FeMessage::parse`] because startup messages don't have
273 : /// message type byte; otherwise, its comments apply.
274 30764 : pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
275 30764 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
276 30764 : const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
277 30764 : const CANCEL_REQUEST_CODE: u32 = 5678;
278 30764 : const NEGOTIATE_SSL_CODE: u32 = 5679;
279 30764 : const NEGOTIATE_GSS_CODE: u32 = 5680;
280 30764 :
281 30764 : // need at least 4 bytes with packet len
282 30764 : if buf.len() < 4 {
283 15383 : let to_read = 4 - buf.len();
284 15383 : buf.reserve(to_read);
285 15383 : return Ok(None);
286 15381 : }
287 15381 :
288 15381 : // We shouldn't advance `buf` as probably full message is not there yet,
289 15381 : // so can't directly use Bytes::get_u32 etc.
290 15381 : let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
291 15381 : // The proposed replacement is `!(4..=MAX_STARTUP_PACKET_LENGTH).contains(&len)`
292 15381 : // which is less readable
293 15381 : #[allow(clippy::manual_range_contains)]
294 15381 : if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
295 UBC 0 : return Err(ProtocolError::Protocol(format!(
296 0 : "invalid startup packet message length {}",
297 0 : len
298 0 : )));
299 CBC 15381 : }
300 15381 :
301 15381 : if buf.len() < len {
302 : // Don't have full message yet.
303 UBC 0 : let to_read = len - buf.len();
304 0 : buf.reserve(to_read);
305 0 : return Ok(None);
306 CBC 15381 : }
307 15381 :
308 15381 : // got the message, advance buffer
309 15381 : let mut msg = buf.split_to(len).freeze();
310 15381 : msg.advance(4); // consume len
311 15381 :
312 15381 : let request_code = msg.get_u32();
313 15381 : let req_hi = request_code >> 16;
314 15381 : let req_lo = request_code & ((1 << 16) - 1);
315 : // StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
316 15381 : let message = match (req_hi, req_lo) {
317 : (RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
318 UBC 0 : if msg.remaining() != 8 {
319 0 : return Err(ProtocolError::BadMessage(
320 0 : "CancelRequest message is malformed, backend PID / secret key missing"
321 0 : .to_owned(),
322 0 : ));
323 0 : }
324 0 : FeStartupPacket::CancelRequest(CancelKeyData {
325 0 : backend_pid: msg.get_i32(),
326 0 : cancel_key: msg.get_i32(),
327 0 : })
328 : }
329 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
330 : // Requested upgrade to SSL (aka TLS)
331 CBC 6653 : FeStartupPacket::SslRequest
332 : }
333 : (RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
334 : // Requested upgrade to GSSAPI
335 UBC 0 : FeStartupPacket::GssEncRequest
336 : }
337 0 : (RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
338 0 : return Err(ProtocolError::Protocol(format!(
339 0 : "Unrecognized request code {unrecognized_code}"
340 0 : )));
341 : }
342 : // TODO bail if protocol major_version is not 3?
343 CBC 8728 : (major_version, minor_version) => {
344 : // StartupMessage
345 :
346 : // Parse pairs of null-terminated strings (key, value).
347 : // See `postgres: ProcessStartupPacket, build_startup_packet`.
348 8728 : let mut tokens = str::from_utf8(&msg)
349 8728 : .map_err(|_e| {
350 UBC 0 : ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
351 CBC 8728 : })?
352 8728 : .strip_suffix('\0') // drop packet's own null
353 8728 : .ok_or_else(|| {
354 UBC 0 : ProtocolError::Protocol(
355 0 : "StartupMessage params: missing null terminator".to_string(),
356 0 : )
357 CBC 8728 : })?
358 8728 : .split_terminator('\0');
359 8728 :
360 8728 : let mut params = HashMap::new();
361 30726 : while let Some(name) = tokens.next() {
362 21998 : let value = tokens.next().ok_or_else(|| {
363 UBC 0 : ProtocolError::Protocol(
364 0 : "StartupMessage params: key without value".to_string(),
365 0 : )
366 CBC 21998 : })?;
367 :
368 21998 : params.insert(name.to_owned(), value.to_owned());
369 : }
370 :
371 8728 : FeStartupPacket::StartupMessage {
372 8728 : major_version,
373 8728 : minor_version,
374 8728 : params: StartupMessageParams { params },
375 8728 : }
376 : }
377 : };
378 15381 : Ok(Some(message))
379 30764 : }
380 : }
381 :
382 : impl FeParseMessage {
383 640 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
384 : // FIXME: the rust-postgres driver uses a named prepared statement
385 : // for copy_out(). We're not prepared to handle that correctly. For
386 : // now, just ignore the statement name, assuming that the client never
387 : // uses more than one prepared statement at a time.
388 :
389 640 : let _pstmt_name = read_cstr(&mut buf)?;
390 640 : let query_string = read_cstr(&mut buf)?;
391 640 : if buf.remaining() < 2 {
392 UBC 0 : return Err(ProtocolError::BadMessage(
393 0 : "Parse message is malformed, nparams missing".to_string(),
394 0 : ));
395 CBC 640 : }
396 640 : let nparams = buf.get_i16();
397 640 :
398 640 : if nparams != 0 {
399 UBC 0 : return Err(ProtocolError::BadMessage(
400 0 : "query params not implemented".to_string(),
401 0 : ));
402 CBC 640 : }
403 640 :
404 640 : Ok(FeMessage::Parse(FeParseMessage { query_string }))
405 640 : }
406 : }
407 :
408 : impl FeDescribeMessage {
409 640 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
410 640 : let kind = buf.get_u8();
411 640 : let _pstmt_name = read_cstr(&mut buf)?;
412 :
413 : // FIXME: see FeParseMessage::parse
414 640 : if kind != b'S' {
415 UBC 0 : return Err(ProtocolError::BadMessage(
416 0 : "only prepared statemement Describe is implemented".to_string(),
417 0 : ));
418 CBC 640 : }
419 640 :
420 640 : Ok(FeMessage::Describe(FeDescribeMessage { kind }))
421 640 : }
422 : }
423 :
424 : impl FeExecuteMessage {
425 640 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
426 640 : let portal_name = read_cstr(&mut buf)?;
427 640 : if buf.remaining() < 4 {
428 UBC 0 : return Err(ProtocolError::BadMessage(
429 0 : "FeExecuteMessage message is malformed, maxrows missing".to_string(),
430 0 : ));
431 CBC 640 : }
432 640 : let maxrows = buf.get_i32();
433 640 :
434 640 : if !portal_name.is_empty() {
435 UBC 0 : return Err(ProtocolError::BadMessage(
436 0 : "named portals not implemented".to_string(),
437 0 : ));
438 CBC 640 : }
439 640 : if maxrows != 0 {
440 UBC 0 : return Err(ProtocolError::BadMessage(
441 0 : "row limit in Execute message not implemented".to_string(),
442 0 : ));
443 CBC 640 : }
444 640 :
445 640 : Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
446 640 : }
447 : }
448 :
449 : impl FeBindMessage {
450 640 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
451 640 : let portal_name = read_cstr(&mut buf)?;
452 640 : let _pstmt_name = read_cstr(&mut buf)?;
453 :
454 : // FIXME: see FeParseMessage::parse
455 640 : if !portal_name.is_empty() {
456 UBC 0 : return Err(ProtocolError::BadMessage(
457 0 : "named portals not implemented".to_string(),
458 0 : ));
459 CBC 640 : }
460 640 :
461 640 : Ok(FeMessage::Bind(FeBindMessage))
462 640 : }
463 : }
464 :
465 : impl FeCloseMessage {
466 639 : fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
467 639 : let _kind = buf.get_u8();
468 639 : let _pstmt_or_portal_name = read_cstr(&mut buf)?;
469 :
470 : // FIXME: we do nothing with Close
471 639 : Ok(FeMessage::Close(FeCloseMessage))
472 639 : }
473 : }
474 :
475 : // Backend
476 :
477 UBC 0 : #[derive(Debug)]
478 : pub enum BeMessage<'a> {
479 : AuthenticationOk,
480 : AuthenticationMD5Password([u8; 4]),
481 : AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
482 : AuthenticationCleartextPassword,
483 : BackendKeyData(CancelKeyData),
484 : BindComplete,
485 : CommandComplete(&'a [u8]),
486 : CopyData(&'a [u8]),
487 : CopyDone,
488 : CopyFail,
489 : CopyInResponse,
490 : CopyOutResponse,
491 : CopyBothResponse,
492 : CloseComplete,
493 : // None means column is NULL
494 : DataRow(&'a [Option<&'a [u8]>]),
495 : // None errcode means internal_error will be sent.
496 : ErrorResponse(&'a str, Option<&'a [u8; 5]>),
497 : /// Single byte - used in response to SSLRequest/GSSENCRequest.
498 : EncryptionResponse(bool),
499 : NoData,
500 : ParameterDescription,
501 : ParameterStatus {
502 : name: &'a [u8],
503 : value: &'a [u8],
504 : },
505 : ParseComplete,
506 : ReadyForQuery,
507 : RowDescription(&'a [RowDescriptor<'a>]),
508 : XLogData(XLogDataBody<'a>),
509 : NoticeResponse(&'a str),
510 : KeepAlive(WalSndKeepAlive),
511 : }
512 :
513 : /// Common shorthands.
514 : impl<'a> BeMessage<'a> {
515 : /// A [`BeMessage::ParameterStatus`] holding the client encoding, i.e. UTF-8.
516 : /// This is a sensible default, given that:
517 : /// * rust strings only support this encoding out of the box.
518 : /// * tokio-postgres, postgres-jdbc (and probably more) mandate it.
519 : ///
520 : /// TODO: do we need to report `server_encoding` as well?
521 : pub const CLIENT_ENCODING: Self = Self::ParameterStatus {
522 : name: b"client_encoding",
523 : value: b"UTF8",
524 : };
525 :
526 : pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
527 : name: b"integer_datetimes",
528 : value: b"on",
529 : };
530 :
531 : /// Build a [`BeMessage::ParameterStatus`] holding the server version.
532 CBC 8459 : pub fn server_version(version: &'a str) -> Self {
533 8459 : Self::ParameterStatus {
534 8459 : name: b"server_version",
535 8459 : value: version.as_bytes(),
536 8459 : }
537 8459 : }
538 : }
539 :
/// Payload of an `AuthenticationSasl` backend message, one variant per
/// stage of the SASL exchange.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// List of SASL mechanism names offered to the client.
    Methods(&'a [&'a str]),
    /// Data for continuing the exchange.
    Continue(&'a [u8]),
    /// Data of the final SASL message.
    Final(&'a [u8]),
}
546 :
/// Typed flavours of a parameter status report.
// NOTE(review): nothing in this file constructs or serializes this type;
// presumably it is used by other modules -- confirm.
#[derive(Debug)]
pub enum BeParameterStatusMessage<'a> {
    Encoding(&'a str),
    ServerVersion(&'a str),
}
552 :
553 : // One row description in RowDescription packet.
554 0 : #[derive(Debug)]
555 : pub struct RowDescriptor<'a> {
556 : pub name: &'a [u8],
557 : pub tableoid: Oid,
558 : pub attnum: i16,
559 : pub typoid: Oid,
560 : pub typlen: i16,
561 : pub typmod: i32,
562 : pub formatcode: i16,
563 : }
564 :
565 : impl Default for RowDescriptor<'_> {
566 CBC 2791 : fn default() -> RowDescriptor<'static> {
567 2791 : RowDescriptor {
568 2791 : name: b"",
569 2791 : tableoid: 0,
570 2791 : attnum: 0,
571 2791 : typoid: 0,
572 2791 : typlen: 0,
573 2791 : typmod: 0,
574 2791 : formatcode: 0,
575 2791 : }
576 2791 : }
577 : }
578 :
579 : impl RowDescriptor<'_> {
580 : /// Convenience function to create a RowDescriptor message for an int8 column
581 45 : pub const fn int8_col(name: &[u8]) -> RowDescriptor {
582 45 : RowDescriptor {
583 45 : name,
584 45 : tableoid: 0,
585 45 : attnum: 0,
586 45 : typoid: INT8_OID,
587 45 : typlen: 8,
588 45 : typmod: 0,
589 45 : formatcode: 0,
590 45 : }
591 45 : }
592 :
593 1508 : pub const fn text_col(name: &[u8]) -> RowDescriptor {
594 1508 : RowDescriptor {
595 1508 : name,
596 1508 : tableoid: 0,
597 1508 : attnum: 0,
598 1508 : typoid: TEXT_OID,
599 1508 : typlen: -1,
600 1508 : typmod: 0,
601 1508 : formatcode: 0,
602 1508 : }
603 1508 : }
604 : }
605 :
/// Payload of an `XLogData` message; serialized by `BeMessage::write` as
/// `wal_start`, `wal_end`, `timestamp` followed by the raw `data`.
#[derive(Debug)]
pub struct XLogDataBody<'a> {
    pub wal_start: u64,
    /// Current end of WAL on the server.
    pub wal_end: u64,
    pub timestamp: i64,
    pub data: &'a [u8],
}
613 :
/// Payload of a keepalive message sent by the WAL sender.
#[derive(Debug)]
pub struct WalSndKeepAlive {
    /// Current end of WAL on the server.
    pub wal_end: u64,
    pub timestamp: i64,
    /// Serialized as a single 0/1 byte.
    pub request_reply: bool,
}
620 :
621 : pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
622 :
623 : // single text column
624 : pub static SINGLE_COL_ROWDESC: BeMessage = BeMessage::RowDescription(&[RowDescriptor {
625 : name: b"data",
626 : tableoid: 0,
627 : attnum: 0,
628 : typoid: TEXT_OID,
629 : typlen: -1,
630 : typmod: 0,
631 : formatcode: 0,
632 : }]);
633 :
634 : /// Call f() to write body of the message and prepend it with 4-byte len as
635 : /// prescribed by the protocol.
636 CBC 8027128 : fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
637 8027128 : let base = buf.len();
638 8027128 : buf.extend_from_slice(&[0; 4]);
639 8027128 :
640 8027128 : let res = f(buf);
641 8027128 :
642 8027128 : let size = i32::try_from(buf.len() - base).expect("message too big to transmit");
643 8027128 : (&mut buf[base..]).put_slice(&size.to_be_bytes());
644 8027128 :
645 8027128 : res
646 8027128 : }
647 :
648 : /// Safe write of s into buf as cstring (String in the protocol).
649 58728 : fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
650 58728 : let bytes = s.as_ref();
651 58728 : if bytes.contains(&0) {
652 UBC 0 : return Err(ProtocolError::BadMessage(
653 0 : "string contains embedded null".to_owned(),
654 0 : ));
655 CBC 58728 : }
656 58728 : buf.put_slice(bytes);
657 58728 : buf.put_u8(0);
658 58728 : Ok(())
659 58728 : }
660 :
661 : /// Read cstring from buf, advancing it.
662 3901695 : pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
663 3901695 : let pos = buf
664 3901695 : .iter()
665 55402786 : .position(|x| *x == 0)
666 3901695 : .ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
667 3901695 : let result = buf.split_to(pos);
668 3901695 : buf.advance(1); // drop the null terminator
669 3901695 : Ok(result)
670 3901695 : }
671 :
/// SQLSTATE code `XX000` (internal_error).
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
/// SQLSTATE code `57P01` (admin_shutdown).
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
/// SQLSTATE code `00000` (successful_completion).
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
675 :
676 : impl<'a> BeMessage<'a> {
677 : /// Serialize `message` to the given `buf`.
678 : /// Apart from smart memory managemet, BytesMut is good here as msg len
679 : /// precedes its body and it is handy to write it down first and then fill
680 : /// the length. With Write we would have to either calc it manually or have
681 : /// one more buffer.
682 8033781 : pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
683 8033781 : match message {
684 8711 : BeMessage::AuthenticationOk => {
685 8711 : buf.put_u8(b'R');
686 8711 : write_body(buf, |buf| {
687 8711 : buf.put_i32(0); // Specifies that the authentication was successful.
688 8711 : });
689 8711 : }
690 :
691 228 : BeMessage::AuthenticationCleartextPassword => {
692 228 : buf.put_u8(b'R');
693 228 : write_body(buf, |buf| {
694 228 : buf.put_i32(3); // Specifies that clear text password is required.
695 228 : });
696 228 : }
697 :
698 UBC 0 : BeMessage::AuthenticationMD5Password(salt) => {
699 0 : buf.put_u8(b'R');
700 0 : write_body(buf, |buf| {
701 0 : buf.put_i32(5); // Specifies that an MD5-encrypted password is required.
702 0 : buf.put_slice(&salt[..]);
703 0 : });
704 0 : }
705 :
706 CBC 89 : BeMessage::AuthenticationSasl(msg) => {
707 89 : buf.put_u8(b'R');
708 89 : write_body(buf, |buf| {
709 89 : use BeAuthenticationSaslMessage::*;
710 89 : match msg {
711 31 : Methods(methods) => {
712 31 : buf.put_i32(10); // Specifies that SASL auth method is used.
713 31 : for method in methods.iter() {
714 31 : write_cstr(method, buf)?;
715 : }
716 31 : buf.put_u8(0); // zero terminator for the list
717 : }
718 31 : Continue(extra) => {
719 31 : buf.put_i32(11); // Continue SASL auth.
720 31 : buf.put_slice(extra);
721 31 : }
722 27 : Final(extra) => {
723 27 : buf.put_i32(12); // Send final SASL message.
724 27 : buf.put_slice(extra);
725 27 : }
726 : }
727 89 : Ok(())
728 89 : })?;
729 : }
730 :
731 29 : BeMessage::BackendKeyData(key_data) => {
732 29 : buf.put_u8(b'K');
733 29 : write_body(buf, |buf| {
734 29 : buf.put_i32(key_data.backend_pid);
735 29 : buf.put_i32(key_data.cancel_key);
736 29 : });
737 29 : }
738 :
739 640 : BeMessage::BindComplete => {
740 640 : buf.put_u8(b'2');
741 640 : write_body(buf, |_| {});
742 640 : }
743 :
744 639 : BeMessage::CloseComplete => {
745 639 : buf.put_u8(b'3');
746 639 : write_body(buf, |_| {});
747 639 : }
748 :
749 2137 : BeMessage::CommandComplete(cmd) => {
750 2137 : buf.put_u8(b'C');
751 2137 : write_body(buf, |buf| write_cstr(cmd, buf))?;
752 : }
753 :
754 7171073 : BeMessage::CopyData(data) => {
755 7171073 : buf.put_u8(b'd');
756 7171073 : write_body(buf, |buf| {
757 7171073 : buf.put_slice(data);
758 7171073 : });
759 7171073 : }
760 :
761 638 : BeMessage::CopyDone => {
762 638 : buf.put_u8(b'c');
763 638 : write_body(buf, |_| {});
764 638 : }
765 :
766 UBC 0 : BeMessage::CopyFail => {
767 0 : buf.put_u8(b'f');
768 0 : write_body(buf, |_| {});
769 0 : }
770 :
771 CBC 7 : BeMessage::CopyInResponse => {
772 7 : buf.put_u8(b'G');
773 7 : write_body(buf, |buf| {
774 7 : buf.put_u8(1); // copy_is_binary
775 7 : buf.put_i16(0); // numAttributes
776 7 : });
777 7 : }
778 :
779 640 : BeMessage::CopyOutResponse => {
780 640 : buf.put_u8(b'H');
781 640 : write_body(buf, |buf| {
782 640 : buf.put_u8(0); // copy_is_binary
783 640 : buf.put_i16(0); // numAttributes
784 640 : });
785 640 : }
786 :
787 7197 : BeMessage::CopyBothResponse => {
788 7197 : buf.put_u8(b'W');
789 7197 : write_body(buf, |buf| {
790 7197 : // doesn't matter, used only for replication
791 7197 : buf.put_u8(0); // copy_is_binary
792 7197 : buf.put_i16(0); // numAttributes
793 7197 : });
794 7197 : }
795 :
796 994 : BeMessage::DataRow(vals) => {
797 994 : buf.put_u8(b'D');
798 994 : write_body(buf, |buf| {
799 994 : buf.put_u16(vals.len() as u16); // num of cols
800 3409 : for val_opt in vals.iter() {
801 3409 : if let Some(val) = val_opt {
802 2712 : buf.put_u32(val.len() as u32);
803 2712 : buf.put_slice(val);
804 2712 : } else {
805 697 : buf.put_i32(-1);
806 697 : }
807 : }
808 994 : });
809 994 : }
810 :
811 : // ErrorResponse is a zero-terminated array of zero-terminated fields.
812 : // First byte of each field represents type of this field. Set just enough fields
813 : // to satisfy rust-postgres client: 'S' -- severity, 'C' -- error, 'M' -- error
814 : // message text.
815 247 : BeMessage::ErrorResponse(error_msg, pg_error_code) => {
816 247 : // 'E' signalizes ErrorResponse messages
817 247 : buf.put_u8(b'E');
818 247 : write_body(buf, |buf| {
819 247 : buf.put_u8(b'S'); // severity
820 247 : buf.put_slice(b"ERROR\0");
821 247 :
822 247 : buf.put_u8(b'C'); // SQLSTATE error code
823 247 : buf.put_slice(&terminate_code(
824 247 : pg_error_code.unwrap_or(SQLSTATE_INTERNAL_ERROR),
825 247 : ));
826 247 :
827 247 : buf.put_u8(b'M'); // the message
828 247 : write_cstr(error_msg, buf)?;
829 :
830 247 : buf.put_u8(0); // terminator
831 247 : Ok(())
832 247 : })?;
833 : }
834 :
835 : // NoticeResponse has the same format as ErrorResponse. From doc: "The frontend should display the
836 : // message but continue listening for ReadyForQuery or ErrorResponse"
837 6 : BeMessage::NoticeResponse(error_msg) => {
838 6 : // For all the errors set Severity to Error and error code to
839 6 : // 'internal error'.
840 6 :
841 6 : // 'N' signalizes NoticeResponse messages
842 6 : buf.put_u8(b'N');
843 6 : write_body(buf, |buf| {
844 6 : buf.put_u8(b'S'); // severity
845 6 : buf.put_slice(b"NOTICE\0");
846 6 :
847 6 : buf.put_u8(b'C'); // SQLSTATE error code
848 6 : buf.put_slice(&terminate_code(SQLSTATE_INTERNAL_ERROR));
849 6 :
850 6 : buf.put_u8(b'M'); // the message
851 6 : write_cstr(error_msg.as_bytes(), buf)?;
852 :
853 6 : buf.put_u8(0); // terminator
854 6 : Ok(())
855 6 : })?;
856 : }
857 :
858 640 : BeMessage::NoData => {
859 640 : buf.put_u8(b'n');
860 640 : write_body(buf, |_| {});
861 640 : }
862 :
863 6653 : BeMessage::EncryptionResponse(should_negotiate) => {
864 6653 : let response = if *should_negotiate { b'S' } else { b'N' };
865 6653 : buf.put_u8(response);
866 : }
867 :
868 25980 : BeMessage::ParameterStatus { name, value } => {
869 25980 : buf.put_u8(b'S');
870 25980 : write_body(buf, |buf| {
871 25980 : write_cstr(name, buf)?;
872 25980 : write_cstr(value, buf)
873 25980 : })?;
874 : }
875 :
876 640 : BeMessage::ParameterDescription => {
877 640 : buf.put_u8(b't');
878 640 : write_body(buf, |buf| {
879 640 : // we don't support params, so always 0
880 640 : buf.put_i16(0);
881 640 : });
882 640 : }
883 :
884 640 : BeMessage::ParseComplete => {
885 640 : buf.put_u8(b'1');
886 640 : write_body(buf, |_| {});
887 640 : }
888 :
889 18909 : BeMessage::ReadyForQuery => {
890 18909 : buf.put_u8(b'Z');
891 18909 : write_body(buf, |buf| {
892 18909 : buf.put_u8(b'I');
893 18909 : });
894 18909 : }
895 :
896 1463 : BeMessage::RowDescription(rows) => {
897 1463 : buf.put_u8(b'T');
898 1463 : write_body(buf, |buf| {
899 1463 : buf.put_i16(rows.len() as i16); // # of fields
900 4347 : for row in rows.iter() {
901 4347 : write_cstr(row.name, buf)?;
902 4347 : buf.put_i32(0); /* table oid */
903 4347 : buf.put_i16(0); /* attnum */
904 4347 : buf.put_u32(row.typoid);
905 4347 : buf.put_i16(row.typlen);
906 4347 : buf.put_i32(-1); /* typmod */
907 4347 : buf.put_i16(0); /* format code */
908 : }
909 1463 : Ok(())
910 1463 : })?;
911 : }
912 :
913 783639 : BeMessage::XLogData(body) => {
914 783639 : buf.put_u8(b'd');
915 783639 : write_body(buf, |buf| {
916 783639 : buf.put_u8(b'w');
917 783639 : buf.put_u64(body.wal_start);
918 783639 : buf.put_u64(body.wal_end);
919 783639 : buf.put_i64(body.timestamp);
920 783639 : buf.put_slice(body.data);
921 783639 : });
922 783639 : }
923 :
924 1942 : BeMessage::KeepAlive(req) => {
925 1942 : buf.put_u8(b'd');
926 1942 : write_body(buf, |buf| {
927 1942 : buf.put_u8(b'k');
928 1942 : buf.put_u64(req.wal_end);
929 1942 : buf.put_i64(req.timestamp);
930 1942 : buf.put_u8(u8::from(req.request_reply));
931 1942 : });
932 1942 : }
933 : }
934 8033781 : Ok(())
935 8033781 : }
936 : }
937 :
/// Turn a 5-byte SQLSTATE code into its null-terminated 6-byte form.
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
    // The buffer starts zeroed, so the last byte is the terminator.
    let mut terminated = [0u8; 6];
    terminated[..code.len()].copy_from_slice(code);
    terminated
}
946 :
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks splitting and unescaping of the `options` startup parameter.
    #[test]
    fn test_startup_message_params_options_escaped() {
        fn split_options(params: &StartupMessageParams) -> Vec<Cow<'_, str>> {
            params
                .options_escaped()
                .expect("options are None")
                .collect()
        }

        let make_params = |options| StartupMessageParams::new([("options", options)]);

        let params = StartupMessageParams::new([]);
        assert!(params.options_escaped().is_none());

        let params = make_params("");
        assert!(split_options(&params).is_empty());

        let params = make_params("foo");
        assert_eq!(split_options(&params), ["foo"]);

        let params = make_params(" foo bar ");
        assert_eq!(split_options(&params), ["foo", "bar"]);

        // Two spaces before `lol`: the first one is escaped and stays part
        // of `baz `, the second one separates the tokens.
        let params = make_params("foo\\ bar \\ \\\\ baz\\  lol");
        assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
    }
}
|