Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt;
3 : use std::net::IpAddr;
4 : use std::task::{Context, Poll};
5 : use std::time::Duration;
6 :
7 : use bytes::BytesMut;
8 : use fallible_iterator::FallibleIterator;
9 : use futures_util::{TryStreamExt, future, ready};
10 : use postgres_protocol2::message::backend::Message;
11 : use postgres_protocol2::message::frontend;
12 : use serde::{Deserialize, Serialize};
13 : use tokio::sync::mpsc;
14 :
15 : use crate::cancel_token::RawCancelToken;
16 : use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
17 : use crate::config::{Host, SslMode};
18 : use crate::connection::gc_bytesmut;
19 : use crate::query::RowStream;
20 : use crate::simple_query::SimpleQueryStream;
21 : use crate::types::{Oid, Type};
22 : use crate::{
23 : CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
24 : query, simple_query,
25 : };
26 :
27 : pub struct Responses {
28 : /// new messages from conn
29 : receiver: mpsc::Receiver<BackendMessages>,
30 : /// current batch of messages
31 : cur: BackendMessages,
32 : /// number of total queries sent.
33 : waiting: usize,
34 : /// number of ReadyForQuery messages received.
35 : received: usize,
36 : }
37 :
38 : impl Responses {
39 0 : pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
40 : loop {
41 : // get the next saved message
42 0 : if let Some(message) = self.cur.next().map_err(Error::parse)? {
43 0 : let received = self.received;
44 :
45 : // increase the query head if this is the last message.
46 0 : if let Message::ReadyForQuery(_) = message {
47 0 : self.received += 1;
48 0 : }
49 :
50 : // check if the client has skipped this query.
51 0 : if received + 1 < self.waiting {
52 : // grab the next message.
53 0 : continue;
54 0 : }
55 :
56 : // convenience: turn the error messaage into a proper error.
57 0 : let res = match message {
58 0 : Message::ErrorResponse(body) => Err(Error::db(body)),
59 0 : message => Ok(message),
60 : };
61 0 : return Poll::Ready(res);
62 0 : }
63 :
64 : // get the next batch of messages.
65 0 : match ready!(self.receiver.poll_recv(cx)) {
66 0 : Some(messages) => self.cur = messages,
67 0 : None => return Poll::Ready(Err(Error::closed())),
68 : }
69 : }
70 0 : }
71 :
72 0 : pub async fn next(&mut self) -> Result<Message, Error> {
73 0 : future::poll_fn(|cx| self.poll_next(cx)).await
74 0 : }
75 : }
76 :
77 : /// A cache of type info and prepared statements for fetching type info
78 : /// (corresponding to the queries in the [crate::prepare] module).
79 : #[derive(Default)]
80 : pub(crate) struct CachedTypeInfo {
81 : /// Cache of types already looked up.
82 : pub(crate) types: HashMap<Oid, Type>,
83 : }
84 :
85 : pub struct InnerClient {
86 : sender: mpsc::UnboundedSender<FrontendMessage>,
87 : responses: Responses,
88 :
89 : /// A buffer to use when writing out postgres commands.
90 : buffer: BytesMut,
91 : }
92 :
93 : impl InnerClient {
94 0 : pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
95 0 : self.responses.waiting += 1;
96 0 : Ok(PartialQuery(Some(self)))
97 0 : }
98 :
99 0 : pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
100 0 : self.responses.waiting += 1;
101 :
102 0 : self.buffer.clear();
103 : // simple queries do not need sync.
104 0 : frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
105 0 : let buf = self.buffer.split();
106 0 : self.send_message(FrontendMessage::Raw(buf))
107 0 : }
108 :
109 0 : fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> {
110 0 : self.sender.send(messages).map_err(|_| Error::closed())?;
111 0 : Ok(&mut self.responses)
112 0 : }
113 : }
114 :
115 : pub struct PartialQuery<'a>(Option<&'a mut InnerClient>);
116 :
117 : impl Drop for PartialQuery<'_> {
118 0 : fn drop(&mut self) {
119 0 : if let Some(client) = self.0.take() {
120 0 : client.buffer.clear();
121 0 : frontend::sync(&mut client.buffer);
122 0 : let buf = client.buffer.split();
123 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
124 0 : }
125 0 : }
126 : }
127 :
128 : impl<'a> PartialQuery<'a> {
129 0 : pub fn send_with_flush<F>(&mut self, f: F) -> Result<&mut Responses, Error>
130 0 : where
131 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
132 : {
133 0 : let client = self.0.as_deref_mut().unwrap();
134 :
135 0 : client.buffer.clear();
136 0 : f(&mut client.buffer)?;
137 0 : frontend::flush(&mut client.buffer);
138 0 : let buf = client.buffer.split();
139 0 : client.send_message(FrontendMessage::Raw(buf))
140 0 : }
141 :
142 0 : pub fn send_with_sync<F>(mut self, f: F) -> Result<&'a mut Responses, Error>
143 0 : where
144 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
145 : {
146 0 : let client = self.0.as_deref_mut().unwrap();
147 :
148 0 : client.buffer.clear();
149 0 : f(&mut client.buffer)?;
150 0 : frontend::sync(&mut client.buffer);
151 0 : let buf = client.buffer.split();
152 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
153 :
154 0 : Ok(&mut self.0.take().unwrap().responses)
155 0 : }
156 : }
157 :
158 0 : #[derive(Clone, Serialize, Deserialize)]
159 : pub struct SocketConfig {
160 : pub host_addr: Option<IpAddr>,
161 : pub host: Host,
162 : pub port: u16,
163 : pub connect_timeout: Option<Duration>,
164 : }
165 :
166 : /// An asynchronous PostgreSQL client.
167 : ///
168 : /// The client is one half of what is returned when a connection is established. Users interact with the database
169 : /// through this client object.
170 : pub struct Client {
171 : inner: InnerClient,
172 : cached_typeinfo: CachedTypeInfo,
173 :
174 : socket_config: SocketConfig,
175 : ssl_mode: SslMode,
176 : process_id: i32,
177 : secret_key: i32,
178 : }
179 :
180 : impl Client {
181 0 : pub(crate) fn new(
182 0 : sender: mpsc::UnboundedSender<FrontendMessage>,
183 0 : receiver: mpsc::Receiver<BackendMessages>,
184 0 : socket_config: SocketConfig,
185 0 : ssl_mode: SslMode,
186 0 : process_id: i32,
187 0 : secret_key: i32,
188 0 : ) -> Client {
189 0 : Client {
190 0 : inner: InnerClient {
191 0 : sender,
192 0 : responses: Responses {
193 0 : receiver,
194 0 : cur: BackendMessages::empty(),
195 0 : waiting: 0,
196 0 : received: 0,
197 0 : },
198 0 : buffer: Default::default(),
199 0 : },
200 0 : cached_typeinfo: Default::default(),
201 0 :
202 0 : socket_config,
203 0 : ssl_mode,
204 0 : process_id,
205 0 : secret_key,
206 0 : }
207 0 : }
208 :
209 : /// Returns process_id.
210 0 : pub fn get_process_id(&self) -> i32 {
211 0 : self.process_id
212 0 : }
213 :
214 0 : pub(crate) fn inner_mut(&mut self) -> &mut InnerClient {
215 0 : &mut self.inner
216 0 : }
217 :
218 0 : pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
219 0 : let (tx, rx) = mpsc::unbounded_channel();
220 :
221 0 : let notices = RecordNotices { sender: tx, limit };
222 0 : self.inner
223 0 : .sender
224 0 : .send(FrontendMessage::RecordNotices(notices))
225 0 : .ok();
226 :
227 0 : rx
228 0 : }
229 :
230 : /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
231 : /// to save a roundtrip
232 0 : pub async fn query_raw_txt<S, I>(
233 0 : &mut self,
234 0 : statement: &str,
235 0 : params: I,
236 0 : ) -> Result<RowStream<'_>, Error>
237 0 : where
238 0 : S: AsRef<str>,
239 0 : I: IntoIterator<Item = Option<S>>,
240 0 : I::IntoIter: ExactSizeIterator,
241 0 : {
242 0 : query::query_txt(
243 0 : &mut self.inner,
244 0 : &mut self.cached_typeinfo,
245 0 : statement,
246 0 : params,
247 0 : )
248 0 : .await
249 0 : }
250 :
251 : /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
252 : ///
253 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
254 : /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
255 : /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
256 : /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
257 : /// or a row of data. This preserves the framing between the separate statements in the request.
258 : ///
259 : /// # Warning
260 : ///
261 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
262 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
263 : /// them to this method!
264 0 : pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
265 0 : self.simple_query_raw(query).await?.try_collect().await
266 0 : }
267 :
268 0 : pub(crate) async fn simple_query_raw(
269 0 : &mut self,
270 0 : query: &str,
271 0 : ) -> Result<SimpleQueryStream<'_>, Error> {
272 0 : simple_query::simple_query(self.inner_mut(), query).await
273 0 : }
274 :
275 : /// Executes a sequence of SQL statements using the simple query protocol.
276 : ///
277 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
278 : /// point. This is intended for use when, for example, initializing a database schema.
279 : ///
280 : /// # Warning
281 : ///
282 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
283 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
284 : /// them to this method!
285 0 : pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
286 0 : simple_query::batch_execute(self.inner_mut(), query).await
287 0 : }
288 :
289 : /// Similar to `discard_all`, but it does not clear any query plans
290 : ///
291 : /// This runs in the background, so it can be executed without `await`ing.
292 0 : pub fn reset_session_background(&mut self) -> Result<(), Error> {
293 : // "CLOSE ALL": closes any cursors
294 : // "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
295 : // "RESET ALL": resets any GUCs back to their session defaults.
296 : // "DEALLOCATE ALL": deallocates any prepared statements
297 : // "UNLISTEN *": stops listening on all channels
298 : // "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
299 : // "DISCARD TEMP;": drops all temporary tables
300 : // "DISCARD SEQUENCES;": deallocates all cached sequence state
301 :
302 0 : let _responses = self.inner_mut().send_simple_query(
303 0 : "ROLLBACK;
304 0 : CLOSE ALL;
305 0 : SET SESSION AUTHORIZATION DEFAULT;
306 0 : RESET ALL;
307 0 : DEALLOCATE ALL;
308 0 : UNLISTEN *;
309 0 : SELECT pg_advisory_unlock_all();
310 0 : DISCARD TEMP;
311 0 : DISCARD SEQUENCES;",
312 0 : )?;
313 :
314 : // Clean up memory usage.
315 0 : gc_bytesmut(&mut self.inner_mut().buffer);
316 :
317 0 : Ok(())
318 0 : }
319 :
320 : /// Begins a new database transaction.
321 : ///
322 : /// The transaction will roll back by default - use the `commit` method to commit it.
323 0 : pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
324 : struct RollbackIfNotDone<'me> {
325 : client: &'me mut Client,
326 : done: bool,
327 : }
328 :
329 : impl Drop for RollbackIfNotDone<'_> {
330 0 : fn drop(&mut self) {
331 0 : if self.done {
332 0 : return;
333 0 : }
334 :
335 0 : let _ = self.client.inner.send_simple_query("ROLLBACK");
336 0 : }
337 : }
338 :
339 : // This is done, as `Future` created by this method can be dropped after
340 : // `RequestMessages` is synchronously send to the `Connection` by
341 : // `batch_execute()`, but before `Responses` is asynchronously polled to
342 : // completion. In that case `Transaction` won't be created and thus
343 : // won't be rolled back.
344 : {
345 0 : let mut cleaner = RollbackIfNotDone {
346 0 : client: self,
347 0 : done: false,
348 0 : };
349 0 : cleaner.client.batch_execute("BEGIN").await?;
350 0 : cleaner.done = true;
351 : }
352 :
353 0 : Ok(Transaction::new(self))
354 0 : }
355 :
356 : /// Returns a builder for a transaction with custom settings.
357 : ///
358 : /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
359 : /// attributes.
360 0 : pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
361 0 : TransactionBuilder::new(self)
362 0 : }
363 :
364 : /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
365 : /// connection associated with this client.
366 0 : pub fn cancel_token(&self) -> CancelToken {
367 0 : CancelToken {
368 0 : socket_config: self.socket_config.clone(),
369 0 : raw: RawCancelToken {
370 0 : ssl_mode: self.ssl_mode,
371 0 : process_id: self.process_id,
372 0 : secret_key: self.secret_key,
373 0 : },
374 0 : }
375 0 : }
376 :
377 : /// Determines if the connection to the server has already closed.
378 : ///
379 : /// In that case, all future queries will fail.
380 0 : pub fn is_closed(&self) -> bool {
381 0 : self.inner.sender.is_closed()
382 0 : }
383 : }
384 :
385 : impl fmt::Debug for Client {
386 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387 0 : f.debug_struct("Client").finish()
388 0 : }
389 : }
|