Line data Source code
1 : use std::future::Future;
2 : use std::pin::Pin;
3 : use std::task::{Context, Poll};
4 :
5 : use bytes::BytesMut;
6 : use fallible_iterator::FallibleIterator;
7 : use futures_util::{Sink, StreamExt, ready};
8 : use postgres_protocol2::message::backend::{Message, NoticeResponseBody};
9 : use postgres_protocol2::message::frontend;
10 : use tokio::io::{AsyncRead, AsyncWrite};
11 : use tokio::sync::mpsc;
12 : use tokio_util::codec::Framed;
13 : use tokio_util::sync::PollSender;
14 : use tracing::trace;
15 :
16 : use crate::Error;
17 : use crate::codec::{
18 : BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices,
19 : };
20 : use crate::maybe_tls_stream::MaybeTlsStream;
21 :
/// Lifecycle phase of the connection driver.
#[derive(Debug, PartialEq)]
enum State {
    /// Still exchanging messages with the server in both directions.
    Active,
    /// The client side has finished; only the stream shutdown remains.
    Closing,
}
27 :
28 : /// A connection to a PostgreSQL database.
29 : ///
30 : /// This is one half of what is returned when a new connection is established. It performs the actual IO with the
31 : /// server, and should generally be spawned off onto an executor to run in the background.
32 : ///
33 : /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
34 : /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
35 : #[must_use = "futures do nothing unless polled"]
36 : pub struct Connection<S, T> {
37 : stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
38 :
39 : sender: PollSender<BackendMessages>,
40 : receiver: mpsc::UnboundedReceiver<FrontendMessage>,
41 : notices: Option<RecordNotices>,
42 :
43 : pending_response: Option<BackendMessages>,
44 : state: State,
45 : }
46 :
/// An uninhabited type: no value of `Never` can ever be constructed.
///
/// Used as the success type of `Connection::poll_read`, so that its signature states it can only
/// stay pending or fail.
pub enum Never {}
48 :
49 : impl<S, T> Connection<S, T>
50 : where
51 : S: AsyncRead + AsyncWrite + Unpin,
52 : T: AsyncRead + AsyncWrite + Unpin,
53 : {
54 0 : pub(crate) fn new(
55 0 : stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
56 0 : sender: mpsc::Sender<BackendMessages>,
57 0 : receiver: mpsc::UnboundedReceiver<FrontendMessage>,
58 0 : ) -> Connection<S, T> {
59 0 : Connection {
60 0 : stream,
61 0 : sender: PollSender::new(sender),
62 0 : receiver,
63 0 : notices: None,
64 0 : pending_response: None,
65 0 : state: State::Active,
66 0 : }
67 0 : }
68 :
69 : /// Read and process messages from the connection to postgres.
70 : /// client <- postgres
71 0 : fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Never, Error>> {
72 : loop {
73 0 : let messages = match self.pending_response.take() {
74 0 : Some(messages) => messages,
75 : None => {
76 0 : let message = match self.stream.poll_next_unpin(cx) {
77 0 : Poll::Pending => return Poll::Pending,
78 0 : Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
79 0 : Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))),
80 0 : Poll::Ready(Some(Ok(message))) => message,
81 : };
82 :
83 0 : match message {
84 0 : BackendMessage::Async(Message::NoticeResponse(body)) => {
85 0 : self.handle_notice(body)?;
86 0 : continue;
87 : }
88 0 : BackendMessage::Async(_) => continue,
89 0 : BackendMessage::Normal { messages } => messages,
90 : }
91 : }
92 : };
93 :
94 0 : match self.sender.poll_reserve(cx) {
95 0 : Poll::Ready(Ok(())) => {
96 0 : let _ = self.sender.send_item(messages);
97 0 : }
98 : Poll::Ready(Err(_)) => {
99 0 : return Poll::Ready(Err(Error::closed()));
100 : }
101 : Poll::Pending => {
102 0 : self.pending_response = Some(messages);
103 0 : trace!("poll_read: waiting on sender");
104 0 : return Poll::Pending;
105 : }
106 : }
107 : }
108 0 : }
109 :
110 0 : fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> {
111 0 : let Some(notices) = &mut self.notices else {
112 0 : return Ok(());
113 : };
114 :
115 0 : let mut fields = body.fields();
116 0 : while let Some(field) = fields.next().map_err(Error::parse)? {
117 : // loop until we find the message field
118 0 : if field.type_() == b'M' {
119 : // if the message field is within the limit, send it.
120 0 : if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) {
121 0 : match notices.sender.send(field.value().into()) {
122 : // set the new limit.
123 0 : Ok(()) => notices.limit = new_limit,
124 : // closed.
125 0 : Err(_) => self.notices = None,
126 : }
127 0 : }
128 0 : break;
129 0 : }
130 : }
131 :
132 0 : Ok(())
133 0 : }
134 :
135 : /// Fetch the next client request and enqueue the response sender.
136 0 : fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
137 0 : if self.receiver.is_closed() {
138 0 : return Poll::Ready(None);
139 0 : }
140 :
141 0 : match self.receiver.poll_recv(cx) {
142 0 : Poll::Ready(Some(request)) => {
143 0 : trace!("polled new request");
144 0 : Poll::Ready(Some(request))
145 : }
146 0 : Poll::Ready(None) => Poll::Ready(None),
147 0 : Poll::Pending => Poll::Pending,
148 : }
149 0 : }
150 :
151 : /// Process client requests and write them to the postgres connection, flushing if necessary.
152 : /// client -> postgres
153 0 : fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
154 : loop {
155 0 : if Pin::new(&mut self.stream)
156 0 : .poll_ready(cx)
157 0 : .map_err(Error::io)?
158 0 : .is_pending()
159 : {
160 0 : trace!("poll_write: waiting on socket");
161 :
162 : // poll_ready is self-flushing.
163 0 : return Poll::Pending;
164 0 : }
165 :
166 0 : match self.poll_request(cx) {
167 : // send the message to postgres
168 0 : Poll::Ready(Some(FrontendMessage::Raw(request))) => {
169 0 : Pin::new(&mut self.stream)
170 0 : .start_send(request)
171 0 : .map_err(Error::io)?;
172 : }
173 0 : Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => {
174 0 : self.notices = Some(notices)
175 : }
176 : // No more messages from the client, and no more responses to wait for.
177 : // Send a terminate message to postgres
178 : Poll::Ready(None) => {
179 0 : trace!("poll_write: at eof, terminating");
180 0 : let mut request = BytesMut::new();
181 0 : frontend::terminate(&mut request);
182 :
183 0 : Pin::new(&mut self.stream)
184 0 : .start_send(request.freeze())
185 0 : .map_err(Error::io)?;
186 :
187 0 : trace!("poll_write: sent eof, closing");
188 0 : trace!("poll_write: done");
189 0 : return Poll::Ready(Ok(()));
190 : }
191 : // Still waiting for a message from the client.
192 : Poll::Pending => {
193 0 : trace!("poll_write: waiting on request");
194 0 : ready!(self.poll_flush(cx))?;
195 0 : return Poll::Pending;
196 : }
197 : }
198 : }
199 0 : }
200 :
201 0 : fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
202 0 : match Pin::new(&mut self.stream)
203 0 : .poll_flush(cx)
204 0 : .map_err(Error::io)?
205 : {
206 : Poll::Ready(()) => {
207 0 : trace!("poll_flush: flushed");
208 0 : Poll::Ready(Ok(()))
209 : }
210 : Poll::Pending => {
211 0 : trace!("poll_flush: waiting on socket");
212 0 : Poll::Pending
213 : }
214 : }
215 0 : }
216 :
217 0 : fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
218 0 : match Pin::new(&mut self.stream)
219 0 : .poll_close(cx)
220 0 : .map_err(Error::io)?
221 : {
222 : Poll::Ready(()) => {
223 0 : trace!("poll_shutdown: complete");
224 0 : Poll::Ready(Ok(()))
225 : }
226 : Poll::Pending => {
227 0 : trace!("poll_shutdown: waiting on socket");
228 0 : Poll::Pending
229 : }
230 : }
231 0 : }
232 :
233 0 : fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Never, Error>>> {
234 0 : if self.state != State::Closing {
235 : // if the state is still active, try read from and write to postgres.
236 0 : let Poll::Pending = self.poll_read(cx)?;
237 0 : if self.poll_write(cx)?.is_ready() {
238 0 : self.state = State::Closing;
239 0 : }
240 :
241 : // poll_read returned Pending.
242 : // poll_write returned Pending or Ready(()).
243 : // if poll_write returned Ready(()), then we are waiting to read more data from postgres.
244 0 : if self.state != State::Closing {
245 0 : return Poll::Pending;
246 0 : }
247 0 : }
248 :
249 0 : match self.poll_shutdown(cx) {
250 0 : Poll::Ready(Ok(())) => Poll::Ready(None),
251 0 : Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
252 0 : Poll::Pending => Poll::Pending,
253 : }
254 0 : }
255 : }
256 :
257 : impl<S, T> Future for Connection<S, T>
258 : where
259 : S: AsyncRead + AsyncWrite + Unpin,
260 : T: AsyncRead + AsyncWrite + Unpin,
261 : {
262 : type Output = Result<(), Error>;
263 :
264 0 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
265 0 : match self.poll_message(cx)? {
266 0 : Poll::Ready(None) => Poll::Ready(Ok(())),
267 0 : Poll::Pending => Poll::Pending,
268 : }
269 0 : }
270 : }
|