Line data Source code
1 : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
2 : use crate::error::DbError;
3 : use crate::maybe_tls_stream::MaybeTlsStream;
4 : use crate::{AsyncMessage, Error, Notification};
5 : use bytes::BytesMut;
6 : use fallible_iterator::FallibleIterator;
7 : use futures_util::{ready, Sink, Stream};
8 : use log::{info, trace};
9 : use postgres_protocol2::message::backend::Message;
10 : use postgres_protocol2::message::frontend;
11 : use std::collections::{HashMap, VecDeque};
12 : use std::future::Future;
13 : use std::pin::Pin;
14 : use std::task::{Context, Poll};
15 : use tokio::io::{AsyncRead, AsyncWrite};
16 : use tokio::sync::mpsc;
17 : use tokio_util::codec::Framed;
18 : use tokio_util::sync::PollSender;
19 :
20 : pub enum RequestMessages {
21 : Single(FrontendMessage),
22 : }
23 :
24 : pub struct Request {
25 : pub messages: RequestMessages,
26 : pub sender: mpsc::Sender<BackendMessages>,
27 : }
28 :
29 : pub struct Response {
30 : sender: PollSender<BackendMessages>,
31 : }
32 :
/// Lifecycle phase of a [`Connection`].
#[derive(Debug, PartialEq)]
enum State {
    /// Still reading responses and writing requests.
    Active,
    /// A Terminate message has been queued; only shutting the stream down remains.
    Closing,
}
38 :
/// Non-pending outcomes of the connection's write half.
enum WriteReady {
    /// No more requests and nothing outstanding: a Terminate message was queued.
    Terminating,
    /// No more requests, but responses to earlier requests are still outstanding.
    WaitingOnRead,
}
43 :
44 : /// A connection to a PostgreSQL database.
45 : ///
46 : /// This is one half of what is returned when a new connection is established. It performs the actual IO with the
47 : /// server, and should generally be spawned off onto an executor to run in the background.
48 : ///
49 : /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
50 : /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
51 : #[must_use = "futures do nothing unless polled"]
52 : pub struct Connection<S, T> {
53 : /// HACK: we need this in the Neon Proxy.
54 : pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
55 : /// HACK: we need this in the Neon Proxy to forward params.
56 : pub parameters: HashMap<String, String>,
57 : receiver: mpsc::UnboundedReceiver<Request>,
58 : pending_responses: VecDeque<BackendMessage>,
59 : responses: VecDeque<Response>,
60 : state: State,
61 : }
62 :
63 : impl<S, T> Connection<S, T>
64 : where
65 : S: AsyncRead + AsyncWrite + Unpin,
66 : T: AsyncRead + AsyncWrite + Unpin,
67 : {
68 0 : pub(crate) fn new(
69 0 : stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
70 0 : pending_responses: VecDeque<BackendMessage>,
71 0 : parameters: HashMap<String, String>,
72 0 : receiver: mpsc::UnboundedReceiver<Request>,
73 0 : ) -> Connection<S, T> {
74 0 : Connection {
75 0 : stream,
76 0 : parameters,
77 0 : receiver,
78 0 : pending_responses,
79 0 : responses: VecDeque::new(),
80 0 : state: State::Active,
81 0 : }
82 0 : }
83 :
84 0 : fn poll_response(
85 0 : &mut self,
86 0 : cx: &mut Context<'_>,
87 0 : ) -> Poll<Option<Result<BackendMessage, Error>>> {
88 0 : if let Some(message) = self.pending_responses.pop_front() {
89 0 : trace!("retrying pending response");
90 0 : return Poll::Ready(Some(Ok(message)));
91 0 : }
92 0 :
93 0 : Pin::new(&mut self.stream)
94 0 : .poll_next(cx)
95 0 : .map(|o| o.map(|r| r.map_err(Error::io)))
96 0 : }
97 :
98 : /// Read and process messages from the connection to postgres.
99 : /// client <- postgres
100 0 : fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
101 : loop {
102 0 : let message = match self.poll_response(cx)? {
103 0 : Poll::Ready(Some(message)) => message,
104 0 : Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
105 : Poll::Pending => {
106 0 : trace!("poll_read: waiting on response");
107 0 : return Poll::Pending;
108 : }
109 : };
110 :
111 0 : let (mut messages, request_complete) = match message {
112 0 : BackendMessage::Async(Message::NoticeResponse(body)) => {
113 0 : let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
114 0 : return Poll::Ready(Ok(AsyncMessage::Notice(error)));
115 : }
116 0 : BackendMessage::Async(Message::NotificationResponse(body)) => {
117 0 : let notification = Notification {
118 0 : process_id: body.process_id(),
119 0 : channel: body.channel().map_err(Error::parse)?.to_string(),
120 0 : payload: body.message().map_err(Error::parse)?.to_string(),
121 0 : };
122 0 : return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
123 : }
124 0 : BackendMessage::Async(Message::ParameterStatus(body)) => {
125 0 : self.parameters.insert(
126 0 : body.name().map_err(Error::parse)?.to_string(),
127 0 : body.value().map_err(Error::parse)?.to_string(),
128 0 : );
129 0 : continue;
130 : }
131 0 : BackendMessage::Async(_) => unreachable!(),
132 : BackendMessage::Normal {
133 0 : messages,
134 0 : request_complete,
135 0 : } => (messages, request_complete),
136 : };
137 :
138 0 : let mut response = match self.responses.pop_front() {
139 0 : Some(response) => response,
140 0 : None => match messages.next().map_err(Error::parse)? {
141 0 : Some(Message::ErrorResponse(error)) => {
142 0 : return Poll::Ready(Err(Error::db(error)))
143 : }
144 0 : _ => return Poll::Ready(Err(Error::unexpected_message())),
145 : },
146 : };
147 :
148 0 : match response.sender.poll_reserve(cx) {
149 : Poll::Ready(Ok(())) => {
150 0 : let _ = response.sender.send_item(messages);
151 0 : if !request_complete {
152 0 : self.responses.push_front(response);
153 0 : }
154 : }
155 : Poll::Ready(Err(_)) => {
156 : // we need to keep paging through the rest of the messages even if the receiver's hung up
157 0 : if !request_complete {
158 0 : self.responses.push_front(response);
159 0 : }
160 : }
161 : Poll::Pending => {
162 0 : self.responses.push_front(response);
163 0 : self.pending_responses.push_back(BackendMessage::Normal {
164 0 : messages,
165 0 : request_complete,
166 0 : });
167 0 : trace!("poll_read: waiting on sender");
168 0 : return Poll::Pending;
169 : }
170 : }
171 : }
172 0 : }
173 :
174 : /// Fetch the next client request and enqueue the response sender.
175 0 : fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
176 0 : if self.receiver.is_closed() {
177 0 : return Poll::Ready(None);
178 0 : }
179 0 :
180 0 : match self.receiver.poll_recv(cx) {
181 0 : Poll::Ready(Some(request)) => {
182 0 : trace!("polled new request");
183 0 : self.responses.push_back(Response {
184 0 : sender: PollSender::new(request.sender),
185 0 : });
186 0 : Poll::Ready(Some(request.messages))
187 : }
188 0 : Poll::Ready(None) => Poll::Ready(None),
189 0 : Poll::Pending => Poll::Pending,
190 : }
191 0 : }
192 :
193 : /// Process client requests and write them to the postgres connection, flushing if necessary.
194 : /// client -> postgres
195 0 : fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
196 : loop {
197 0 : if Pin::new(&mut self.stream)
198 0 : .poll_ready(cx)
199 0 : .map_err(Error::io)?
200 0 : .is_pending()
201 : {
202 0 : trace!("poll_write: waiting on socket");
203 :
204 : // poll_ready is self-flushing.
205 0 : return Poll::Pending;
206 0 : }
207 0 :
208 0 : match self.poll_request(cx) {
209 : // send the message to postgres
210 0 : Poll::Ready(Some(RequestMessages::Single(request))) => {
211 0 : Pin::new(&mut self.stream)
212 0 : .start_send(request)
213 0 : .map_err(Error::io)?;
214 : }
215 : // No more messages from the client, and no more responses to wait for.
216 : // Send a terminate message to postgres
217 0 : Poll::Ready(None) if self.responses.is_empty() => {
218 0 : trace!("poll_write: at eof, terminating");
219 0 : let mut request = BytesMut::new();
220 0 : frontend::terminate(&mut request);
221 0 : let request = FrontendMessage::Raw(request.freeze());
222 0 :
223 0 : Pin::new(&mut self.stream)
224 0 : .start_send(request)
225 0 : .map_err(Error::io)?;
226 :
227 0 : trace!("poll_write: sent eof, closing");
228 0 : trace!("poll_write: done");
229 0 : return Poll::Ready(Ok(WriteReady::Terminating));
230 : }
231 : // No more messages from the client, but there are still some responses to wait for.
232 : Poll::Ready(None) => {
233 0 : trace!(
234 0 : "poll_write: at eof, pending responses {}",
235 0 : self.responses.len()
236 : );
237 0 : ready!(self.poll_flush(cx))?;
238 0 : return Poll::Ready(Ok(WriteReady::WaitingOnRead));
239 : }
240 : // Still waiting for a message from the client.
241 : Poll::Pending => {
242 0 : trace!("poll_write: waiting on request");
243 0 : ready!(self.poll_flush(cx))?;
244 0 : return Poll::Pending;
245 : }
246 : }
247 : }
248 0 : }
249 :
250 0 : fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
251 0 : match Pin::new(&mut self.stream)
252 0 : .poll_flush(cx)
253 0 : .map_err(Error::io)?
254 : {
255 : Poll::Ready(()) => {
256 0 : trace!("poll_flush: flushed");
257 0 : Poll::Ready(Ok(()))
258 : }
259 : Poll::Pending => {
260 0 : trace!("poll_flush: waiting on socket");
261 0 : Poll::Pending
262 : }
263 : }
264 0 : }
265 :
266 0 : fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
267 0 : match Pin::new(&mut self.stream)
268 0 : .poll_close(cx)
269 0 : .map_err(Error::io)?
270 : {
271 : Poll::Ready(()) => {
272 0 : trace!("poll_shutdown: complete");
273 0 : Poll::Ready(Ok(()))
274 : }
275 : Poll::Pending => {
276 0 : trace!("poll_shutdown: waiting on socket");
277 0 : Poll::Pending
278 : }
279 : }
280 0 : }
281 :
282 : /// Returns the value of a runtime parameter for this connection.
283 0 : pub fn parameter(&self, name: &str) -> Option<&str> {
284 0 : self.parameters.get(name).map(|s| &**s)
285 0 : }
286 :
287 : /// Polls for asynchronous messages from the server.
288 : ///
289 : /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
290 : /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
291 0 : pub fn poll_message(
292 0 : &mut self,
293 0 : cx: &mut Context<'_>,
294 0 : ) -> Poll<Option<Result<AsyncMessage, Error>>> {
295 0 : if self.state != State::Closing {
296 : // if the state is still active, try read from and write to postgres.
297 0 : let message = self.poll_read(cx)?;
298 0 : let closing = self.poll_write(cx)?;
299 0 : if let Poll::Ready(WriteReady::Terminating) = closing {
300 0 : self.state = State::Closing;
301 0 : }
302 :
303 0 : if let Poll::Ready(message) = message {
304 0 : return Poll::Ready(Some(Ok(message)));
305 0 : }
306 0 :
307 0 : // poll_read returned Pending.
308 0 : // poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
309 0 : // if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
310 0 : if self.state != State::Closing {
311 0 : return Poll::Pending;
312 0 : }
313 0 : }
314 :
315 0 : match self.poll_shutdown(cx) {
316 0 : Poll::Ready(Ok(())) => Poll::Ready(None),
317 0 : Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
318 0 : Poll::Pending => Poll::Pending,
319 : }
320 0 : }
321 : }
322 :
323 : impl<S, T> Future for Connection<S, T>
324 : where
325 : S: AsyncRead + AsyncWrite + Unpin,
326 : T: AsyncRead + AsyncWrite + Unpin,
327 : {
328 : type Output = Result<(), Error>;
329 :
330 0 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
331 0 : while let Some(message) = ready!(self.poll_message(cx)?) {
332 0 : if let AsyncMessage::Notice(notice) = message {
333 0 : info!("{}: {}", notice.severity(), notice.message());
334 0 : }
335 : }
336 0 : Poll::Ready(Ok(()))
337 0 : }
338 : }
|