Line data Source code
1 : use std::io;
2 :
3 : use bytes::{Buf, Bytes, BytesMut};
4 : use fallible_iterator::FallibleIterator;
5 : use postgres_protocol2::message::backend;
6 : use postgres_protocol2::message::frontend::CopyData;
7 : use tokio_util::codec::{Decoder, Encoder};
8 :
9 : pub enum FrontendMessage {
10 : Raw(Bytes),
11 : CopyData(CopyData<Box<dyn Buf + Send>>),
12 : }
13 :
14 : pub enum BackendMessage {
15 : Normal {
16 : messages: BackendMessages,
17 : request_complete: bool,
18 : },
19 : Async(backend::Message),
20 : }
21 :
22 : pub struct BackendMessages(BytesMut);
23 :
24 : impl BackendMessages {
25 15 : pub fn empty() -> BackendMessages {
26 15 : BackendMessages(BytesMut::new())
27 15 : }
28 : }
29 :
30 : impl FallibleIterator for BackendMessages {
31 : type Item = backend::Message;
32 : type Error = io::Error;
33 :
34 128 : fn next(&mut self) -> io::Result<Option<backend::Message>> {
35 128 : backend::Message::parse(&mut self.0)
36 128 : }
37 : }
38 :
39 : pub struct PostgresCodec;
40 :
41 : impl Encoder<FrontendMessage> for PostgresCodec {
42 : type Error = io::Error;
43 :
44 36 : fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
45 36 : match item {
46 36 : FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
47 0 : FrontendMessage::CopyData(data) => data.write(dst),
48 : }
49 :
50 36 : Ok(())
51 36 : }
52 : }
53 :
54 : impl Decoder for PostgresCodec {
55 : type Item = BackendMessage;
56 : type Error = io::Error;
57 :
58 65 : fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
59 65 : let mut idx = 0;
60 65 : let mut request_complete = false;
61 :
62 100 : while let Some(header) = backend::Header::parse(&src[idx..])? {
63 56 : let len = header.len() as usize + 1;
64 56 : if src[idx..].len() < len {
65 0 : break;
66 56 : }
67 56 :
68 56 : match header.tag() {
69 : backend::NOTICE_RESPONSE_TAG
70 : | backend::NOTIFICATION_RESPONSE_TAG
71 : | backend::PARAMETER_STATUS_TAG => {
72 14 : if idx == 0 {
73 7 : let message = backend::Message::parse(src)?.unwrap();
74 7 : return Ok(Some(BackendMessage::Async(message)));
75 : } else {
76 7 : break;
77 : }
78 : }
79 42 : _ => {}
80 42 : }
81 42 :
82 42 : idx += len;
83 42 :
84 42 : if header.tag() == backend::READY_FOR_QUERY_TAG {
85 7 : request_complete = true;
86 7 : break;
87 35 : }
88 : }
89 :
90 58 : if idx == 0 {
91 21 : Ok(None)
92 : } else {
93 37 : Ok(Some(BackendMessage::Normal {
94 37 : messages: BackendMessages(src.split_to(idx)),
95 37 : request_complete,
96 37 : }))
97 : }
98 65 : }
99 : }
|