Line data Source code
1 : use std::future::Future;
2 : use std::pin::Pin;
3 : use std::sync::Arc;
4 :
5 : use bytes::Bytes;
6 : use fallible_iterator::FallibleIterator;
7 : use futures_util::{TryStreamExt, pin_mut};
8 : use log::debug;
9 : use postgres_protocol2::message::backend::Message;
10 : use postgres_protocol2::message::frontend;
11 :
12 : use crate::client::{CachedTypeInfo, InnerClient};
13 : use crate::codec::FrontendMessage;
14 : use crate::connection::RequestMessages;
15 : use crate::types::{Kind, Oid, Type};
16 : use crate::{Column, Error, Statement, query, slice_iter};
17 :
/// SQL used to look up the catalog metadata of a single type by OID.
///
/// Takes one parameter (`$1`, compared against `pg_type.oid`) and selects, in
/// this order: `typname`, `typtype`, `typelem`, `rngsubtype`, `typbasetype`,
/// `nspname`, `typrelid`. `get_type` reads the result row by these positional
/// indices (0 through 6), so the column order must not be changed.
///
/// `pg_range` is attached with a LEFT OUTER JOIN, so `rngsubtype` is NULL when
/// the type has no `pg_range` entry.
pub(crate) const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
25 :
26 0 : async fn prepare_typecheck(
27 0 : client: &Arc<InnerClient>,
28 0 : name: &'static str,
29 0 : query: &str,
30 0 : types: &[Type],
31 0 : ) -> Result<Statement, Error> {
32 0 : let buf = encode(client, name, query, types)?;
33 0 : let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
34 :
35 0 : match responses.next().await? {
36 0 : Message::ParseComplete => {}
37 0 : _ => return Err(Error::unexpected_message()),
38 : }
39 :
40 0 : let parameter_description = match responses.next().await? {
41 0 : Message::ParameterDescription(body) => body,
42 0 : _ => return Err(Error::unexpected_message()),
43 : };
44 :
45 0 : let row_description = match responses.next().await? {
46 0 : Message::RowDescription(body) => Some(body),
47 0 : Message::NoData => None,
48 0 : _ => return Err(Error::unexpected_message()),
49 : };
50 :
51 0 : let mut parameters = vec![];
52 0 : let mut it = parameter_description.parameters();
53 0 : while let Some(oid) = it.next().map_err(Error::parse)? {
54 0 : let type_ = Type::from_oid(oid).ok_or_else(Error::unexpected_message)?;
55 0 : parameters.push(type_);
56 : }
57 :
58 0 : let mut columns = vec![];
59 0 : if let Some(row_description) = row_description {
60 0 : let mut it = row_description.fields();
61 0 : while let Some(field) = it.next().map_err(Error::parse)? {
62 0 : let type_ = Type::from_oid(field.type_oid()).ok_or_else(Error::unexpected_message)?;
63 0 : let column = Column::new(field.name().to_string(), type_, field);
64 0 : columns.push(column);
65 : }
66 0 : }
67 :
68 0 : Ok(Statement::new(client, name, parameters, columns))
69 0 : }
70 :
71 0 : fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
72 0 : if types.is_empty() {
73 0 : debug!("preparing query {}: {}", name, query);
74 : } else {
75 0 : debug!("preparing query {} with types {:?}: {}", name, types, query);
76 : }
77 :
78 0 : client.with_buf(|buf| {
79 0 : frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
80 0 : frontend::describe(b'S', name, buf).map_err(Error::encode)?;
81 0 : frontend::sync(buf);
82 0 : Ok(buf.split().freeze())
83 0 : })
84 0 : }
85 :
86 0 : pub async fn get_type(
87 0 : client: &Arc<InnerClient>,
88 0 : typecache: &mut CachedTypeInfo,
89 0 : oid: Oid,
90 0 : ) -> Result<Type, Error> {
91 0 : if let Some(type_) = Type::from_oid(oid) {
92 0 : return Ok(type_);
93 0 : }
94 :
95 0 : if let Some(type_) = typecache.types.get(&oid) {
96 0 : return Ok(type_.clone());
97 0 : };
98 :
99 0 : let stmt = typeinfo_statement(client, typecache).await?;
100 :
101 0 : let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
102 0 : pin_mut!(rows);
103 :
104 0 : let row = match rows.try_next().await? {
105 0 : Some(row) => row,
106 0 : None => return Err(Error::unexpected_message()),
107 : };
108 :
109 0 : let name: String = row.try_get(0)?;
110 0 : let type_: i8 = row.try_get(1)?;
111 0 : let elem_oid: Oid = row.try_get(2)?;
112 0 : let rngsubtype: Option<Oid> = row.try_get(3)?;
113 0 : let basetype: Oid = row.try_get(4)?;
114 0 : let schema: String = row.try_get(5)?;
115 0 : let relid: Oid = row.try_get(6)?;
116 :
117 0 : let kind = if type_ == b'e' as i8 {
118 0 : Kind::Enum
119 0 : } else if type_ == b'p' as i8 {
120 0 : Kind::Pseudo
121 0 : } else if basetype != 0 {
122 0 : Kind::Domain(basetype)
123 0 : } else if elem_oid != 0 {
124 0 : let type_ = get_type_rec(client, typecache, elem_oid).await?;
125 0 : Kind::Array(type_)
126 0 : } else if relid != 0 {
127 0 : Kind::Composite(relid)
128 0 : } else if let Some(rngsubtype) = rngsubtype {
129 0 : let type_ = get_type_rec(client, typecache, rngsubtype).await?;
130 0 : Kind::Range(type_)
131 : } else {
132 0 : Kind::Simple
133 : };
134 :
135 0 : let type_ = Type::new(name, oid, kind, schema);
136 0 : typecache.types.insert(oid, type_.clone());
137 0 :
138 0 : Ok(type_)
139 0 : }
140 :
141 0 : fn get_type_rec<'a>(
142 0 : client: &'a Arc<InnerClient>,
143 0 : typecache: &'a mut CachedTypeInfo,
144 0 : oid: Oid,
145 0 : ) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
146 0 : Box::pin(get_type(client, typecache, oid))
147 0 : }
148 :
149 0 : async fn typeinfo_statement(
150 0 : client: &Arc<InnerClient>,
151 0 : typecache: &mut CachedTypeInfo,
152 0 : ) -> Result<Statement, Error> {
153 0 : if let Some(stmt) = &typecache.typeinfo {
154 0 : return Ok(stmt.clone());
155 0 : }
156 0 :
157 0 : let typeinfo = "neon_proxy_typeinfo";
158 0 : let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?;
159 :
160 0 : typecache.typeinfo = Some(stmt.clone());
161 0 : Ok(stmt)
162 0 : }
|