Line data Source code
1 : use std::io;
2 :
3 : use bytes::{Bytes, BytesMut};
4 : use fallible_iterator::FallibleIterator;
5 : use postgres_protocol2::message::backend;
6 : use tokio::sync::mpsc::UnboundedSender;
7 : use tokio_util::codec::{Decoder, Encoder};
8 :
9 : pub enum FrontendMessage {
10 : Raw(Bytes),
11 : RecordNotices(RecordNotices),
12 : }
13 :
14 : pub struct RecordNotices {
15 : pub sender: UnboundedSender<Box<str>>,
16 : pub limit: usize,
17 : }
18 :
19 : pub enum BackendMessage {
20 : Normal { messages: BackendMessages },
21 : Async(backend::Message),
22 : }
23 :
24 : pub struct BackendMessages(BytesMut);
25 :
26 : impl BackendMessages {
27 15 : pub fn empty() -> BackendMessages {
28 15 : BackendMessages(BytesMut::new())
29 15 : }
30 : }
31 :
32 : impl FallibleIterator for BackendMessages {
33 : type Item = backend::Message;
34 : type Error = io::Error;
35 :
36 129 : fn next(&mut self) -> io::Result<Option<backend::Message>> {
37 129 : backend::Message::parse(&mut self.0)
38 129 : }
39 : }
40 :
41 : pub struct PostgresCodec;
42 :
43 : impl Encoder<Bytes> for PostgresCodec {
44 : type Error = io::Error;
45 :
46 36 : fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> {
47 36 : dst.extend_from_slice(&item);
48 36 : Ok(())
49 36 : }
50 : }
51 :
52 : impl Decoder for PostgresCodec {
53 : type Item = BackendMessage;
54 : type Error = io::Error;
55 :
56 69 : fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
57 69 : let mut idx = 0;
58 :
59 104 : while let Some(header) = backend::Header::parse(&src[idx..])? {
60 55 : let len = header.len() as usize + 1;
61 55 : if src[idx..].len() < len {
62 0 : break;
63 55 : }
64 :
65 55 : match header.tag() {
66 : backend::NOTICE_RESPONSE_TAG
67 : | backend::NOTIFICATION_RESPONSE_TAG
68 : | backend::PARAMETER_STATUS_TAG => {
69 13 : if idx == 0 {
70 7 : let message = backend::Message::parse(src)?.unwrap();
71 7 : return Ok(Some(BackendMessage::Async(message)));
72 : } else {
73 6 : break;
74 : }
75 : }
76 42 : _ => {}
77 : }
78 :
79 42 : idx += len;
80 :
81 42 : if header.tag() == backend::READY_FOR_QUERY_TAG {
82 7 : break;
83 35 : }
84 : }
85 :
86 62 : if idx == 0 {
87 24 : Ok(None)
88 : } else {
89 38 : Ok(Some(BackendMessage::Normal {
90 38 : messages: BackendMessages(src.split_to(idx)),
91 38 : }))
92 : }
93 69 : }
94 : }
|