Line data Source code
1 : use bytes::{Buf, Bytes, BytesMut};
2 : use fallible_iterator::FallibleIterator;
3 : use postgres_protocol2::message::backend;
4 : use postgres_protocol2::message::frontend::CopyData;
5 : use std::io;
6 : use tokio_util::codec::{Decoder, Encoder};
7 :
8 : pub enum FrontendMessage {
9 : Raw(Bytes),
10 : CopyData(CopyData<Box<dyn Buf + Send>>),
11 : }
12 :
13 : pub enum BackendMessage {
14 : Normal {
15 : messages: BackendMessages,
16 : request_complete: bool,
17 : },
18 : Async(backend::Message),
19 : }
20 :
21 : pub struct BackendMessages(BytesMut);
22 :
23 : impl BackendMessages {
24 15 : pub fn empty() -> BackendMessages {
25 15 : BackendMessages(BytesMut::new())
26 15 : }
27 : }
28 :
29 : impl FallibleIterator for BackendMessages {
30 : type Item = backend::Message;
31 : type Error = io::Error;
32 :
33 128 : fn next(&mut self) -> io::Result<Option<backend::Message>> {
34 128 : backend::Message::parse(&mut self.0)
35 128 : }
36 : }
37 :
38 : pub struct PostgresCodec;
39 :
40 : impl Encoder<FrontendMessage> for PostgresCodec {
41 : type Error = io::Error;
42 :
43 36 : fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
44 36 : match item {
45 36 : FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
46 0 : FrontendMessage::CopyData(data) => data.write(dst),
47 : }
48 :
49 36 : Ok(())
50 36 : }
51 : }
52 :
53 : impl Decoder for PostgresCodec {
54 : type Item = BackendMessage;
55 : type Error = io::Error;
56 :
57 65 : fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
58 65 : let mut idx = 0;
59 65 : let mut request_complete = false;
60 :
61 100 : while let Some(header) = backend::Header::parse(&src[idx..])? {
62 56 : let len = header.len() as usize + 1;
63 56 : if src[idx..].len() < len {
64 0 : break;
65 56 : }
66 56 :
67 56 : match header.tag() {
68 : backend::NOTICE_RESPONSE_TAG
69 : | backend::NOTIFICATION_RESPONSE_TAG
70 : | backend::PARAMETER_STATUS_TAG => {
71 14 : if idx == 0 {
72 7 : let message = backend::Message::parse(src)?.unwrap();
73 7 : return Ok(Some(BackendMessage::Async(message)));
74 : } else {
75 7 : break;
76 : }
77 : }
78 42 : _ => {}
79 42 : }
80 42 :
81 42 : idx += len;
82 42 :
83 42 : if header.tag() == backend::READY_FOR_QUERY_TAG {
84 7 : request_complete = true;
85 7 : break;
86 35 : }
87 : }
88 :
89 58 : if idx == 0 {
90 21 : Ok(None)
91 : } else {
92 37 : Ok(Some(BackendMessage::Normal {
93 37 : messages: BackendMessages(src.split_to(idx)),
94 37 : request_complete,
95 37 : }))
96 : }
97 65 : }
98 : }
|