Line data Source code
1 : use crate::codec::{BackendMessages, FrontendMessage};
2 :
3 : use crate::config::Host;
4 : use crate::config::SslMode;
5 : use crate::connection::{Request, RequestMessages};
6 :
7 : use crate::query::RowStream;
8 : use crate::simple_query::SimpleQueryStream;
9 :
10 : use crate::types::{Oid, ToSql, Type};
11 :
12 : use crate::{
13 : prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row,
14 : SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
15 : };
16 : use bytes::BytesMut;
17 : use fallible_iterator::FallibleIterator;
18 : use futures_util::{future, ready, TryStreamExt};
19 : use parking_lot::Mutex;
20 : use postgres_protocol2::message::{backend::Message, frontend};
21 : use serde::{Deserialize, Serialize};
22 : use std::collections::HashMap;
23 : use std::fmt;
24 : use std::sync::Arc;
25 : use std::task::{Context, Poll};
26 : use tokio::sync::mpsc;
27 :
28 : use std::time::Duration;
29 :
/// The stream of backend messages produced in reply to a single request.
///
/// Created by `InnerClient::send`, which keeps the receiving half of the
/// per-request channel here and ships the sending half with the request.
pub struct Responses {
    /// Receiving half of the per-request channel; each item is one batch of
    /// backend messages.
    receiver: mpsc::Receiver<BackendMessages>,
    /// The batch currently being drained; replaced with the next received
    /// batch once it is exhausted.
    cur: BackendMessages,
}
34 :
35 : impl Responses {
36 0 : pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
37 : loop {
38 0 : match self.cur.next().map_err(Error::parse)? {
39 0 : Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
40 0 : Some(message) => return Poll::Ready(Ok(message)),
41 0 : None => {}
42 : }
43 :
44 0 : match ready!(self.receiver.poll_recv(cx)) {
45 0 : Some(messages) => self.cur = messages,
46 0 : None => return Poll::Ready(Err(Error::closed())),
47 : }
48 : }
49 0 : }
50 :
51 0 : pub async fn next(&mut self) -> Result<Message, Error> {
52 0 : future::poll_fn(|cx| self.poll_next(cx)).await
53 0 : }
54 : }
55 :
/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare] module).
#[derive(Default)]
struct CachedTypeInfo {
    /// A statement for basic information for a type from its
    /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
    /// fallback).
    typeinfo: Option<Statement>,
    /// A statement for getting information for a composite type from its OID.
    /// Corresponds to
    /// [TYPEINFO_COMPOSITE_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
    typeinfo_composite: Option<Statement>,
    /// A statement for getting information for an enum type from its OID.
    /// NOTE(review): presumably corresponds to the enum type-info query in the
    /// `prepare` module (or its fallback) - confirm the constant name there.
    typeinfo_enum: Option<Statement>,

    /// Cache of types already looked up.
    types: HashMap<Oid, Type>,
}
75 :
/// State shared behind the `Arc` held by [`Client`]: the request channel,
/// the type-info cache and a reusable encoding buffer.
pub struct InnerClient {
    /// Channel on which requests are handed over for sending.
    sender: mpsc::UnboundedSender<Request>,
    /// Cached type-info statements and already resolved types.
    cached_typeinfo: Mutex<CachedTypeInfo>,

    /// A buffer to use when writing out postgres commands.
    buffer: Mutex<BytesMut>,
}
83 :
84 : impl InnerClient {
85 0 : pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
86 0 : let (sender, receiver) = mpsc::channel(1);
87 0 : let request = Request { messages, sender };
88 0 : self.sender.send(request).map_err(|_| Error::closed())?;
89 :
90 0 : Ok(Responses {
91 0 : receiver,
92 0 : cur: BackendMessages::empty(),
93 0 : })
94 0 : }
95 :
96 0 : pub fn typeinfo(&self) -> Option<Statement> {
97 0 : self.cached_typeinfo.lock().typeinfo.clone()
98 0 : }
99 :
100 0 : pub fn set_typeinfo(&self, statement: &Statement) {
101 0 : self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
102 0 : }
103 :
104 0 : pub fn typeinfo_composite(&self) -> Option<Statement> {
105 0 : self.cached_typeinfo.lock().typeinfo_composite.clone()
106 0 : }
107 :
108 0 : pub fn set_typeinfo_composite(&self, statement: &Statement) {
109 0 : self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
110 0 : }
111 :
112 0 : pub fn typeinfo_enum(&self) -> Option<Statement> {
113 0 : self.cached_typeinfo.lock().typeinfo_enum.clone()
114 0 : }
115 :
116 0 : pub fn set_typeinfo_enum(&self, statement: &Statement) {
117 0 : self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
118 0 : }
119 :
120 0 : pub fn type_(&self, oid: Oid) -> Option<Type> {
121 0 : self.cached_typeinfo.lock().types.get(&oid).cloned()
122 0 : }
123 :
124 0 : pub fn set_type(&self, oid: Oid, type_: &Type) {
125 0 : self.cached_typeinfo.lock().types.insert(oid, type_.clone());
126 0 : }
127 :
128 : /// Call the given function with a buffer to be used when writing out
129 : /// postgres commands.
130 0 : pub fn with_buf<F, R>(&self, f: F) -> R
131 0 : where
132 0 : F: FnOnce(&mut BytesMut) -> R,
133 0 : {
134 0 : let mut buffer = self.buffer.lock();
135 0 : let r = f(&mut buffer);
136 0 : buffer.clear();
137 0 : r
138 0 : }
139 : }
140 :
/// Socket-level connection parameters kept by [`Client`].
///
/// A clone of this value is placed into every `CancelToken` produced by
/// `Client::cancel_token`.
#[derive(Clone, Serialize, Deserialize)]
pub struct SocketConfig {
    /// Host of the server.
    pub host: Host,
    /// Port of the server.
    pub port: u16,
    /// Optional connect timeout.
    /// NOTE(review): presumably `None` means no timeout is applied - confirm
    /// at the connect site.
    pub connect_timeout: Option<Duration>,
    // pub keepalive: Option<KeepaliveConfig>,
}
148 :
/// An asynchronous PostgreSQL client.
///
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
    /// Shared state: request channel, type-info cache and encoding buffer.
    inner: Arc<InnerClient>,

    /// Socket parameters; cloned into cancel tokens.
    socket_config: SocketConfig,
    /// SSL mode; copied into cancel tokens.
    ssl_mode: SslMode,
    /// Process id supplied at construction; exposed via `get_process_id` and
    /// copied into cancel tokens.
    process_id: i32,
    /// Secret key supplied at construction; copied into cancel tokens.
    /// NOTE(review): presumably the backend key data used to authorize
    /// cancel requests - confirm against the connect code.
    secret_key: i32,
}
161 :
162 : impl Client {
163 0 : pub(crate) fn new(
164 0 : sender: mpsc::UnboundedSender<Request>,
165 0 : socket_config: SocketConfig,
166 0 : ssl_mode: SslMode,
167 0 : process_id: i32,
168 0 : secret_key: i32,
169 0 : ) -> Client {
170 0 : Client {
171 0 : inner: Arc::new(InnerClient {
172 0 : sender,
173 0 : cached_typeinfo: Default::default(),
174 0 : buffer: Default::default(),
175 0 : }),
176 0 :
177 0 : socket_config,
178 0 : ssl_mode,
179 0 : process_id,
180 0 : secret_key,
181 0 : }
182 0 : }
183 :
184 : /// Returns process_id.
185 0 : pub fn get_process_id(&self) -> i32 {
186 0 : self.process_id
187 0 : }
188 :
189 0 : pub(crate) fn inner(&self) -> &Arc<InnerClient> {
190 0 : &self.inner
191 0 : }
192 :
193 : /// Creates a new prepared statement.
194 : ///
195 : /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
196 : /// which are set when executed. Prepared statements can only be used with the connection that created them.
197 0 : pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
198 0 : self.prepare_typed(query, &[]).await
199 0 : }
200 :
201 : /// Like `prepare`, but allows the types of query parameters to be explicitly specified.
202 : ///
203 : /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
204 : /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
205 0 : pub async fn prepare_typed(
206 0 : &self,
207 0 : query: &str,
208 0 : parameter_types: &[Type],
209 0 : ) -> Result<Statement, Error> {
210 0 : prepare::prepare(&self.inner, query, parameter_types).await
211 0 : }
212 :
213 : /// Executes a statement, returning a vector of the resulting rows.
214 : ///
215 : /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
216 : /// provided, 1-indexed.
217 : ///
218 : /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
219 : /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
220 : /// with the `prepare` method.
221 : ///
222 : /// # Panics
223 : ///
224 : /// Panics if the number of parameters provided does not match the number expected.
225 0 : pub async fn query<T>(
226 0 : &self,
227 0 : statement: &T,
228 0 : params: &[&(dyn ToSql + Sync)],
229 0 : ) -> Result<Vec<Row>, Error>
230 0 : where
231 0 : T: ?Sized + ToStatement,
232 0 : {
233 0 : self.query_raw(statement, slice_iter(params))
234 0 : .await?
235 0 : .try_collect()
236 0 : .await
237 0 : }
238 :
239 : /// The maximally flexible version of [`query`].
240 : ///
241 : /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
242 : /// provided, 1-indexed.
243 : ///
244 : /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
245 : /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
246 : /// with the `prepare` method.
247 : ///
248 : /// # Panics
249 : ///
250 : /// Panics if the number of parameters provided does not match the number expected.
251 : ///
252 : /// [`query`]: #method.query
253 0 : pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
254 0 : where
255 0 : T: ?Sized + ToStatement,
256 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
257 0 : I::IntoIter: ExactSizeIterator,
258 0 : {
259 0 : let statement = statement.__convert().into_statement(self).await?;
260 0 : query::query(&self.inner, statement, params).await
261 0 : }
262 :
263 : /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
264 : /// to save a roundtrip
265 0 : pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
266 0 : where
267 0 : S: AsRef<str>,
268 0 : I: IntoIterator<Item = Option<S>>,
269 0 : I::IntoIter: ExactSizeIterator,
270 0 : {
271 0 : query::query_txt(&self.inner, statement, params).await
272 0 : }
273 :
274 : /// Executes a statement, returning the number of rows modified.
275 : ///
276 : /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
277 : /// provided, 1-indexed.
278 : ///
279 : /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
280 : /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
281 : /// with the `prepare` method.
282 : ///
283 : /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
284 : ///
285 : /// # Panics
286 : ///
287 : /// Panics if the number of parameters provided does not match the number expected.
288 0 : pub async fn execute<T>(
289 0 : &self,
290 0 : statement: &T,
291 0 : params: &[&(dyn ToSql + Sync)],
292 0 : ) -> Result<u64, Error>
293 0 : where
294 0 : T: ?Sized + ToStatement,
295 0 : {
296 0 : self.execute_raw(statement, slice_iter(params)).await
297 0 : }
298 :
299 : /// The maximally flexible version of [`execute`].
300 : ///
301 : /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
302 : /// provided, 1-indexed.
303 : ///
304 : /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
305 : /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
306 : /// with the `prepare` method.
307 : ///
308 : /// # Panics
309 : ///
310 : /// Panics if the number of parameters provided does not match the number expected.
311 : ///
312 : /// [`execute`]: #method.execute
313 0 : pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
314 0 : where
315 0 : T: ?Sized + ToStatement,
316 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
317 0 : I::IntoIter: ExactSizeIterator,
318 0 : {
319 0 : let statement = statement.__convert().into_statement(self).await?;
320 0 : query::execute(self.inner(), statement, params).await
321 0 : }
322 :
323 : /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
324 : ///
325 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
326 : /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
327 : /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
328 : /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
329 : /// or a row of data. This preserves the framing between the separate statements in the request.
330 : ///
331 : /// # Warning
332 : ///
333 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
334 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
335 : /// them to this method!
336 0 : pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
337 0 : self.simple_query_raw(query).await?.try_collect().await
338 0 : }
339 :
340 0 : pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
341 0 : simple_query::simple_query(self.inner(), query).await
342 0 : }
343 :
344 : /// Executes a sequence of SQL statements using the simple query protocol.
345 : ///
346 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
347 : /// point. This is intended for use when, for example, initializing a database schema.
348 : ///
349 : /// # Warning
350 : ///
351 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
352 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
353 : /// them to this method!
354 0 : pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
355 0 : simple_query::batch_execute(self.inner(), query).await
356 0 : }
357 :
/// Begins a new database transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
///
/// Issues `BEGIN` through `batch_execute` and, on success, wraps this client in a [`Transaction`].
///
/// # Errors
///
/// Returns whatever error `batch_execute("BEGIN")` reports.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
    // Drop guard: unless `done` has been set, queues a `ROLLBACK` when dropped.
    struct RollbackIfNotDone<'me> {
        client: &'me Client,
        done: bool,
    }

    impl Drop for RollbackIfNotDone<'_> {
        fn drop(&mut self) {
            if self.done {
                return;
            }

            // Encode a `ROLLBACK` query message into the shared buffer and
            // split the encoded bytes off before the buffer is cleared.
            let buf = self.client.inner().with_buf(|buf| {
                frontend::query("ROLLBACK", buf).unwrap();
                buf.split().freeze()
            });
            // Best effort: the send result is deliberately ignored, and the
            // returned `Responses` is dropped without being read.
            let _ = self
                .client
                .inner()
                .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
        }
    }

    // The guard is needed because the `Future` created by this method can be
    // dropped after `RequestMessages` has been synchronously sent to the
    // `Connection` by `batch_execute()`, but before `Responses` has been
    // asynchronously polled to completion. In that case no `Transaction` is
    // created, so nothing else would roll the transaction back. The guard
    // also fires when `BEGIN` returns an error, because `?` leaves the scope
    // before `done` is set.
    {
        let mut cleaner = RollbackIfNotDone {
            client: self,
            done: false,
        };
        self.batch_execute("BEGIN").await?;
        cleaner.done = true;
    }

    Ok(Transaction::new(self))
}
400 :
401 : /// Returns a builder for a transaction with custom settings.
402 : ///
403 : /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
404 : /// attributes.
405 0 : pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
406 0 : TransactionBuilder::new(self)
407 0 : }
408 :
409 : /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
410 : /// connection associated with this client.
411 0 : pub fn cancel_token(&self) -> CancelToken {
412 0 : CancelToken {
413 0 : socket_config: Some(self.socket_config.clone()),
414 0 : ssl_mode: self.ssl_mode,
415 0 : process_id: self.process_id,
416 0 : secret_key: self.secret_key,
417 0 : }
418 0 : }
419 :
420 : /// Query for type information
421 0 : pub async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
422 0 : crate::prepare::get_type(&self.inner, oid).await
423 0 : }
424 :
425 : /// Determines if the connection to the server has already closed.
426 : ///
427 : /// In that case, all future queries will fail.
428 0 : pub fn is_closed(&self) -> bool {
429 0 : self.inner.sender.is_closed()
430 0 : }
431 : }
432 :
433 : impl fmt::Debug for Client {
434 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 0 : f.debug_struct("Client").finish()
436 0 : }
437 : }
|