Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt;
3 : use std::net::IpAddr;
4 : use std::task::{Context, Poll};
5 : use std::time::Duration;
6 :
7 : use bytes::BytesMut;
8 : use fallible_iterator::FallibleIterator;
9 : use futures_util::{TryStreamExt, future, ready};
10 : use postgres_protocol2::message::backend::Message;
11 : use postgres_protocol2::message::frontend;
12 : use serde::{Deserialize, Serialize};
13 : use tokio::sync::mpsc;
14 :
15 : use crate::cancel_token::RawCancelToken;
16 : use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
17 : use crate::config::{Host, SslMode};
18 : use crate::query::RowStream;
19 : use crate::simple_query::SimpleQueryStream;
20 : use crate::types::{Oid, Type};
21 : use crate::{
22 : CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
23 : query, simple_query,
24 : };
25 :
26 : pub struct Responses {
27 : /// new messages from conn
28 : receiver: mpsc::Receiver<BackendMessages>,
29 : /// current batch of messages
30 : cur: BackendMessages,
31 : /// number of total queries sent.
32 : waiting: usize,
33 : /// number of ReadyForQuery messages received.
34 : received: usize,
35 : }
36 :
37 : impl Responses {
38 0 : pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
39 : loop {
40 : // get the next saved message
41 0 : if let Some(message) = self.cur.next().map_err(Error::parse)? {
42 0 : let received = self.received;
43 :
44 : // increase the query head if this is the last message.
45 0 : if let Message::ReadyForQuery(_) = message {
46 0 : self.received += 1;
47 0 : }
48 :
49 : // check if the client has skipped this query.
50 0 : if received + 1 < self.waiting {
51 : // grab the next message.
52 0 : continue;
53 0 : }
54 :
55 : // convenience: turn the error messaage into a proper error.
56 0 : let res = match message {
57 0 : Message::ErrorResponse(body) => Err(Error::db(body)),
58 0 : message => Ok(message),
59 : };
60 0 : return Poll::Ready(res);
61 0 : }
62 :
63 : // get the next batch of messages.
64 0 : match ready!(self.receiver.poll_recv(cx)) {
65 0 : Some(messages) => self.cur = messages,
66 0 : None => return Poll::Ready(Err(Error::closed())),
67 : }
68 : }
69 0 : }
70 :
71 0 : pub async fn next(&mut self) -> Result<Message, Error> {
72 0 : future::poll_fn(|cx| self.poll_next(cx)).await
73 0 : }
74 : }
75 :
76 : /// A cache of type info and prepared statements for fetching type info
77 : /// (corresponding to the queries in the [crate::prepare] module).
78 : #[derive(Default)]
79 : pub(crate) struct CachedTypeInfo {
80 : /// Cache of types already looked up.
81 : pub(crate) types: HashMap<Oid, Type>,
82 : }
83 :
84 : pub struct InnerClient {
85 : sender: mpsc::UnboundedSender<FrontendMessage>,
86 : responses: Responses,
87 :
88 : /// A buffer to use when writing out postgres commands.
89 : buffer: BytesMut,
90 : }
91 :
92 : impl InnerClient {
93 0 : pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
94 0 : self.responses.waiting += 1;
95 0 : Ok(PartialQuery(Some(self)))
96 0 : }
97 :
98 : // pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
99 : // where
100 : // F: FnOnce(&mut BytesMut) -> Result<(), Error>,
101 : // {
102 : // self.start()?.send_with_sync(f)
103 : // }
104 :
105 0 : pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
106 0 : self.responses.waiting += 1;
107 :
108 0 : self.buffer.clear();
109 : // simple queries do not need sync.
110 0 : frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
111 0 : let buf = self.buffer.split().freeze();
112 0 : self.send_message(FrontendMessage::Raw(buf))
113 0 : }
114 :
115 0 : fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> {
116 0 : self.sender.send(messages).map_err(|_| Error::closed())?;
117 0 : Ok(&mut self.responses)
118 0 : }
119 : }
120 :
121 : pub struct PartialQuery<'a>(Option<&'a mut InnerClient>);
122 :
123 : impl Drop for PartialQuery<'_> {
124 0 : fn drop(&mut self) {
125 0 : if let Some(client) = self.0.take() {
126 0 : client.buffer.clear();
127 0 : frontend::sync(&mut client.buffer);
128 0 : let buf = client.buffer.split().freeze();
129 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
130 0 : }
131 0 : }
132 : }
133 :
134 : impl<'a> PartialQuery<'a> {
135 0 : pub fn send_with_flush<F>(&mut self, f: F) -> Result<&mut Responses, Error>
136 0 : where
137 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
138 : {
139 0 : let client = self.0.as_deref_mut().unwrap();
140 :
141 0 : client.buffer.clear();
142 0 : f(&mut client.buffer)?;
143 0 : frontend::flush(&mut client.buffer);
144 0 : let buf = client.buffer.split().freeze();
145 0 : client.send_message(FrontendMessage::Raw(buf))
146 0 : }
147 :
148 0 : pub fn send_with_sync<F>(mut self, f: F) -> Result<&'a mut Responses, Error>
149 0 : where
150 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
151 : {
152 0 : let client = self.0.as_deref_mut().unwrap();
153 :
154 0 : client.buffer.clear();
155 0 : f(&mut client.buffer)?;
156 0 : frontend::sync(&mut client.buffer);
157 0 : let buf = client.buffer.split().freeze();
158 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
159 :
160 0 : Ok(&mut self.0.take().unwrap().responses)
161 0 : }
162 : }
163 :
164 0 : #[derive(Clone, Serialize, Deserialize)]
165 : pub struct SocketConfig {
166 : pub host_addr: Option<IpAddr>,
167 : pub host: Host,
168 : pub port: u16,
169 : pub connect_timeout: Option<Duration>,
170 : }
171 :
172 : /// An asynchronous PostgreSQL client.
173 : ///
174 : /// The client is one half of what is returned when a connection is established. Users interact with the database
175 : /// through this client object.
176 : pub struct Client {
177 : inner: InnerClient,
178 : cached_typeinfo: CachedTypeInfo,
179 :
180 : socket_config: SocketConfig,
181 : ssl_mode: SslMode,
182 : process_id: i32,
183 : secret_key: i32,
184 : }
185 :
186 : impl Client {
187 0 : pub(crate) fn new(
188 0 : sender: mpsc::UnboundedSender<FrontendMessage>,
189 0 : receiver: mpsc::Receiver<BackendMessages>,
190 0 : socket_config: SocketConfig,
191 0 : ssl_mode: SslMode,
192 0 : process_id: i32,
193 0 : secret_key: i32,
194 0 : ) -> Client {
195 0 : Client {
196 0 : inner: InnerClient {
197 0 : sender,
198 0 : responses: Responses {
199 0 : receiver,
200 0 : cur: BackendMessages::empty(),
201 0 : waiting: 0,
202 0 : received: 0,
203 0 : },
204 0 : buffer: Default::default(),
205 0 : },
206 0 : cached_typeinfo: Default::default(),
207 0 :
208 0 : socket_config,
209 0 : ssl_mode,
210 0 : process_id,
211 0 : secret_key,
212 0 : }
213 0 : }
214 :
215 : /// Returns process_id.
216 0 : pub fn get_process_id(&self) -> i32 {
217 0 : self.process_id
218 0 : }
219 :
220 0 : pub(crate) fn inner_mut(&mut self) -> &mut InnerClient {
221 0 : &mut self.inner
222 0 : }
223 :
224 0 : pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver<Box<str>> {
225 0 : let (tx, rx) = mpsc::unbounded_channel();
226 :
227 0 : let notices = RecordNotices { sender: tx, limit };
228 0 : self.inner
229 0 : .sender
230 0 : .send(FrontendMessage::RecordNotices(notices))
231 0 : .ok();
232 :
233 0 : rx
234 0 : }
235 :
236 : /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
237 : /// to save a roundtrip
238 0 : pub async fn query_raw_txt<S, I>(
239 0 : &mut self,
240 0 : statement: &str,
241 0 : params: I,
242 0 : ) -> Result<RowStream<'_>, Error>
243 0 : where
244 0 : S: AsRef<str>,
245 0 : I: IntoIterator<Item = Option<S>>,
246 0 : I::IntoIter: ExactSizeIterator,
247 0 : {
248 0 : query::query_txt(
249 0 : &mut self.inner,
250 0 : &mut self.cached_typeinfo,
251 0 : statement,
252 0 : params,
253 0 : )
254 0 : .await
255 0 : }
256 :
257 : /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
258 : ///
259 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
260 : /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
261 : /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
262 : /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
263 : /// or a row of data. This preserves the framing between the separate statements in the request.
264 : ///
265 : /// # Warning
266 : ///
267 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
268 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
269 : /// them to this method!
270 0 : pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
271 0 : self.simple_query_raw(query).await?.try_collect().await
272 0 : }
273 :
274 0 : pub(crate) async fn simple_query_raw(
275 0 : &mut self,
276 0 : query: &str,
277 0 : ) -> Result<SimpleQueryStream<'_>, Error> {
278 0 : simple_query::simple_query(self.inner_mut(), query).await
279 0 : }
280 :
281 : /// Executes a sequence of SQL statements using the simple query protocol.
282 : ///
283 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
284 : /// point. This is intended for use when, for example, initializing a database schema.
285 : ///
286 : /// # Warning
287 : ///
288 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
289 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
290 : /// them to this method!
291 0 : pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
292 0 : simple_query::batch_execute(self.inner_mut(), query).await
293 0 : }
294 :
295 0 : pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
296 0 : self.batch_execute("discard all").await
297 0 : }
298 :
299 : /// Begins a new database transaction.
300 : ///
301 : /// The transaction will roll back by default - use the `commit` method to commit it.
302 0 : pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
303 : struct RollbackIfNotDone<'me> {
304 : client: &'me mut Client,
305 : done: bool,
306 : }
307 :
308 : impl Drop for RollbackIfNotDone<'_> {
309 0 : fn drop(&mut self) {
310 0 : if self.done {
311 0 : return;
312 0 : }
313 :
314 0 : let _ = self.client.inner.send_simple_query("ROLLBACK");
315 0 : }
316 : }
317 :
318 : // This is done, as `Future` created by this method can be dropped after
319 : // `RequestMessages` is synchronously send to the `Connection` by
320 : // `batch_execute()`, but before `Responses` is asynchronously polled to
321 : // completion. In that case `Transaction` won't be created and thus
322 : // won't be rolled back.
323 : {
324 0 : let mut cleaner = RollbackIfNotDone {
325 0 : client: self,
326 0 : done: false,
327 0 : };
328 0 : cleaner.client.batch_execute("BEGIN").await?;
329 0 : cleaner.done = true;
330 : }
331 :
332 0 : Ok(Transaction::new(self))
333 0 : }
334 :
335 : /// Returns a builder for a transaction with custom settings.
336 : ///
337 : /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
338 : /// attributes.
339 0 : pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
340 0 : TransactionBuilder::new(self)
341 0 : }
342 :
343 : /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
344 : /// connection associated with this client.
345 0 : pub fn cancel_token(&self) -> CancelToken {
346 0 : CancelToken {
347 0 : socket_config: self.socket_config.clone(),
348 0 : raw: RawCancelToken {
349 0 : ssl_mode: self.ssl_mode,
350 0 : process_id: self.process_id,
351 0 : secret_key: self.secret_key,
352 0 : },
353 0 : }
354 0 : }
355 :
356 : /// Determines if the connection to the server has already closed.
357 : ///
358 : /// In that case, all future queries will fail.
359 0 : pub fn is_closed(&self) -> bool {
360 0 : self.inner.sender.is_closed()
361 0 : }
362 : }
363 :
364 : impl fmt::Debug for Client {
365 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366 0 : f.debug_struct("Client").finish()
367 0 : }
368 : }
|