Line data Source code
1 : use std::collections::{HashMap, VecDeque};
2 : use std::future::Future;
3 : use std::pin::Pin;
4 : use std::task::{Context, Poll};
5 :
6 : use bytes::BytesMut;
7 : use futures_util::{Sink, Stream, ready};
8 : use postgres_protocol2::message::backend::Message;
9 : use postgres_protocol2::message::frontend;
10 : use tokio::io::{AsyncRead, AsyncWrite};
11 : use tokio::sync::mpsc;
12 : use tokio_util::codec::Framed;
13 : use tokio_util::sync::PollSender;
14 : use tracing::{info, trace};
15 :
16 : use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
17 : use crate::error::DbError;
18 : use crate::maybe_tls_stream::MaybeTlsStream;
19 : use crate::{AsyncMessage, Error, Notification};
20 :
/// Lifecycle of a `Connection`: either still relaying traffic or tearing the socket down.
#[derive(Debug, PartialEq)]
enum State {
    /// Requests and responses are still being relayed between client and server.
    Active,
    /// The request side has finished; only the socket shutdown remains.
    Closing,
}
26 :
27 : /// A connection to a PostgreSQL database.
28 : ///
29 : /// This is one half of what is returned when a new connection is established. It performs the actual IO with the
30 : /// server, and should generally be spawned off onto an executor to run in the background.
31 : ///
32 : /// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
33 : /// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
34 : #[must_use = "futures do nothing unless polled"]
35 : pub struct Connection<S, T> {
36 : /// HACK: we need this in the Neon Proxy.
37 : pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
38 : /// HACK: we need this in the Neon Proxy to forward params.
39 : pub parameters: HashMap<String, String>,
40 :
41 : sender: PollSender<BackendMessages>,
42 : receiver: mpsc::UnboundedReceiver<FrontendMessage>,
43 :
44 : pending_responses: VecDeque<BackendMessage>,
45 : state: State,
46 : }
47 :
48 : impl<S, T> Connection<S, T>
49 : where
50 : S: AsyncRead + AsyncWrite + Unpin,
51 : T: AsyncRead + AsyncWrite + Unpin,
52 : {
53 0 : pub(crate) fn new(
54 0 : stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
55 0 : pending_responses: VecDeque<BackendMessage>,
56 0 : parameters: HashMap<String, String>,
57 0 : sender: mpsc::Sender<BackendMessages>,
58 0 : receiver: mpsc::UnboundedReceiver<FrontendMessage>,
59 0 : ) -> Connection<S, T> {
60 0 : Connection {
61 0 : stream,
62 0 : parameters,
63 0 : sender: PollSender::new(sender),
64 0 : receiver,
65 0 : pending_responses,
66 0 : state: State::Active,
67 0 : }
68 0 : }
69 :
70 0 : fn poll_response(
71 0 : &mut self,
72 0 : cx: &mut Context<'_>,
73 0 : ) -> Poll<Option<Result<BackendMessage, Error>>> {
74 0 : if let Some(message) = self.pending_responses.pop_front() {
75 0 : trace!("retrying pending response");
76 0 : return Poll::Ready(Some(Ok(message)));
77 0 : }
78 :
79 0 : Pin::new(&mut self.stream)
80 0 : .poll_next(cx)
81 0 : .map(|o| o.map(|r| r.map_err(Error::io)))
82 0 : }
83 :
84 : /// Read and process messages from the connection to postgres.
85 : /// client <- postgres
86 0 : fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<AsyncMessage, Error>> {
87 : loop {
88 0 : let message = match self.poll_response(cx)? {
89 0 : Poll::Ready(Some(message)) => message,
90 0 : Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
91 : Poll::Pending => {
92 0 : trace!("poll_read: waiting on response");
93 0 : return Poll::Pending;
94 : }
95 : };
96 :
97 0 : let messages = match message {
98 0 : BackendMessage::Async(Message::NoticeResponse(body)) => {
99 0 : let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
100 0 : return Poll::Ready(Ok(AsyncMessage::Notice(error)));
101 : }
102 0 : BackendMessage::Async(Message::NotificationResponse(body)) => {
103 0 : let notification = Notification {
104 0 : process_id: body.process_id(),
105 0 : channel: body.channel().map_err(Error::parse)?.to_string(),
106 0 : payload: body.message().map_err(Error::parse)?.to_string(),
107 : };
108 0 : return Poll::Ready(Ok(AsyncMessage::Notification(notification)));
109 : }
110 0 : BackendMessage::Async(Message::ParameterStatus(body)) => {
111 0 : self.parameters.insert(
112 0 : body.name().map_err(Error::parse)?.to_string(),
113 0 : body.value().map_err(Error::parse)?.to_string(),
114 : );
115 0 : continue;
116 : }
117 0 : BackendMessage::Async(_) => unreachable!(),
118 0 : BackendMessage::Normal { messages } => messages,
119 : };
120 :
121 0 : match self.sender.poll_reserve(cx) {
122 0 : Poll::Ready(Ok(())) => {
123 0 : let _ = self.sender.send_item(messages);
124 0 : }
125 : Poll::Ready(Err(_)) => {
126 0 : return Poll::Ready(Err(Error::closed()));
127 : }
128 : Poll::Pending => {
129 0 : self.pending_responses
130 0 : .push_back(BackendMessage::Normal { messages });
131 0 : trace!("poll_read: waiting on sender");
132 0 : return Poll::Pending;
133 : }
134 : }
135 : }
136 0 : }
137 :
138 : /// Fetch the next client request and enqueue the response sender.
139 0 : fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
140 0 : if self.receiver.is_closed() {
141 0 : return Poll::Ready(None);
142 0 : }
143 :
144 0 : match self.receiver.poll_recv(cx) {
145 0 : Poll::Ready(Some(request)) => {
146 0 : trace!("polled new request");
147 0 : Poll::Ready(Some(request))
148 : }
149 0 : Poll::Ready(None) => Poll::Ready(None),
150 0 : Poll::Pending => Poll::Pending,
151 : }
152 0 : }
153 :
154 : /// Process client requests and write them to the postgres connection, flushing if necessary.
155 : /// client -> postgres
156 0 : fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
157 : loop {
158 0 : if Pin::new(&mut self.stream)
159 0 : .poll_ready(cx)
160 0 : .map_err(Error::io)?
161 0 : .is_pending()
162 : {
163 0 : trace!("poll_write: waiting on socket");
164 :
165 : // poll_ready is self-flushing.
166 0 : return Poll::Pending;
167 0 : }
168 :
169 0 : match self.poll_request(cx) {
170 : // send the message to postgres
171 0 : Poll::Ready(Some(request)) => {
172 0 : Pin::new(&mut self.stream)
173 0 : .start_send(request)
174 0 : .map_err(Error::io)?;
175 : }
176 : // No more messages from the client, and no more responses to wait for.
177 : // Send a terminate message to postgres
178 : Poll::Ready(None) => {
179 0 : trace!("poll_write: at eof, terminating");
180 0 : let mut request = BytesMut::new();
181 0 : frontend::terminate(&mut request);
182 0 : let request = FrontendMessage::Raw(request.freeze());
183 :
184 0 : Pin::new(&mut self.stream)
185 0 : .start_send(request)
186 0 : .map_err(Error::io)?;
187 :
188 0 : trace!("poll_write: sent eof, closing");
189 0 : trace!("poll_write: done");
190 0 : return Poll::Ready(Ok(()));
191 : }
192 : // Still waiting for a message from the client.
193 : Poll::Pending => {
194 0 : trace!("poll_write: waiting on request");
195 0 : ready!(self.poll_flush(cx))?;
196 0 : return Poll::Pending;
197 : }
198 : }
199 : }
200 0 : }
201 :
202 0 : fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
203 0 : match Pin::new(&mut self.stream)
204 0 : .poll_flush(cx)
205 0 : .map_err(Error::io)?
206 : {
207 : Poll::Ready(()) => {
208 0 : trace!("poll_flush: flushed");
209 0 : Poll::Ready(Ok(()))
210 : }
211 : Poll::Pending => {
212 0 : trace!("poll_flush: waiting on socket");
213 0 : Poll::Pending
214 : }
215 : }
216 0 : }
217 :
218 0 : fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
219 0 : match Pin::new(&mut self.stream)
220 0 : .poll_close(cx)
221 0 : .map_err(Error::io)?
222 : {
223 : Poll::Ready(()) => {
224 0 : trace!("poll_shutdown: complete");
225 0 : Poll::Ready(Ok(()))
226 : }
227 : Poll::Pending => {
228 0 : trace!("poll_shutdown: waiting on socket");
229 0 : Poll::Pending
230 : }
231 : }
232 0 : }
233 :
234 : /// Returns the value of a runtime parameter for this connection.
235 0 : pub fn parameter(&self, name: &str) -> Option<&str> {
236 0 : self.parameters.get(name).map(|s| &**s)
237 0 : }
238 :
239 : /// Polls for asynchronous messages from the server.
240 : ///
241 : /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
242 : /// examine those messages should use this method to drive the connection rather than its `Future` implementation.
243 0 : pub fn poll_message(
244 0 : &mut self,
245 0 : cx: &mut Context<'_>,
246 0 : ) -> Poll<Option<Result<AsyncMessage, Error>>> {
247 0 : if self.state != State::Closing {
248 : // if the state is still active, try read from and write to postgres.
249 0 : let message = self.poll_read(cx)?;
250 0 : let closing = self.poll_write(cx)?;
251 0 : if let Poll::Ready(()) = closing {
252 0 : self.state = State::Closing;
253 0 : }
254 :
255 0 : if let Poll::Ready(message) = message {
256 0 : return Poll::Ready(Some(Ok(message)));
257 0 : }
258 :
259 : // poll_read returned Pending.
260 : // poll_write returned Pending or Ready(WriteReady::WaitingOnRead).
261 : // if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres.
262 0 : if self.state != State::Closing {
263 0 : return Poll::Pending;
264 0 : }
265 0 : }
266 :
267 0 : match self.poll_shutdown(cx) {
268 0 : Poll::Ready(Ok(())) => Poll::Ready(None),
269 0 : Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
270 0 : Poll::Pending => Poll::Pending,
271 : }
272 0 : }
273 : }
274 :
275 : impl<S, T> Future for Connection<S, T>
276 : where
277 : S: AsyncRead + AsyncWrite + Unpin,
278 : T: AsyncRead + AsyncWrite + Unpin,
279 : {
280 : type Output = Result<(), Error>;
281 :
282 0 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
283 0 : while let Some(message) = ready!(self.poll_message(cx)?) {
284 0 : if let AsyncMessage::Notice(notice) = message {
285 0 : info!("{}: {}", notice.severity(), notice.message());
286 0 : }
287 : }
288 0 : Poll::Ready(Ok(()))
289 0 : }
290 : }
|