Line data Source code
1 : use bytes::BytesMut;
2 : use fallible_iterator::FallibleIterator;
3 : use postgres_protocol2::IsNull;
4 : use postgres_protocol2::message::backend::{Message, RowDescriptionBody};
5 : use postgres_protocol2::message::frontend;
6 : use postgres_protocol2::types::oid_to_sql;
7 : use postgres_types2::Format;
8 :
9 : use crate::client::{CachedTypeInfo, PartialQuery, Responses};
10 : use crate::types::{Kind, Oid, Type};
11 : use crate::{Column, Error, Row, Statement};
12 :
/// Catalog query that describes a single type, identified by the OID bound to `$1`.
///
/// The select list is, in order: `typname`, `typtype`, `typelem`, `rngsubtype`,
/// `typbasetype`, `nspname` and `typrelid`.
pub(crate) const TYPEINFO_QUERY: &str = concat!(
    "SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid\n",
    "FROM pg_catalog.pg_type t\n",
    "LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid\n",
    "INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid\n",
    "WHERE t.oid = $1\n",
);
20 :
21 : /// we need to make sure we close this prepared statement.
22 : struct CloseStmt<'a, 'b> {
23 : client: Option<&'a mut PartialQuery<'b>>,
24 : name: &'static str,
25 : }
26 :
27 : impl<'a> CloseStmt<'a, '_> {
28 0 : fn close(mut self) -> Result<&'a mut Responses, Error> {
29 0 : let client = self.client.take().unwrap();
30 0 : client.send_with_flush(|buf| {
31 0 : frontend::close(b'S', self.name, buf).map_err(Error::encode)?;
32 0 : Ok(())
33 0 : })
34 0 : }
35 : }
36 :
37 : impl Drop for CloseStmt<'_, '_> {
38 0 : fn drop(&mut self) {
39 0 : if let Some(client) = self.client.take() {
40 0 : let _ = client.send_with_flush(|buf| {
41 0 : frontend::close(b'S', self.name, buf).map_err(Error::encode)?;
42 0 : Ok(())
43 0 : });
44 0 : }
45 0 : }
46 : }
47 :
48 0 : async fn prepare_typecheck(
49 0 : client: &mut PartialQuery<'_>,
50 0 : name: &'static str,
51 0 : query: &str,
52 0 : ) -> Result<Statement, Error> {
53 0 : let responses = client.send_with_flush(|buf| {
54 0 : frontend::parse(name, query, [], buf).map_err(Error::encode)?;
55 0 : frontend::describe(b'S', name, buf).map_err(Error::encode)?;
56 0 : Ok(())
57 0 : })?;
58 :
59 0 : match responses.next().await? {
60 0 : Message::ParseComplete => {}
61 0 : _ => return Err(Error::unexpected_message()),
62 : }
63 :
64 0 : match responses.next().await? {
65 0 : Message::ParameterDescription(_) => {}
66 0 : _ => return Err(Error::unexpected_message()),
67 : };
68 :
69 0 : let row_description = match responses.next().await? {
70 0 : Message::RowDescription(body) => Some(body),
71 0 : Message::NoData => None,
72 0 : _ => return Err(Error::unexpected_message()),
73 : };
74 :
75 0 : let mut columns = vec![];
76 0 : if let Some(row_description) = row_description {
77 0 : let mut it = row_description.fields();
78 0 : while let Some(field) = it.next().map_err(Error::parse)? {
79 0 : let type_ = Type::from_oid(field.type_oid()).ok_or_else(Error::unexpected_message)?;
80 0 : let column = Column::new(field.name().to_string(), type_, field);
81 0 : columns.push(column);
82 : }
83 0 : }
84 :
85 0 : Ok(Statement::new(name, columns))
86 0 : }
87 :
88 0 : fn try_from_cache(typecache: &CachedTypeInfo, oid: Oid) -> Option<Type> {
89 0 : if let Some(type_) = Type::from_oid(oid) {
90 0 : return Some(type_);
91 0 : }
92 :
93 0 : if let Some(type_) = typecache.types.get(&oid) {
94 0 : return Some(type_.clone());
95 0 : };
96 :
97 0 : None
98 0 : }
99 :
100 0 : pub async fn parse_row_description(
101 0 : client: &mut PartialQuery<'_>,
102 0 : typecache: &mut CachedTypeInfo,
103 0 : row_description: Option<RowDescriptionBody>,
104 0 : ) -> Result<Vec<Column>, Error> {
105 0 : let mut columns = vec![];
106 :
107 0 : if let Some(row_description) = row_description {
108 0 : let mut it = row_description.fields();
109 0 : while let Some(field) = it.next().map_err(Error::parse)? {
110 0 : let type_ = try_from_cache(typecache, field.type_oid()).unwrap_or(Type::UNKNOWN);
111 0 : let column = Column::new(field.name().to_string(), type_, field);
112 0 : columns.push(column);
113 0 : }
114 0 : }
115 :
116 0 : let all_known = columns.iter().all(|c| c.type_ != Type::UNKNOWN);
117 0 : if all_known {
118 : // all known, return early.
119 0 : return Ok(columns);
120 0 : }
121 :
122 0 : let typeinfo = "neon_proxy_typeinfo";
123 :
124 : // make sure to close the typeinfo statement before exiting.
125 0 : let mut guard = CloseStmt {
126 0 : name: typeinfo,
127 0 : client: None,
128 0 : };
129 0 : let client = guard.client.insert(client);
130 :
131 : // get the typeinfo statement.
132 0 : let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY).await?;
133 :
134 0 : for column in &mut columns {
135 0 : column.type_ = get_type(client, typecache, &stmt, column.type_oid()).await?;
136 : }
137 :
138 : // cancel the close guard.
139 0 : let responses = guard.close()?;
140 :
141 0 : match responses.next().await? {
142 0 : Message::CloseComplete => {}
143 0 : _ => return Err(Error::unexpected_message()),
144 : }
145 :
146 0 : Ok(columns)
147 0 : }
148 :
149 0 : async fn get_type(
150 0 : client: &mut PartialQuery<'_>,
151 0 : typecache: &mut CachedTypeInfo,
152 0 : stmt: &Statement,
153 0 : mut oid: Oid,
154 0 : ) -> Result<Type, Error> {
155 0 : let mut stack = vec![];
156 0 : let mut type_ = loop {
157 0 : if let Some(type_) = try_from_cache(typecache, oid) {
158 0 : break type_;
159 0 : }
160 :
161 0 : let row = exec(client, stmt, oid).await?;
162 0 : if stack.len() > 8 {
163 0 : return Err(Error::unexpected_message());
164 0 : }
165 :
166 0 : let name: String = row.try_get(0)?;
167 0 : let type_: i8 = row.try_get(1)?;
168 0 : let elem_oid: Oid = row.try_get(2)?;
169 0 : let rngsubtype: Option<Oid> = row.try_get(3)?;
170 0 : let basetype: Oid = row.try_get(4)?;
171 0 : let schema: String = row.try_get(5)?;
172 0 : let relid: Oid = row.try_get(6)?;
173 :
174 0 : let kind = if type_ == b'e' as i8 {
175 0 : Kind::Enum
176 0 : } else if type_ == b'p' as i8 {
177 0 : Kind::Pseudo
178 0 : } else if basetype != 0 {
179 0 : Kind::Domain(basetype)
180 0 : } else if elem_oid != 0 {
181 0 : stack.push((name, oid, schema));
182 0 : oid = elem_oid;
183 0 : continue;
184 0 : } else if relid != 0 {
185 0 : Kind::Composite(relid)
186 0 : } else if let Some(rngsubtype) = rngsubtype {
187 0 : Kind::Range(rngsubtype)
188 : } else {
189 0 : Kind::Simple
190 : };
191 :
192 0 : let type_ = Type::new(name, oid, kind, schema);
193 0 : typecache.types.insert(oid, type_.clone());
194 0 : break type_;
195 : };
196 :
197 0 : while let Some((name, oid, schema)) = stack.pop() {
198 0 : type_ = Type::new(name, oid, Kind::Array(type_), schema);
199 0 : typecache.types.insert(oid, type_.clone());
200 0 : }
201 :
202 0 : Ok(type_)
203 0 : }
204 :
205 : /// exec the typeinfo statement returning one row.
206 0 : async fn exec(
207 0 : client: &mut PartialQuery<'_>,
208 0 : statement: &Statement,
209 0 : param: Oid,
210 0 : ) -> Result<Row, Error> {
211 0 : let responses = client.send_with_flush(|buf| {
212 0 : encode_bind(statement, param, "", buf);
213 0 : frontend::execute("", 0, buf).map_err(Error::encode)?;
214 0 : Ok(())
215 0 : })?;
216 :
217 0 : match responses.next().await? {
218 0 : Message::BindComplete => {}
219 0 : _ => return Err(Error::unexpected_message()),
220 : }
221 :
222 0 : let row = match responses.next().await? {
223 0 : Message::DataRow(body) => Row::new(statement.clone(), body, Format::Binary)?,
224 0 : _ => return Err(Error::unexpected_message()),
225 : };
226 :
227 0 : match responses.next().await? {
228 0 : Message::CommandComplete(_) => {}
229 0 : _ => return Err(Error::unexpected_message()),
230 : };
231 :
232 0 : Ok(row)
233 0 : }
234 :
235 0 : fn encode_bind(statement: &Statement, param: Oid, portal: &str, buf: &mut BytesMut) {
236 0 : frontend::bind(
237 0 : portal,
238 0 : statement.name(),
239 0 : [Format::Binary as i16],
240 0 : [param],
241 0 : |param, buf| {
242 0 : oid_to_sql(param, buf);
243 0 : Ok(IsNull::No)
244 0 : },
245 0 : [Format::Binary as i16],
246 0 : buf,
247 : )
248 0 : .unwrap();
249 0 : }
|