Line data Source code
1 : use crate::client::{InnerClient, Responses};
2 : use crate::codec::FrontendMessage;
3 : use crate::connection::RequestMessages;
4 : use crate::types::IsNull;
5 : use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
6 : use bytes::{BufMut, Bytes, BytesMut};
7 : use fallible_iterator::FallibleIterator;
8 : use futures_util::{ready, Stream};
9 : use log::{debug, log_enabled, Level};
10 : use pin_project_lite::pin_project;
11 : use postgres_protocol2::message::backend::Message;
12 : use postgres_protocol2::message::frontend;
13 : use postgres_types2::{Format, ToSql, Type};
14 : use std::fmt;
15 : use std::marker::PhantomPinned;
16 : use std::pin::Pin;
17 : use std::sync::Arc;
18 : use std::task::{Context, Poll};
19 :
20 : struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
21 :
22 : impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
23 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 0 : f.debug_list().entries(self.0.iter()).finish()
25 0 : }
26 : }
27 :
28 0 : pub async fn query<'a, I>(
29 0 : client: &InnerClient,
30 0 : statement: Statement,
31 0 : params: I,
32 0 : ) -> Result<RowStream, Error>
33 0 : where
34 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
35 0 : I::IntoIter: ExactSizeIterator,
36 0 : {
37 0 : let buf = if log_enabled!(Level::Debug) {
38 0 : let params = params.into_iter().collect::<Vec<_>>();
39 0 : debug!(
40 0 : "executing statement {} with parameters: {:?}",
41 0 : statement.name(),
42 0 : BorrowToSqlParamsDebug(params.as_slice()),
43 : );
44 0 : encode(client, &statement, params)?
45 : } else {
46 0 : encode(client, &statement, params)?
47 : };
48 0 : let responses = start(client, buf).await?;
49 0 : Ok(RowStream {
50 0 : statement,
51 0 : responses,
52 0 : command_tag: None,
53 0 : status: ReadyForQueryStatus::Unknown,
54 0 : output_format: Format::Binary,
55 0 : _p: PhantomPinned,
56 0 : })
57 0 : }
58 :
59 0 : pub async fn query_txt<S, I>(
60 0 : client: &Arc<InnerClient>,
61 0 : query: &str,
62 0 : params: I,
63 0 : ) -> Result<RowStream, Error>
64 0 : where
65 0 : S: AsRef<str>,
66 0 : I: IntoIterator<Item = Option<S>>,
67 0 : I::IntoIter: ExactSizeIterator,
68 0 : {
69 0 : let params = params.into_iter();
70 :
71 0 : let buf = client.with_buf(|buf| {
72 0 : frontend::parse(
73 0 : "", // unnamed prepared statement
74 0 : query, // query to parse
75 0 : std::iter::empty(), // give no type info
76 0 : buf,
77 0 : )
78 0 : .map_err(Error::encode)?;
79 0 : frontend::describe(b'S', "", buf).map_err(Error::encode)?;
80 : // Bind, pass params as text, retrieve as binary
81 : match frontend::bind(
82 0 : "", // empty string selects the unnamed portal
83 0 : "", // unnamed prepared statement
84 0 : std::iter::empty(), // all parameters use the default format (text)
85 0 : params,
86 0 : |param, buf| match param {
87 0 : Some(param) => {
88 0 : buf.put_slice(param.as_ref().as_bytes());
89 0 : Ok(postgres_protocol2::IsNull::No)
90 : }
91 0 : None => Ok(postgres_protocol2::IsNull::Yes),
92 0 : },
93 0 : Some(0), // all text
94 0 : buf,
95 : ) {
96 0 : Ok(()) => Ok(()),
97 0 : Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
98 0 : Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
99 0 : }?;
100 :
101 : // Execute
102 0 : frontend::execute("", 0, buf).map_err(Error::encode)?;
103 : // Sync
104 0 : frontend::sync(buf);
105 0 :
106 0 : Ok(buf.split().freeze())
107 0 : })?;
108 :
109 : // now read the responses
110 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
111 :
112 0 : match responses.next().await? {
113 0 : Message::ParseComplete => {}
114 0 : _ => return Err(Error::unexpected_message()),
115 : }
116 :
117 0 : let parameter_description = match responses.next().await? {
118 0 : Message::ParameterDescription(body) => body,
119 0 : _ => return Err(Error::unexpected_message()),
120 : };
121 :
122 0 : let row_description = match responses.next().await? {
123 0 : Message::RowDescription(body) => Some(body),
124 0 : Message::NoData => None,
125 0 : _ => return Err(Error::unexpected_message()),
126 : };
127 :
128 0 : match responses.next().await? {
129 0 : Message::BindComplete => {}
130 0 : _ => return Err(Error::unexpected_message()),
131 : }
132 :
133 0 : let mut parameters = vec![];
134 0 : let mut it = parameter_description.parameters();
135 0 : while let Some(oid) = it.next().map_err(Error::parse)? {
136 0 : let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
137 0 : parameters.push(type_);
138 0 : }
139 :
140 0 : let mut columns = vec![];
141 0 : if let Some(row_description) = row_description {
142 0 : let mut it = row_description.fields();
143 0 : while let Some(field) = it.next().map_err(Error::parse)? {
144 0 : let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
145 0 : let column = Column::new(field.name().to_string(), type_, field);
146 0 : columns.push(column);
147 0 : }
148 0 : }
149 :
150 0 : Ok(RowStream {
151 0 : statement: Statement::new_anonymous(parameters, columns),
152 0 : responses,
153 0 : command_tag: None,
154 0 : status: ReadyForQueryStatus::Unknown,
155 0 : output_format: Format::Text,
156 0 : _p: PhantomPinned,
157 0 : })
158 0 : }
159 :
160 0 : async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
161 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
162 :
163 0 : match responses.next().await? {
164 0 : Message::BindComplete => {}
165 0 : _ => return Err(Error::unexpected_message()),
166 : }
167 :
168 0 : Ok(responses)
169 0 : }
170 :
171 0 : pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
172 0 : where
173 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
174 0 : I::IntoIter: ExactSizeIterator,
175 0 : {
176 0 : client.with_buf(|buf| {
177 0 : encode_bind(statement, params, "", buf)?;
178 0 : frontend::execute("", 0, buf).map_err(Error::encode)?;
179 0 : frontend::sync(buf);
180 0 : Ok(buf.split().freeze())
181 0 : })
182 0 : }
183 :
184 0 : pub fn encode_bind<'a, I>(
185 0 : statement: &Statement,
186 0 : params: I,
187 0 : portal: &str,
188 0 : buf: &mut BytesMut,
189 0 : ) -> Result<(), Error>
190 0 : where
191 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
192 0 : I::IntoIter: ExactSizeIterator,
193 0 : {
194 0 : let param_types = statement.params();
195 0 : let params = params.into_iter();
196 0 :
197 0 : assert!(
198 0 : param_types.len() == params.len(),
199 0 : "expected {} parameters but got {}",
200 0 : param_types.len(),
201 0 : params.len()
202 : );
203 :
204 0 : let (param_formats, params): (Vec<_>, Vec<_>) = params
205 0 : .zip(param_types.iter())
206 0 : .map(|(p, ty)| (p.encode_format(ty) as i16, p))
207 0 : .unzip();
208 0 :
209 0 : let params = params.into_iter();
210 0 :
211 0 : let mut error_idx = 0;
212 0 : let r = frontend::bind(
213 0 : portal,
214 0 : statement.name(),
215 0 : param_formats,
216 0 : params.zip(param_types).enumerate(),
217 0 : |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
218 0 : Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
219 0 : Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
220 0 : Err(e) => {
221 0 : error_idx = idx;
222 0 : Err(e)
223 : }
224 0 : },
225 0 : Some(1),
226 0 : buf,
227 0 : );
228 0 : match r {
229 0 : Ok(()) => Ok(()),
230 0 : Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
231 0 : Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
232 : }
233 0 : }
234 :
235 : pin_project! {
236 : /// A stream of table rows.
237 : pub struct RowStream {
238 : statement: Statement,
239 : responses: Responses,
240 : command_tag: Option<String>,
241 : output_format: Format,
242 : status: ReadyForQueryStatus,
243 : #[pin]
244 : _p: PhantomPinned,
245 : }
246 : }
247 :
248 : impl Stream for RowStream {
249 : type Item = Result<Row, Error>;
250 :
251 0 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
252 0 : let this = self.project();
253 : loop {
254 0 : match ready!(this.responses.poll_next(cx)?) {
255 0 : Message::DataRow(body) => {
256 0 : return Poll::Ready(Some(Ok(Row::new(
257 0 : this.statement.clone(),
258 0 : body,
259 0 : *this.output_format,
260 0 : )?)))
261 : }
262 0 : Message::EmptyQueryResponse | Message::PortalSuspended => {}
263 0 : Message::CommandComplete(body) => {
264 0 : if let Ok(tag) = body.tag() {
265 0 : *this.command_tag = Some(tag.to_string());
266 0 : }
267 : }
268 0 : Message::ReadyForQuery(status) => {
269 0 : *this.status = status.into();
270 0 : return Poll::Ready(None);
271 : }
272 0 : _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
273 : }
274 : }
275 0 : }
276 : }
277 :
278 : impl RowStream {
279 : /// Returns information about the columns of data in the row.
280 0 : pub fn columns(&self) -> &[Column] {
281 0 : self.statement.columns()
282 0 : }
283 :
284 : /// Returns the command tag of this query.
285 : ///
286 : /// This is only available after the stream has been exhausted.
287 0 : pub fn command_tag(&self) -> Option<String> {
288 0 : self.command_tag.clone()
289 0 : }
290 :
291 : /// Returns if the connection is ready for querying, with the status of the connection.
292 : ///
293 : /// This might be available only after the stream has been exhausted.
294 0 : pub fn ready_status(&self) -> ReadyForQueryStatus {
295 0 : self.status
296 0 : }
297 : }
|