Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt;
3 : use std::net::IpAddr;
4 : use std::task::{Context, Poll};
5 : use std::time::Duration;
6 :
7 : use bytes::BytesMut;
8 : use fallible_iterator::FallibleIterator;
9 : use futures_util::{TryStreamExt, future, ready};
10 : use postgres_protocol2::message::backend::Message;
11 : use postgres_protocol2::message::frontend;
12 : use serde::{Deserialize, Serialize};
13 : use tokio::sync::mpsc;
14 :
15 : use crate::codec::{BackendMessages, FrontendMessage};
16 : use crate::config::{Host, SslMode};
17 : use crate::query::RowStream;
18 : use crate::simple_query::SimpleQueryStream;
19 : use crate::types::{Oid, Type};
20 : use crate::{
21 : CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
22 : query, simple_query,
23 : };
24 :
/// Reader for the backend messages that answer this client's queries.
///
/// The connection task delivers messages in batches over `receiver`. They are
/// handed out one at a time by `poll_next` / `next`. Comparing `waiting` with
/// `received` lets it skip messages that belong to queries whose caller stopped
/// reading before that query's final `ReadyForQuery`.
pub struct Responses {
    /// Batches of backend messages delivered by the connection.
    receiver: mpsc::Receiver<BackendMessages>,
    /// The batch currently being drained.
    cur: BackendMessages,
    /// Number of queries sent so far (bumped when a query is started or a simple query is sent).
    waiting: usize,
    /// Number of `ReadyForQuery` messages received so far; equivalently, the 0-based
    /// index of the query whose messages are currently arriving.
    received: usize,
}
35 :
36 : impl Responses {
37 0 : pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
38 : loop {
39 : // get the next saved message
40 0 : if let Some(message) = self.cur.next().map_err(Error::parse)? {
41 0 : let received = self.received;
42 0 :
43 0 : // increase the query head if this is the last message.
44 0 : if let Message::ReadyForQuery(_) = message {
45 0 : self.received += 1;
46 0 : }
47 :
48 : // check if the client has skipped this query.
49 0 : if received + 1 < self.waiting {
50 : // grab the next message.
51 0 : continue;
52 0 : }
53 :
54 : // convenience: turn the error messaage into a proper error.
55 0 : let res = match message {
56 0 : Message::ErrorResponse(body) => Err(Error::db(body)),
57 0 : message => Ok(message),
58 : };
59 0 : return Poll::Ready(res);
60 0 : }
61 :
62 : // get the next batch of messages.
63 0 : match ready!(self.receiver.poll_recv(cx)) {
64 0 : Some(messages) => self.cur = messages,
65 0 : None => return Poll::Ready(Err(Error::closed())),
66 : }
67 : }
68 0 : }
69 :
70 0 : pub async fn next(&mut self) -> Result<Message, Error> {
71 0 : future::poll_fn(|cx| self.poll_next(cx)).await
72 0 : }
73 : }
74 :
/// A cache of type information looked up from the server, reused across queries.
///
/// The type-info lookups themselves correspond to the queries in the [crate::prepare]
/// module. NOTE(review): an earlier version of this doc said prepared statements are
/// cached here too, but the only field is `types` — confirm and drop that claim for good.
#[derive(Default)]
pub(crate) struct CachedTypeInfo {
    /// Types already looked up, keyed by their OID.
    pub(crate) types: HashMap<Oid, Type>,
}
82 :
/// The I/O half of a [`Client`]: sends frontend messages and reads the responses.
pub struct InnerClient {
    /// Outgoing frontend messages, consumed by the connection task.
    sender: mpsc::UnboundedSender<FrontendMessage>,
    /// Reader for the backend messages answering queries sent through `sender`.
    responses: Responses,

    /// A buffer to use when writing out postgres commands.
    /// It is cleared before each message is encoded and split off when sent, so its
    /// allocation can be reused between messages.
    buffer: BytesMut,
}
90 :
91 : impl InnerClient {
92 0 : pub fn start(&mut self) -> Result<PartialQuery, Error> {
93 0 : self.responses.waiting += 1;
94 0 : Ok(PartialQuery(Some(self)))
95 0 : }
96 :
97 : // pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
98 : // where
99 : // F: FnOnce(&mut BytesMut) -> Result<(), Error>,
100 : // {
101 : // self.start()?.send_with_sync(f)
102 : // }
103 :
104 0 : pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
105 0 : self.responses.waiting += 1;
106 0 :
107 0 : self.buffer.clear();
108 0 : // simple queries do not need sync.
109 0 : frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
110 0 : let buf = self.buffer.split().freeze();
111 0 : self.send_message(FrontendMessage::Raw(buf))
112 0 : }
113 :
114 0 : fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> {
115 0 : self.sender.send(messages).map_err(|_| Error::closed())?;
116 0 : Ok(&mut self.responses)
117 0 : }
118 : }
119 :
/// A query started with `InnerClient::start` whose messages are still being written.
///
/// Holds the client until `send_with_sync` finishes the query. If the guard is dropped
/// while still holding the client, its `Drop` impl sends a lone `Sync` message.
pub struct PartialQuery<'a>(Option<&'a mut InnerClient>);
121 :
122 : impl Drop for PartialQuery<'_> {
123 0 : fn drop(&mut self) {
124 0 : if let Some(client) = self.0.take() {
125 0 : client.buffer.clear();
126 0 : frontend::sync(&mut client.buffer);
127 0 : let buf = client.buffer.split().freeze();
128 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
129 0 : }
130 0 : }
131 : }
132 :
133 : impl<'a> PartialQuery<'a> {
134 0 : pub fn send_with_flush<F>(&mut self, f: F) -> Result<&mut Responses, Error>
135 0 : where
136 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
137 0 : {
138 0 : let client = self.0.as_deref_mut().unwrap();
139 0 :
140 0 : client.buffer.clear();
141 0 : f(&mut client.buffer)?;
142 0 : frontend::flush(&mut client.buffer);
143 0 : let buf = client.buffer.split().freeze();
144 0 : client.send_message(FrontendMessage::Raw(buf))
145 0 : }
146 :
147 0 : pub fn send_with_sync<F>(mut self, f: F) -> Result<&'a mut Responses, Error>
148 0 : where
149 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
150 0 : {
151 0 : let client = self.0.as_deref_mut().unwrap();
152 0 :
153 0 : client.buffer.clear();
154 0 : f(&mut client.buffer)?;
155 0 : frontend::sync(&mut client.buffer);
156 0 : let buf = client.buffer.split().freeze();
157 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
158 0 :
159 0 : Ok(&mut self.0.take().unwrap().responses)
160 0 : }
161 : }
162 :
/// Network parameters describing how to reach the server.
///
/// `Client::cancel_token` copies a clone of this into every [`CancelToken`].
/// Presumably this lets a cancel request open a separate connection to the same
/// server — confirm against `CancelToken`.
#[derive(Clone, Serialize, Deserialize)]
pub struct SocketConfig {
    /// Optional concrete address for `host`.
    /// NOTE(review): presumably used instead of resolving `host` when set — confirm.
    pub host_addr: Option<IpAddr>,
    /// The server host as configured.
    pub host: Host,
    /// The server port.
    pub port: u16,
    /// Optional limit on how long establishing a connection may take.
    pub connect_timeout: Option<Duration>,
}
170 :
/// An asynchronous PostgreSQL client.
///
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
    /// Sends frontend messages and reads their responses.
    inner: InnerClient,
    /// Types already looked up from the server, reused across queries.
    cached_typeinfo: CachedTypeInfo,

    // The remaining fields are copied into every `CancelToken` built by `cancel_token`.
    /// How to reach the server.
    socket_config: SocketConfig,
    /// TLS mode for connections made on this client's behalf (e.g. by a `CancelToken`).
    ssl_mode: SslMode,
    /// Backend process ID of this session; also exposed via `get_process_id`.
    process_id: i32,
    /// Backend secret key, passed to `CancelToken` together with `process_id`.
    secret_key: i32,
}
184 :
185 : impl Client {
186 0 : pub(crate) fn new(
187 0 : sender: mpsc::UnboundedSender<FrontendMessage>,
188 0 : receiver: mpsc::Receiver<BackendMessages>,
189 0 : socket_config: SocketConfig,
190 0 : ssl_mode: SslMode,
191 0 : process_id: i32,
192 0 : secret_key: i32,
193 0 : ) -> Client {
194 0 : Client {
195 0 : inner: InnerClient {
196 0 : sender,
197 0 : responses: Responses {
198 0 : receiver,
199 0 : cur: BackendMessages::empty(),
200 0 : waiting: 0,
201 0 : received: 0,
202 0 : },
203 0 : buffer: Default::default(),
204 0 : },
205 0 : cached_typeinfo: Default::default(),
206 0 :
207 0 : socket_config,
208 0 : ssl_mode,
209 0 : process_id,
210 0 : secret_key,
211 0 : }
212 0 : }
213 :
214 : /// Returns process_id.
215 0 : pub fn get_process_id(&self) -> i32 {
216 0 : self.process_id
217 0 : }
218 :
219 0 : pub(crate) fn inner_mut(&mut self) -> &mut InnerClient {
220 0 : &mut self.inner
221 0 : }
222 :
223 : /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
224 : /// to save a roundtrip
225 0 : pub async fn query_raw_txt<S, I>(
226 0 : &mut self,
227 0 : statement: &str,
228 0 : params: I,
229 0 : ) -> Result<RowStream, Error>
230 0 : where
231 0 : S: AsRef<str>,
232 0 : I: IntoIterator<Item = Option<S>>,
233 0 : I::IntoIter: ExactSizeIterator,
234 0 : {
235 0 : query::query_txt(
236 0 : &mut self.inner,
237 0 : &mut self.cached_typeinfo,
238 0 : statement,
239 0 : params,
240 0 : )
241 0 : .await
242 0 : }
243 :
244 : /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
245 : ///
246 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
247 : /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
248 : /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
249 : /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
250 : /// or a row of data. This preserves the framing between the separate statements in the request.
251 : ///
252 : /// # Warning
253 : ///
254 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
255 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
256 : /// them to this method!
257 0 : pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
258 0 : self.simple_query_raw(query).await?.try_collect().await
259 0 : }
260 :
261 0 : pub(crate) async fn simple_query_raw(
262 0 : &mut self,
263 0 : query: &str,
264 0 : ) -> Result<SimpleQueryStream, Error> {
265 0 : simple_query::simple_query(self.inner_mut(), query).await
266 0 : }
267 :
268 : /// Executes a sequence of SQL statements using the simple query protocol.
269 : ///
270 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
271 : /// point. This is intended for use when, for example, initializing a database schema.
272 : ///
273 : /// # Warning
274 : ///
275 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
276 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
277 : /// them to this method!
278 0 : pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
279 0 : simple_query::batch_execute(self.inner_mut(), query).await
280 0 : }
281 :
282 0 : pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
283 0 : self.batch_execute("discard all").await
284 0 : }
285 :
286 : /// Begins a new database transaction.
287 : ///
288 : /// The transaction will roll back by default - use the `commit` method to commit it.
289 0 : pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
290 : struct RollbackIfNotDone<'me> {
291 : client: &'me mut Client,
292 : done: bool,
293 : }
294 :
295 : impl Drop for RollbackIfNotDone<'_> {
296 0 : fn drop(&mut self) {
297 0 : if self.done {
298 0 : return;
299 0 : }
300 0 :
301 0 : let _ = self.client.inner.send_simple_query("ROLLBACK");
302 0 : }
303 : }
304 :
305 : // This is done, as `Future` created by this method can be dropped after
306 : // `RequestMessages` is synchronously send to the `Connection` by
307 : // `batch_execute()`, but before `Responses` is asynchronously polled to
308 : // completion. In that case `Transaction` won't be created and thus
309 : // won't be rolled back.
310 : {
311 0 : let mut cleaner = RollbackIfNotDone {
312 0 : client: self,
313 0 : done: false,
314 0 : };
315 0 : cleaner.client.batch_execute("BEGIN").await?;
316 0 : cleaner.done = true;
317 0 : }
318 0 :
319 0 : Ok(Transaction::new(self))
320 0 : }
321 :
322 : /// Returns a builder for a transaction with custom settings.
323 : ///
324 : /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
325 : /// attributes.
326 0 : pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
327 0 : TransactionBuilder::new(self)
328 0 : }
329 :
330 : /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
331 : /// connection associated with this client.
332 0 : pub fn cancel_token(&self) -> CancelToken {
333 0 : CancelToken {
334 0 : socket_config: Some(self.socket_config.clone()),
335 0 : ssl_mode: self.ssl_mode,
336 0 : process_id: self.process_id,
337 0 : secret_key: self.secret_key,
338 0 : }
339 0 : }
340 :
341 : /// Determines if the connection to the server has already closed.
342 : ///
343 : /// In that case, all future queries will fail.
344 0 : pub fn is_closed(&self) -> bool {
345 0 : self.inner.sender.is_closed()
346 0 : }
347 : }
348 :
349 : impl fmt::Debug for Client {
350 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 0 : f.debug_struct("Client").finish()
352 0 : }
353 : }
|