diesel/pg/types/
record.rs

1use byteorder::*;
2use std::io::Write;
3use std::num::NonZeroU32;
4
5use crate::deserialize::{self, FromSql, Queryable};
6use crate::expression::{
7    AppearsOnTable, AsExpression, Expression, SelectableExpression, TypedExpressionType,
8    ValidGrouping,
9};
10use crate::pg::{Pg, PgValue};
11use crate::query_builder::bind_collector::ByteWrapper;
12use crate::query_builder::{AstPass, QueryFragment, QueryId};
13use crate::result::QueryResult;
14use crate::serialize::{self, IsNull, Output, ToSql, WriteTuple};
15use crate::sql_types::{HasSqlType, Record, SqlType};
16
17macro_rules! tuple_impls {
18    ($(
19        $Tuple:tt {
20            $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+
21        }
22    )+) => {$(
23        #[cfg(feature = "postgres_backend")]
24        impl<$($T,)+ $($ST,)+> FromSql<Record<($($ST,)+)>, Pg> for ($($T,)+)
25        where
26            $($T: FromSql<$ST, Pg>,)+
27        {
28            // Yes, we're relying on the order of evaluation of subexpressions
29            // but the only other option would be to use `mem::uninitialized`
30            // and `ptr::write`.
31            #[allow(clippy::mixed_read_write_in_expression)]
32            fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
33                let mut bytes = value.as_bytes();
34                let num_elements = bytes.read_i32::<NetworkEndian>()?;
35
36                if num_elements != $Tuple {
37                    return Err(format!(
38                        "Expected a tuple of {} elements, got {}",
39                        $Tuple,
40                        num_elements,
41                    ).into());
42                }
43
44                let result = ($({
45                    // We could in theory validate the OID here, but that
46                    // ignores cases like text vs varchar where the
47                    // representation is the same and we don't care which we
48                    // got.
49                    let oid = NonZeroU32::new(bytes.read_u32::<NetworkEndian>()?).expect("Oid's aren't zero");
50                    let num_bytes = bytes.read_i32::<NetworkEndian>()?;
51
52                    if num_bytes == -1 {
53                        $T::from_nullable_sql(None)?
54                    } else {
55                        let (elem_bytes, new_bytes) = bytes.split_at(num_bytes.try_into()?);
56                        bytes = new_bytes;
57                        $T::from_sql(PgValue::new_internal(
58                            elem_bytes,
59                            &oid,
60                        ))?
61                    }
62                },)+);
63
64                if bytes.is_empty() {
65                    Ok(result)
66                } else {
67                    Err("Received too many bytes. This tuple likely contains \
68                        an element of the wrong SQL type.".into())
69                }
70            }
71        }
72
73        #[cfg(feature = "postgres_backend")]
74        impl<$($T,)+ $($ST,)+> Queryable<Record<($($ST,)+)>, Pg> for ($($T,)+)
75        where Self: FromSql<Record<($($ST,)+)>, Pg>
76        {
77            type Row = Self;
78
79            fn build(row: Self::Row) -> deserialize::Result<Self> {
80                Ok(row)
81            }
82        }
83
84        #[cfg(feature = "postgres_backend")]
85        impl<$($T,)+ $($ST,)+> AsExpression<Record<($($ST,)+)>> for ($($T,)+)
86        where
87            $($ST: SqlType + TypedExpressionType,)+
88            $($T: AsExpression<$ST>,)+
89            PgTuple<($($T::Expression,)+)>: Expression<SqlType = Record<($($ST,)+)>>,
90        {
91            type Expression = PgTuple<($($T::Expression,)+)>;
92
93            fn as_expression(self) -> Self::Expression {
94                PgTuple(($(
95                    self.$idx.as_expression(),
96                )+))
97            }
98        }
99
100        #[cfg(feature = "postgres_backend")]
101        impl<$($T,)+ $($ST,)+> WriteTuple<($($ST,)+)> for ($($T,)+)
102        where
103            $($T: ToSql<$ST, Pg>,)+
104            $(Pg: HasSqlType<$ST>),+
105        {
106            fn write_tuple(&self, out: &mut Output<'_, '_, Pg>) -> serialize::Result {
107                let mut buffer = Vec::new();
108                out.write_i32::<NetworkEndian>($Tuple)?;
109
110                $(
111                    let oid = <Pg as HasSqlType<$ST>>::metadata(out.metadata_lookup()).oid()?;
112                    out.write_u32::<NetworkEndian>(oid)?;
113                    let is_null = {
114                        let mut temp_buffer = Output::new(ByteWrapper(&mut buffer), out.metadata_lookup());
115                        let is_null = self.$idx.to_sql(&mut temp_buffer)?;
116                        is_null
117                    };
118
119                    if let IsNull::No = is_null {
120                        out.write_i32::<NetworkEndian>(buffer.len().try_into()?)?;
121                        out.write_all(&buffer)?;
122                        buffer.clear();
123                    } else {
124                        out.write_i32::<NetworkEndian>(-1)?;
125                    }
126                )+
127
128                Ok(IsNull::No)
129            }
130        }
131    )+}
132}
133
134diesel_derives::__diesel_for_each_tuple!(tuple_impls);
135
136#[derive(Debug, Clone, Copy, QueryId, ValidGrouping)]
137pub struct PgTuple<T>(T);
138
139impl<T> QueryFragment<Pg> for PgTuple<T>
140where
141    T: QueryFragment<Pg>,
142{
143    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
144        out.push_sql("ROW(");
145        self.0.walk_ast(out.reborrow())?;
146        out.push_sql(")");
147        Ok(())
148    }
149}
150
151impl<T> Expression for PgTuple<T>
152where
153    T: Expression,
154    T::SqlType: 'static,
155{
156    type SqlType = Record<T::SqlType>;
157}
158
159impl<T, QS> SelectableExpression<QS> for PgTuple<T>
160where
161    T: SelectableExpression<QS>,
162    Self: AppearsOnTable<QS>,
163{
164}
165
166impl<T, QS> AppearsOnTable<QS> for PgTuple<T>
167where
168    T: AppearsOnTable<QS>,
169    Self: Expression,
170{
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::sql;
    use crate::prelude::*;
    use crate::sql_types::*;
    use crate::test_helpers::*;

    // Records selected from the database load into (possibly nested) tuples,
    // including NULL elements and single-element records.
    #[diesel_test_helper::test]
    fn record_deserializes_correctly() {
        type Flat = Record<(Integer, Text)>;
        type Nested = Record<(Record<(Integer, Text)>, Integer)>;
        type NullableNested = Record<(
            Record<(Nullable<Integer>, Nullable<Text>)>,
            Nullable<Integer>,
        )>;
        type Single = Record<(Integer,)>;

        let conn = &mut pg_connection();

        let flat = sql::<Flat>("SELECT (1, 'hi')").get_result::<(i32, String)>(conn);
        assert_eq!(Ok((1, "hi".to_owned())), flat);

        let nested =
            sql::<Nested>("SELECT ((2, 'bye'), 3)").get_result::<((i32, String), i32)>(conn);
        assert_eq!(Ok(((2, "bye".to_owned()), 3)), nested);

        let with_nulls = sql::<NullableNested>("SELECT ((4, NULL), NULL)")
            .get_result::<((Option<i32>, Option<String>), Option<i32>)>(conn);
        assert_eq!(Ok(((Some(4), None), None)), with_nulls);

        let single = sql::<Single>("SELECT ROW(1)").get_result::<(i32,)>(conn);
        assert_eq!(Ok((1,)), single);
    }

    // Tuples used as the right-hand side of a comparison are rendered through
    // `PgTuple` and compare equal to the matching SQL record literal.
    #[diesel_test_helper::test]
    fn record_kinda_sorta_not_really_serializes_correctly() {
        type Flat = Record<(Integer, Text)>;
        type Nested = Record<(Record<(Integer, Text)>, Integer)>;
        type Single = Record<(Integer,)>;
        type NullableNested = Record<(
            Record<(Nullable<Integer>, Nullable<Text>)>,
            Nullable<Integer>,
        )>;

        let conn = &mut pg_connection();

        let flat = sql::<Flat>("(1, 'hi')");
        let flat_matches = crate::select(flat.eq((1, "hi"))).get_result::<bool>(conn);
        assert_eq!(Ok(true), flat_matches);

        let nested = sql::<Nested>("((2, 'bye'::text), 3)");
        let nested_matches = crate::select(nested.eq(((2, "bye"), 3))).get_result::<bool>(conn);
        assert_eq!(Ok(true), nested_matches);

        let single = sql::<Single>("ROW(3)");
        let single_matches = crate::select(single.eq((3,))).get_result::<bool>(conn);
        assert_eq!(Ok(true), single_matches);

        let with_nulls = sql::<NullableNested>("((4, NULL::text), NULL::int4)");
        let expected = ((Some(4), None::<&str>), None::<i32>);
        let with_nulls_matches =
            crate::select(with_nulls.is_not_distinct_from(expected)).get_result::<bool>(conn);
        assert_eq!(Ok(true), with_nulls_matches);
    }

    // A user-defined struct can be bound as a named composite type by
    // delegating its `ToSql` implementation to `WriteTuple`.
    #[diesel_test_helper::test]
    fn serializing_named_composite_types() {
        #[derive(SqlType, QueryId, Debug, Clone, Copy)]
        #[diesel(postgres_type(name = "my_type"))]
        struct MyType;

        #[derive(Debug, AsExpression)]
        #[diesel(sql_type = MyType)]
        struct MyStruct<'a>(i32, &'a str);

        impl ToSql<MyType, Pg> for MyStruct<'_> {
            fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
                WriteTuple::<(Integer, Text)>::write_tuple(&(self.0, self.1), out)
            }
        }

        let conn = &mut pg_connection();

        // The composite type has to exist before a value of it can be bound.
        crate::sql_query("CREATE TYPE my_type AS (i int4, t text)")
            .execute(conn)
            .unwrap();

        let comparison =
            sql::<Bool>("(1, 'hi')::my_type = ").bind::<MyType, _>(MyStruct(1, "hi"));
        let matches = crate::select(comparison).get_result::<bool>(conn);
        assert_eq!(Ok(true), matches);
    }
}
258}