Line data Source code
1 : //! Postgres protocol codec
2 : //!
3 : //! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
4 :
5 : use std::fmt;
6 : use std::io::{self, Cursor};
7 :
8 : use bytes::{Buf, BufMut};
9 : use itertools::Itertools;
10 : use rand::distr::{Distribution, StandardUniform};
11 : use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
12 : use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
13 :
14 : pub type ErrorCode = [u8; 5];
15 :
16 : pub const FE_PASSWORD_MESSAGE: u8 = b'p';
17 :
18 : pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
19 :
20 : /// The protocol version number.
21 : ///
22 : /// The most significant 16 bits are the major version number (3 for the protocol described here).
23 : /// The least significant 16 bits are the minor version number (0 for the protocol described here).
24 : /// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE>
25 : #[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)]
26 : #[repr(C)]
27 : pub struct ProtocolVersion {
28 : major: big_endian::U16,
29 : minor: big_endian::U16,
30 : }
31 :
32 : impl ProtocolVersion {
33 3 : pub const fn new(major: u16, minor: u16) -> Self {
34 3 : Self {
35 3 : major: big_endian::U16::new(major),
36 3 : minor: big_endian::U16::new(minor),
37 3 : }
38 3 : }
39 1 : pub const fn minor(self) -> u16 {
40 1 : self.minor.get()
41 1 : }
42 24 : pub const fn major(self) -> u16 {
43 24 : self.major.get()
44 24 : }
45 : }
46 :
47 : impl fmt::Debug for ProtocolVersion {
48 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 0 : f.debug_list()
50 0 : .entry(&self.major())
51 0 : .entry(&self.minor())
52 0 : .finish()
53 0 : }
54 : }
55 :
56 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
57 : const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
58 : const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
59 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
60 : const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
61 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
62 : const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
63 : /// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
64 : const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
65 :
66 : /// This first reads the startup message header, is 8 bytes.
67 : /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
68 : ///
69 : /// The length value is inclusive of the header. For example,
70 : /// an empty message will always have length 8.
71 : #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
72 : #[repr(C)]
73 : struct StartupHeader {
74 : len: big_endian::U32,
75 : version: ProtocolVersion,
76 : }
77 :
/// Reads exactly `size_of::<$t>()` bytes from the stream and reinterprets them as `$t`
/// using zerocopy.
///
/// Expands to code that uses `.await` and `?`, so it must be invoked inside an async
/// context whose error type accepts the error of `read_exact`.
///
/// Not cancel safe.
macro_rules! read {
    ($s:expr => $t:ty) => {{
        // A macro rather than a generic function: the array length depends on `$t`,
        // which would need const-generic expressions.
        let mut raw = [0u8; size_of::<$t>()];
        $s.read_exact(&mut raw).await?;
        let value: $t = zerocopy::transmute!(raw);
        value
    }};
}
90 :
91 : /// Returns true if TLS is supported.
92 : ///
93 : /// This is not cancel safe.
94 0 : pub async fn request_tls<S>(stream: &mut S) -> io::Result<bool>
95 0 : where
96 0 : S: AsyncRead + AsyncWrite + Unpin,
97 0 : {
98 0 : let payload = StartupHeader {
99 0 : len: 8.into(),
100 0 : version: NEGOTIATE_SSL_CODE,
101 0 : };
102 0 : stream.write_all(payload.as_bytes()).await?;
103 0 : stream.flush().await?;
104 :
105 : // we expect back either `S` or `N` as a single byte.
106 0 : let mut res = *b"0";
107 0 : stream.read_exact(&mut res).await?;
108 :
109 0 : debug_assert!(
110 0 : res == *b"S" || res == *b"N",
111 0 : "unexpected SSL negotiation response: {}",
112 0 : char::from(res[0]),
113 : );
114 :
115 : // S for SSL.
116 0 : Ok(res == *b"S")
117 0 : }
118 :
119 46 : pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
120 46 : where
121 46 : S: AsyncRead + Unpin,
122 46 : {
123 46 : let header = read!(stream => StartupHeader);
124 :
125 : // <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
126 : // First byte indicates standard SSL handshake message
127 : // (It can't be a Postgres startup length because in network byte order
128 : // that would be a startup packet hundreds of megabytes long)
129 46 : if header.as_bytes()[0] == 0x16 {
130 : return Ok(FeStartupPacket::SslRequest {
131 : // The bytes we read for the header are actually part of a TLS ClientHello.
132 : // In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here.
133 : // In practice though, I see no world where a ClientHello is less than 8 bytes
134 : // since it includes ephemeral keys etc.
135 1 : direct: Some(zerocopy::transmute!(header)),
136 : });
137 45 : }
138 :
139 45 : let Some(len) = (header.len.get() as usize).checked_sub(8) else {
140 0 : return Err(io::Error::other(format!(
141 0 : "invalid startup message length {}, must be at least 8.",
142 0 : header.len,
143 0 : )));
144 : };
145 :
146 : // TODO: add a histogram for startup packet lengths
147 45 : if len > MAX_STARTUP_PACKET_LENGTH {
148 1 : tracing::warn!("large startup message detected: {len} bytes");
149 1 : return Err(io::Error::other(format!(
150 1 : "invalid startup message length {len}"
151 1 : )));
152 44 : }
153 :
154 23 : match header.version {
155 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-CANCELREQUEST>
156 : CANCEL_REQUEST_CODE => {
157 0 : if len != 8 {
158 0 : return Err(io::Error::other(
159 0 : "CancelRequest message is malformed, backend PID / secret key missing",
160 0 : ));
161 0 : }
162 :
163 : Ok(FeStartupPacket::CancelRequest(
164 0 : read!(stream => CancelKeyData),
165 : ))
166 : }
167 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST>
168 : NEGOTIATE_SSL_CODE => {
169 : // Requested upgrade to SSL (aka TLS)
170 21 : Ok(FeStartupPacket::SslRequest { direct: None })
171 : }
172 : NEGOTIATE_GSS_CODE => {
173 : // Requested upgrade to GSSAPI
174 0 : Ok(FeStartupPacket::GssEncRequest)
175 : }
176 23 : version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other(
177 0 : format!("Unrecognized request code {version:?}"),
178 0 : )),
179 : // StartupMessage
180 23 : version => {
181 : // The protocol version number is followed by one or more pairs of parameter name and value strings.
182 : // A zero byte is required as a terminator after the last name/value pair.
183 : // Parameters can appear in any order. user is required, others are optional.
184 :
185 23 : let mut buf = vec![0; len];
186 23 : stream.read_exact(&mut buf).await?;
187 :
188 23 : if buf.pop() != Some(b'\0') {
189 0 : return Err(io::Error::other(
190 0 : "StartupMessage params: missing null terminator",
191 0 : ));
192 23 : }
193 :
194 : // TODO: Don't do this.
195 : // There's no guarantee that these messages are utf8,
196 : // but they usually happen to be simple ascii.
197 23 : let params = String::from_utf8(buf)
198 23 : .map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?;
199 :
200 23 : Ok(FeStartupPacket::StartupMessage {
201 23 : version,
202 23 : params: StartupMessageParams { params },
203 23 : })
204 : }
205 : }
206 46 : }
207 :
208 : /// Read a raw postgres packet, which will respect the max length requested.
209 : ///
210 : /// This returns the message tag, as well as the message body. The message
211 : /// body is written into `buf`, and it is otherwise completely overwritten.
212 : ///
213 : /// This is not cancel safe.
214 29 : pub async fn read_message<'a, S>(
215 29 : stream: &mut S,
216 29 : buf: &'a mut Vec<u8>,
217 29 : max: u32,
218 29 : ) -> io::Result<(u8, &'a mut [u8])>
219 29 : where
220 29 : S: AsyncRead + Unpin,
221 29 : {
222 : /// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes.
223 : /// The first byte is a message tag, and the next 4 bytes is a big-endian length.
224 : ///
225 : /// Awkwardly, the length value is inclusive of itself, but not of the tag. For example,
226 : /// an empty message will always have length 4.
227 : #[derive(Clone, Copy, FromBytes)]
228 : #[repr(C)]
229 : struct Header {
230 : tag: u8,
231 : len: big_endian::U32,
232 : }
233 :
234 29 : let header = read!(stream => Header);
235 :
236 : // as described above, the length must be at least 4.
237 28 : let Some(len) = header.len.get().checked_sub(4) else {
238 0 : return Err(io::Error::other(format!(
239 0 : "invalid startup message length {}, must be at least 4.",
240 0 : header.len,
241 0 : )));
242 : };
243 :
244 : // TODO: add a histogram for message lengths
245 :
246 : // check if the message exceeds our desired max.
247 28 : if len > max {
248 1 : tracing::warn!("large postgres message detected: {len} bytes");
249 1 : return Err(io::Error::other(format!("invalid message length {len}")));
250 27 : }
251 :
252 : // read in our entire message.
253 27 : buf.resize(len as usize, 0);
254 27 : stream.read_exact(buf).await?;
255 :
256 27 : Ok((header.tag, buf))
257 29 : }
258 :
/// Buffer of outgoing protocol bytes.
///
/// The cursor position separates the bytes that were already consumed (before the
/// position) from the bytes that are still pending (from the position to the end).
pub struct WriteBuf(Cursor<Vec<u8>>);
260 :
261 : impl Buf for WriteBuf {
262 : #[inline]
263 117 : fn remaining(&self) -> usize {
264 117 : self.0.remaining()
265 117 : }
266 :
267 : #[inline]
268 55 : fn chunk(&self) -> &[u8] {
269 55 : self.0.chunk()
270 55 : }
271 :
272 : #[inline]
273 55 : fn advance(&mut self, cnt: usize) {
274 55 : self.0.advance(cnt);
275 55 : }
276 : }
277 :
278 : impl WriteBuf {
279 45 : pub const fn new() -> Self {
280 45 : Self(Cursor::new(Vec::new()))
281 45 : }
282 :
283 : /// Use a heuristic to determine if we should shrink the write buffer.
284 : #[inline]
285 57 : fn should_shrink(&self) -> bool {
286 57 : let n = self.0.position() as usize;
287 57 : let len = self.0.get_ref().len();
288 :
289 : // the unused space at the front of our buffer is 2x the size of our filled portion.
290 57 : n + n > len
291 57 : }
292 :
293 : /// Shrink the write buffer so that subsequent writes have more spare capacity.
294 : #[cold]
295 1 : fn shrink(&mut self) {
296 1 : let n = self.0.position() as usize;
297 1 : let buf = self.0.get_mut();
298 :
299 : // buf repr:
300 : // [----unused------|-----filled-----|-----uninit-----]
301 : // ^ n ^ buf.len() ^ buf.capacity()
302 1 : let filled = n..buf.len();
303 1 : let filled_len = filled.len();
304 1 : buf.copy_within(filled, 0);
305 1 : buf.truncate(filled_len);
306 1 : self.0.set_position(0);
307 1 : }
308 :
309 : /// clear the write buffer.
310 62 : pub fn reset(&mut self) {
311 62 : let buf = self.0.get_mut();
312 62 : buf.clear();
313 62 : self.0.set_position(0);
314 62 : }
315 :
316 : /// Shrinks the buffer if efficient to do so, and returns the remaining size.
317 0 : pub fn occupied_len(&mut self) -> usize {
318 0 : if self.should_shrink() {
319 0 : self.shrink();
320 0 : }
321 0 : self.0.get_mut().len()
322 0 : }
323 :
324 : /// Write a raw message to the internal buffer.
325 : ///
326 : /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
327 : /// we calculate the length after the fact.
328 57 : pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
329 57 : if self.should_shrink() {
330 0 : self.shrink();
331 57 : }
332 :
333 57 : let buf = self.0.get_mut();
334 57 : buf.reserve(5 + size_hint);
335 :
336 57 : buf.push(tag);
337 57 : let start = buf.len();
338 57 : buf.extend_from_slice(&[0, 0, 0, 0]);
339 :
340 57 : f(buf);
341 :
342 57 : let end = buf.len();
343 57 : let len = (end - start) as u32;
344 57 : buf[start..start + 4].copy_from_slice(&len.to_be_bytes());
345 57 : }
346 :
347 : /// Write an encryption response message.
348 20 : pub fn encryption(&mut self, m: u8) {
349 20 : self.0.get_mut().push(m);
350 20 : }
351 :
352 1 : pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) {
353 1 : self.shrink();
354 :
355 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE>
356 : // <https://www.postgresql.org/docs/current/protocol-error-fields.html>
357 : // "SERROR\0CXXXXX\0M\0\0".len() == 17
358 1 : self.write_raw(17 + msg.len(), b'E', |buf| {
359 : // Severity: ERROR
360 1 : buf.put_slice(b"SERROR\0");
361 :
362 : // Code: error_code
363 1 : buf.put_u8(b'C');
364 1 : buf.put_slice(&error_code);
365 1 : buf.put_u8(0);
366 :
367 : // Message: msg
368 1 : buf.put_u8(b'M');
369 1 : buf.put_slice(msg.as_bytes());
370 1 : buf.put_u8(0);
371 :
372 : // End.
373 1 : buf.put_u8(0);
374 1 : });
375 1 : }
376 : }
377 :
378 : #[derive(Debug)]
379 : pub enum FeStartupPacket {
380 : CancelRequest(CancelKeyData),
381 : SslRequest {
382 : direct: Option<[u8; 8]>,
383 : },
384 : GssEncRequest,
385 : StartupMessage {
386 : version: ProtocolVersion,
387 : params: StartupMessageParams,
388 : },
389 : }
390 :
/// Parameters of a StartupMessage, kept in their wire representation.
///
/// `params` holds every name and every value followed by a `\0` byte,
/// in the order `name\0value\0name\0value\0...`.
#[derive(Debug, Clone, Default)]
pub struct StartupMessageParams {
    /// The raw, `\0`-separated name/value data.
    pub params: String,
}
395 :
396 : impl StartupMessageParams {
397 : /// Get parameter's value by its name.
398 41 : pub fn get(&self, name: &str) -> Option<&str> {
399 60 : self.iter().find_map(|(k, v)| (k == name).then_some(v))
400 41 : }
401 :
402 : /// Split command-line options according to PostgreSQL's logic,
403 : /// taking into account all escape sequences but leaving them as-is.
404 : /// [`None`] means that there's no `options` in [`Self`].
405 27 : pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
406 27 : self.get("options").map(Self::parse_options_raw)
407 27 : }
408 :
409 : /// Split command-line options according to PostgreSQL's logic,
410 : /// taking into account all escape sequences but leaving them as-is.
411 34 : pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
412 : // See `postgres: pg_split_opts`.
413 34 : let mut last_was_escape = false;
414 34 : input
415 608 : .split(move |c: char| {
416 : // We split by non-escaped whitespace symbols.
417 608 : let should_split = c.is_ascii_whitespace() && !last_was_escape;
418 608 : last_was_escape = c == '\\' && !last_was_escape;
419 608 : should_split
420 608 : })
421 74 : .filter(|s| !s.is_empty())
422 34 : }
423 :
424 : /// Iterate through key-value pairs in an arbitrary order.
425 41 : pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
426 41 : self.params.split_terminator('\0').tuples()
427 41 : }
428 :
429 : // This function is mostly useful in tests.
430 : #[cfg(test)]
431 13 : pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
432 13 : let mut b = Self {
433 13 : params: String::new(),
434 13 : };
435 36 : for (k, v) in pairs {
436 23 : b.insert(k, v);
437 23 : }
438 13 : b
439 13 : }
440 :
441 : /// Set parameter's value by its name.
442 : /// name and value must not contain a \0 byte
443 23 : pub fn insert(&mut self, name: &str, value: &str) {
444 23 : self.params.reserve(name.len() + value.len() + 2);
445 23 : self.params.push_str(name);
446 23 : self.params.push('\0');
447 23 : self.params.push_str(value);
448 23 : self.params.push('\0');
449 23 : }
450 : }
451 :
452 : /// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just
453 : /// opaque bytes.
454 : #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)]
455 : pub struct CancelKeyData(pub big_endian::U64);
456 :
457 1 : pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
458 1 : CancelKeyData(big_endian::U64::new(id))
459 1 : }
460 :
461 : impl fmt::Display for CancelKeyData {
462 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
463 0 : let id = self.0;
464 0 : f.debug_tuple("CancelKeyData")
465 0 : .field(&format_args!("{id:x}"))
466 0 : .finish()
467 0 : }
468 : }
469 : impl Distribution<CancelKeyData> for StandardUniform {
470 0 : fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
471 0 : id_to_cancel_key(rng.random())
472 0 : }
473 : }
474 :
475 : pub enum BeMessage<'a> {
476 : AuthenticationOk,
477 : AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
478 : AuthenticationCleartextPassword,
479 : BackendKeyData(CancelKeyData),
480 : ParameterStatus {
481 : name: &'a [u8],
482 : value: &'a [u8],
483 : },
484 : ReadyForQuery,
485 : NoticeResponse(&'a str),
486 : NegotiateProtocolVersion {
487 : version: ProtocolVersion,
488 : options: &'a [&'a str],
489 : },
490 : }
491 :
/// The SASL flavours of the authentication message.
#[derive(Debug)]
pub enum BeAuthenticationSaslMessage<'a> {
    /// Authentication code 10: the list of offered SASL mechanisms.
    Methods(&'a [&'a str]),
    /// Authentication code 11: SASL continuation data.
    Continue(&'a [u8]),
    /// Authentication code 12: final SASL data.
    Final(&'a [u8]),
}
498 :
499 : impl BeMessage<'_> {
500 : /// Write the message into an internal buffer
501 56 : pub fn write_message(self, buf: &mut WriteBuf) {
502 30 : match self {
503 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
504 : BeMessage::AuthenticationOk => {
505 10 : buf.write_raw(1, b'R', |buf| buf.put_i32(0));
506 : }
507 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
508 : BeMessage::AuthenticationCleartextPassword => {
509 2 : buf.write_raw(1, b'R', |buf| buf.put_i32(3));
510 : }
511 :
512 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
513 13 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
514 25 : let len: usize = methods.iter().map(|m| m.len() + 1).sum();
515 13 : buf.write_raw(len + 2, b'R', |buf| {
516 13 : buf.put_i32(10); // Specifies that SASL auth method is used.
517 38 : for method in methods {
518 25 : buf.put_slice(method.as_bytes());
519 25 : buf.put_u8(0);
520 25 : }
521 13 : buf.put_u8(0); // zero terminator for the list
522 13 : });
523 : }
524 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
525 11 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
526 11 : buf.write_raw(extra.len() + 1, b'R', |buf| {
527 11 : buf.put_i32(11); // Continue SASL auth.
528 11 : buf.put_slice(extra);
529 11 : });
530 : }
531 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
532 6 : BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
533 6 : buf.write_raw(extra.len() + 1, b'R', |buf| {
534 6 : buf.put_i32(12); // Send final SASL message.
535 6 : buf.put_slice(extra);
536 6 : });
537 : }
538 :
539 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
540 0 : BeMessage::BackendKeyData(key_data) => {
541 0 : buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
542 : }
543 :
544 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
545 : // <https://www.postgresql.org/docs/current/protocol-error-fields.html>
546 0 : BeMessage::NoticeResponse(msg) => {
547 : // 'N' signalizes NoticeResponse messages
548 0 : buf.write_raw(18 + msg.len(), b'N', |buf| {
549 : // Severity: NOTICE
550 0 : buf.put_slice(b"SNOTICE\0");
551 :
552 : // Code: XX000 (ignored for notice, but still required)
553 0 : buf.put_slice(b"CXX000\0");
554 :
555 : // Message: msg
556 0 : buf.put_u8(b'M');
557 0 : buf.put_slice(msg.as_bytes());
558 0 : buf.put_u8(0);
559 :
560 : // End notice.
561 0 : buf.put_u8(0);
562 0 : });
563 : }
564 :
565 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-PARAMETERSTATUS>
566 7 : BeMessage::ParameterStatus { name, value } => {
567 7 : buf.write_raw(name.len() + value.len() + 2, b'S', |buf| {
568 7 : buf.put_slice(name.as_bytes());
569 7 : buf.put_u8(0);
570 7 : buf.put_slice(value.as_bytes());
571 7 : buf.put_u8(0);
572 7 : });
573 : }
574 :
575 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
576 : BeMessage::ReadyForQuery => {
577 7 : buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
578 : }
579 :
580 : // <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
581 0 : BeMessage::NegotiateProtocolVersion { version, options } => {
582 0 : let len: usize = options.iter().map(|o| o.len() + 1).sum();
583 0 : buf.write_raw(8 + len, b'v', |buf| {
584 0 : buf.put_slice(version.as_bytes());
585 0 : buf.put_u32(options.len() as u32);
586 0 : for option in options {
587 0 : buf.put_slice(option.as_bytes());
588 0 : buf.put_u8(0);
589 0 : }
590 0 : });
591 : }
592 : }
593 56 : }
594 : }
595 :
596 : #[cfg(test)]
597 : mod tests {
598 : use std::io::Cursor;
599 :
600 : use tokio::io::{AsyncWriteExt, duplex};
601 : use zerocopy::IntoBytes;
602 :
603 : use super::ProtocolVersion;
604 : use crate::pqproto::{FeStartupPacket, read_message, read_startup};
605 :
606 : #[tokio::test]
607 1 : async fn reject_large_startup() {
608 : // we're going to define a v3.0 startup message with far too many parameters.
609 1 : let mut payload = vec![];
610 : // 10001 + 8 bytes.
611 1 : payload.extend_from_slice(&10009_u32.to_be_bytes());
612 1 : payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
613 1 : payload.resize(10009, b'a');
614 :
615 1 : let (mut server, mut client) = duplex(128);
616 : #[rustfmt::skip]
617 1 : let (server, client) = tokio::join!(
618 1 : async move { read_startup(&mut server).await.unwrap_err() },
619 1 : async move { client.write_all(&payload).await.unwrap_err() },
620 : );
621 :
622 1 : assert_eq!(server.to_string(), "invalid startup message length 10001");
623 1 : assert_eq!(client.to_string(), "broken pipe");
624 1 : }
625 :
626 : #[tokio::test]
627 1 : async fn reject_large_password() {
628 : // we're going to define a password message that is far too long.
629 1 : let mut payload = vec![];
630 1 : payload.push(b'p');
631 1 : payload.extend_from_slice(&517_u32.to_be_bytes());
632 1 : payload.resize(518, b'a');
633 :
634 1 : let (mut server, mut client) = duplex(128);
635 : #[rustfmt::skip]
636 1 : let (server, client) = tokio::join!(
637 1 : async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() },
638 1 : async move { client.write_all(&payload).await.unwrap_err() },
639 : );
640 :
641 1 : assert_eq!(server.to_string(), "invalid message length 513");
642 1 : assert_eq!(client.to_string(), "broken pipe");
643 1 : }
644 :
645 : #[tokio::test]
646 1 : async fn read_startup_message() {
647 1 : let mut payload = vec![];
648 1 : payload.extend_from_slice(&17_u32.to_be_bytes());
649 1 : payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
650 1 : payload.extend_from_slice(b"abc\0def\0\0");
651 :
652 1 : let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
653 1 : let FeStartupPacket::StartupMessage { version, params } = startup else {
654 0 : panic!("unexpected startup message: {startup:?}");
655 : };
656 :
657 1 : assert_eq!(version.major(), 3);
658 1 : assert_eq!(version.minor(), 0);
659 1 : assert_eq!(params.params, "abc\0def\0");
660 1 : }
661 :
662 : #[tokio::test]
663 1 : async fn read_ssl_message() {
664 1 : let mut payload = vec![];
665 1 : payload.extend_from_slice(&8_u32.to_be_bytes());
666 1 : payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes());
667 :
668 1 : let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
669 1 : let FeStartupPacket::SslRequest { direct: None } = startup else {
670 1 : panic!("unexpected startup message: {startup:?}");
671 1 : };
672 1 : }
673 :
674 : #[tokio::test]
675 1 : async fn read_tls_message() {
676 : // sample client hello taken from <https://tls13.xargs.org/#client-hello>
677 1 : let client_hello = [
678 1 : 0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02,
679 1 : 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
680 1 : 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e,
681 1 : 0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
682 1 : 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9,
683 1 : 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01,
684 1 : 0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00,
685 1 : 0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65,
686 1 : 0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
687 1 : 0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19,
688 1 : 0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23,
689 1 : 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e,
690 1 : 0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09,
691 1 : 0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
692 1 : 0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01,
693 1 : 0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72,
694 1 : 0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38,
695 1 : 0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62,
696 1 : 0x54,
697 1 : ];
698 :
699 1 : let mut cursor = Cursor::new(&client_hello);
700 :
701 1 : let startup = read_startup(&mut cursor).await.unwrap();
702 : let FeStartupPacket::SslRequest {
703 1 : direct: Some(prefix),
704 1 : } = startup
705 : else {
706 0 : panic!("unexpected startup message: {startup:?}");
707 : };
708 :
709 : // check that no data is lost.
710 1 : assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]);
711 1 : assert_eq!(cursor.position(), 8);
712 1 : }
713 :
714 : #[tokio::test]
715 1 : async fn read_message_success() {
716 1 : let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2";
717 1 : let mut cursor = Cursor::new(&query);
718 :
719 1 : let mut buf = vec![];
720 1 : let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
721 1 : assert_eq!(tag, b'Q');
722 1 : assert_eq!(message, b"SELECT 1");
723 :
724 1 : let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
725 1 : assert_eq!(tag, b'Q');
726 1 : assert_eq!(message, b"SELECT 2");
727 1 : }
728 : }
|