Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt;
3 : use std::net::IpAddr;
4 : use std::task::{Context, Poll};
5 : use std::time::Duration;
6 :
7 : use bytes::BytesMut;
8 : use fallible_iterator::FallibleIterator;
9 : use futures_util::{TryStreamExt, future, ready};
10 : use postgres_protocol2::message::backend::Message;
11 : use postgres_protocol2::message::frontend;
12 : use serde::{Deserialize, Serialize};
13 : use tokio::sync::mpsc;
14 :
15 : use crate::cancel_token::RawCancelToken;
16 : use crate::codec::{BackendMessages, FrontendMessage};
17 : use crate::config::{Host, SslMode};
18 : use crate::query::RowStream;
19 : use crate::simple_query::SimpleQueryStream;
20 : use crate::types::{Oid, Type};
21 : use crate::{
22 : CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
23 : query, simple_query,
24 : };
25 :
26 : pub struct Responses {
27 : /// new messages from conn
28 : receiver: mpsc::Receiver<BackendMessages>,
29 : /// current batch of messages
30 : cur: BackendMessages,
31 : /// number of total queries sent.
32 : waiting: usize,
33 : /// number of ReadyForQuery messages received.
34 : received: usize,
35 : }
36 :
37 : impl Responses {
38 0 : pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
39 : loop {
40 : // get the next saved message
41 0 : if let Some(message) = self.cur.next().map_err(Error::parse)? {
42 0 : let received = self.received;
43 :
44 : // increase the query head if this is the last message.
45 0 : if let Message::ReadyForQuery(_) = message {
46 0 : self.received += 1;
47 0 : }
48 :
49 : // check if the client has skipped this query.
50 0 : if received + 1 < self.waiting {
51 : // grab the next message.
52 0 : continue;
53 0 : }
54 :
55 : // convenience: turn the error messaage into a proper error.
56 0 : let res = match message {
57 0 : Message::ErrorResponse(body) => Err(Error::db(body)),
58 0 : message => Ok(message),
59 : };
60 0 : return Poll::Ready(res);
61 0 : }
62 :
63 : // get the next batch of messages.
64 0 : match ready!(self.receiver.poll_recv(cx)) {
65 0 : Some(messages) => self.cur = messages,
66 0 : None => return Poll::Ready(Err(Error::closed())),
67 : }
68 : }
69 0 : }
70 :
71 0 : pub async fn next(&mut self) -> Result<Message, Error> {
72 0 : future::poll_fn(|cx| self.poll_next(cx)).await
73 0 : }
74 : }
75 :
76 : /// A cache of type info and prepared statements for fetching type info
77 : /// (corresponding to the queries in the [crate::prepare] module).
78 : #[derive(Default)]
79 : pub(crate) struct CachedTypeInfo {
80 : /// Cache of types already looked up.
81 : pub(crate) types: HashMap<Oid, Type>,
82 : }
83 :
84 : pub struct InnerClient {
85 : sender: mpsc::UnboundedSender<FrontendMessage>,
86 : responses: Responses,
87 :
88 : /// A buffer to use when writing out postgres commands.
89 : buffer: BytesMut,
90 : }
91 :
92 : impl InnerClient {
93 0 : pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
94 0 : self.responses.waiting += 1;
95 0 : Ok(PartialQuery(Some(self)))
96 0 : }
97 :
98 : // pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
99 : // where
100 : // F: FnOnce(&mut BytesMut) -> Result<(), Error>,
101 : // {
102 : // self.start()?.send_with_sync(f)
103 : // }
104 :
105 0 : pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
106 0 : self.responses.waiting += 1;
107 :
108 0 : self.buffer.clear();
109 : // simple queries do not need sync.
110 0 : frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
111 0 : let buf = self.buffer.split().freeze();
112 0 : self.send_message(FrontendMessage::Raw(buf))
113 0 : }
114 :
115 0 : fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> {
116 0 : self.sender.send(messages).map_err(|_| Error::closed())?;
117 0 : Ok(&mut self.responses)
118 0 : }
119 : }
120 :
121 : pub struct PartialQuery<'a>(Option<&'a mut InnerClient>);
122 :
123 : impl Drop for PartialQuery<'_> {
124 0 : fn drop(&mut self) {
125 0 : if let Some(client) = self.0.take() {
126 0 : client.buffer.clear();
127 0 : frontend::sync(&mut client.buffer);
128 0 : let buf = client.buffer.split().freeze();
129 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
130 0 : }
131 0 : }
132 : }
133 :
134 : impl<'a> PartialQuery<'a> {
135 0 : pub fn send_with_flush<F>(&mut self, f: F) -> Result<&mut Responses, Error>
136 0 : where
137 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
138 : {
139 0 : let client = self.0.as_deref_mut().unwrap();
140 :
141 0 : client.buffer.clear();
142 0 : f(&mut client.buffer)?;
143 0 : frontend::flush(&mut client.buffer);
144 0 : let buf = client.buffer.split().freeze();
145 0 : client.send_message(FrontendMessage::Raw(buf))
146 0 : }
147 :
148 0 : pub fn send_with_sync<F>(mut self, f: F) -> Result<&'a mut Responses, Error>
149 0 : where
150 0 : F: FnOnce(&mut BytesMut) -> Result<(), Error>,
151 : {
152 0 : let client = self.0.as_deref_mut().unwrap();
153 :
154 0 : client.buffer.clear();
155 0 : f(&mut client.buffer)?;
156 0 : frontend::sync(&mut client.buffer);
157 0 : let buf = client.buffer.split().freeze();
158 0 : let _ = client.send_message(FrontendMessage::Raw(buf));
159 :
160 0 : Ok(&mut self.0.take().unwrap().responses)
161 0 : }
162 : }
163 :
164 0 : #[derive(Clone, Serialize, Deserialize)]
165 : pub struct SocketConfig {
166 : pub host_addr: Option<IpAddr>,
167 : pub host: Host,
168 : pub port: u16,
169 : pub connect_timeout: Option<Duration>,
170 : }
171 :
172 : /// An asynchronous PostgreSQL client.
173 : ///
174 : /// The client is one half of what is returned when a connection is established. Users interact with the database
175 : /// through this client object.
176 : pub struct Client {
177 : inner: InnerClient,
178 : cached_typeinfo: CachedTypeInfo,
179 :
180 : socket_config: SocketConfig,
181 : ssl_mode: SslMode,
182 : process_id: i32,
183 : secret_key: i32,
184 : }
185 :
186 : impl Client {
187 0 : pub(crate) fn new(
188 0 : sender: mpsc::UnboundedSender<FrontendMessage>,
189 0 : receiver: mpsc::Receiver<BackendMessages>,
190 0 : socket_config: SocketConfig,
191 0 : ssl_mode: SslMode,
192 0 : process_id: i32,
193 0 : secret_key: i32,
194 0 : ) -> Client {
195 0 : Client {
196 0 : inner: InnerClient {
197 0 : sender,
198 0 : responses: Responses {
199 0 : receiver,
200 0 : cur: BackendMessages::empty(),
201 0 : waiting: 0,
202 0 : received: 0,
203 0 : },
204 0 : buffer: Default::default(),
205 0 : },
206 0 : cached_typeinfo: Default::default(),
207 0 :
208 0 : socket_config,
209 0 : ssl_mode,
210 0 : process_id,
211 0 : secret_key,
212 0 : }
213 0 : }
214 :
215 : /// Returns process_id.
216 0 : pub fn get_process_id(&self) -> i32 {
217 0 : self.process_id
218 0 : }
219 :
220 0 : pub(crate) fn inner_mut(&mut self) -> &mut InnerClient {
221 0 : &mut self.inner
222 0 : }
223 :
224 : /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
225 : /// to save a roundtrip
226 0 : pub async fn query_raw_txt<S, I>(
227 0 : &mut self,
228 0 : statement: &str,
229 0 : params: I,
230 0 : ) -> Result<RowStream<'_>, Error>
231 0 : where
232 0 : S: AsRef<str>,
233 0 : I: IntoIterator<Item = Option<S>>,
234 0 : I::IntoIter: ExactSizeIterator,
235 0 : {
236 0 : query::query_txt(
237 0 : &mut self.inner,
238 0 : &mut self.cached_typeinfo,
239 0 : statement,
240 0 : params,
241 0 : )
242 0 : .await
243 0 : }
244 :
245 : /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
246 : ///
247 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
248 : /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
249 : /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
250 : /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
251 : /// or a row of data. This preserves the framing between the separate statements in the request.
252 : ///
253 : /// # Warning
254 : ///
255 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
256 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
257 : /// them to this method!
258 0 : pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
259 0 : self.simple_query_raw(query).await?.try_collect().await
260 0 : }
261 :
262 0 : pub(crate) async fn simple_query_raw(
263 0 : &mut self,
264 0 : query: &str,
265 0 : ) -> Result<SimpleQueryStream<'_>, Error> {
266 0 : simple_query::simple_query(self.inner_mut(), query).await
267 0 : }
268 :
269 : /// Executes a sequence of SQL statements using the simple query protocol.
270 : ///
271 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
272 : /// point. This is intended for use when, for example, initializing a database schema.
273 : ///
274 : /// # Warning
275 : ///
276 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
277 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
278 : /// them to this method!
279 0 : pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
280 0 : simple_query::batch_execute(self.inner_mut(), query).await
281 0 : }
282 :
283 0 : pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
284 0 : self.batch_execute("discard all").await
285 0 : }
286 :
287 : /// Begins a new database transaction.
288 : ///
289 : /// The transaction will roll back by default - use the `commit` method to commit it.
290 0 : pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
291 : struct RollbackIfNotDone<'me> {
292 : client: &'me mut Client,
293 : done: bool,
294 : }
295 :
296 : impl Drop for RollbackIfNotDone<'_> {
297 0 : fn drop(&mut self) {
298 0 : if self.done {
299 0 : return;
300 0 : }
301 :
302 0 : let _ = self.client.inner.send_simple_query("ROLLBACK");
303 0 : }
304 : }
305 :
306 : // This is done, as `Future` created by this method can be dropped after
307 : // `RequestMessages` is synchronously send to the `Connection` by
308 : // `batch_execute()`, but before `Responses` is asynchronously polled to
309 : // completion. In that case `Transaction` won't be created and thus
310 : // won't be rolled back.
311 : {
312 0 : let mut cleaner = RollbackIfNotDone {
313 0 : client: self,
314 0 : done: false,
315 0 : };
316 0 : cleaner.client.batch_execute("BEGIN").await?;
317 0 : cleaner.done = true;
318 : }
319 :
320 0 : Ok(Transaction::new(self))
321 0 : }
322 :
323 : /// Returns a builder for a transaction with custom settings.
324 : ///
325 : /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
326 : /// attributes.
327 0 : pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
328 0 : TransactionBuilder::new(self)
329 0 : }
330 :
331 : /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
332 : /// connection associated with this client.
333 0 : pub fn cancel_token(&self) -> CancelToken {
334 0 : CancelToken {
335 0 : socket_config: self.socket_config.clone(),
336 0 : raw: RawCancelToken {
337 0 : ssl_mode: self.ssl_mode,
338 0 : process_id: self.process_id,
339 0 : secret_key: self.secret_key,
340 0 : },
341 0 : }
342 0 : }
343 :
344 : /// Determines if the connection to the server has already closed.
345 : ///
346 : /// In that case, all future queries will fail.
347 0 : pub fn is_closed(&self) -> bool {
348 0 : self.inner.sender.is_closed()
349 0 : }
350 : }
351 :
352 : impl fmt::Debug for Client {
353 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
354 0 : f.debug_struct("Client").finish()
355 0 : }
356 : }
|