Line data Source code
1 : use crate::client::InnerClient;
2 : use crate::codec::FrontendMessage;
3 : use crate::connection::RequestMessages;
4 : use crate::error::SqlState;
5 : use crate::types::{Field, Kind, Oid, Type};
6 : use crate::{query, slice_iter};
7 : use crate::{Column, Error, Statement};
8 : use bytes::Bytes;
9 : use fallible_iterator::FallibleIterator;
10 : use futures_util::{pin_mut, TryStreamExt};
11 : use log::debug;
12 : use postgres_protocol2::message::backend::Message;
13 : use postgres_protocol2::message::frontend;
14 : use std::future::Future;
15 : use std::pin::Pin;
16 : use std::sync::atomic::{AtomicUsize, Ordering};
17 : use std::sync::Arc;
18 :
/// Catalog lookup for a single type OID (`$1`).
///
/// Returns, in column order: type name, `typtype` category, element type,
/// range subtype (NULL when the type is not a range), domain base type,
/// schema name and the related relation id (used for composite types).
pub(crate) const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";

// Range types weren't added until Postgres 9.2, so pg_range may not exist.
/// Fallback for [`TYPEINFO_QUERY`] that avoids `pg_range`; the range-subtype
/// column is always NULL, so column positions stay identical.
const TYPEINFO_FALLBACK_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";

/// Lists the labels of the enum type `$1`, in declared sort order.
const TYPEINFO_ENUM_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
";

// Postgres 9.0 didn't have enumsortorder.
/// Fallback for [`TYPEINFO_ENUM_QUERY`] that orders labels by row OID instead.
const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY oid
";

/// Lists the live (non-dropped, user-visible) attributes of relation `$1` as
/// `(name, type oid)` pairs, in attribute-number order.
pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
";

/// Process-wide counter used by `prepare` to generate unique statement names
/// (`s0`, `s1`, ...).
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
60 :
61 0 : pub async fn prepare(
62 0 : client: &Arc<InnerClient>,
63 0 : query: &str,
64 0 : types: &[Type],
65 0 : ) -> Result<Statement, Error> {
66 0 : let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
67 0 : let buf = encode(client, &name, query, types)?;
68 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
69 :
70 0 : match responses.next().await? {
71 0 : Message::ParseComplete => {}
72 0 : _ => return Err(Error::unexpected_message()),
73 : }
74 :
75 0 : let parameter_description = match responses.next().await? {
76 0 : Message::ParameterDescription(body) => body,
77 0 : _ => return Err(Error::unexpected_message()),
78 : };
79 :
80 0 : let row_description = match responses.next().await? {
81 0 : Message::RowDescription(body) => Some(body),
82 0 : Message::NoData => None,
83 0 : _ => return Err(Error::unexpected_message()),
84 : };
85 :
86 0 : let mut parameters = vec![];
87 0 : let mut it = parameter_description.parameters();
88 0 : while let Some(oid) = it.next().map_err(Error::parse)? {
89 0 : let type_ = get_type(client, oid).await?;
90 0 : parameters.push(type_);
91 : }
92 :
93 0 : let mut columns = vec![];
94 0 : if let Some(row_description) = row_description {
95 0 : let mut it = row_description.fields();
96 0 : while let Some(field) = it.next().map_err(Error::parse)? {
97 0 : let type_ = get_type(client, field.type_oid()).await?;
98 0 : let column = Column::new(field.name().to_string(), type_, field);
99 0 : columns.push(column);
100 : }
101 0 : }
102 :
103 0 : Ok(Statement::new(client, name, parameters, columns))
104 0 : }
105 :
106 0 : fn prepare_rec<'a>(
107 0 : client: &'a Arc<InnerClient>,
108 0 : query: &'a str,
109 0 : types: &'a [Type],
110 0 : ) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
111 0 : Box::pin(prepare(client, query, types))
112 0 : }
113 :
114 0 : fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
115 0 : if types.is_empty() {
116 0 : debug!("preparing query {}: {}", name, query);
117 : } else {
118 0 : debug!("preparing query {} with types {:?}: {}", name, types, query);
119 : }
120 :
121 0 : client.with_buf(|buf| {
122 0 : frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
123 0 : frontend::describe(b'S', name, buf).map_err(Error::encode)?;
124 0 : frontend::sync(buf);
125 0 : Ok(buf.split().freeze())
126 0 : })
127 0 : }
128 :
129 0 : pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
130 0 : if let Some(type_) = Type::from_oid(oid) {
131 0 : return Ok(type_);
132 0 : }
133 :
134 0 : if let Some(type_) = client.type_(oid) {
135 0 : return Ok(type_);
136 0 : }
137 :
138 0 : let stmt = typeinfo_statement(client).await?;
139 :
140 0 : let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
141 0 : pin_mut!(rows);
142 :
143 0 : let row = match rows.try_next().await? {
144 0 : Some(row) => row,
145 0 : None => return Err(Error::unexpected_message()),
146 : };
147 :
148 0 : let name: String = row.try_get(0)?;
149 0 : let type_: i8 = row.try_get(1)?;
150 0 : let elem_oid: Oid = row.try_get(2)?;
151 0 : let rngsubtype: Option<Oid> = row.try_get(3)?;
152 0 : let basetype: Oid = row.try_get(4)?;
153 0 : let schema: String = row.try_get(5)?;
154 0 : let relid: Oid = row.try_get(6)?;
155 :
156 0 : let kind = if type_ == b'e' as i8 {
157 0 : let variants = get_enum_variants(client, oid).await?;
158 0 : Kind::Enum(variants)
159 0 : } else if type_ == b'p' as i8 {
160 0 : Kind::Pseudo
161 0 : } else if basetype != 0 {
162 0 : let type_ = get_type_rec(client, basetype).await?;
163 0 : Kind::Domain(type_)
164 0 : } else if elem_oid != 0 {
165 0 : let type_ = get_type_rec(client, elem_oid).await?;
166 0 : Kind::Array(type_)
167 0 : } else if relid != 0 {
168 0 : let fields = get_composite_fields(client, relid).await?;
169 0 : Kind::Composite(fields)
170 0 : } else if let Some(rngsubtype) = rngsubtype {
171 0 : let type_ = get_type_rec(client, rngsubtype).await?;
172 0 : Kind::Range(type_)
173 : } else {
174 0 : Kind::Simple
175 : };
176 :
177 0 : let type_ = Type::new(name, oid, kind, schema);
178 0 : client.set_type(oid, &type_);
179 0 :
180 0 : Ok(type_)
181 0 : }
182 :
183 0 : fn get_type_rec<'a>(
184 0 : client: &'a Arc<InnerClient>,
185 0 : oid: Oid,
186 0 : ) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
187 0 : Box::pin(get_type(client, oid))
188 0 : }
189 :
190 0 : async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
191 0 : if let Some(stmt) = client.typeinfo() {
192 0 : return Ok(stmt);
193 0 : }
194 :
195 0 : let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await {
196 0 : Ok(stmt) => stmt,
197 0 : Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
198 0 : prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await?
199 : }
200 0 : Err(e) => return Err(e),
201 : };
202 :
203 0 : client.set_typeinfo(&stmt);
204 0 : Ok(stmt)
205 0 : }
206 :
207 0 : async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
208 0 : let stmt = typeinfo_enum_statement(client).await?;
209 :
210 0 : query::query(client, stmt, slice_iter(&[&oid]))
211 0 : .await?
212 0 : .and_then(|row| async move { row.try_get(0) })
213 0 : .try_collect()
214 0 : .await
215 0 : }
216 :
217 0 : async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
218 0 : if let Some(stmt) = client.typeinfo_enum() {
219 0 : return Ok(stmt);
220 0 : }
221 :
222 0 : let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await {
223 0 : Ok(stmt) => stmt,
224 0 : Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
225 0 : prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
226 : }
227 0 : Err(e) => return Err(e),
228 : };
229 :
230 0 : client.set_typeinfo_enum(&stmt);
231 0 : Ok(stmt)
232 0 : }
233 :
234 0 : async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
235 0 : let stmt = typeinfo_composite_statement(client).await?;
236 :
237 0 : let rows = query::query(client, stmt, slice_iter(&[&oid]))
238 0 : .await?
239 0 : .try_collect::<Vec<_>>()
240 0 : .await?;
241 :
242 0 : let mut fields = vec![];
243 0 : for row in rows {
244 0 : let name = row.try_get(0)?;
245 0 : let oid = row.try_get(1)?;
246 0 : let type_ = get_type_rec(client, oid).await?;
247 0 : fields.push(Field::new(name, type_));
248 : }
249 :
250 0 : Ok(fields)
251 0 : }
252 :
253 0 : async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
254 0 : if let Some(stmt) = client.typeinfo_composite() {
255 0 : return Ok(stmt);
256 0 : }
257 :
258 0 : let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
259 :
260 0 : client.set_typeinfo_composite(&stmt);
261 0 : Ok(stmt)
262 0 : }
|