Line data Source code
1 : //! Postgres protocol codec
2 : //!
3 : //! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
4 :
5 : use std::fmt;
6 : use std::io::{self, Cursor};
7 :
8 : use bytes::{Buf, BufMut};
9 : use itertools::Itertools;
10 : use rand::distributions::{Distribution, Standard};
11 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12 : use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
13 :
14 : pub type ErrorCode = [u8; 5];
15 :
16 : pub const FE_PASSWORD_MESSAGE: u8 = b'p';
17 :
18 : pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
19 :
20 : /// The protocol version number.
21 : ///
22 : /// The most significant 16 bits are the major version number (3 for the protocol described here).
23 : /// The least significant 16 bits are the minor version number (0 for the protocol described here).
24 : /// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE>
#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
pub struct ProtocolVersion {
    /// Major version number, stored big-endian; with `repr(C)` it occupies the first two bytes.
    major: big_endian::U16,
    /// Minor version number, stored big-endian; with `repr(C)` it occupies the last two bytes.
    minor: big_endian::U16,
}
31 :
32 : impl ProtocolVersion {
33 3 : pub const fn new(major: u16, minor: u16) -> Self {
34 3 : Self {
35 3 : major: big_endian::U16::new(major),
36 3 : minor: big_endian::U16::new(minor),
37 3 : }
38 3 : }
39 1 : pub const fn minor(self) -> u16 {
40 1 : self.minor.get()
41 1 : }
42 24 : pub const fn major(self) -> u16 {
43 24 : self.major.get()
44 24 : }
45 : }
46 :
47 : impl fmt::Debug for ProtocolVersion {
48 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 0 : f.debug_list()
50 0 : .entry(&self.major())
51 0 : .entry(&self.minor())
52 0 : .finish()
53 0 : }
54 : }
55 :
/// Upper bound applied by `read_startup` to the startup payload length
/// (the length field minus the 8 header bytes).
///
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
/// Major "version" shared by the special request codes below. `read_startup` rejects
/// any other code carrying this major version as an unrecognized request.
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
/// Request code identifying a CancelRequest packet.
///
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
/// Request code identifying an SSLRequest packet.
///
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
/// Request code identifying a GSSENCRequest packet.
///
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
65 :
/// The startup message header, which is the first thing read from a new connection.
/// It is 8 bytes long: the first 4 bytes are a big-endian message length,
/// and the next 4 bytes are a version number.
///
/// The length value is inclusive of the header. For example,
/// an empty message will always have length 8.
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
#[repr(C)]
struct StartupHeader {
    /// Total message length in bytes, including these 8 header bytes.
    len: big_endian::U32,
    /// Protocol version, or one of the special request codes (cancel / SSL / GSS).
    version: ProtocolVersion,
}
77 :
/// Reads exactly `size_of::<$t>()` bytes from the stream and reinterprets them
/// as a `$t` using zerocopy.
///
/// Expands to `.await?`, so it can only be used inside an async function returning
/// an `io::Result`. Not cancel safe.
macro_rules! read {
    ($s:expr => $t:ty) => {{
        // This has to be a macro: a generic function would need an array length of
        // `size_of::<T>()`, which requires const-generic expressions.
        let mut raw = [0; size_of::<$t>()];
        $s.read_exact(&mut raw).await?;
        let value: $t = zerocopy::transmute!(raw);
        value
    }};
}
90 :
91 : /// Returns true if TLS is supported.
92 : ///
93 : /// This is not cancel safe.
94 0 : pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
95 0 : where
96 0 : S: AsyncRead + AsyncWrite + Unpin,
97 0 : {
98 0 : let payload = StartupHeader {
99 0 : len: 8.into(),
100 0 : version: NEGOTIATE_SSL_CODE,
101 0 : };
102 0 : stream.write_all(payload.as_bytes()).await?;
103 0 : stream.flush().await?;
104 :
105 : // we expect back either `S` or `N` as a single byte.
106 0 : let mut res = *b"0";
107 0 : stream.read_exact(&mut res).await?;
108 :
109 0 : debug_assert!(
110 0 : res == *b"S" || res == *b"N",
111 0 : "unexpected SSL negotiation response: {}",
112 0 : char::from(res[0]),
113 : );
114 :
115 : // S for SSL.
116 0 : Ok(res == *b"S")
117 0 : }
118 :
119 46 : pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
120 46 : where
121 46 : S: AsyncRead + Unpin,
122 46 : {
123 46 : let header = read!(stream => StartupHeader);
124 :
125 : // <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
126 : // First byte indicates standard SSL handshake message
127 : // (It can't be a Postgres startup length because in network byte order
128 : // that would be a startup packet hundreds of megabytes long)
129 46 : if header.as_bytes()[0] == 0x16 {
130 : return Ok(FeStartupPacket::SslRequest {
131 : // The bytes we read for the header are actually part of a TLS ClientHello.
132 : // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here.
133 : // In practice though, I see no world where a ClientHello is less than 8 bytes
134 : // since it includes ephemeral keys etc.
135 1 : direct: Some(zerocopy::transmute!(header)),
136 : });
137 45 : }
138 :
139 45 : let Some(len) = (header.len.get() as usize).checked_sub(8) else {
140 0 : return Err(io::Error::other(format!(
141 0 : "invalid startup message length {}, must be at least 8.",
142 0 : header.len,
143 0 : )));
144 : };
145 :
146 : // TODO: add a histogram for startup packet lengths
147 45 : if len > MAX_STARTUP_PACKET_LENGTH {
148 1 : tracing::warn!("large startup message detected: {len} bytes");
149 1 : return Err(io::Error::other(format!(
150 1 : "invalid startup message length {len}"
151 1 : )));
152 44 : }
153 :
154 23 : match header.version {
155 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-CANCELREQUEST>
156 : CANCEL_REQUEST_CODE => {
157 0 : if len != 8 {
158 0 : return Err(io::Error::other(
159 0 : "CancelRequest message is malformed, backend PID / secret key missing",
160 0 : ));
161 0 : }
162 :
163 : Ok(FeStartupPacket::CancelRequest(
164 0 : read!(stream => CancelKeyData),
165 : ))
166 : }
167 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST>
168 : NEGOTIATE_SSL_CODE => {
169 : // Requested upgrade to SSL (aka TLS)
170 21 : Ok(FeStartupPacket::SslRequest { direct: None })
171 : }
172 : NEGOTIATE_GSS_CODE => {
173 : // Requested upgrade to GSSAPI
174 0 : Ok(FeStartupPacket::GssEncRequest)
175 : }
176 23 : version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other(
177 0 : format!("Unrecognized request code {version:?}"),
178 0 : )),
179 : // StartupMessage
180 23 : version => {
181 : // The protocol version number is followed by one or more pairs of parameter name and value strings.
182 : // A zero byte is required as a terminator after the last name/value pair.
183 : // Parameters can appear in any order. user is required, others are optional.
184 :
185 23 : let mut buf = vec![0; len];
186 23 : stream.read_exact(&mut buf).await?;
187 :
188 23 : if buf.pop() != Some(b'\0') {
189 0 : return Err(io::Error::other(
190 0 : "StartupMessage params: missing null terminator",
191 0 : ));
192 23 : }
193 :
194 : // TODO: Don't do this.
195 : // There's no guarantee that these messages are utf8,
196 : // but they usually happen to be simple ascii.
197 23 : let params = String::from_utf8(buf)
198 23 : .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?;
199 :
200 23 : Ok(FeStartupPacket::StartupMessage {
201 23 : version,
202 23 : params: StartupMessageParams { params },
203 23 : })
204 : }
205 : }
206 46 : }
207 :
208 : /// Read a raw postgres packet, which will respect the max length requested.
209 : ///
210 : /// This returns the message tag, as well as the message body. The message
211 : /// body is written into `buf`, and it is otherwise completely overwritten.
212 : ///
213 : /// This is not cancel safe.
214 29 : pub async fn read_message<'a, S>(
215 29 : stream: &mut S,
216 29 : buf: &'a mut Vec<u8>,
217 29 : max: u32,
218 29 : ) -> io::Result<(u8, &'a mut [u8])>
219 29 : where
220 29 : S: AsyncRead + Unpin,
221 29 : {
222 : /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes.
223 : /// The first byte is a message tag, and the next 4 bytes is a big-endian length.
224 : ///
225 : /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example,
226 : /// an empty message will always have length 4.
227 : #[derive(Clone, Copy, FromBytes)]
228 : #[repr(C)]
229 : struct Header {
230 : tag: u8,
231 : len: big_endian::U32,
232 : }
233 :
234 29 : let header = read!(stream => Header);
235 :
236 : // as described above, the length must be at least 4.
237 28 : let Some(len) = header.len.get().checked_sub(4) else {
238 0 : return Err(io::Error::other(format!(
239 0 : "invalid startup message length {}, must be at least 4.",
240 0 : header.len,
241 0 : )));
242 : };
243 :
244 : // TODO: add a histogram for message lengths
245 :
246 : // check if the message exceeds our desired max.
247 28 : if len > max {
248 1 : tracing::warn!("large postgres message detected: {len} bytes");
249 1 : return Err(io::Error::other(format!("invalid message length {len}")));
250 27 : }
251 :
252 : // read in our entire message.
253 27 : buf.resize(len as usize, 0);
254 27 : stream.read_exact(buf).await?;
255 :
256 27 : Ok((header.tag, buf))
257 29 : }
258 :
/// Buffer of outgoing protocol bytes.
///
/// Messages are appended to the inner `Vec<u8>`; the `Cursor` position tracks how much
/// of the front of that vector has already been consumed through the [`Buf`] impl.
pub struct WriteBuf(Cursor<Vec<u8>>);
260 :
261 : impl Buf for WriteBuf {
262 : #[inline]
263 117 : fn remaining(&self) -> usize {
264 117 : self.0.remaining()
265 117 : }
266 :
267 : #[inline]
268 55 : fn chunk(&self) -> &[u8] {
269 55 : self.0.chunk()
270 55 : }
271 :
272 : #[inline]
273 55 : fn advance(&mut self, cnt: usize) {
274 55 : self.0.advance(cnt);
275 55 : }
276 : }
277 :
278 : impl WriteBuf {
279 45 : pub const fn new() -> Self {
280 45 : Self(Cursor::new(Vec::new()))
281 45 : }
282 :
283 : /// Use a heuristic to determine if we should shrink the write buffer.
284 : #[inline]
285 57 : fn should_shrink(&self) -> bool {
286 57 : let n = self.0.position() as usize;
287 57 : let len = self.0.get_ref().len();
288 :
289 : // the unused space at the front of our buffer is 2x the size of our filled portion.
290 57 : n + n > len
291 57 : }
292 :
293 : /// Shrink the write buffer so that subsequent writes have more spare capacity.
294 : #[cold]
295 1 : fn shrink(&mut self) {
296 1 : let n = self.0.position() as usize;
297 1 : let buf = self.0.get_mut();
298 :
299 : // buf repr:
300 : // [----unused------|-----filled-----|-----uninit-----]
301 : // ^ n ^ buf.len() ^ buf.capacity()
302 1 : let filled = n..buf.len();
303 1 : let filled_len = filled.len();
304 1 : buf.copy_within(filled, 0);
305 1 : buf.truncate(filled_len);
306 1 : self.0.set_position(0);
307 1 : }
308 :
309 : /// clear the write buffer.
310 62 : pub fn reset(&mut self) {
311 62 : let buf = self.0.get_mut();
312 62 : buf.clear();
313 62 : self.0.set_position(0);
314 62 : }
315 :
316 : /// Write a raw message to the internal buffer.
317 : ///
318 : /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
319 : /// we calculate the length after the fact.
320 57 : pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
321 57 : if self.should_shrink() {
322 0 : self.shrink();
323 57 : }
324 :
325 57 : let buf = self.0.get_mut();
326 57 : buf.reserve(5 + size_hint);
327 :
328 57 : buf.push(tag);
329 57 : let start = buf.len();
330 57 : buf.extend_from_slice(&[0, 0, 0, 0]);
331 :
332 57 : f(buf);
333 :
334 57 : let end = buf.len();
335 57 : let len = (end - start) as u32;
336 57 : buf[start..start + 4].copy_from_slice(&len.to_be_bytes());
337 57 : }
338 :
339 : /// Write an encryption response message.
340 20 : pub fn encryption(&mut self, m: u8) {
341 20 : self.0.get_mut().push(m);
342 20 : }
343 :
344 1 : pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) {
345 1 : self.shrink();
346 :
347 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE>
348 : // <https://www.postgresql.org/docs/current/protocol-error-fields.html>
349 : // "SERROR\0CXXXXX\0M\0\0".len() == 17
350 1 : self.write_raw(17 + msg.len(), b'E', |buf| {
351 : // Severity: ERROR
352 1 : buf.put_slice(b"SERROR\0");
353 :
354 : // Code: error_code
355 1 : buf.put_u8(b'C');
356 1 : buf.put_slice(&error_code);
357 1 : buf.put_u8(0);
358 :
359 : // Message: msg
360 1 : buf.put_u8(b'M');
361 1 : buf.put_slice(msg.as_bytes());
362 1 : buf.put_u8(0);
363 :
364 : // End.
365 1 : buf.put_u8(0);
366 1 : });
367 1 : }
368 : }
369 :
/// The kinds of first packet a frontend can send, as classified by `read_startup`.
#[derive(Debug)]
pub enum FeStartupPacket {
    /// A CancelRequest, carrying the 8-byte cancel key that followed the header.
    CancelRequest(CancelKeyData),
    /// A request to upgrade the connection to SSL (aka TLS).
    SslRequest {
        /// `None` for a regular SSLRequest packet. `Some(bytes)` when the connection
        /// started directly with a TLS handshake record; `bytes` are the 8 bytes that
        /// were already consumed while reading the header.
        direct: Option<[u8; 8]>,
    },
    /// A request to upgrade the connection to GSSAPI encryption.
    GssEncRequest,
    /// A regular StartupMessage.
    StartupMessage {
        /// Protocol version requested by the frontend.
        version: ProtocolVersion,
        /// The parameter name/value pairs that followed the version.
        params: StartupMessageParams,
    },
}
382 :
/// Parameters of a StartupMessage, kept in their serialized form.
#[derive(Debug, Clone, Default)]
pub struct StartupMessageParams {
    /// Concatenated `name\0value\0` pairs, as produced by `insert` and parsed by `iter`.
    pub params: String,
}
387 :
388 : impl StartupMessageParams {
389 : /// Get parameter's value by its name.
390 41 : pub fn get(&self, name: &str) -> Option<&str> {
391 60 : self.iter().find_map(|(k, v)| (k == name).then_some(v))
392 41 : }
393 :
394 : /// Split command-line options according to PostgreSQL's logic,
395 : /// taking into account all escape sequences but leaving them as-is.
396 : /// [`None`] means that there's no `options` in [`Self`].
397 27 : pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
398 27 : self.get("options").map(Self::parse_options_raw)
399 27 : }
400 :
401 : /// Split command-line options according to PostgreSQL's logic,
402 : /// taking into account all escape sequences but leaving them as-is.
403 34 : pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
404 : // See `postgres: pg_split_opts`.
405 34 : let mut last_was_escape = false;
406 34 : input
407 608 : .split(move |c: char| {
408 : // We split by non-escaped whitespace symbols.
409 608 : let should_split = c.is_ascii_whitespace() && !last_was_escape;
410 608 : last_was_escape = c == '\\' && !last_was_escape;
411 608 : should_split
412 608 : })
413 74 : .filter(|s| !s.is_empty())
414 34 : }
415 :
416 : /// Iterate through key-value pairs in an arbitrary order.
417 41 : pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
418 41 : self.params.split_terminator('\0').tuples()
419 41 : }
420 :
421 : // This function is mostly useful in tests.
422 : #[cfg(test)]
423 13 : pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
424 13 : let mut b = Self {
425 13 : params: String::new(),
426 13 : };
427 36 : for (k, v) in pairs {
428 23 : b.insert(k, v);
429 23 : }
430 13 : b
431 13 : }
432 :
433 : /// Set parameter's value by its name.
434 : /// name and value must not contain a \0 byte
435 23 : pub fn insert(&mut self, name: &str, value: &str) {
436 23 : self.params.reserve(name.len() + value.len() + 2);
437 23 : self.params.push_str(name);
438 23 : self.params.push('\0');
439 23 : self.params.push_str(value);
440 23 : self.params.push('\0');
441 23 : }
442 : }
443 :
/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just
/// opaque bytes.
///
/// The 8 bytes are held as a single big-endian `u64`.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)]
pub struct CancelKeyData(pub big_endian::U64);
448 :
449 1 : pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
450 1 : CancelKeyData(big_endian::U64::new(id))
451 1 : }
452 :
453 : impl fmt::Display for CancelKeyData {
454 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455 0 : let id = self.0;
456 0 : f.debug_tuple("CancelKeyData")
457 0 : .field(&format_args!("{id:x}"))
458 0 : .finish()
459 0 : }
460 : }
461 : impl Distribution<CancelKeyData> for Standard {
462 0 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
463 0 : id_to_cancel_key(rng.r#gen())
464 0 : }
465 : }
466 :
/// Backend messages that can be serialized with `BeMessage::write_message`.
pub enum BeMessage<'a> {
    /// Authentication request with code 0 (authentication succeeded).
    AuthenticationOk,
    /// One of the SASL authentication messages.
    AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
    /// Authentication request with code 3 (cleartext password required).
    AuthenticationCleartextPassword,
    /// Cancellation key data for this connection.
    BackendKeyData(CancelKeyData),
    /// A run-time parameter report; `name` and `value` are written NUL-terminated.
    ParameterStatus {
        name: &'a [u8],
        value: &'a [u8],
    },
    /// Written with the status byte `I`.
    ReadyForQuery,
    /// A notice with severity `NOTICE`, code `XX000` and the given message text.
    NoticeResponse(&'a str),
    /// Protocol version negotiation: a version followed by a list of option names.
    NegotiateProtocolVersion {
        version: ProtocolVersion,
        options: &'a [&'a str],
    },
}
483 :
/// The SASL flavours of the backend authentication message.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Authentication code 10: the list of SASL mechanism names offered.
    Methods(&'a [&'a str]),
    /// Authentication code 11: continuation data, written verbatim.
    Continue(&'a [u8]),
    /// Authentication code 12: final data, written verbatim.
    Final(&'a [u8]),
}
490 :
491 : impl BeMessage<'_> {
492 : /// Write the message into an internal buffer
493 56 : pub fn write_message(self, buf: &mut WriteBuf) {
494 30 : match self {
495 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
496 : BeMessage::AuthenticationOk => {
497 10 : buf.write_raw(1, b'R', |buf| buf.put_i32(0));
498 : }
499 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
500 : BeMessage::AuthenticationCleartextPassword => {
501 2 : buf.write_raw(1, b'R', |buf| buf.put_i32(3));
502 : }
503 :
504 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
505 13 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
506 25 : let len: usize = methods.iter().map(|m| m.len() + 1).sum();
507 13 : buf.write_raw(len + 2, b'R', |buf| {
508 13 : buf.put_i32(10); // Specifies that SASL auth method is used.
509 38 : for method in methods {
510 25 : buf.put_slice(method.as_bytes());
511 25 : buf.put_u8(0);
512 25 : }
513 13 : buf.put_u8(0); // zero terminator for the list
514 13 : });
515 : }
516 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
517 11 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
518 11 : buf.write_raw(extra.len() + 1, b'R', |buf| {
519 11 : buf.put_i32(11); // Continue SASL auth.
520 11 : buf.put_slice(extra);
521 11 : });
522 : }
523 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
524 6 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
525 6 : buf.write_raw(extra.len() + 1, b'R', |buf| {
526 6 : buf.put_i32(12); // Send final SASL message.
527 6 : buf.put_slice(extra);
528 6 : });
529 : }
530 :
531 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
532 0 : BeMessage::BackendKeyData(key_data) => {
533 0 : buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
534 : }
535 :
536 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
537 : // <https://www.postgresql.org/docs/current/protocol-error-fields.html>
538 0 : BeMessage::NoticeResponse(msg) => {
539 : // 'N' signalizes NoticeResponse messages
540 0 : buf.write_raw(18 + msg.len(), b'N', |buf| {
541 : // Severity: NOTICE
542 0 : buf.put_slice(b"SNOTICE\0");
543 :
544 : // Code: XX000 (ignored for notice, but still required)
545 0 : buf.put_slice(b"CXX000\0");
546 :
547 : // Message: msg
548 0 : buf.put_u8(b'M');
549 0 : buf.put_slice(msg.as_bytes());
550 0 : buf.put_u8(0);
551 :
552 : // End notice.
553 0 : buf.put_u8(0);
554 0 : });
555 : }
556 :
557 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-PARAMETERSTATUS>
558 7 : BeMessage::ParameterStatus { name, value } => {
559 7 : buf.write_raw(name.len() + value.len() + 2, b'S', |buf| {
560 7 : buf.put_slice(name.as_bytes());
561 7 : buf.put_u8(0);
562 7 : buf.put_slice(value.as_bytes());
563 7 : buf.put_u8(0);
564 7 : });
565 : }
566 :
567 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
568 : BeMessage::ReadyForQuery => {
569 7 : buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
570 : }
571 :
572 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
573 0 : BeMessage::NegotiateProtocolVersion { version, options } => {
574 0 : let len: usize = options.iter().map(|o| o.len() + 1).sum();
575 0 : buf.write_raw(8 + len, b'v', |buf| {
576 0 : buf.put_slice(version.as_bytes());
577 0 : buf.put_u32(options.len() as u32);
578 0 : for option in options {
579 0 : buf.put_slice(option.as_bytes());
580 0 : buf.put_u8(0);
581 0 : }
582 0 : });
583 : }
584 : }
585 56 : }
586 : }
587 :
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use tokio::io::{AsyncWriteExt, duplex};
    use zerocopy::IntoBytes;

    use super::ProtocolVersion;
    use crate::pqproto::{FeStartupPacket, read_message, read_startup};

    /// A startup message whose payload exceeds the limit is rejected before it is read.
    #[tokio::test]
    async fn reject_large_startup() {
        // we're going to define a v3.0 startup message with far too many parameters.
        // total length: 10001 + 8 bytes.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&10009_u32.to_be_bytes());
        bytes.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
        bytes.resize(10009, b'a');

        let (mut server, mut client) = duplex(128);
        #[rustfmt::skip]
        let (server_err, client_err) = tokio::join!(
            async move { read_startup(&mut server).await.unwrap_err() },
            async move { client.write_all(&bytes).await.unwrap_err() },
        );

        assert_eq!(
            server_err.to_string(),
            "invalid startup message length 10001"
        );
        assert_eq!(client_err.to_string(), "broken pipe");
    }

    /// A regular message longer than the requested maximum is rejected before it is read.
    #[tokio::test]
    async fn reject_large_password() {
        // we're going to define a password message that is far too long.
        let mut bytes = vec![b'p'];
        bytes.extend_from_slice(&517_u32.to_be_bytes());
        bytes.resize(518, b'a');

        let (mut server, mut client) = duplex(128);
        #[rustfmt::skip]
        let (server_err, client_err) = tokio::join!(
            async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() },
            async move { client.write_all(&bytes).await.unwrap_err() },
        );

        assert_eq!(server_err.to_string(), "invalid message length 513");
        assert_eq!(client_err.to_string(), "broken pipe");
    }

    /// A well-formed v3.0 startup message yields its version and raw parameters.
    #[tokio::test]
    async fn read_startup_message() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&17_u32.to_be_bytes());
        bytes.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
        bytes.extend_from_slice(b"abc\0def\0\0");

        let packet = read_startup(&mut Cursor::new(&bytes)).await.unwrap();
        let (version, params) = match packet {
            FeStartupPacket::StartupMessage { version, params } => (version, params),
            startup => panic!("unexpected startup message: {startup:?}"),
        };

        assert_eq!(version.major(), 3);
        assert_eq!(version.minor(), 0);
        // the trailing list terminator is stripped.
        assert_eq!(params.params, "abc\0def\0");
    }

    /// An SSLRequest packet is recognized as a non-direct SSL request.
    #[tokio::test]
    async fn read_ssl_message() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&8_u32.to_be_bytes());
        bytes.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes());

        let startup = read_startup(&mut Cursor::new(&bytes)).await.unwrap();
        assert!(
            matches!(startup, FeStartupPacket::SslRequest { direct: None }),
            "unexpected startup message: {startup:?}"
        );
    }

    /// A connection that starts with a TLS ClientHello is reported as a direct SSL
    /// request, handing back the 8 bytes that were consumed as the header.
    #[tokio::test]
    async fn read_tls_message() {
        // sample client hello taken from <https://tls13.xargs.org/#client-hello>
        let client_hello = [
            0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02,
            0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
            0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e,
            0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
            0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9,
            0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01,
            0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00,
            0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65,
            0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
            0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19,
            0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23,
            0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e,
            0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09,
            0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
            0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01,
            0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72,
            0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38,
            0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62,
            0x54,
        ];

        let mut cursor = Cursor::new(&client_hello);

        let packet = read_startup(&mut cursor).await.unwrap();
        let prefix = match packet {
            FeStartupPacket::SslRequest {
                direct: Some(prefix),
            } => prefix,
            startup => panic!("unexpected startup message: {startup:?}"),
        };

        // check that no data is lost.
        assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]);
        assert_eq!(cursor.position(), 8);
    }

    /// Two back-to-back messages are read one at a time, reusing the same buffer.
    #[tokio::test]
    async fn read_message_success() {
        let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2";
        let mut cursor = Cursor::new(&query);

        let mut scratch = vec![];

        let (tag, body) = read_message(&mut cursor, &mut scratch, 100).await.unwrap();
        assert_eq!(tag, b'Q');
        assert_eq!(body, b"SELECT 1");

        let (tag, body) = read_message(&mut cursor, &mut scratch, 100).await.unwrap();
        assert_eq!(tag, b'Q');
        assert_eq!(body, b"SELECT 2");
    }
}
|