diesel/expression/
array_comparison.rs

//! This module contains the query DSL node definitions
//! for array comparison operations like `IN` and `NOT IN`.
3
4use super::expression_types::NotSelectable;
5use crate::backend::{sql_dialect, Backend, SqlDialect};
6use crate::expression::subselect::Subselect;
7use crate::expression::{
8    AppearsOnTable, AsExpression, Expression, SelectableExpression, TypedExpressionType,
9    ValidGrouping,
10};
11use crate::query_builder::combination_clause::CombinationClause;
12use crate::query_builder::{
13    AstPass, BoxedSelectStatement, QueryFragment, QueryId, SelectQuery, SelectStatement,
14};
15use crate::result::QueryResult;
16use crate::serialize::ToSql;
17use crate::sql_types::{self, HasSqlType, SingleValue, SqlType};
18use std::marker::PhantomData;
19
20/// Query dsl node that represents a `left IN (values)`
21/// expression
22///
23/// Third party backend can customize the [`QueryFragment`]
24/// implementation of this query dsl node via
25/// [`SqlDialect::ArrayComparison`]. A customized implementation
26/// is expected to provide the same semantics as an ANSI SQL
27/// `IN` expression.
28///
29/// The postgres backend provided a specialized implementation
30/// by using `left = ANY(values)` as optimized variant instead
31/// if this is possible. For cases where this is not possible
32/// like for example if values is a vector of arrays we
33/// generate an ordinary `IN` expression instead.
34#[derive(Debug, Copy, Clone, QueryId, ValidGrouping)]
35#[non_exhaustive]
36pub struct In<T, U> {
37    /// The expression on the left side of the `IN` keyword
38    pub left: T,
39    /// The values clause of the `IN` expression
40    pub values: U,
41}
42
43/// Query dsl node that represents a `left NOT IN (values)`
44/// expression
45///
46/// Third party backend can customize the [`QueryFragment`]
47/// implementation of this query dsl node via
48/// [`SqlDialect::ArrayComparison`]. A customized implementation
49/// is expected to provide the same semantics as an ANSI SQL
50/// `NOT IN` expression.0
51///
52/// The postgres backend provided a specialized implementation
53/// by using `left != ALL(values)` as optimized variant instead
54/// if this is possible. For cases where this is not possible
55/// like for example if values is a vector of arrays we
56/// generate a ordinary `NOT IN` expression instead
57#[derive(Debug, Copy, Clone, QueryId, ValidGrouping)]
58#[non_exhaustive]
59pub struct NotIn<T, U> {
60    /// The expression on the left side of the `NOT IN` keyword
61    pub left: T,
62    /// The values clause of the `NOT IN` expression
63    pub values: U,
64}
65
66impl<T, U> In<T, U> {
67    pub(crate) fn new(left: T, values: U) -> Self {
68        In { left, values }
69    }
70
71    pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
72    where
73        DB: Backend,
74        T: QueryFragment<DB>,
75        U: QueryFragment<DB> + InExpression,
76    {
77        if self.values.is_empty() {
78            out.push_sql("1=0");
79        } else {
80            self.left.walk_ast(out.reborrow())?;
81            out.push_sql(" IN (");
82            self.values.walk_ast(out.reborrow())?;
83            out.push_sql(")");
84        }
85        Ok(())
86    }
87}
88
89impl<T, U> NotIn<T, U> {
90    pub(crate) fn new(left: T, values: U) -> Self {
91        NotIn { left, values }
92    }
93
94    pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
95    where
96        DB: Backend,
97        T: QueryFragment<DB>,
98        U: QueryFragment<DB> + InExpression,
99    {
100        if self.values.is_empty() {
101            out.push_sql("1=1");
102        } else {
103            self.left.walk_ast(out.reborrow())?;
104            out.push_sql(" NOT IN (");
105            self.values.walk_ast(out.reborrow())?;
106            out.push_sql(")");
107        }
108        Ok(())
109    }
110}
111
112impl<T, U> Expression for In<T, U>
113where
114    T: Expression,
115    U: InExpression<SqlType = T::SqlType>,
116    T::SqlType: SqlType,
117    sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>:
118        sql_types::MaybeNullableType<sql_types::Bool>,
119{
120    type SqlType = sql_types::is_nullable::MaybeNullable<
121        sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>,
122        sql_types::Bool,
123    >;
124}
125
126impl<T, U> Expression for NotIn<T, U>
127where
128    T: Expression,
129    U: InExpression<SqlType = T::SqlType>,
130    T::SqlType: SqlType,
131    sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>:
132        sql_types::MaybeNullableType<sql_types::Bool>,
133{
134    type SqlType = sql_types::is_nullable::MaybeNullable<
135        sql_types::is_nullable::IsSqlTypeNullable<T::SqlType>,
136        sql_types::Bool,
137    >;
138}
139
140impl<T, U, DB> QueryFragment<DB> for In<T, U>
141where
142    DB: Backend,
143    Self: QueryFragment<DB, DB::ArrayComparison>,
144{
145    fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
146        <Self as QueryFragment<DB, DB::ArrayComparison>>::walk_ast(self, pass)
147    }
148}
149
150impl<T, U, DB> QueryFragment<DB, sql_dialect::array_comparison::AnsiSqlArrayComparison> for In<T, U>
151where
152    DB: Backend
153        + SqlDialect<ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison>,
154    T: QueryFragment<DB>,
155    U: QueryFragment<DB> + InExpression,
156{
157    fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
158        self.walk_ansi_ast(out)
159    }
160}
161
162impl<T, U, DB> QueryFragment<DB> for NotIn<T, U>
163where
164    DB: Backend,
165    Self: QueryFragment<DB, DB::ArrayComparison>,
166{
167    fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
168        <Self as QueryFragment<DB, DB::ArrayComparison>>::walk_ast(self, pass)
169    }
170}
171
172impl<T, U, DB> QueryFragment<DB, sql_dialect::array_comparison::AnsiSqlArrayComparison>
173    for NotIn<T, U>
174where
175    DB: Backend
176        + SqlDialect<ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison>,
177    T: QueryFragment<DB>,
178    U: QueryFragment<DB> + InExpression,
179{
180    fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
181        self.walk_ansi_ast(out)
182    }
183}
184
185impl_selectable_expression!(In<T, U>);
186impl_selectable_expression!(NotIn<T, U>);
187
188/// This trait describes how a type is transformed to the
189/// `IN (values)` value expression
190///
191/// Diesel provided several implementations here:
192///
193///  - An implementation for any [`Iterator`] over values
194///    that implement [`AsExpression<ST>`] for the corresponding
195///    sql type ST. The corresponding values clause will contain
196///    bind statements for each individual value.
197///  - An implementation for select statements, that returns
198///    a single field. The corresponding values clause will contain
199///    the sub query.
200///
201///  This trait is exposed for custom third party backends so
202///  that they can restrict the [`QueryFragment`] implementations
203///  for [`In`] and [`NotIn`].
204pub trait AsInExpression<T: SqlType> {
205    /// Type of the expression returned by [AsInExpression::as_in_expression]
206    type InExpression: InExpression<SqlType = T>;
207
208    /// Construct the diesel query dsl representation of
209    /// the `IN (values)` clause for the given type
210    #[allow(clippy::wrong_self_convention)]
211    // That's a public api, we cannot just change it to
212    // appease clippy
213    fn as_in_expression(self) -> Self::InExpression;
214}
215
216impl<I, T, ST> AsInExpression<ST> for I
217where
218    I: IntoIterator<Item = T>,
219    T: AsExpression<ST>,
220    ST: SqlType + TypedExpressionType,
221{
222    type InExpression = Many<ST, T>;
223
224    fn as_in_expression(self) -> Self::InExpression {
225        Many {
226            values: self.into_iter().collect(),
227            p: PhantomData,
228        }
229    }
230}
231
232/// A marker trait that identifies query fragments that can be used in `IN(...)` and `NOT IN(...)`
233/// clauses, (or `= ANY (...)` clauses on the Postgres backend)
234///
235/// These can be wrapped in [`In`] or [`NotIn`] query dsl nodes
236pub trait InExpression {
237    /// The SQL type of the inner values, which should be the same as the left of the `IN` or
238    /// `NOT IN` clause
239    type SqlType: SqlType;
240
241    /// Returns `true` if self represents an empty collection
242    /// Otherwise `false` is returned.
243    fn is_empty(&self) -> bool;
244
245    /// Returns `true` if the values clause represents
246    /// bind values and each bind value is a postgres array type
247    fn is_array(&self) -> bool;
248}
249
250impl<ST, F, S, D, W, O, LOf, G, H, LC> AsInExpression<ST>
251    for SelectStatement<F, S, D, W, O, LOf, G, H, LC>
252where
253    ST: SqlType,
254    Subselect<Self, ST>: Expression<SqlType = ST>,
255    Self: SelectQuery<SqlType = ST>,
256{
257    type InExpression = Subselect<Self, ST>;
258
259    fn as_in_expression(self) -> Self::InExpression {
260        Subselect::new(self)
261    }
262}
263
264impl<'a, ST, QS, DB, GB> AsInExpression<ST> for BoxedSelectStatement<'a, ST, QS, DB, GB>
265where
266    ST: SqlType,
267    Subselect<BoxedSelectStatement<'a, ST, QS, DB, GB>, ST>: Expression<SqlType = ST>,
268{
269    type InExpression = Subselect<Self, ST>;
270
271    fn as_in_expression(self) -> Self::InExpression {
272        Subselect::new(self)
273    }
274}
275
276impl<ST, Combinator, Rule, Source, Rhs> AsInExpression<ST>
277    for CombinationClause<Combinator, Rule, Source, Rhs>
278where
279    ST: SqlType,
280    Self: SelectQuery<SqlType = ST>,
281    Subselect<Self, ST>: Expression<SqlType = ST>,
282{
283    type InExpression = Subselect<Self, ST>;
284
285    fn as_in_expression(self) -> Self::InExpression {
286        Subselect::new(self)
287    }
288}
289
/// Query dsl node for the `values` part of an `IN (values)` clause
/// containing a variable number of bind values.
///
/// Third-party backends can customize the [`QueryFragment`]
/// implementation of this query dsl node via
/// [`SqlDialect::ArrayComparison`]. The default (ANSI SQL)
/// implementation generates one bind per value in the
/// `values` field, separated by `, `.
///
/// Diesel provides an optimized implementation for PostgreSQL-like
/// database systems that binds all values with a single
/// bind value of the type `Array<ST>` instead.
#[derive(Debug, Clone)]
pub struct Many<ST, I> {
    /// The values contained in the `IN (values)` clause
    pub values: Vec<I>,
    // Records the SQL type `ST` of the values without storing data of that type.
    p: PhantomData<ST>,
}
308
309impl<ST, I, GB> ValidGrouping<GB> for Many<ST, I>
310where
311    ST: SingleValue,
312    I: AsExpression<ST>,
313    I::Expression: ValidGrouping<GB>,
314{
315    type IsAggregate = <I::Expression as ValidGrouping<GB>>::IsAggregate;
316}
317
318impl<ST, I> Expression for Many<ST, I>
319where
320    ST: TypedExpressionType,
321{
322    // Comma-ed fake expressions are not usable directly in SQL
323    // This is only implemented so that we can use the usual SelectableExpression & co traits
324    // as constraints for the same implementations on [`In`] and [`NotIn`]
325    type SqlType = NotSelectable;
326}
327
328impl<ST, I> InExpression for Many<ST, I>
329where
330    ST: SqlType,
331{
332    type SqlType = ST;
333
334    fn is_empty(&self) -> bool {
335        self.values.is_empty()
336    }
337
338    fn is_array(&self) -> bool {
339        ST::IS_ARRAY
340    }
341}
342
343impl<ST, I, QS> SelectableExpression<QS> for Many<ST, I>
344where
345    Many<ST, I>: AppearsOnTable<QS>,
346    ST: SingleValue,
347    I: AsExpression<ST>,
348    <I as AsExpression<ST>>::Expression: SelectableExpression<QS>,
349{
350}
351
352impl<ST, I, QS> AppearsOnTable<QS> for Many<ST, I>
353where
354    Many<ST, I>: Expression,
355    I: AsExpression<ST>,
356    ST: SingleValue,
357    <I as AsExpression<ST>>::Expression: SelectableExpression<QS>,
358{
359}
360
361impl<ST, I, DB> QueryFragment<DB> for Many<ST, I>
362where
363    Self: QueryFragment<DB, DB::ArrayComparison>,
364    DB: Backend,
365{
366    fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
367        <Self as QueryFragment<DB, DB::ArrayComparison>>::walk_ast(self, pass)
368    }
369}
370
371impl<ST, I, DB> QueryFragment<DB, sql_dialect::array_comparison::AnsiSqlArrayComparison>
372    for Many<ST, I>
373where
374    DB: Backend
375        + HasSqlType<ST>
376        + SqlDialect<ArrayComparison = sql_dialect::array_comparison::AnsiSqlArrayComparison>,
377    ST: SingleValue,
378    I: ToSql<ST, DB>,
379{
380    fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> QueryResult<()> {
381        self.walk_ansi_ast(out)
382    }
383}
384
385impl<ST, I> Many<ST, I> {
386    pub(crate) fn walk_ansi_ast<'b, DB>(&'b self, mut out: AstPass<'_, 'b, DB>) -> QueryResult<()>
387    where
388        DB: Backend + HasSqlType<ST>,
389        ST: SingleValue,
390        I: ToSql<ST, DB>,
391    {
392        out.unsafe_to_cache_prepared();
393        let mut first = true;
394        for value in &self.values {
395            if first {
396                first = false;
397            } else {
398                out.push_sql(", ");
399            }
400            out.push_bind_param(value)?;
401        }
402        Ok(())
403    }
404}
405
406impl<ST, I> QueryId for Many<ST, I> {
407    type QueryId = ();
408
409    const HAS_STATIC_QUERY_ID: bool = false;
410}