Line data Source code
1 : use std::future::Future;
2 : use std::pin::Pin;
3 : use std::task::{Context, Poll};
4 :
5 : use bytes::BytesMut;
6 : use fallible_iterator::FallibleIterator;
7 : use futures_util::{Sink, StreamExt, ready};
8 : use postgres_protocol2::message::backend::{Message, NoticeResponseBody};
9 : use postgres_protocol2::message::frontend;
10 : use tokio::io::{AsyncRead, AsyncWrite};
11 : use tokio::sync::mpsc;
12 : use tokio_util::codec::Framed;
13 : use tokio_util::sync::PollSender;
14 : use tracing::trace;
15 :
16 : use crate::Error;
17 : use crate::codec::{
18 : BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices,
19 : };
20 : use crate::maybe_tls_stream::MaybeTlsStream;
21 :
/// Lifecycle phase of a [`Connection`].
///
/// The connection starts out `Active` and moves to `Closing` once the write side
/// reports that there is nothing more to send; it never moves back.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// Messages are still being read from and written to the server.
    Active,
    /// The write side has finished; only the stream shutdown remains to be driven.
    Closing,
}
27 :
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
    // Framed transport to the server: backend messages are read from it and
    // frontend requests are written to it.
    stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,

    // Channel on which batches of backend messages are forwarded by `poll_read`.
    sender: PollSender<BackendMessages>,
    // Channel from which `poll_request` pulls the next frontend message to process.
    receiver: mpsc::UnboundedReceiver<FrontendMessage>,
    // Notice recording state; `None` until a `FrontendMessage::RecordNotices` arrives,
    // and reset to `None` when the notice receiver turns out to be closed.
    notices: Option<RecordNotices>,

    // A batch that was read from the server but could not be forwarded yet because
    // `sender` had no capacity; it is retried first on the next `poll_read`.
    pending_response: Option<BackendMessages>,
    // Whether the connection is still exchanging messages or is shutting down.
    state: State,
}
46 :
/// Minimum capacity, in bytes, of the replacement buffer allocated by [`gc_bytesmut`].
pub const INITIAL_CAPACITY: usize = 2 * 1024;
/// Length threshold, in bytes, that selects how much spare capacity [`gc_bytesmut`]
/// requires before it re-allocates a buffer.
pub const GC_THRESHOLD: usize = 16 * 1024;
49 :
50 : /// Gargabe collect the [`BytesMut`] if it has too much spare capacity.
51 0 : pub fn gc_bytesmut(buf: &mut BytesMut) {
52 : // We use a different mode to shrink the buf when above the threshold.
53 : // When above the threshold, we only re-allocate when the buf has 2x spare capacity.
54 0 : let reclaim = GC_THRESHOLD.checked_sub(buf.len()).unwrap_or(buf.len());
55 :
56 : // `try_reclaim` tries to get the capacity from any shared `BytesMut`s,
57 : // before then comparing the length against the capacity.
58 0 : if buf.try_reclaim(reclaim) {
59 0 : let capacity = usize::max(buf.len(), INITIAL_CAPACITY);
60 0 :
61 0 : // Allocate a new `BytesMut` so that we deallocate the old version.
62 0 : let mut new = BytesMut::with_capacity(capacity);
63 0 : new.extend_from_slice(buf);
64 0 : *buf = new;
65 0 : }
66 0 : }
67 :
/// An uninhabited type: it has no variants, so no value of it can be constructed.
///
/// `Connection::poll_read` uses it as its success type, which means that method
/// can only ever complete with an error.
pub enum Never {}
69 :
70 : impl<S, T> Connection<S, T>
71 : where
72 : S: AsyncRead + AsyncWrite + Unpin,
73 : T: AsyncRead + AsyncWrite + Unpin,
74 : {
75 0 : pub(crate) fn new(
76 0 : stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
77 0 : sender: mpsc::Sender<BackendMessages>,
78 0 : receiver: mpsc::UnboundedReceiver<FrontendMessage>,
79 0 : ) -> Connection<S, T> {
80 0 : Connection {
81 0 : stream,
82 0 : sender: PollSender::new(sender),
83 0 : receiver,
84 0 : notices: None,
85 0 : pending_response: None,
86 0 : state: State::Active,
87 0 : }
88 0 : }
89 :
90 : /// Read and process messages from the connection to postgres.
91 : /// client <- postgres
92 0 : fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Never, Error>> {
93 : loop {
94 0 : let messages = match self.pending_response.take() {
95 0 : Some(messages) => messages,
96 : None => {
97 0 : let message = match self.stream.poll_next_unpin(cx) {
98 0 : Poll::Pending => return Poll::Pending,
99 0 : Poll::Ready(None) => return Poll::Ready(Err(Error::closed())),
100 0 : Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))),
101 0 : Poll::Ready(Some(Ok(message))) => message,
102 : };
103 :
104 0 : match message {
105 0 : BackendMessage::Async(Message::NoticeResponse(body)) => {
106 0 : self.handle_notice(body)?;
107 0 : continue;
108 : }
109 0 : BackendMessage::Async(_) => continue,
110 0 : BackendMessage::Normal { messages, ready } => {
111 : // if we read a ReadyForQuery from postgres, let's try GC the read buffer.
112 0 : if ready {
113 0 : gc_bytesmut(self.stream.read_buffer_mut());
114 0 : }
115 :
116 0 : messages
117 : }
118 : }
119 : }
120 : };
121 :
122 0 : match self.sender.poll_reserve(cx) {
123 0 : Poll::Ready(Ok(())) => {
124 0 : let _ = self.sender.send_item(messages);
125 0 : }
126 : Poll::Ready(Err(_)) => {
127 0 : return Poll::Ready(Err(Error::closed()));
128 : }
129 : Poll::Pending => {
130 0 : self.pending_response = Some(messages);
131 0 : trace!("poll_read: waiting on sender");
132 0 : return Poll::Pending;
133 : }
134 : }
135 : }
136 0 : }
137 :
138 0 : fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> {
139 0 : let Some(notices) = &mut self.notices else {
140 0 : return Ok(());
141 : };
142 :
143 0 : let mut fields = body.fields();
144 0 : while let Some(field) = fields.next().map_err(Error::parse)? {
145 : // loop until we find the message field
146 0 : if field.type_() == b'M' {
147 : // if the message field is within the limit, send it.
148 0 : if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) {
149 0 : match notices.sender.send(field.value().into()) {
150 : // set the new limit.
151 0 : Ok(()) => notices.limit = new_limit,
152 : // closed.
153 0 : Err(_) => self.notices = None,
154 : }
155 0 : }
156 0 : break;
157 0 : }
158 : }
159 :
160 0 : Ok(())
161 0 : }
162 :
163 : /// Fetch the next client request and enqueue the response sender.
164 0 : fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
165 0 : if self.receiver.is_closed() {
166 0 : return Poll::Ready(None);
167 0 : }
168 :
169 0 : match self.receiver.poll_recv(cx) {
170 0 : Poll::Ready(Some(request)) => {
171 0 : trace!("polled new request");
172 0 : Poll::Ready(Some(request))
173 : }
174 0 : Poll::Ready(None) => Poll::Ready(None),
175 0 : Poll::Pending => Poll::Pending,
176 : }
177 0 : }
178 :
179 : /// Process client requests and write them to the postgres connection, flushing if necessary.
180 : /// client -> postgres
181 0 : fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
182 : loop {
183 0 : if Pin::new(&mut self.stream)
184 0 : .poll_ready(cx)
185 0 : .map_err(Error::io)?
186 0 : .is_pending()
187 : {
188 0 : trace!("poll_write: waiting on socket");
189 :
190 : // poll_ready is self-flushing.
191 0 : return Poll::Pending;
192 0 : }
193 :
194 0 : match self.poll_request(cx) {
195 : // send the message to postgres
196 0 : Poll::Ready(Some(FrontendMessage::Raw(request))) => {
197 0 : Pin::new(&mut self.stream)
198 0 : .start_send(request)
199 0 : .map_err(Error::io)?;
200 : }
201 0 : Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => {
202 0 : self.notices = Some(notices)
203 : }
204 : // No more messages from the client, and no more responses to wait for.
205 : // Send a terminate message to postgres
206 : Poll::Ready(None) => {
207 0 : trace!("poll_write: at eof, terminating");
208 0 : frontend::terminate(self.stream.write_buffer_mut());
209 :
210 0 : trace!("poll_write: sent eof, closing");
211 0 : trace!("poll_write: done");
212 0 : return Poll::Ready(Ok(()));
213 : }
214 : // Still waiting for a message from the client.
215 : Poll::Pending => {
216 0 : trace!("poll_write: waiting on request");
217 0 : ready!(self.poll_flush(cx))?;
218 0 : return Poll::Pending;
219 : }
220 : }
221 : }
222 0 : }
223 :
224 0 : fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
225 0 : match Pin::new(&mut self.stream)
226 0 : .poll_flush(cx)
227 0 : .map_err(Error::io)?
228 : {
229 : Poll::Ready(()) => {
230 0 : trace!("poll_flush: flushed");
231 :
232 : // Since our codec prefers to share the buffer with the `Client`,
233 : // if we don't release our share, then the `Client` would have to re-alloc
234 : // the buffer when they next use it.
235 0 : debug_assert!(self.stream.write_buffer().is_empty());
236 0 : *self.stream.write_buffer_mut() = BytesMut::new();
237 :
238 0 : Poll::Ready(Ok(()))
239 : }
240 : Poll::Pending => {
241 0 : trace!("poll_flush: waiting on socket");
242 0 : Poll::Pending
243 : }
244 : }
245 0 : }
246 :
247 0 : fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
248 0 : match Pin::new(&mut self.stream)
249 0 : .poll_close(cx)
250 0 : .map_err(Error::io)?
251 : {
252 : Poll::Ready(()) => {
253 0 : trace!("poll_shutdown: complete");
254 0 : Poll::Ready(Ok(()))
255 : }
256 : Poll::Pending => {
257 0 : trace!("poll_shutdown: waiting on socket");
258 0 : Poll::Pending
259 : }
260 : }
261 0 : }
262 :
263 0 : fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Never, Error>>> {
264 0 : if self.state != State::Closing {
265 : // if the state is still active, try read from and write to postgres.
266 0 : let Poll::Pending = self.poll_read(cx)?;
267 0 : if self.poll_write(cx)?.is_ready() {
268 0 : self.state = State::Closing;
269 0 : }
270 :
271 : // poll_read returned Pending.
272 : // poll_write returned Pending or Ready(()).
273 : // if poll_write returned Ready(()), then we are waiting to read more data from postgres.
274 0 : if self.state != State::Closing {
275 0 : return Poll::Pending;
276 0 : }
277 0 : }
278 :
279 0 : match self.poll_shutdown(cx) {
280 0 : Poll::Ready(Ok(())) => Poll::Ready(None),
281 0 : Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
282 0 : Poll::Pending => Poll::Pending,
283 : }
284 0 : }
285 : }
286 :
287 : impl<S, T> Future for Connection<S, T>
288 : where
289 : S: AsyncRead + AsyncWrite + Unpin,
290 : T: AsyncRead + AsyncWrite + Unpin,
291 : {
292 : type Output = Result<(), Error>;
293 :
294 0 : fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
295 0 : match self.poll_message(cx)? {
296 0 : Poll::Ready(None) => Poll::Ready(Ok(())),
297 0 : Poll::Pending => Poll::Pending,
298 : }
299 0 : }
300 : }
|