Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt;
3 : use std::net::IpAddr;
4 : use std::task::{Context, Poll};
5 : use std::time::Duration;
6 :
7 : use bytes::BytesMut;
8 : use fallible_iterator::FallibleIterator;
9 : use futures_util::{TryStreamExt, future, ready};
10 : use postgres_protocol2::message::backend::Message;
11 : use postgres_protocol2::message::frontend;
12 : use serde::{Deserialize, Serialize};
13 : use tokio::sync::mpsc;
14 :
15 : use crate::cancel_token::RawCancelToken;
16 : use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
17 : use crate::config::{Host, SslMode};
18 : use crate::connection::gc_bytesmut;
19 : use crate::query::RowStream;
20 : use crate::simple_query::SimpleQueryStream;
21 : use crate::types::{Oid, Type};
22 : use crate::{
23 : CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
24 : query, simple_query,
25 : };
26 :
27 : pub struct Responses {
28 : /// new messages from conn
29 : receiver: mpsc::Receiver<BackendMessages>,
30 : /// current batch of messages
31 : cur: BackendMessages,
32 : /// number of total queries sent.
33 : waiting: usize,
34 : /// number of ReadyForQuery messages received.
35 : received: usize,
36 : }
37 :
38 : impl Responses {
39 0 : pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
40 : loop {
41 : // get the next saved message
42 0 : if let Some(message) = self.cur.next().map_err(Error::parse)? {
43 0 : let received = self.received;
44 :
45 : // increase the query head if this is the last message.
46 0 : if let Message::ReadyForQuery(_) = message {
47 0 : self.received += 1;
48 0 : }
49 :
50 : // check if the client has skipped this query.
51 0 : if received + 1 < self.waiting {
52 : // grab the next message.
53 0 : continue;
54 0 : }
55 :
56 : // convenience: turn the error messaage into a proper error.
57 0 : let res = match message {
58 0 : Message::ErrorResponse(body) => Err(Error::db(body)),
59 0 : message => Ok(message),
60 : };
61 0 : return Poll::Ready(res);
62 0 : }
63 :
64 : // get the next batch of messages.
65 0 : match ready!(self.receiver.poll_recv(cx)) {
66 0 : Some(messages) => self.cur = messages,
67 0 : None => return Poll::Ready(Err(Error::closed())),
68 : }
69 : }
70 0 : }
71 :
72 0 : pub async fn next(&mut self) -> Result<Message, Error> {
73 0 : future::poll_fn(|cx| self.poll_next(cx)).await
74 0 : }
75 : }
76 :
77 : /// A cache of type info and prepared statements for fetching type info
78 : /// (corresponding to the queries in the [crate::prepare] module).
79 : #[derive(Default)]
80 : pub(crate) struct CachedTypeInfo {
81 : /// Cache of types already looked up.
82 : pub(crate) types: HashMap<Oid, Type>,
83 : }
84 :
85 : pub struct InnerClient {
86 : sender: mpsc::UnboundedSender<FrontendMessage>,
87 : responses: Responses,
88 :
89 : /// A buffer to use when writing out postgres commands.
90 : buffer: BytesMut,
91 : }
92 :
93 : impl InnerClient {
94 0 : pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
95 0 : self.responses.waiting += 1;
96 0 : Ok(PartialQuery(Some(self)))
97 0 : }
98 :
99 0 : pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
100 0 : self.responses.waiting += 1;
101 :
102 0 : self.buffer.clear();
103 : // simple queries do not need sync.
104 0 : frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
105 0 : let buf = self.buffer.split();
106 0 : self.send_message(FrontendMessage::Raw(buf))
107 0 : }
108 :
109 0 : fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> {
110 0 : self.sender.send(messages).map_err(|_| Error::closed())?;
111 0 : Ok(&mut self.responses)
112 0 : }
113 : }
114 :
115 : pub struct PartialQuery<'a>(Option<&'a mut InnerClient>);
116 :
117 : impl Drop for PartialQuery<'_> {
118 0 : fn drop(&mut self) {
119 0 : if let Some(client) = self.0.take() {
120 0 : client.buffer.clear();
121 0 : frontend::sync(&mut client.buffer);
122 0 : let buf = client.buffer.split();
123 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
124 0 : }
125 0 : }
126 : }
127 :
128 : impl<'a> PartialQuery<'a> {
129 0 : pub fn send_with_flush<F>(&mut self, f: F) -> Result<&mut Responses, Error>
130 0 : where
131 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
132 : {
133 0 : let client = self.0.as_deref_mut().unwrap();
134 :
135 0 : client.buffer.clear();
136 0 : f(&mut client.buffer)?;
137 0 : frontend::flush(&mut client.buffer);
138 0 : let buf = client.buffer.split();
139 0 : client.send_message(FrontendMessage::Raw(buf))
140 0 : }
141 :
142 0 : pub fn send_with_sync<F>(mut self, f: F) -> Result<&'a mut Responses, Error>
143 0 : where
144 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
145 : {
146 0 : let client = self.0.as_deref_mut().unwrap();
147 :
148 0 : client.buffer.clear();
149 0 : f(&mut client.buffer)?;
150 0 : frontend::sync(&mut client.buffer);
151 0 : let buf = client.buffer.split();
152 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
153 :
154 0 : Ok(&mut self.0.take().unwrap().responses)
155 0 : }
156 : }
157 :
158 0 : #[derive(Clone, Serialize, Deserialize)]
159 : pub struct SocketConfig {
160 : pub host_addr: Option<IpAddr>,
161 : pub host: Host,
162 : pub port: u16,
163 : pub connect_timeout: Option<Duration>,
164 : }
165 :
166 : /// An asynchronous PostgreSQL client.
167 : ///
168 : /// The client is one half of what is returned when a connection is established. Users interact with the database
169 : /// through this client object.
170 : pub struct Client {
171 : inner: InnerClient,
172 : cached_typeinfo: CachedTypeInfo,
173 :
174 : socket_config: SocketConfig,
175 : ssl_mode: SslMode,
176 : process_id: i32,
177 : secret_key: i32,
178 : }
179 :
180 : impl Client {
181 0 : pub(crate) fn new(
182 0 : sender: mpsc::UnboundedSender<FrontendMessage>,
183 0 : receiver: mpsc::Receiver<BackendMessages>,
184 0 : socket_config: SocketConfig,
185 0 : ssl_mode: SslMode,
186 0 : process_id: i32,
187 0 : secret_key: i32,
188 0 : write_buf: BytesMut,
189 0 : ) -> Client {
190 0 : Client {
191 0 : inner: InnerClient {
192 0 : sender,
193 0 : responses: Responses {
194 0 : receiver,
195 0 : cur: BackendMessages::empty(),
196 0 : waiting: 0,
197 0 : received: 0,
198 0 : },
199 0 : buffer: write_buf,
200 0 : },
201 0 : cached_typeinfo: Default::default(),
202 0 :
203 0 : socket_config,
204 0 : ssl_mode,
205 0 : process_id,
206 0 : secret_key,
207 0 : }
208 0 : }
209 :
210 : /// Returns process_id.
211 0 : pub fn get_process_id(&self) -> i32 {
212 0 : self.process_id
213 0 : }
214 :
215 0 : pub(crate) fn inner_mut(&mut self) -> &mut InnerClient {
216 0 : &mut self.inner
217 0 : }
218 :
219 0 : pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
220 0 : let (tx, rx) = mpsc::unbounded_channel();
221 :
222 0 : let notices = RecordNotices { sender: tx, limit };
223 0 : self.inner
224 0 : .sender
225 0 : .send(FrontendMessage::RecordNotices(notices))
226 0 : .ok();
227 :
228 0 : rx
229 0 : }
230 :
231 : /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
232 : /// to save a roundtrip
233 0 : pub async fn query_raw_txt<S, I>(
234 0 : &mut self,
235 0 : statement: &str,
236 0 : params: I,
237 0 : ) -> Result<RowStream<'_>, Error>
238 0 : where
239 0 : S: AsRef<str>,
240 0 : I: IntoIterator<Item = Option<S>>,
241 0 : I::IntoIter: ExactSizeIterator,
242 0 : {
243 0 : query::query_txt(
244 0 : &mut self.inner,
245 0 : &mut self.cached_typeinfo,
246 0 : statement,
247 0 : params,
248 0 : )
249 0 : .await
250 0 : }
251 :
252 : /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
253 : ///
254 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
255 : /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
256 : /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
257 : /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
258 : /// or a row of data. This preserves the framing between the separate statements in the request.
259 : ///
260 : /// # Warning
261 : ///
262 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
263 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
264 : /// them to this method!
265 0 : pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
266 0 : self.simple_query_raw(query).await?.try_collect().await
267 0 : }
268 :
269 0 : pub(crate) async fn simple_query_raw(
270 0 : &mut self,
271 0 : query: &str,
272 0 : ) -> Result<SimpleQueryStream<'_>, Error> {
273 0 : simple_query::simple_query(self.inner_mut(), query).await
274 0 : }
275 :
276 : /// Executes a sequence of SQL statements using the simple query protocol.
277 : ///
278 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
279 : /// point. This is intended for use when, for example, initializing a database schema.
280 : ///
281 : /// # Warning
282 : ///
283 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
284 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
285 : /// them to this method!
286 0 : pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
287 0 : simple_query::batch_execute(self.inner_mut(), query).await
288 0 : }
289 :
290 : /// Similar to `discard_all`, but it does not clear any query plans
291 : ///
292 : /// This runs in the background, so it can be executed without `await`ing.
293 0 : pub fn reset_session_background(&mut self) -> Result<(), Error> {
294 : // "CLOSE ALL": closes any cursors
295 : // "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
296 : // "RESET ALL": resets any GUCs back to their session defaults.
297 : // "DEALLOCATE ALL": deallocates any prepared statements
298 : // "UNLISTEN *": stops listening on all channels
299 : // "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
300 : // "DISCARD TEMP;": drops all temporary tables
301 : // "DISCARD SEQUENCES;": deallocates all cached sequence state
302 :
303 0 : let _responses = self.inner_mut().send_simple_query(
304 0 : "ROLLBACK;
305 0 : CLOSE ALL;
306 0 : SET SESSION AUTHORIZATION DEFAULT;
307 0 : RESET ALL;
308 0 : DEALLOCATE ALL;
309 0 : UNLISTEN *;
310 0 : SELECT pg_advisory_unlock_all();
311 0 : DISCARD TEMP;
312 0 : DISCARD SEQUENCES;",
313 0 : )?;
314 :
315 : // Clean up memory usage.
316 0 : gc_bytesmut(&mut self.inner_mut().buffer);
317 :
318 0 : Ok(())
319 0 : }
320 :
321 : /// Begins a new database transaction.
322 : ///
323 : /// The transaction will roll back by default - use the `commit` method to commit it.
324 0 : pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
325 : struct RollbackIfNotDone<'me> {
326 : client: &'me mut Client,
327 : done: bool,
328 : }
329 :
330 : impl Drop for RollbackIfNotDone<'_> {
331 0 : fn drop(&mut self) {
332 0 : if self.done {
333 0 : return;
334 0 : }
335 :
336 0 : let _ = self.client.inner.send_simple_query("ROLLBACK");
337 0 : }
338 : }
339 :
340 : // This is done, as `Future` created by this method can be dropped after
341 : // `RequestMessages` is synchronously send to the `Connection` by
342 : // `batch_execute()`, but before `Responses` is asynchronously polled to
343 : // completion. In that case `Transaction` won't be created and thus
344 : // won't be rolled back.
345 : {
346 0 : let mut cleaner = RollbackIfNotDone {
347 0 : client: self,
348 0 : done: false,
349 0 : };
350 0 : cleaner.client.batch_execute("BEGIN").await?;
351 0 : cleaner.done = true;
352 : }
353 :
354 0 : Ok(Transaction::new(self))
355 0 : }
356 :
357 : /// Returns a builder for a transaction with custom settings.
358 : ///
359 : /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
360 : /// attributes.
361 0 : pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
362 0 : TransactionBuilder::new(self)
363 0 : }
364 :
365 : /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
366 : /// connection associated with this client.
367 0 : pub fn cancel_token(&self) -> CancelToken {
368 0 : CancelToken {
369 0 : socket_config: self.socket_config.clone(),
370 0 : raw: RawCancelToken {
371 0 : ssl_mode: self.ssl_mode,
372 0 : process_id: self.process_id,
373 0 : secret_key: self.secret_key,
374 0 : },
375 0 : }
376 0 : }
377 :
378 : /// Determines if the connection to the server has already closed.
379 : ///
380 : /// In that case, all future queries will fail.
381 0 : pub fn is_closed(&self) -> bool {
382 0 : self.inner.sender.is_closed()
383 0 : }
384 : }
385 :
386 : impl fmt::Debug for Client {
387 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
388 0 : f.debug_struct("Client").finish()
389 0 : }
390 : }
|