Line data Source code
1 : use std::collections::HashMap;
2 : use std::fmt;
3 : use std::net::IpAddr;
4 : use std::sync::Arc;
5 : use std::task::{Context, Poll};
6 : use std::time::Duration;
7 :
8 : use bytes::BytesMut;
9 : use fallible_iterator::FallibleIterator;
10 : use futures_util::{TryStreamExt, future, ready};
11 : use parking_lot::Mutex;
12 : use postgres_protocol2::message::backend::Message;
13 : use postgres_protocol2::message::frontend;
14 : use serde::{Deserialize, Serialize};
15 : use tokio::sync::mpsc;
16 :
17 : use crate::codec::{BackendMessages, FrontendMessage};
18 : use crate::config::{Host, SslMode};
19 : use crate::connection::{Request, RequestMessages};
20 : use crate::query::RowStream;
21 : use crate::simple_query::SimpleQueryStream;
22 : use crate::types::{Oid, Type};
23 : use crate::{
24 : CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Statement, Transaction,
25 : TransactionBuilder, query, simple_query,
26 : };
27 :
28 : pub struct Responses {
29 : receiver: mpsc::Receiver<BackendMessages>,
30 : cur: BackendMessages,
31 : }
32 :
33 : impl Responses {
34 0 : pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
35 : loop {
36 0 : match self.cur.next().map_err(Error::parse)? {
37 0 : Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
38 0 : Some(message) => return Poll::Ready(Ok(message)),
39 0 : None => {}
40 : }
41 :
42 0 : match ready!(self.receiver.poll_recv(cx)) {
43 0 : Some(messages) => self.cur = messages,
44 0 : None => return Poll::Ready(Err(Error::closed())),
45 : }
46 : }
47 0 : }
48 :
49 0 : pub async fn next(&mut self) -> Result<Message, Error> {
50 0 : future::poll_fn(|cx| self.poll_next(cx)).await
51 0 : }
52 : }
53 :
54 : /// A cache of type info and prepared statements for fetching type info
55 : /// (corresponding to the queries in the [crate::prepare] module).
56 : #[derive(Default)]
57 : pub(crate) struct CachedTypeInfo {
58 : /// A statement for basic information for a type from its
59 : /// OID. Corresponds to [TYPEINFO_QUERY](crate::prepare::TYPEINFO_QUERY) (or its
60 : /// fallback).
61 : pub(crate) typeinfo: Option<Statement>,
62 :
63 : /// Cache of types already looked up.
64 : pub(crate) types: HashMap<Oid, Type>,
65 : }
66 :
67 : pub struct InnerClient {
68 : sender: mpsc::UnboundedSender<Request>,
69 :
70 : /// A buffer to use when writing out postgres commands.
71 : buffer: Mutex<BytesMut>,
72 : }
73 :
74 : impl InnerClient {
75 0 : pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
76 0 : let (sender, receiver) = mpsc::channel(1);
77 0 : let request = Request { messages, sender };
78 0 : self.sender.send(request).map_err(|_| Error::closed())?;
79 :
80 0 : Ok(Responses {
81 0 : receiver,
82 0 : cur: BackendMessages::empty(),
83 0 : })
84 0 : }
85 :
86 : /// Call the given function with a buffer to be used when writing out
87 : /// postgres commands.
88 0 : pub fn with_buf<F, R>(&self, f: F) -> R
89 0 : where
90 0 : F: FnOnce(&mut BytesMut) -> R,
91 0 : {
92 0 : let mut buffer = self.buffer.lock();
93 0 : let r = f(&mut buffer);
94 0 : buffer.clear();
95 0 : r
96 0 : }
97 : }
98 :
99 0 : #[derive(Clone, Serialize, Deserialize)]
100 : pub struct SocketConfig {
101 : pub host_addr: Option<IpAddr>,
102 : pub host: Host,
103 : pub port: u16,
104 : pub connect_timeout: Option<Duration>,
105 : }
106 :
107 : /// An asynchronous PostgreSQL client.
108 : ///
109 : /// The client is one half of what is returned when a connection is established. Users interact with the database
110 : /// through this client object.
111 : pub struct Client {
112 : inner: Arc<InnerClient>,
113 : cached_typeinfo: CachedTypeInfo,
114 :
115 : socket_config: SocketConfig,
116 : ssl_mode: SslMode,
117 : process_id: i32,
118 : secret_key: i32,
119 : }
120 :
121 : impl Client {
122 0 : pub(crate) fn new(
123 0 : sender: mpsc::UnboundedSender<Request>,
124 0 : socket_config: SocketConfig,
125 0 : ssl_mode: SslMode,
126 0 : process_id: i32,
127 0 : secret_key: i32,
128 0 : ) -> Client {
129 0 : Client {
130 0 : inner: Arc::new(InnerClient {
131 0 : sender,
132 0 : buffer: Default::default(),
133 0 : }),
134 0 : cached_typeinfo: Default::default(),
135 0 :
136 0 : socket_config,
137 0 : ssl_mode,
138 0 : process_id,
139 0 : secret_key,
140 0 : }
141 0 : }
142 :
143 : /// Returns process_id.
144 0 : pub fn get_process_id(&self) -> i32 {
145 0 : self.process_id
146 0 : }
147 :
148 0 : pub(crate) fn inner(&self) -> &Arc<InnerClient> {
149 0 : &self.inner
150 0 : }
151 :
152 : /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
153 : /// to save a roundtrip
154 0 : pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
155 0 : where
156 0 : S: AsRef<str>,
157 0 : I: IntoIterator<Item = Option<S>>,
158 0 : I::IntoIter: ExactSizeIterator,
159 0 : {
160 0 : query::query_txt(&self.inner, statement, params).await
161 0 : }
162 :
163 : /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
164 : ///
165 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
166 : /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
167 : /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
168 : /// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
169 : /// or a row of data. This preserves the framing between the separate statements in the request.
170 : ///
171 : /// # Warning
172 : ///
173 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
174 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
175 : /// them to this method!
176 0 : pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
177 0 : self.simple_query_raw(query).await?.try_collect().await
178 0 : }
179 :
180 0 : pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
181 0 : simple_query::simple_query(self.inner(), query).await
182 0 : }
183 :
184 : /// Executes a sequence of SQL statements using the simple query protocol.
185 : ///
186 : /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
187 : /// point. This is intended for use when, for example, initializing a database schema.
188 : ///
189 : /// # Warning
190 : ///
191 : /// Prepared statements should be use for any query which contains user-specified data, as they provided the
192 : /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
193 : /// them to this method!
194 0 : pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
195 0 : simple_query::batch_execute(self.inner(), query).await
196 0 : }
197 :
198 0 : pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
199 0 : // clear the prepared statements that are about to be nuked from the postgres session
200 0 :
201 0 : self.cached_typeinfo.typeinfo = None;
202 0 :
203 0 : self.batch_execute("discard all").await
204 0 : }
205 :
206 : /// Begins a new database transaction.
207 : ///
208 : /// The transaction will roll back by default - use the `commit` method to commit it.
209 0 : pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
210 : struct RollbackIfNotDone<'me> {
211 : client: &'me Client,
212 : done: bool,
213 : }
214 :
215 : impl Drop for RollbackIfNotDone<'_> {
216 0 : fn drop(&mut self) {
217 0 : if self.done {
218 0 : return;
219 0 : }
220 0 :
221 0 : let buf = self.client.inner().with_buf(|buf| {
222 0 : frontend::query("ROLLBACK", buf).unwrap();
223 0 : buf.split().freeze()
224 0 : });
225 0 : let _ = self
226 0 : .client
227 0 : .inner()
228 0 : .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
229 0 : }
230 : }
231 :
232 : // This is done, as `Future` created by this method can be dropped after
233 : // `RequestMessages` is synchronously send to the `Connection` by
234 : // `batch_execute()`, but before `Responses` is asynchronously polled to
235 : // completion. In that case `Transaction` won't be created and thus
236 : // won't be rolled back.
237 : {
238 0 : let mut cleaner = RollbackIfNotDone {
239 0 : client: self,
240 0 : done: false,
241 0 : };
242 0 : self.batch_execute("BEGIN").await?;
243 0 : cleaner.done = true;
244 0 : }
245 0 :
246 0 : Ok(Transaction::new(self))
247 0 : }
248 :
249 : /// Returns a builder for a transaction with custom settings.
250 : ///
251 : /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
252 : /// attributes.
253 0 : pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
254 0 : TransactionBuilder::new(self)
255 0 : }
256 :
257 : /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
258 : /// connection associated with this client.
259 0 : pub fn cancel_token(&self) -> CancelToken {
260 0 : CancelToken {
261 0 : socket_config: Some(self.socket_config.clone()),
262 0 : ssl_mode: self.ssl_mode,
263 0 : process_id: self.process_id,
264 0 : secret_key: self.secret_key,
265 0 : }
266 0 : }
267 :
268 : /// Query for type information
269 0 : pub(crate) async fn get_type_inner(&mut self, oid: Oid) -> Result<Type, Error> {
270 0 : crate::prepare::get_type(&self.inner, &mut self.cached_typeinfo, oid).await
271 0 : }
272 :
273 : /// Determines if the connection to the server has already closed.
274 : ///
275 : /// In that case, all future queries will fail.
276 0 : pub fn is_closed(&self) -> bool {
277 0 : self.inner.sender.is_closed()
278 0 : }
279 : }
280 :
281 : impl fmt::Debug for Client {
282 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283 0 : f.debug_struct("Client").finish()
284 0 : }
285 : }
|