// diesel/pg/types/array.rs

1use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
2use std::fmt;
3use std::io::Write;
4
5use crate::deserialize::{self, FromSql};
6use crate::pg::{Pg, PgTypeMetadata, PgValue};
7use crate::query_builder::bind_collector::ByteWrapper;
8use crate::serialize::{self, IsNull, Output, ToSql};
9use crate::sql_types::{Array, HasSqlType, Nullable};
10
#[cfg(feature = "postgres_backend")]
impl<T> HasSqlType<Array<T>> for Pg
where
    Pg: HasSqlType<T>,
{
    /// Resolves the metadata of `Array<T>` from the metadata of its element type `T`.
    ///
    /// On success the element's `array_oid` becomes the OID of the returned metadata;
    /// a failed element lookup is forwarded unchanged.
    fn metadata(lookup: &mut Self::MetadataLookup) -> PgTypeMetadata {
        let element_metadata = <Pg as HasSqlType<T>>::metadata(lookup);
        match element_metadata.0 {
            Err(lookup_error) => PgTypeMetadata(Err(lookup_error)),
            // NOTE(review): the second argument is passed as 0; presumably no
            // "array of this array" OID is tracked — confirm against PgTypeMetadata::new.
            Ok(element) => PgTypeMetadata::new(element.array_oid, 0),
        }
    }
}
23
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Array<ST>, Pg> for Vec<T>
where
    T: FromSql<ST, Pg>,
{
    /// Decodes a one-dimensional PostgreSQL array received in binary format.
    ///
    /// The value starts with three big-endian `i32` header fields: the number of
    /// dimensions, a has-null flag and the element type OID (ignored). A
    /// zero-dimensional value decodes to an empty `Vec`. Otherwise the element count
    /// and lower bound (ignored) follow, then every element as an `i32` byte length
    /// followed by that many bytes, which are handed to `T::from_sql`. When the
    /// has-null flag is set, a length of `-1` is decoded via `T::from_nullable_sql(None)`.
    ///
    /// # Errors
    ///
    /// Returns an error if the input ends early, the array has more than one
    /// dimension, the element count is negative, an element's declared length is
    /// negative (other than the `-1` NULL marker when the has-null flag is set) or
    /// exceeds the remaining bytes, or an element fails to decode as `T`.
    fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
        let mut bytes = value.as_bytes();
        let num_dimensions = bytes.read_i32::<NetworkEndian>()?;
        let has_null = bytes.read_i32::<NetworkEndian>()? != 0;
        let _oid = bytes.read_i32::<NetworkEndian>()?;

        if num_dimensions == 0 {
            return Ok(Vec::new());
        }

        let num_elements = bytes.read_i32::<NetworkEndian>()?;
        let _lower_bound = bytes.read_i32::<NetworkEndian>()?;

        if num_dimensions != 1 {
            return Err("multi-dimensional arrays are not supported".into());
        }
        // A negative count would otherwise produce an empty range below and
        // silently decode a malformed value as an empty array.
        if num_elements < 0 {
            return Err(format!("Invalid array element count: {num_elements}").into());
        }

        (0..num_elements)
            .map(|_| -> deserialize::Result<_> {
                let elem_size = bytes.read_i32::<NetworkEndian>()?;
                if has_null && elem_size == -1 {
                    // A length of -1 marks a NULL element.
                    T::from_nullable_sql(None)
                } else {
                    // Any other negative length fails the `usize` conversion here.
                    let (elem_bytes, rest) = bytes
                        .split_at_checked(elem_size.try_into()?)
                        .ok_or_else(|| {
                            format!(
                                "Invalid element byte count: Expected at least {elem_size} bytes, but only {} bytes were received",
                                bytes.len()
                            )
                        })?;
                    bytes = rest;
                    T::from_sql(PgValue::new_internal(elem_bytes, &value))
                }
            })
            .collect()
    }
}
67
68use crate::expression::bound::Bound;
69use crate::expression::AsExpression;
70
// Generates `AsExpression<$target_sql_type>` for `$rust_ty`, binding the value as a
// `Bound` parameter. `$rust_ty` and `$target_sql_type` may refer to the lifetimes
// `'a`/`'b` and the type parameters `ST`/`T` that every generated impl declares.
macro_rules! array_as_expression {
    ($rust_ty:ty, $target_sql_type:ty) => {
        #[cfg(feature = "postgres_backend")]
        // Both lifetimes are declared on every impl so that invocations such as
        // `&'a &'b [T]` can name them; impls that do not use them would
        // otherwise trigger this lint.
        #[allow(clippy::extra_unused_lifetimes)]
        impl<'a, 'b, ST: 'static, T> AsExpression<$target_sql_type> for $rust_ty {
            type Expression = Bound<$target_sql_type, Self>;

            fn as_expression(self) -> Self::Expression {
                Bound::new(self)
            }
        }
    };
}
86
87#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a [T] {
    type Expression = Bound<Array<ST>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a [T], Array<ST>);
88#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for &'a [T] {
    type Expression = Bound<Nullable<Array<ST>>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a [T], Nullable<Array<ST>>);
89#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a &'b [T] {
    type Expression = Bound<Array<ST>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b [T], Array<ST>);
90#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for &'a &'b [T]
    {
    type Expression = Bound<Nullable<Array<ST>>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b [T], Nullable<Array<ST>>);
91#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for Vec<T> {
    type Expression = Bound<Array<ST>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(Vec<T>, Array<ST>);
92#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for Vec<T> {
    type Expression = Bound<Nullable<Array<ST>>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(Vec<T>, Nullable<Array<ST>>);
93#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a Vec<T> {
    type Expression = Bound<Array<ST>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a Vec<T>, Array<ST>);
94#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for &'a Vec<T>
    {
    type Expression = Bound<Nullable<Array<ST>>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a Vec<T>, Nullable<Array<ST>>);
95#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a &'b Vec<T> {
    type Expression = Bound<Array<ST>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b Vec<T>, Array<ST>);
96#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for
    &'a &'b Vec<T> {
    type Expression = Bound<Nullable<Array<ST>>, Self>;
    fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b Vec<T>, Nullable<Array<ST>>);
97
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Array<ST>, Pg> for [T]
where
    Pg: HasSqlType<ST>,
    T: ToSql<ST, Pg>,
{
    /// Serializes the slice as a one-dimensional PostgreSQL array in binary format.
    ///
    /// Writes the header (1 dimension, flags `0`, the element type OID, the element
    /// count and a lower bound of 1), then each element as an `i32` byte length
    /// followed by its bytes, or a length of `-1` for an element reported as NULL.
    ///
    /// # Errors
    ///
    /// Fails if the element OID lookup fails, if the element count or an element's
    /// encoded size does not fit in an `i32`, if an element fails to serialize, or
    /// if writing to `out` fails.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        let num_dimensions = 1;
        out.write_i32::<NetworkEndian>(num_dimensions)?;
        let flags = 0;
        out.write_i32::<NetworkEndian>(flags)?;
        let element_oid = <Pg as HasSqlType<ST>>::metadata(out.metadata_lookup()).oid()?;
        out.write_u32::<NetworkEndian>(element_oid)?;
        out.write_i32::<NetworkEndian>(self.len().try_into()?)?;
        let lower_bound = 1;
        out.write_i32::<NetworkEndian>(lower_bound)?;

        // This buffer is created outside of the loop to reuse the underlying memory allocation.
        // For most cases all array elements will have the same serialized size.
        let mut buffer = Vec::new();

        for elem in self {
            let is_null = {
                let mut temp_buffer = Output::new(ByteWrapper(&mut buffer), out.metadata_lookup());
                elem.to_sql(&mut temp_buffer)?
            };

            match is_null {
                IsNull::No => {
                    out.write_i32::<NetworkEndian>(buffer.len().try_into()?)?;
                    out.write_all(&buffer)?;
                }
                IsNull::Yes => {
                    // https://github.com/postgres/postgres/blob/82f8107b92c9104ec9d9465f3f6a4c6dab4c124a/src/backend/utils/adt/arrayfuncs.c#L1461
                    out.write_i32::<NetworkEndian>(-1)?;
                }
            }
            // Clear unconditionally: if an element wrote bytes before reporting
            // `IsNull::Yes`, they must not be prepended to the next element.
            buffer.clear();
        }

        Ok(IsNull::No)
    }
}
138
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Nullable<Array<ST>>, Pg> for [T]
where
    ST: 'static,
    [T]: ToSql<Array<ST>, Pg>,
{
    /// Serializes a slice bound as a nullable array.
    ///
    /// Delegates unchanged to the slice's `Array<ST>` encoding.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        <[T] as ToSql<Array<ST>, Pg>>::to_sql(self, out)
    }
}
149
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Array<ST>, Pg> for Vec<T>
where
    ST: 'static,
    [T]: ToSql<Array<ST>, Pg>,
    // NOTE(review): presumably required because `ToSql` implementors must be
    // `Debug`, which for `Vec<T>` needs `T: Debug` — confirm before removing.
    T: fmt::Debug,
{
    /// Serializes the vector by delegating to the `[T]` array implementation.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        let elements: &'b [T] = self.as_slice();
        <[T] as ToSql<Array<ST>, Pg>>::to_sql(elements, out)
    }
}
161
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Nullable<Array<ST>>, Pg> for Vec<T>
where
    ST: 'static,
    Vec<T>: ToSql<Array<ST>, Pg>,
{
    /// Serializes a vector bound as a nullable array.
    ///
    /// Delegates unchanged to the vector's `Array<ST>` encoding.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        <Vec<T> as ToSql<Array<ST>, Pg>>::to_sql(self, out)
    }
}
172
#[cfg(test)]
mod tests {
    use byteorder::{NetworkEndian, WriteBytesExt};

    use crate::deserialize::FromSql;
    use crate::pg::{Pg, PgValue};
    use crate::sql_types::{Array, Integer};

    /// Builds the binary encoding of a one-dimensional array that declares
    /// `num_elements` elements but carries only one element: a length field of
    /// `elem_size` followed by the four bytes of the integer `42`.
    fn truncated_array(num_elements: i32, elem_size: i32) -> Vec<u8> {
        let fields = [
            1,            // dimensions
            0,            // has null
            0,            // oid
            num_elements, // num elements
            0,            // lower bound
            elem_size,    // elem size element 1
            42,           // the element itself
        ];
        let mut encoded = Vec::<u8>::new();
        for field in fields {
            encoded.write_i32::<NetworkEndian>(field).unwrap();
        }
        encoded
    }

    /// Decodes `raw` as `Array<Integer>` and returns the rendered error message,
    /// asserting that decoding fails.
    fn decode_error(raw: &[u8]) -> String {
        let res = <Vec<i32> as FromSql<Array<Integer>, Pg>>::from_sql(PgValue::for_test(raw));
        assert!(res.is_err());
        format!("{}", res.unwrap_err())
    }

    #[test]
    fn check_invalid_element_size_for_array() {
        // check for the wrong element size
        assert_eq!(
            decode_error(&truncated_array(2, 6)),
            "Invalid element byte count: Expected at least 6 bytes, but only 4 bytes were received",
        );

        // check for the wrong number of elements
        assert_eq!(
            decode_error(&truncated_array(2, 4)),
            "failed to fill whole buffer"
        );
    }
}