1use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
2use core::fmt;
3use std::io::Write;
4
5use crate::deserialize::{self, FromSql, FromSqlRow};
6use crate::pg::{Pg, PgTypeMetadata, PgValue};
7use crate::query_builder::bind_collector::ByteWrapper;
8use crate::serialize::{self, IsNull, Output, ToSql};
9use crate::sql_types::{Array, HasSqlType, Nullable};
10
/// A PostgreSQL array of any dimensionality, stored as a flat element buffer.
///
/// `dims` holds the length of every dimension, in the order the dimensions are
/// listed in the binary array header; `data` holds every element in the order
/// it appears in the binary payload. When produced by the `FromSql` impl in
/// this module, a zero-dimensional (empty) array yields empty `dims` and
/// `data`; otherwise `data.len()` equals the product of `dims`. The
/// per-dimension lower bounds sent in the header are not retained.
///
/// NOTE(review): `sql_type = Array<T>` reuses the Rust element type `T` as the
/// SQL element type for the derived `AsExpression` impls — confirm intended.
#[cfg(feature = "postgres_backend")]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, AsExpression, FromSqlRow)]
#[diesel(sql_type = Array<T>)]
pub struct NdArray<T> {
    /// Length of each dimension, in header order.
    pub dims: Vec<usize>,
    /// All elements, flattened in payload order.
    pub data: Vec<T>,
}

/// Type metadata for `Array<T>` is derived from the element type's metadata:
/// on success the element's `array_oid` becomes the array's OID, and a failed
/// element lookup is handed back unchanged.
#[cfg(feature = "postgres_backend")]
impl<T> HasSqlType<Array<T>> for Pg
where
    Pg: HasSqlType<T>,
{
    fn metadata(lookup: &mut Self::MetadataLookup) -> PgTypeMetadata {
        let element = <Pg as HasSqlType<T>>::metadata(lookup).0;
        if let Ok(inner) = &element {
            return PgTypeMetadata::new(inner.array_oid, 0);
        }
        // Only the error case reaches this point; propagate it as-is.
        PgTypeMetadata(element)
    }
}

/// Decodes a one-dimensional PostgreSQL binary array into a `Vec<T>`.
///
/// The payload is a header (dimension count, has-null flag, element OID),
/// one `(length, lower bound)` pair per dimension, then every element as a
/// big-endian `i32` byte length followed by that many bytes. A length of `-1`
/// is decoded as `NULL` when the has-null flag is set.
///
/// # Errors
///
/// Fails when the buffer is truncated, when the array has more than one
/// dimension, when the element count is negative, when an element length is
/// invalid or exceeds the remaining bytes, or when `T::from_sql` fails.
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Array<ST>, Pg> for Vec<T>
where
    T: FromSql<ST, Pg>,
{
    fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
        let mut bytes = value.as_bytes();
        let num_dimensions = bytes.read_i32::<NetworkEndian>()?;
        let has_null = bytes.read_i32::<NetworkEndian>()? != 0;
        let _oid = bytes.read_i32::<NetworkEndian>()?;

        // An empty array carries no dimension headers at all.
        if num_dimensions == 0 {
            return Ok(Vec::new());
        }

        let num_elements = bytes.read_i32::<NetworkEndian>()?;
        let _lower_bound = bytes.read_i32::<NetworkEndian>()?;

        if num_dimensions != 1 {
            return Err("multi-dimensional arrays are not supported".into());
        }

        // A negative count is malformed input: reject it rather than silently
        // producing an empty vector (same message as the `NdArray` decoder).
        let num_elements: usize = num_elements
            .try_into()
            .map_err(|_| "array dimension length must be positive")?;

        (0..num_elements)
            .map(|_| -> deserialize::Result<T> {
                let elem_size = bytes.read_i32::<NetworkEndian>()?;
                if has_null && elem_size == -1 {
                    return T::from_nullable_sql(None);
                }
                // `split_at_checked` turns an oversized length into an error
                // instead of a panic.
                let (elem_bytes, rest) = bytes
                    .split_at_checked(elem_size.try_into()?)
                    .ok_or_else(|| {
                        format!(
                            "Invalid element byte count: Expected at least {elem_size} bytes, but only {} bytes were received",
                            bytes.len()
                        )
                    })?;
                bytes = rest;
                T::from_sql(PgValue::new_internal(elem_bytes, &value))
            })
            .collect()
    }
}

/// Decodes a PostgreSQL binary array of any dimensionality into an `NdArray`.
///
/// The payload is a header (dimension count, has-null flag, element OID),
/// one `(length, lower bound)` pair per dimension, then every element as a
/// big-endian `i32` byte length followed by that many bytes. A length of `-1`
/// is decoded as `NULL` when the has-null flag is set. Lower bounds are read
/// and discarded.
///
/// # Errors
///
/// Fails when the buffer is truncated, when the dimension count or a dimension
/// length is negative, when the product of the dimension lengths overflows
/// `usize`, when an element length is invalid or exceeds the remaining bytes,
/// or when `T::from_sql` fails.
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Array<ST>, Pg> for NdArray<T>
where
    T: FromSql<ST, Pg>,
{
    fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
        let mut bytes = value.as_bytes();
        let num_dimensions = bytes.read_i32::<NetworkEndian>()?;
        let has_null = bytes.read_i32::<NetworkEndian>()? != 0;
        let _oid = bytes.read_i32::<NetworkEndian>()?;

        // An empty array carries no dimension headers at all.
        if num_dimensions == 0 {
            return Ok(NdArray {
                dims: Vec::new(),
                data: Vec::new(),
            });
        }

        let num_dims: usize = num_dimensions
            .try_into()
            .map_err(|_| "number of dimensions must be positive")?;

        // One `(length, lower bound)` pair per dimension.
        let dims = (0..num_dims)
            .map(|_| -> deserialize::Result<usize> {
                let num_elements = bytes.read_i32::<NetworkEndian>()?;
                let _lower_bound = bytes.read_i32::<NetworkEndian>()?;

                let dim: usize = num_elements
                    .try_into()
                    .map_err(|_| "array dimension length must be positive")?;
                Ok(dim)
            })
            .collect::<deserialize::Result<Vec<_>>>()?;

        // Total element count; checked so a malformed header cannot wrap around.
        let total_elements = dims
            .iter()
            .try_fold(1_usize, |acc, &len| acc.checked_mul(len))
            .ok_or("Overflow while deserializing package size")?;

        let data = (0..total_elements)
            .map(|_| -> deserialize::Result<T> {
                let elem_size = bytes.read_i32::<NetworkEndian>()?;
                if has_null && elem_size == -1 {
                    return T::from_nullable_sql(None);
                }
                // `split_at_checked` turns an oversized length into an error
                // instead of a panic.
                let (elem_bytes, rest) = bytes
                    .split_at_checked(elem_size.try_into()?)
                    .ok_or_else(|| {
                        format!(
                            "Invalid element byte count: Expected at least {elem_size} bytes, but only {} bytes were received",
                            bytes.len()
                        )
                    })?;
                bytes = rest;
                T::from_sql(PgValue::new_internal(elem_bytes, &value))
            })
            .collect::<deserialize::Result<Vec<T>>>()?;

        Ok(NdArray { dims, data })
    }
}

149use crate::expression::AsExpression;
150use crate::expression::bound::Bound;
151
/// Generates an `AsExpression<$sql_ty>` impl for `$rust_ty` that wraps the
/// value in a `Bound` bind parameter.
///
/// Every generated impl declares the lifetimes `'a` and `'b` and the generic
/// parameters `ST` and `T`, so an invocation may mention any of them in its
/// type arguments. Invocations that leave a lifetime unused would otherwise
/// trip clippy, hence the `allow`.
macro_rules! array_as_expression {
    ($rust_ty:ty, $sql_ty:ty) => {
        #[cfg(feature = "postgres_backend")]
        #[allow(clippy::extra_unused_lifetimes)]
        impl<'a, 'b, ST: 'static, T> AsExpression<$sql_ty> for $rust_ty {
            type Expression = Bound<$sql_ty, $rust_ty>;

            fn as_expression(self) -> Bound<$sql_ty, $rust_ty> {
                <Bound<$sql_ty, $rust_ty>>::new(self)
            }
        }
    };
}

168#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a [T] {
type Expression = Bound<Array<ST>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a [T], Array<ST>);
169#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for &'a [T] {
type Expression = Bound<Nullable<Array<ST>>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a [T], Nullable<Array<ST>>);
170#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a &'b [T] {
type Expression = Bound<Array<ST>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b [T], Array<ST>);
171#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for &'a &'b [T]
{
type Expression = Bound<Nullable<Array<ST>>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b [T], Nullable<Array<ST>>);
172#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for Vec<T> {
type Expression = Bound<Array<ST>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(Vec<T>, Array<ST>);
173#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for Vec<T> {
type Expression = Bound<Nullable<Array<ST>>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(Vec<T>, Nullable<Array<ST>>);
174#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a Vec<T> {
type Expression = Bound<Array<ST>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a Vec<T>, Array<ST>);
175#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for &'a Vec<T>
{
type Expression = Bound<Nullable<Array<ST>>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a Vec<T>, Nullable<Array<ST>>);
176#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Array<ST>> for &'a &'b Vec<T> {
type Expression = Bound<Array<ST>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b Vec<T>, Array<ST>);
177#[allow(clippy :: extra_unused_lifetimes)]
impl<'a, 'b, ST: 'static, T> AsExpression<Nullable<Array<ST>>> for
&'a &'b Vec<T> {
type Expression = Bound<Nullable<Array<ST>>, Self>;
fn as_expression(self) -> Self::Expression { Bound::new(self) }
}array_as_expression!(&'a &'b Vec<T>, Nullable<Array<ST>>);
178
/// Encodes a slice as a one-dimensional PostgreSQL binary array.
///
/// Writes the header (one dimension, flags `0`, the element type's OID, the
/// element count and a lower bound of `1`). It then writes every element as a
/// big-endian `i32` byte length followed by its bytes, or just a length of
/// `-1` when the element serializes as `NULL`.
///
/// # Errors
///
/// Fails when the element type's OID cannot be looked up, when the slice
/// length or an element's encoded size does not fit in an `i32`, or when an
/// element's own `to_sql` or a write to `out` fails.
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Array<ST>, Pg> for [T]
where
    Pg: HasSqlType<ST>,
    T: ToSql<ST, Pg>,
{
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        let num_dimensions = 1;
        out.write_i32::<NetworkEndian>(num_dimensions)?;
        // NOTE(review): the has-null flag is always 0 even when `-1` (NULL)
        // lengths are written below; the `FromSql` impls in this module only
        // accept `-1` when that flag is set, so confirm receivers ignore it.
        let flags = 0;
        out.write_i32::<NetworkEndian>(flags)?;
        let element_oid = Pg::metadata(out.metadata_lookup()).oid()?;
        out.write_u32::<NetworkEndian>(element_oid)?;
        out.write_i32::<NetworkEndian>(self.len().try_into()?)?;
        let lower_bound = 1;
        out.write_i32::<NetworkEndian>(lower_bound)?;

        // One scratch buffer, created outside the loop so its allocation is
        // reused for every element.
        let mut buffer = Vec::new();

        for elem in self.iter() {
            let is_null = {
                let mut temp_buffer = Output::new(ByteWrapper(&mut buffer), out.metadata_lookup());
                elem.to_sql(&mut temp_buffer)?
            };

            if let IsNull::No = is_null {
                out.write_i32::<NetworkEndian>(buffer.len().try_into()?)?;
                out.write_all(&buffer)?;
            } else {
                out.write_i32::<NetworkEndian>(-1)?;
            }
            // Clear on both paths: bytes left behind by an element that
            // reported NULL must not leak into the next element's payload.
            buffer.clear();
        }

        Ok(IsNull::No)
    }
}

/// `Nullable<Array<ST>>` encoding for slices: a present slice is written
/// exactly like its non-nullable `Array<ST>` form.
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Nullable<Array<ST>>, Pg> for [T]
where
    ST: 'static,
    [T]: ToSql<Array<ST>, Pg>,
{
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        <[T] as ToSql<Array<ST>, Pg>>::to_sql(self, out)
    }
}

/// Encodes a `Vec<T>` as `Array<ST>` by delegating to the slice impl.
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Array<ST>, Pg> for Vec<T>
where
    ST: 'static,
    [T]: ToSql<Array<ST>, Pg>,
    T: fmt::Debug,
{
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        let elements: &'b [T] = self.as_slice();
        <[T] as ToSql<Array<ST>, Pg>>::to_sql(elements, out)
    }
}

/// `Nullable<Array<ST>>` encoding for vectors: a present vector is written
/// exactly like its non-nullable `Array<ST>` form.
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Nullable<Array<ST>>, Pg> for Vec<T>
where
    ST: 'static,
    Vec<T>: ToSql<Array<ST>, Pg>,
{
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        <Vec<T> as ToSql<Array<ST>, Pg>>::to_sql(self, out)
    }
}

#[cfg(test)]
mod tests {
    use byteorder::{NetworkEndian, WriteBytesExt};

    use crate::data_types::NdArray;
    use crate::deserialize::FromSql;
    use crate::pg::{Pg, PgValue};
    use crate::sql_types::{Array, Integer};

    #[test]
    fn check_invalid_element_size_for_array() {
        let mut value = Vec::<u8>::new();
        // number of dimensions
        value.write_i32::<NetworkEndian>(1).unwrap();
        // has_null flag
        value.write_i32::<NetworkEndian>(0).unwrap();
        // element OID (ignored by the decoder)
        value.write_i32::<NetworkEndian>(0).unwrap();
        // length of the single dimension
        value.write_i32::<NetworkEndian>(2).unwrap();
        // lower bound of the single dimension
        value.write_i32::<NetworkEndian>(0).unwrap();
        // first element declares 6 bytes, but only 4 follow
        value.write_i32::<NetworkEndian>(6).unwrap();
        value.write_i32::<NetworkEndian>(42).unwrap();

        let value = PgValue::for_test(&value);
        let res = <Vec<i32> as FromSql<Array<Integer>, Pg>>::from_sql(value);
        assert!(res.is_err());
        assert_eq!(
            format!("{}", res.unwrap_err()),
            "Invalid element byte count: Expected at least 6 bytes, but only 4 bytes were received",
        );

        let mut value = Vec::<u8>::new();
        // number of dimensions
        value.write_i32::<NetworkEndian>(1).unwrap();
        // has_null flag
        value.write_i32::<NetworkEndian>(0).unwrap();
        // element OID (ignored by the decoder)
        value.write_i32::<NetworkEndian>(0).unwrap();
        // length of the single dimension
        value.write_i32::<NetworkEndian>(2).unwrap();
        // lower bound of the single dimension
        value.write_i32::<NetworkEndian>(0).unwrap();
        // first element is complete; the second element is missing entirely
        value.write_i32::<NetworkEndian>(4).unwrap();
        value.write_i32::<NetworkEndian>(42).unwrap();

        let value = PgValue::for_test(&value);
        let res = <Vec<i32> as FromSql<Array<Integer>, Pg>>::from_sql(value);
        assert!(res.is_err());
        assert_eq!(
            format!("{}", res.unwrap_err()),
            "failed to fill whole buffer"
        );
    }

    #[test]
    fn check_invalid_element_size_for_multidimensional_array() {
        let mut value = Vec::<u8>::new();
        // number of dimensions
        value.write_i32::<NetworkEndian>(1).unwrap();
        // has_null flag
        value.write_i32::<NetworkEndian>(0).unwrap();
        // element OID (ignored by the decoder)
        value.write_i32::<NetworkEndian>(0).unwrap();
        // length of the single dimension
        value.write_i32::<NetworkEndian>(2).unwrap();
        // lower bound of the single dimension
        value.write_i32::<NetworkEndian>(0).unwrap();
        // first element declares 6 bytes, but only 4 follow
        value.write_i32::<NetworkEndian>(6).unwrap();
        value.write_i32::<NetworkEndian>(42).unwrap();

        let value = PgValue::for_test(&value);
        let res = <NdArray<i32> as FromSql<Array<Integer>, Pg>>::from_sql(value);
        assert!(res.is_err());
        assert_eq!(
            format!("{}", res.unwrap_err()),
            "Invalid element byte count: Expected at least 6 bytes, but only 4 bytes were received",
        );

        let mut value = Vec::<u8>::new();
        // number of dimensions
        value.write_i32::<NetworkEndian>(1).unwrap();
        // has_null flag
        value.write_i32::<NetworkEndian>(0).unwrap();
        // element OID (ignored by the decoder)
        value.write_i32::<NetworkEndian>(0).unwrap();
        // length of the single dimension
        value.write_i32::<NetworkEndian>(2).unwrap();
        // lower bound of the single dimension
        value.write_i32::<NetworkEndian>(0).unwrap();
        // first element is complete; the second element is missing entirely
        value.write_i32::<NetworkEndian>(4).unwrap();
        value.write_i32::<NetworkEndian>(42).unwrap();

        let value = PgValue::for_test(&value);
        let res = <NdArray<i32> as FromSql<Array<Integer>, Pg>>::from_sql(value);
        assert!(res.is_err());
        assert_eq!(
            format!("{}", res.unwrap_err()),
            "failed to fill whole buffer"
        );
    }

    /// Exercises a genuinely two-dimensional payload (the tests above only use
    /// one dimension) through both decoders.
    #[test]
    fn decode_two_dimensional_array() {
        let mut value = Vec::<u8>::new();
        // number of dimensions
        value.write_i32::<NetworkEndian>(2).unwrap();
        // has_null flag
        value.write_i32::<NetworkEndian>(0).unwrap();
        // element OID (ignored by the decoder)
        value.write_i32::<NetworkEndian>(0).unwrap();
        // (length, lower bound) of each dimension
        for (len, lower_bound) in [(2, 1), (3, 1)] {
            value.write_i32::<NetworkEndian>(len).unwrap();
            value.write_i32::<NetworkEndian>(lower_bound).unwrap();
        }
        // six 4-byte elements
        for elem in 1..=6 {
            value.write_i32::<NetworkEndian>(4).unwrap();
            value.write_i32::<NetworkEndian>(elem).unwrap();
        }

        let pg_value = PgValue::for_test(&value);
        let res = <NdArray<i32> as FromSql<Array<Integer>, Pg>>::from_sql(pg_value).unwrap();
        assert_eq!(
            res,
            NdArray {
                dims: vec![2, 3],
                data: vec![1, 2, 3, 4, 5, 6],
            }
        );

        // The one-dimensional `Vec` decoder must reject the same payload.
        let pg_value = PgValue::for_test(&value);
        let res = <Vec<i32> as FromSql<Array<Integer>, Pg>>::from_sql(pg_value);
        assert_eq!(
            format!("{}", res.unwrap_err()),
            "multi-dimensional arrays are not supported"
        );
    }
}