Line data Source code
1 : use std::collections::{HashMap, VecDeque};
2 : use std::future::Future;
3 : use std::pin::Pin;
4 : use std::task::{Context, Poll};
5 :
6 : use bytes::BytesMut;
7 : use fallible_iterator::FallibleIterator;
8 : use futures_util::{Sink, Stream, ready};
9 : use log::{info, trace};
10 : use postgres_protocol2::message::backend::Message;
11 : use postgres_protocol2::message::frontend;
12 : use tokio::io::{AsyncRead, AsyncWrite};
13 : use tokio::sync::mpsc;
14 : use tokio_util::codec::Framed;
15 : use tokio_util::sync::PollSender;
16 :
17 : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
18 : use crate::error::DbError;
19 : use crate::maybe_tls_stream::MaybeTlsStream;
20 : use crate::{AsyncMessage, Error, Notification};
21 :
/// The frontend message(s) a client request asks the connection to write to the server.
pub enum RequestMessages {
    /// A single frontend message, handed to the framed stream's `start_send` as-is.
    Single(FrontendMessage),
}
25 :
/// A request submitted by the client to the background `Connection`.
pub struct Request {
    /// What to write to the server for this request.
    pub messages: RequestMessages,
    /// Channel on which the connection forwards the `BackendMessages` batches belonging to this
    /// request. If the receiving side is dropped, the connection keeps reading (and discarding)
    /// the rest of this request's messages.
    pub sender: mpsc::Sender<BackendMessages>,
}
30 :
/// Bookkeeping for a request whose response has not been fully delivered yet.
pub struct Response {
    /// Poll-friendly wrapper around the request's `sender`, used to reserve capacity before
    /// forwarding each batch of backend messages.
    sender: PollSender<BackendMessages>,
}
34 :
/// Lifecycle of a `Connection`.
#[derive(PartialEq, Debug)]
enum State {
    /// Reading from and writing to the server normally.
    Active,
    /// A Terminate message has been queued; only shutting down the stream remains.
    Closing,
}
40 :
/// Outcome of a write pass that finished without waiting on the socket or the client.
enum WriteReady {
    /// The client is gone and nothing is outstanding: a Terminate message was queued and the
    /// connection should close.
    Terminating,
    /// The client is gone, but responses for earlier requests are still being read.
    WaitingOnRead,
}
45 :
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
    /// HACK: we need this in the Neon Proxy.
    pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// HACK: we need this in the Neon Proxy to forward params.
    pub parameters: HashMap<String, String>,
    // Requests submitted by the client.
    receiver: mpsc::UnboundedReceiver<Request>,
    // Backend messages already obtained that must be processed before reading new data from
    // `stream` (e.g. a batch put back because its response channel had no capacity).
    pending_responses: VecDeque<BackendMessage>,
    // Response senders for in-flight requests, oldest first; the front entry receives the next
    // non-async batch of backend messages.
    responses: VecDeque<Response>,
    // Whether the connection is still active or is shutting down.
    state: State,
}
64 :
65 : impl<S, T> Connection<S, T>
66 : where
67 : S: AsyncRead + AsyncWrite + Unpin,
68 : T: AsyncRead + AsyncWrite + Unpin,
69 : {
70 0 : pub(crate) fn new(
71 0 : stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
72 0 : pending_responses: VecDeque<BackendMessage>,
73 0 : parameters: HashMap<String, String>,
74 0 : receiver: mpsc::UnboundedReceiver<Request>,
75 0 : ) -> Connection<S, T> {
76 0 : Connection {
77 0 : stream,
78 0 : parameters,
79 0 : receiver,
80 0 : pending_responses,
81 0 : responses: VecDeque::new(),
82 0 : state: State::Active,
83 0 : }
84 0 : }
85 :
86 0 : fn poll_response(
87 0 : &mut self,
88 0 : cx: &mut Context<'_>,
89 0 : ) -> Poll<Option<Result<BackendMessage, Error>>> {
90 0 : if let Some(message) = self.pending_responses.pop_front() {
91 0 : trace!("retrying pending response");
92 0 : return Poll::Ready(Some(Ok(message)));
93 0 : }
94 0 :
95 0 : Pin::new(&mut self.stream)
96 0 : .poll_next(cx)
97 0 : .map(|o| o.map(|r| r.map_err(Error::io)))
98 0 : }
99 :
100 : /// Read and process messages from the connection to postgres.
101 : /// client <- postgres
102 0 : fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
103 : loop {
104 0 : let message = match self.poll_response(cx)? {
105 0 : Poll::Ready(Some(message)) => message,
106 0 : Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
107 : Poll::Pending => {
108 0 : trace!("poll_read: waiting on response");
109 0 : return Poll::Pending;
110 : }
111 : };
112 :
113 0 : let (mut messages, request_complete) = match message {
114 0 : BackendMessage::Async(Message::NoticeResponse(body)) => {
115 0 : let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
116 0 : return Poll::Ready(Ok(AsyncMessage::Notice(error)));
117 : }
118 0 : BackendMessage::Async(Message::NotificationResponse(body)) => {
119 0 : let notification = Notification {
120 0 : process_id: body.process_id(),
121 0 : channel: body.channel().map_err(Error::parse)?.to_string(),
122 0 : payload: body.message().map_err(Error::parse)?.to_string(),
123 0 : };
124 0 : return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
125 : }
126 0 : BackendMessage::Async(Message::ParameterStatus(body)) => {
127 0 : self.parameters.insert(
128 0 : body.name().map_err(Error::parse)?.to_string(),
129 0 : body.value().map_err(Error::parse)?.to_string(),
130 0 : );
131 0 : continue;
132 : }
133 0 : BackendMessage::Async(_) => unreachable!(),
134 : BackendMessage::Normal {
135 0 : messages,
136 0 : request_complete,
137 0 : } => (messages, request_complete),
138 : };
139 :
140 0 : let mut response = match self.responses.pop_front() {
141 0 : Some(response) => response,
142 0 : None => match messages.next().map_err(Error::parse)? {
143 0 : Some(Message::ErrorResponse(error)) => {
144 0 : return Poll::Ready(Err(Error::db(error)));
145 : }
146 0 : _ => return Poll::Ready(Err(Error::unexpected_message())),
147 : },
148 : };
149 :
150 0 : match response.sender.poll_reserve(cx) {
151 : Poll::Ready(Ok(())) => {
152 0 : let _ = response.sender.send_item(messages);
153 0 : if !request_complete {
154 0 : self.responses.push_front(response);
155 0 : }
156 : }
157 : Poll::Ready(Err(_)) => {
158 : // we need to keep paging through the rest of the messages even if the receiver's hung up
159 0 : if !request_complete {
160 0 : self.responses.push_front(response);
161 0 : }
162 : }
163 : Poll::Pending => {
164 0 : self.responses.push_front(response);
165 0 : self.pending_responses.push_back(BackendMessage::Normal {
166 0 : messages,
167 0 : request_complete,
168 0 : });
169 0 : trace!("poll_read: waiting on sender");
170 0 : return Poll::Pending;
171 : }
172 : }
173 : }
174 0 : }
175 :
176 : /// Fetch the next client request and enqueue the response sender.
177 0 : fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
178 0 : if self.receiver.is_closed() {
179 0 : return Poll::Ready(None);
180 0 : }
181 0 :
182 0 : match self.receiver.poll_recv(cx) {
183 0 : Poll::Ready(Some(request)) => {
184 0 : trace!("polled new request");
185 0 : self.responses.push_back(Response {
186 0 : sender: PollSender::new(request.sender),
187 0 : });
188 0 : Poll::Ready(Some(request.messages))
189 : }
190 0 : Poll::Ready(None) => Poll::Ready(None),
191 0 : Poll::Pending => Poll::Pending,
192 : }
193 0 : }
194 :
195 : /// Process client requests and write them to the postgres connection, flushing if necessary.
196 : /// client -> postgres
197 0 : fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
198 : loop {
199 0 : if Pin::new(&mut self.stream)
200 0 : .poll_ready(cx)
201 0 : .map_err(Error::io)?
202 0 : .is_pending()
203 : {
204 0 : trace!("poll_write: waiting on socket");
205 :
206 : // poll_ready is self-flushing.
207 0 : return Poll::Pending;
208 0 : }
209 0 :
210 0 : match self.poll_request(cx) {
211 : // send the message to postgres
212 0 : Poll::Ready(Some(RequestMessages::Single(request))) => {
213 0 : Pin::new(&mut self.stream)
214 0 : .start_send(request)
215 0 : .map_err(Error::io)?;
216 : }
217 : // No more messages from the client, and no more responses to wait for.
218 : // Send a terminate message to postgres
219 0 : Poll::Ready(None) if self.responses.is_empty() => {
220 0 : trace!("poll_write: at eof, terminating");
221 0 : let mut request = BytesMut::new();
222 0 : frontend::terminate(&mut request);
223 0 : let request = FrontendMessage::Raw(request.freeze());
224 0 :
225 0 : Pin::new(&mut self.stream)
226 0 : .start_send(request)
227 0 : .map_err(Error::io)?;
228 :
229 0 : trace!("poll_write: sent eof, closing");
230 0 : trace!("poll_write: done");
231 0 : return Poll::Ready(Ok(WriteReady::Terminating));
232 : }
233 : // No more messages from the client, but there are still some responses to wait for.
234 : Poll::Ready(None) => {
235 0 : trace!(
236 0 : "poll_write: at eof, pending responses {}",
237 0 : self.responses.len()
238 : );
239 0 : ready!(self.poll_flush(cx))?;
240 0 : return Poll::Ready(Ok(WriteReady::WaitingOnRead));
241 : }
242 : // Still waiting for a message from the client.
243 : Poll::Pending => {
244 0 : trace!("poll_write: waiting on request");
245 0 : ready!(self.poll_flush(cx))?;
246 0 : return Poll::Pending;
247 : }
248 : }
249 : }
250 0 : }
251 :
252 0 : fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
253 0 : match Pin::new(&mut self.stream)
254 0 : .poll_flush(cx)
255 0 : .map_err(Error::io)?
256 : {
257 : Poll::Ready(()) => {
258 0 : trace!("poll_flush: flushed");
259 0 : Poll::Ready(Ok(()))
260 : }
261 : Poll::Pending => {
262 0 : trace!("poll_flush: waiting on socket");
263 0 : Poll::Pending
264 : }
265 : }
266 0 : }
267 :
268 0 : fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
269 0 : match Pin::new(&mut self.stream)
270 0 : .poll_close(cx)
271 0 : .map_err(Error::io)?
272 : {
273 : Poll::Ready(()) => {
274 0 : trace!("poll_shutdown: complete");
275 0 : Poll::Ready(Ok(()))
276 : }
277 : Poll::Pending => {
278 0 : trace!("poll_shutdown: waiting on socket");
279 0 : Poll::Pending
280 : }
281 : }
282 0 : }
283 :
284 : /// Returns the value of a runtime parameter for this connection.
285 0 : pub fn parameter(&self, name: &str) -> Option<&str> {
286 0 : self.parameters.get(name).map(|s| &**s)
287 0 : }
288 :
289 : /// Polls for asynchronous messages from the server.
290 : ///
291 : /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
292 : /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
293 0 : pub fn poll_message(
294 0 : &mut self,
295 0 : cx: &mut Context<'_>,
296 0 : ) -> Poll<Option<Result<AsyncMessage, Error>>> {
297 0 : if self.state != State::Closing {
298 : // if the state is still active, try read from and write to postgres.
299 0 : let message = self.poll_read(cx)?;
300 0 : let closing = self.poll_write(cx)?;
301 0 : if let Poll::Ready(WriteReady::Terminating) = closing {
302 0 : self.state = State::Closing;
303 0 : }
304 :
305 0 : if let Poll::Ready(message) = message {
306 0 : return Poll::Ready(Some(Ok(message)));
307 0 : }
308 0 :
309 0 : // poll_read returned Pending.
310 0 : // poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
311 0 : // if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
312 0 : if self.state != State::Closing {
313 0 : return Poll::Pending;
314 0 : }
315 0 : }
316 :
317 0 : match self.poll_shutdown(cx) {
318 0 : Poll::Ready(Ok(())) => Poll::Ready(None),
319 0 : Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
320 0 : Poll::Pending => Poll::Pending,
321 : }
322 0 : }
323 : }
324 :
325 : impl<S, T> Future for Connection<S, T>
326 : where
327 : S: AsyncRead + AsyncWrite + Unpin,
328 : T: AsyncRead + AsyncWrite + Unpin,
329 : {
330 : type Output = Result<(), Error>;
331 :
332 0 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
333 0 : while let Some(message) = ready!(self.poll_message(cx)?) {
334 0 : if let AsyncMessage::Notice(notice) = message {
335 0 : info!("{}: {}", notice.severity(), notice.message());
336 0 : }
337 : }
338 0 : Poll::Ready(Ok(()))
339 0 : }
340 : }
|