Line data Source code
1 : use crate::client::{InnerClient, Responses};
2 : use crate::codec::FrontendMessage;
3 : use crate::connection::RequestMessages;
4 : use crate::types::IsNull;
5 : use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
6 : use bytes::{BufMut, Bytes, BytesMut};
7 : use fallible_iterator::FallibleIterator;
8 : use futures_util::{ready, Stream};
9 : use log::{debug, log_enabled, Level};
10 : use pin_project_lite::pin_project;
11 : use postgres_protocol2::message::backend::Message;
12 : use postgres_protocol2::message::frontend;
13 : use postgres_types2::{Format, ToSql, Type};
14 : use std::fmt;
15 : use std::marker::PhantomPinned;
16 : use std::pin::Pin;
17 : use std::sync::Arc;
18 : use std::task::{Context, Poll};
19 :
20 : struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
21 :
22 : impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
23 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 0 : f.debug_list().entries(self.0.iter()).finish()
25 0 : }
26 : }
27 :
28 0 : pub async fn query<'a, I>(
29 0 : client: &InnerClient,
30 0 : statement: Statement,
31 0 : params: I,
32 0 : ) -> Result<RowStream, Error>
33 0 : where
34 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
35 0 : I::IntoIter: ExactSizeIterator,
36 0 : {
37 0 : let buf = if log_enabled!(Level::Debug) {
38 0 : let params = params.into_iter().collect::<Vec<_>>();
39 0 : debug!(
40 0 : "executing statement {} with parameters: {:?}",
41 0 : statement.name(),
42 0 : BorrowToSqlParamsDebug(params.as_slice()),
43 : );
44 0 : encode(client, &statement, params)?
45 : } else {
46 0 : encode(client, &statement, params)?
47 : };
48 0 : let responses = start(client, buf).await?;
49 0 : Ok(RowStream {
50 0 : statement,
51 0 : responses,
52 0 : command_tag: None,
53 0 : status: ReadyForQueryStatus::Unknown,
54 0 : output_format: Format::Binary,
55 0 : _p: PhantomPinned,
56 0 : })
57 0 : }
58 :
59 0 : pub async fn query_txt<S, I>(
60 0 : client: &Arc<InnerClient>,
61 0 : query: &str,
62 0 : params: I,
63 0 : ) -> Result<RowStream, Error>
64 0 : where
65 0 : S: AsRef<str>,
66 0 : I: IntoIterator<Item = Option<S>>,
67 0 : I::IntoIter: ExactSizeIterator,
68 0 : {
69 0 : let params = params.into_iter();
70 :
71 0 : let buf = client.with_buf(|buf| {
72 0 : frontend::parse(
73 0 : "", // unnamed prepared statement
74 0 : query, // query to parse
75 0 : std::iter::empty(), // give no type info
76 0 : buf,
77 0 : )
78 0 : .map_err(Error::encode)?;
79 0 : frontend::describe(b'S', "", buf).map_err(Error::encode)?;
80 : // Bind, pass params as text, retrieve as binary
81 : match frontend::bind(
82 0 : "", // empty string selects the unnamed portal
83 0 : "", // unnamed prepared statement
84 0 : std::iter::empty(), // all parameters use the default format (text)
85 0 : params,
86 0 : |param, buf| match param {
87 0 : Some(param) => {
88 0 : buf.put_slice(param.as_ref().as_bytes());
89 0 : Ok(postgres_protocol2::IsNull::No)
90 : }
91 0 : None => Ok(postgres_protocol2::IsNull::Yes),
92 0 : },
93 0 : Some(0), // all text
94 0 : buf,
95 : ) {
96 0 : Ok(()) => Ok(()),
97 0 : Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
98 0 : Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
99 0 : }?;
100 :
101 : // Execute
102 0 : frontend::execute("", 0, buf).map_err(Error::encode)?;
103 : // Sync
104 0 : frontend::sync(buf);
105 0 :
106 0 : Ok(buf.split().freeze())
107 0 : })?;
108 :
109 : // now read the responses
110 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
111 :
112 0 : match responses.next().await? {
113 0 : Message::ParseComplete => {}
114 0 : _ => return Err(Error::unexpected_message()),
115 : }
116 :
117 0 : let parameter_description = match responses.next().await? {
118 0 : Message::ParameterDescription(body) => body,
119 0 : _ => return Err(Error::unexpected_message()),
120 : };
121 :
122 0 : let row_description = match responses.next().await? {
123 0 : Message::RowDescription(body) => Some(body),
124 0 : Message::NoData => None,
125 0 : _ => return Err(Error::unexpected_message()),
126 : };
127 :
128 0 : match responses.next().await? {
129 0 : Message::BindComplete => {}
130 0 : _ => return Err(Error::unexpected_message()),
131 : }
132 :
133 0 : let mut parameters = vec![];
134 0 : let mut it = parameter_description.parameters();
135 0 : while let Some(oid) = it.next().map_err(Error::parse)? {
136 0 : let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
137 0 : parameters.push(type_);
138 0 : }
139 :
140 0 : let mut columns = vec![];
141 0 : if let Some(row_description) = row_description {
142 0 : let mut it = row_description.fields();
143 0 : while let Some(field) = it.next().map_err(Error::parse)? {
144 0 : let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
145 0 : let column = Column::new(field.name().to_string(), type_, field);
146 0 : columns.push(column);
147 0 : }
148 0 : }
149 :
150 0 : Ok(RowStream {
151 0 : statement: Statement::new_anonymous(parameters, columns),
152 0 : responses,
153 0 : command_tag: None,
154 0 : status: ReadyForQueryStatus::Unknown,
155 0 : output_format: Format::Text,
156 0 : _p: PhantomPinned,
157 0 : })
158 0 : }
159 :
160 0 : pub async fn execute<'a, I>(
161 0 : client: &InnerClient,
162 0 : statement: Statement,
163 0 : params: I,
164 0 : ) -> Result<u64, Error>
165 0 : where
166 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
167 0 : I::IntoIter: ExactSizeIterator,
168 0 : {
169 0 : let buf = if log_enabled!(Level::Debug) {
170 0 : let params = params.into_iter().collect::<Vec<_>>();
171 0 : debug!(
172 0 : "executing statement {} with parameters: {:?}",
173 0 : statement.name(),
174 0 : BorrowToSqlParamsDebug(params.as_slice()),
175 : );
176 0 : encode(client, &statement, params)?
177 : } else {
178 0 : encode(client, &statement, params)?
179 : };
180 0 : let mut responses = start(client, buf).await?;
181 :
182 0 : let mut rows = 0;
183 : loop {
184 0 : match responses.next().await? {
185 0 : Message::DataRow(_) => {}
186 0 : Message::CommandComplete(body) => {
187 0 : rows = body
188 0 : .tag()
189 0 : .map_err(Error::parse)?
190 0 : .rsplit(' ')
191 0 : .next()
192 0 : .unwrap()
193 0 : .parse()
194 0 : .unwrap_or(0);
195 : }
196 0 : Message::EmptyQueryResponse => rows = 0,
197 0 : Message::ReadyForQuery(_) => return Ok(rows),
198 0 : _ => return Err(Error::unexpected_message()),
199 : }
200 : }
201 0 : }
202 :
203 0 : async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
204 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
205 :
206 0 : match responses.next().await? {
207 0 : Message::BindComplete => {}
208 0 : _ => return Err(Error::unexpected_message()),
209 : }
210 :
211 0 : Ok(responses)
212 0 : }
213 :
214 0 : pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
215 0 : where
216 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
217 0 : I::IntoIter: ExactSizeIterator,
218 0 : {
219 0 : client.with_buf(|buf| {
220 0 : encode_bind(statement, params, "", buf)?;
221 0 : frontend::execute("", 0, buf).map_err(Error::encode)?;
222 0 : frontend::sync(buf);
223 0 : Ok(buf.split().freeze())
224 0 : })
225 0 : }
226 :
227 0 : pub fn encode_bind<'a, I>(
228 0 : statement: &Statement,
229 0 : params: I,
230 0 : portal: &str,
231 0 : buf: &mut BytesMut,
232 0 : ) -> Result<(), Error>
233 0 : where
234 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
235 0 : I::IntoIter: ExactSizeIterator,
236 0 : {
237 0 : let param_types = statement.params();
238 0 : let params = params.into_iter();
239 0 :
240 0 : assert!(
241 0 : param_types.len() == params.len(),
242 0 : "expected {} parameters but got {}",
243 0 : param_types.len(),
244 0 : params.len()
245 : );
246 :
247 0 : let (param_formats, params): (Vec<_>, Vec<_>) = params
248 0 : .zip(param_types.iter())
249 0 : .map(|(p, ty)| (p.encode_format(ty) as i16, p))
250 0 : .unzip();
251 0 :
252 0 : let params = params.into_iter();
253 0 :
254 0 : let mut error_idx = 0;
255 0 : let r = frontend::bind(
256 0 : portal,
257 0 : statement.name(),
258 0 : param_formats,
259 0 : params.zip(param_types).enumerate(),
260 0 : |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
261 0 : Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
262 0 : Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
263 0 : Err(e) => {
264 0 : error_idx = idx;
265 0 : Err(e)
266 : }
267 0 : },
268 0 : Some(1),
269 0 : buf,
270 0 : );
271 0 : match r {
272 0 : Ok(()) => Ok(()),
273 0 : Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
274 0 : Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
275 : }
276 0 : }
277 :
278 : pin_project! {
279 : /// A stream of table rows.
280 : pub struct RowStream {
281 : statement: Statement,
282 : responses: Responses,
283 : command_tag: Option<String>,
284 : output_format: Format,
285 : status: ReadyForQueryStatus,
286 : #[pin]
287 : _p: PhantomPinned,
288 : }
289 : }
290 :
291 : impl Stream for RowStream {
292 : type Item = Result<Row, Error>;
293 :
294 0 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
295 0 : let this = self.project();
296 : loop {
297 0 : match ready!(this.responses.poll_next(cx)?) {
298 0 : Message::DataRow(body) => {
299 0 : return Poll::Ready(Some(Ok(Row::new(
300 0 : this.statement.clone(),
301 0 : body,
302 0 : *this.output_format,
303 0 : )?)))
304 : }
305 0 : Message::EmptyQueryResponse | Message::PortalSuspended => {}
306 0 : Message::CommandComplete(body) => {
307 0 : if let Ok(tag) = body.tag() {
308 0 : *this.command_tag = Some(tag.to_string());
309 0 : }
310 : }
311 0 : Message::ReadyForQuery(status) => {
312 0 : *this.status = status.into();
313 0 : return Poll::Ready(None);
314 : }
315 0 : _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
316 : }
317 : }
318 0 : }
319 : }
320 :
321 : impl RowStream {
322 : /// Returns information about the columns of data in the row.
323 0 : pub fn columns(&self) -> &[Column] {
324 0 : self.statement.columns()
325 0 : }
326 :
327 : /// Returns the command tag of this query.
328 : ///
329 : /// This is only available after the stream has been exhausted.
330 0 : pub fn command_tag(&self) -> Option<String> {
331 0 : self.command_tag.clone()
332 0 : }
333 :
334 : /// Returns if the connection is ready for querying, with the status of the connection.
335 : ///
336 : /// This might be available only after the stream has been exhausted.
337 0 : pub fn ready_status(&self) -> ReadyForQueryStatus {
338 0 : self.status
339 0 : }
340 : }
|