diesel/pg/types/
ranges.rs

1use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
2use std::collections::Bound;
3use std::error::Error;
4use std::io::Write;
5
6use crate::deserialize::{self, Defaultable, FromSql, Queryable};
7use crate::expression::bound::Bound as SqlBound;
8use crate::expression::AsExpression;
9use crate::pg::{Pg, PgTypeMetadata, PgValue};
10use crate::query_builder::bind_collector::ByteWrapper;
11use crate::serialize::{self, IsNull, Output, ToSql};
12use crate::sql_types::*;
13
14// https://github.com/postgres/postgres/blob/113b0045e20d40f726a0a30e33214455e4f1385e/src/include/utils/rangetypes.h#L35-L43
15bitflags::bitflags! {
16    struct RangeFlags: u8 {
17        const EMPTY = 0x01;
18        const LB_INC = 0x02;
19        const UB_INC = 0x04;
20        const LB_INF = 0x08;
21        const UB_INF = 0x10;
22        const LB_NULL = 0x20;
23        const UB_NULL = 0x40;
24        const CONTAIN_EMPTY = 0x80;
25    }
26}
27
/// Implements `AsExpression<$sql_type>` for `$rust_type` by wrapping the value
/// in a bind parameter (`SqlBound`).
///
/// The generated impl always declares the generics `'a`, `ST` and `T`; the
/// types passed at the call site are expected to be spelled in terms of them.
macro_rules! range_as_expression {
    ($rust_type:ty; $sql_type:ty) => {
        #[cfg(feature = "postgres_backend")]
        // `'a` is declared unconditionally so that one macro arm can serve
        // both the owned and the borrowed invocations; owned types simply
        // leave it unused.
        #[allow(clippy::extra_unused_lifetimes)]
        impl<'a, ST, T> AsExpression<$sql_type> for $rust_type
        where
            ST: 'static,
        {
            type Expression = SqlBound<$sql_type, Self>;

            fn as_expression(self) -> Self::Expression {
                SqlBound::new(self)
            }
        }
    };
}

// Explicit `(lower, upper)` bound pairs.
range_as_expression!((Bound<T>, Bound<T>); Range<ST>);
range_as_expression!(&'a (Bound<T>, Bound<T>); Range<ST>);
range_as_expression!((Bound<T>, Bound<T>); Nullable<Range<ST>>);
range_as_expression!(&'a (Bound<T>, Bound<T>); Nullable<Range<ST>>);

// `a..b`
range_as_expression!(std::ops::Range<T>; Range<ST>);
range_as_expression!(&'a std::ops::Range<T>; Range<ST>);
range_as_expression!(std::ops::Range<T>; Nullable<Range<ST>>);
range_as_expression!(&'a std::ops::Range<T>; Nullable<Range<ST>>);

// `a..=b`
range_as_expression!(std::ops::RangeInclusive<T>; Range<ST>);
range_as_expression!(&'a std::ops::RangeInclusive<T>; Range<ST>);
range_as_expression!(std::ops::RangeInclusive<T>; Nullable<Range<ST>>);
range_as_expression!(&'a std::ops::RangeInclusive<T>; Nullable<Range<ST>>);

// `..=b`
range_as_expression!(std::ops::RangeToInclusive<T>; Range<ST>);
range_as_expression!(&'a std::ops::RangeToInclusive<T>; Range<ST>);
range_as_expression!(std::ops::RangeToInclusive<T>; Nullable<Range<ST>>);
range_as_expression!(&'a std::ops::RangeToInclusive<T>; Nullable<Range<ST>>);

// `a..`
range_as_expression!(std::ops::RangeFrom<T>; Range<ST>);
range_as_expression!(&'a std::ops::RangeFrom<T>; Range<ST>);
range_as_expression!(std::ops::RangeFrom<T>; Nullable<Range<ST>>);
range_as_expression!(&'a std::ops::RangeFrom<T>; Nullable<Range<ST>>);

// `..b`
range_as_expression!(std::ops::RangeTo<T>; Range<ST>);
range_as_expression!(&'a std::ops::RangeTo<T>; Range<ST>);
range_as_expression!(std::ops::RangeTo<T>; Nullable<Range<ST>>);
range_as_expression!(&'a std::ops::RangeTo<T>; Nullable<Range<ST>>);
73
#[cfg(feature = "postgres_backend")]
impl<T, ST> FromSql<Range<ST>, Pg> for (Bound<T>, Bound<T>)
where
    T: FromSql<ST, Pg> + Defaultable,
{
    /// Decodes a range value from PostgreSQL's binary representation into a
    /// `(lower, upper)` pair of bounds.
    ///
    /// The payload starts with one flag byte (see `RangeFlags`). Each bound
    /// that is present follows as a big-endian `i32` byte length and that
    /// many bytes of element data, lower bound first.
    ///
    /// An empty range is returned as
    /// `(Excluded(T::default_value()), Excluded(T::default_value()))`.
    ///
    /// # Errors
    ///
    /// Returns an error when the payload is truncated, when a bound length is
    /// negative or exceeds the bytes that remain, or when decoding a bound
    /// element fails.
    fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
        // Reads one length-prefixed element from the front of `bytes` and
        // advances `bytes` past it. The declared length is checked against
        // the remaining input so that a malformed payload yields an error
        // rather than a slice-index panic.
        fn take_bound<'a>(bytes: &mut &'a [u8]) -> deserialize::Result<&'a [u8]> {
            // A negative length fails the `i32 -> usize` conversion.
            let len: usize = bytes.read_i32::<NetworkEndian>()?.try_into()?;
            let remaining: &'a [u8] = *bytes;
            if len > remaining.len() {
                return Err(
                    "invalid range value: bound length exceeds the remaining payload".into(),
                );
            }
            let (element, rest) = remaining.split_at(len);
            *bytes = rest;
            Ok(element)
        }

        let mut bytes = value.as_bytes();
        let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?);

        if flags.contains(RangeFlags::EMPTY) {
            // No bound data is read for an empty range; it is represented by
            // two exclusive bounds around the element's default value.
            return Ok((
                Bound::Excluded(T::default_value()),
                Bound::Excluded(T::default_value()),
            ));
        }

        let lower_bound = if flags.contains(RangeFlags::LB_INF) {
            Bound::Unbounded
        } else {
            let elem_bytes = take_bound(&mut bytes)?;
            let elem = T::from_sql(PgValue::new_internal(elem_bytes, &value))?;

            if flags.contains(RangeFlags::LB_INC) {
                Bound::Included(elem)
            } else {
                Bound::Excluded(elem)
            }
        };

        let upper_bound = if flags.contains(RangeFlags::UB_INF) {
            Bound::Unbounded
        } else {
            // Hand the element parser exactly the bytes announced by the
            // length prefix instead of everything that is left.
            let elem_bytes = take_bound(&mut bytes)?;
            let elem = T::from_sql(PgValue::new_internal(elem_bytes, &value))?;

            if flags.contains(RangeFlags::UB_INC) {
                Bound::Included(elem)
            } else {
                Bound::Excluded(elem)
            }
        };

        Ok((lower_bound, upper_bound))
    }
}
116
#[cfg(feature = "postgres_backend")]
impl<T, ST> Queryable<Range<ST>, Pg> for (Bound<T>, Bound<T>)
where
    T: FromSql<ST, Pg> + Defaultable,
{
    // The bound pair is deserialized directly, so the row type is the pair itself.
    type Row = Self;

    /// Returns the deserialized bound pair unchanged; this never fails.
    fn build(bounds: Self::Row) -> deserialize::Result<Self> {
        Ok(bounds)
    }
}
128
/// Writes a range described by `start` and `end` to `out`.
///
/// The output is one flag byte followed, for every bound that is not
/// `Unbounded`, by the serialized element's byte length as a big-endian
/// 32-bit integer and then the element bytes. The lower bound is written
/// before the upper bound.
///
/// Always returns `IsNull::No` on success. Errors from serializing an
/// element, from converting its length, or from writing to `out` are
/// propagated.
#[cfg(feature = "postgres_backend")]
fn to_sql<ST, T>(
    start: Bound<&T>,
    end: Bound<&T>,
    out: &mut Output<'_, '_, Pg>,
) -> serialize::Result
where
    T: ToSql<ST, Pg>,
{
    let lower_flags = match start {
        Bound::Included(_) => RangeFlags::LB_INC,
        Bound::Excluded(_) => RangeFlags::empty(),
        Bound::Unbounded => RangeFlags::LB_INF,
    };
    let upper_flags = match end {
        Bound::Included(_) => RangeFlags::UB_INC,
        Bound::Excluded(_) => RangeFlags::empty(),
        Bound::Unbounded => RangeFlags::UB_INF,
    };

    out.write_u8((lower_flags | upper_flags).bits())?;

    // Each element is serialized into this scratch buffer first, because its
    // byte length has to be written ahead of the bytes themselves.
    let mut scratch = Vec::new();

    for bound in [start, end] {
        let value = match bound {
            Bound::Included(value) | Bound::Excluded(value) => value,
            // An unbounded side contributes nothing beyond its flag bit.
            Bound::Unbounded => continue,
        };

        scratch.clear();
        {
            // Scoped so the borrows of `scratch` and `out` end before both
            // are used again below.
            let mut element_out = Output::new(ByteWrapper(&mut scratch), out.metadata_lookup());
            value.to_sql(&mut element_out)?;
        }
        out.write_u32::<NetworkEndian>(scratch.len().try_into()?)?;
        out.write_all(&scratch)?;
    }

    Ok(IsNull::No)
}
181
#[cfg(feature = "postgres_backend")]
impl<ST, T> ToSql<Range<ST>, Pg> for (Bound<T>, Bound<T>)
where
    T: ToSql<ST, Pg>,
{
    /// Serializes the `(lower, upper)` pair by borrowing both bounds and
    /// delegating to the module-level `to_sql` helper.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        let (lower, upper) = self;
        to_sql(lower.as_ref(), upper.as_ref(), out)
    }
}
191
192use std::ops::RangeBounds;
193macro_rules! range_std_to_sql {
194    ($ty:ty) => {
195        #[cfg(feature = "postgres_backend")]
196        impl<ST, T> ToSql<Range<ST>, Pg> for $ty
197        where
198            ST: 'static,
199            T: ToSql<ST, Pg>,
200        {
201            fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
202                to_sql(self.start_bound(), self.end_bound(), out)
203            }
204        }
205    };
206}
207
208range_std_to_sql!(std::ops::Range<T>);
209range_std_to_sql!(std::ops::RangeInclusive<T>);
210range_std_to_sql!(std::ops::RangeFrom<T>);
211range_std_to_sql!(std::ops::RangeTo<T>);
212range_std_to_sql!(std::ops::RangeToInclusive<T>);
213
/// Implements `ToSql<Nullable<Range<ST>>, Pg>` for a type that already
/// implements `ToSql<Range<ST>, Pg>`, by forwarding to that implementation.
macro_rules! range_to_sql_nullable {
    ($range:ty) => {
        #[cfg(feature = "postgres_backend")]
        impl<ST, T> ToSql<Nullable<Range<ST>>, Pg> for $range
        where
            ST: 'static,
            $range: ToSql<Range<ST>, Pg>,
        {
            fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
                // Fully qualified so the call selects the non-nullable impl
                // rather than recursing into this one.
                <Self as ToSql<Range<ST>, Pg>>::to_sql(self, out)
            }
        }
    };
}

range_to_sql_nullable!((Bound<T>, Bound<T>));
range_to_sql_nullable!(std::ops::Range<T>);
range_to_sql_nullable!(std::ops::RangeInclusive<T>);
range_to_sql_nullable!(std::ops::RangeFrom<T>);
range_to_sql_nullable!(std::ops::RangeTo<T>);
range_to_sql_nullable!(std::ops::RangeToInclusive<T>);
234
#[cfg(feature = "postgres_backend")]
impl HasSqlType<Int4range> for Pg {
    /// Fixed metadata for `Int4range`; the lookup argument is not consulted.
    fn metadata(_lookup: &mut Self::MetadataLookup) -> PgTypeMetadata {
        // OIDs of the built-in `int4range` type and of its array type.
        let (type_oid, array_oid) = (3904, 3905);
        PgTypeMetadata::new(type_oid, array_oid)
    }
}
241
#[cfg(feature = "postgres_backend")]
impl HasSqlType<Numrange> for Pg {
    /// Fixed metadata for `Numrange`; the lookup argument is not consulted.
    fn metadata(_lookup: &mut Self::MetadataLookup) -> PgTypeMetadata {
        // OIDs of the built-in `numrange` type and of its array type.
        let (type_oid, array_oid) = (3906, 3907);
        PgTypeMetadata::new(type_oid, array_oid)
    }
}
248
// Gated on the backend feature like every other `HasSqlType` impl in this file.
#[cfg(feature = "postgres_backend")]
impl HasSqlType<Tsrange> for Pg {
    /// Fixed metadata for `Tsrange`; the lookup argument is not consulted.
    fn metadata(_: &mut Self::MetadataLookup) -> PgTypeMetadata {
        // OIDs of the built-in `tsrange` type and of its array type.
        PgTypeMetadata::new(3908, 3909)
    }
}
254
#[cfg(feature = "postgres_backend")]
impl HasSqlType<Tstzrange> for Pg {
    /// Fixed metadata for `Tstzrange`; the lookup argument is not consulted.
    fn metadata(_lookup: &mut Self::MetadataLookup) -> PgTypeMetadata {
        // OIDs of the built-in `tstzrange` type and of its array type.
        let (type_oid, array_oid) = (3910, 3911);
        PgTypeMetadata::new(type_oid, array_oid)
    }
}
261
#[cfg(feature = "postgres_backend")]
impl HasSqlType<Daterange> for Pg {
    /// Fixed metadata for `Daterange`; the lookup argument is not consulted.
    fn metadata(_lookup: &mut Self::MetadataLookup) -> PgTypeMetadata {
        // OIDs of the built-in `daterange` type and of its array type.
        let (type_oid, array_oid) = (3912, 3913);
        PgTypeMetadata::new(type_oid, array_oid)
    }
}
268
#[cfg(feature = "postgres_backend")]
impl HasSqlType<Int8range> for Pg {
    /// Fixed metadata for `Int8range`; the lookup argument is not consulted.
    fn metadata(_lookup: &mut Self::MetadataLookup) -> PgTypeMetadata {
        // OIDs of the built-in `int8range` type and of its array type.
        let (type_oid, array_oid) = (3926, 3927);
        PgTypeMetadata::new(type_oid, array_oid)
    }
}
275
#[cfg(feature = "postgres_backend")]
impl ToSql<RangeBoundEnum, Pg> for RangeBound {
    /// Writes the two-character bracket notation for this bound combination:
    /// `[` / `]` mark an inclusive side, `(` / `)` an exclusive one.
    ///
    /// Returns `IsNull::No` on success; a write failure is boxed and returned.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        let brackets = match self {
            Self::LowerBoundInclusiveUpperBoundInclusive => "[]",
            Self::LowerBoundInclusiveUpperBoundExclusive => "[)",
            Self::LowerBoundExclusiveUpperBoundInclusive => "(]",
            Self::LowerBoundExclusiveUpperBoundExclusive => "()",
        };

        match out.write_all(brackets.as_bytes()) {
            Ok(()) => Ok(IsNull::No),
            Err(io_err) => Err(Box::new(io_err) as Box<dyn Error + Send + Sync>),
        }
    }
}