Line data Source code
1 : use std::fmt;
2 : use std::marker::PhantomPinned;
3 : use std::pin::Pin;
4 : use std::sync::Arc;
5 : use std::task::{Context, Poll};
6 :
7 : use bytes::{BufMut, Bytes, BytesMut};
8 : use fallible_iterator::FallibleIterator;
9 : use futures_util::{Stream, ready};
10 : use log::{Level, debug, log_enabled};
11 : use pin_project_lite::pin_project;
12 : use postgres_protocol2::message::backend::Message;
13 : use postgres_protocol2::message::frontend;
14 : use postgres_types2::{Format, ToSql, Type};
15 :
16 : use crate::client::{InnerClient, Responses};
17 : use crate::codec::FrontendMessage;
18 : use crate::connection::RequestMessages;
19 : use crate::types::IsNull;
20 : use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
21 :
22 : struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
23 :
24 : impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
25 0 : fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 0 : f.debug_list().entries(self.0.iter()).finish()
27 0 : }
28 : }
29 :
30 0 : pub async fn query<'a, I>(
31 0 : client: &InnerClient,
32 0 : statement: Statement,
33 0 : params: I,
34 0 : ) -> Result<RowStream, Error>
35 0 : where
36 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
37 0 : I::IntoIter: ExactSizeIterator,
38 0 : {
39 0 : let buf = if log_enabled!(Level::Debug) {
40 0 : let params = params.into_iter().collect::<Vec<_>>();
41 0 : debug!(
42 0 : "executing statement {} with parameters: {:?}",
43 0 : statement.name(),
44 0 : BorrowToSqlParamsDebug(params.as_slice()),
45 : );
46 0 : encode(client, &statement, params)?
47 : } else {
48 0 : encode(client, &statement, params)?
49 : };
50 0 : let responses = start(client, buf).await?;
51 0 : Ok(RowStream {
52 0 : statement,
53 0 : responses,
54 0 : command_tag: None,
55 0 : status: ReadyForQueryStatus::Unknown,
56 0 : output_format: Format::Binary,
57 0 : _p: PhantomPinned,
58 0 : })
59 0 : }
60 :
61 0 : pub async fn query_txt<S, I>(
62 0 : client: &Arc<InnerClient>,
63 0 : query: &str,
64 0 : params: I,
65 0 : ) -> Result<RowStream, Error>
66 0 : where
67 0 : S: AsRef<str>,
68 0 : I: IntoIterator<Item = Option<S>>,
69 0 : I::IntoIter: ExactSizeIterator,
70 0 : {
71 0 : let params = params.into_iter();
72 :
73 0 : let buf = client.with_buf(|buf| {
74 0 : frontend::parse(
75 0 : "", // unnamed prepared statement
76 0 : query, // query to parse
77 0 : std::iter::empty(), // give no type info
78 0 : buf,
79 0 : )
80 0 : .map_err(Error::encode)?;
81 0 : frontend::describe(b'S', "", buf).map_err(Error::encode)?;
82 : // Bind, pass params as text, retrieve as binary
83 : match frontend::bind(
84 0 : "", // empty string selects the unnamed portal
85 0 : "", // unnamed prepared statement
86 0 : std::iter::empty(), // all parameters use the default format (text)
87 0 : params,
88 0 : |param, buf| match param {
89 0 : Some(param) => {
90 0 : buf.put_slice(param.as_ref().as_bytes());
91 0 : Ok(postgres_protocol2::IsNull::No)
92 : }
93 0 : None => Ok(postgres_protocol2::IsNull::Yes),
94 0 : },
95 0 : Some(0), // all text
96 0 : buf,
97 : ) {
98 0 : Ok(()) => Ok(()),
99 0 : Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
100 0 : Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
101 0 : }?;
102 :
103 : // Execute
104 0 : frontend::execute("", 0, buf).map_err(Error::encode)?;
105 : // Sync
106 0 : frontend::sync(buf);
107 0 :
108 0 : Ok(buf.split().freeze())
109 0 : })?;
110 :
111 : // now read the responses
112 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
113 :
114 0 : match responses.next().await? {
115 0 : Message::ParseComplete => {}
116 0 : _ => return Err(Error::unexpected_message()),
117 : }
118 :
119 0 : let parameter_description = match responses.next().await? {
120 0 : Message::ParameterDescription(body) => body,
121 0 : _ => return Err(Error::unexpected_message()),
122 : };
123 :
124 0 : let row_description = match responses.next().await? {
125 0 : Message::RowDescription(body) => Some(body),
126 0 : Message::NoData => None,
127 0 : _ => return Err(Error::unexpected_message()),
128 : };
129 :
130 0 : match responses.next().await? {
131 0 : Message::BindComplete => {}
132 0 : _ => return Err(Error::unexpected_message()),
133 : }
134 :
135 0 : let mut parameters = vec![];
136 0 : let mut it = parameter_description.parameters();
137 0 : while let Some(oid) = it.next().map_err(Error::parse)? {
138 0 : let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
139 0 : parameters.push(type_);
140 0 : }
141 :
142 0 : let mut columns = vec![];
143 0 : if let Some(row_description) = row_description {
144 0 : let mut it = row_description.fields();
145 0 : while let Some(field) = it.next().map_err(Error::parse)? {
146 0 : let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
147 0 : let column = Column::new(field.name().to_string(), type_, field);
148 0 : columns.push(column);
149 0 : }
150 0 : }
151 :
152 0 : Ok(RowStream {
153 0 : statement: Statement::new_anonymous(parameters, columns),
154 0 : responses,
155 0 : command_tag: None,
156 0 : status: ReadyForQueryStatus::Unknown,
157 0 : output_format: Format::Text,
158 0 : _p: PhantomPinned,
159 0 : })
160 0 : }
161 :
162 0 : async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
163 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
164 :
165 0 : match responses.next().await? {
166 0 : Message::BindComplete => {}
167 0 : _ => return Err(Error::unexpected_message()),
168 : }
169 :
170 0 : Ok(responses)
171 0 : }
172 :
173 0 : pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
174 0 : where
175 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
176 0 : I::IntoIter: ExactSizeIterator,
177 0 : {
178 0 : client.with_buf(|buf| {
179 0 : encode_bind(statement, params, "", buf)?;
180 0 : frontend::execute("", 0, buf).map_err(Error::encode)?;
181 0 : frontend::sync(buf);
182 0 : Ok(buf.split().freeze())
183 0 : })
184 0 : }
185 :
186 0 : pub fn encode_bind<'a, I>(
187 0 : statement: &Statement,
188 0 : params: I,
189 0 : portal: &str,
190 0 : buf: &mut BytesMut,
191 0 : ) -> Result<(), Error>
192 0 : where
193 0 : I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
194 0 : I::IntoIter: ExactSizeIterator,
195 0 : {
196 0 : let param_types = statement.params();
197 0 : let params = params.into_iter();
198 0 :
199 0 : assert!(
200 0 : param_types.len() == params.len(),
201 0 : "expected {} parameters but got {}",
202 0 : param_types.len(),
203 0 : params.len()
204 : );
205 :
206 0 : let (param_formats, params): (Vec<_>, Vec<_>) = params
207 0 : .zip(param_types.iter())
208 0 : .map(|(p, ty)| (p.encode_format(ty) as i16, p))
209 0 : .unzip();
210 0 :
211 0 : let params = params.into_iter();
212 0 :
213 0 : let mut error_idx = 0;
214 0 : let r = frontend::bind(
215 0 : portal,
216 0 : statement.name(),
217 0 : param_formats,
218 0 : params.zip(param_types).enumerate(),
219 0 : |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
220 0 : Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
221 0 : Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
222 0 : Err(e) => {
223 0 : error_idx = idx;
224 0 : Err(e)
225 : }
226 0 : },
227 0 : Some(1),
228 0 : buf,
229 0 : );
230 0 : match r {
231 0 : Ok(()) => Ok(()),
232 0 : Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
233 0 : Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
234 : }
235 0 : }
236 :
237 : pin_project! {
238 : /// A stream of table rows.
239 : pub struct RowStream {
240 : statement: Statement,
241 : responses: Responses,
242 : command_tag: Option<String>,
243 : output_format: Format,
244 : status: ReadyForQueryStatus,
245 : #[pin]
246 : _p: PhantomPinned,
247 : }
248 : }
249 :
250 : impl Stream for RowStream {
251 : type Item = Result<Row, Error>;
252 :
253 0 : fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254 0 : let this = self.project();
255 : loop {
256 0 : match ready!(this.responses.poll_next(cx)?) {
257 0 : Message::DataRow(body) => {
258 0 : return Poll::Ready(Some(Ok(Row::new(
259 0 : this.statement.clone(),
260 0 : body,
261 0 : *this.output_format,
262 0 : )?)));
263 : }
264 0 : Message::EmptyQueryResponse | Message::PortalSuspended => {}
265 0 : Message::CommandComplete(body) => {
266 0 : if let Ok(tag) = body.tag() {
267 0 : *this.command_tag = Some(tag.to_string());
268 0 : }
269 : }
270 0 : Message::ReadyForQuery(status) => {
271 0 : *this.status = status.into();
272 0 : return Poll::Ready(None);
273 : }
274 0 : _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
275 : }
276 : }
277 0 : }
278 : }
279 :
280 : impl RowStream {
281 : /// Returns information about the columns of data in the row.
282 0 : pub fn columns(&self) -> &[Column] {
283 0 : self.statement.columns()
284 0 : }
285 :
286 : /// Returns the command tag of this query.
287 : ///
288 : /// This is only available after the stream has been exhausted.
289 0 : pub fn command_tag(&self) -> Option<String> {
290 0 : self.command_tag.clone()
291 0 : }
292 :
293 : /// Returns if the connection is ready for querying, with the status of the connection.
294 : ///
295 : /// This might be available only after the stream has been exhausted.
296 0 : pub fn ready_status(&self) -> ReadyForQueryStatus {
297 0 : self.status
298 0 : }
299 : }
|