// diesel/expression/case_when.rs

1use crate::expression::grouped::Grouped;
2use crate::expression::{helper_types, Expression};
3use crate::sql_types::{BoolOrNullableBool, SqlType};
4use diesel_derives::{DieselNumericOps, QueryId, ValidGrouping};
5
6use super::{AsExpression, TypedExpressionType};
7
8/// Creates a SQL `CASE WHEN ... END` expression
9///
10/// # Example
11///
12/// ```
13/// # include!("../doctest_setup.rs");
14/// #
15/// # fn main() {
16/// #     use schema::users::dsl::*;
17/// #     let connection = &mut establish_connection();
18/// use diesel::dsl::case_when;
19///
20/// let users_with_name: Vec<(i32, Option<i32>)> = users
21///     .select((id, case_when(name.eq("Sean"), id)))
22///     .load(connection)
23///     .unwrap();
24///
25/// assert_eq!(&[(1, Some(1)), (2, None)], users_with_name.as_slice());
26/// # }
27/// ```
28///
29/// # `ELSE` clause
30/// ```
31/// # include!("../doctest_setup.rs");
32/// #
33/// # fn main() {
34/// #     use schema::users::dsl::*;
35/// #     let connection = &mut establish_connection();
36/// use diesel::dsl::case_when;
37///
38/// let users_with_name: Vec<(i32, i32)> = users
39///     .select((id, case_when(name.eq("Sean"), id).otherwise(0)))
40///     .load(connection)
41///     .unwrap();
42///
43/// assert_eq!(&[(1, 1), (2, 0)], users_with_name.as_slice());
44/// # }
45/// ```
46///
47/// Note that the SQL types of the `case_when` and `else` expressions should
48/// be equal. This includes whether they are wrapped in
49/// [`Nullable`](crate::sql_types::Nullable), so you may need to call
50/// [`nullable`](crate::expression_methods::NullableExpressionMethods::nullable)
51/// on one of them.
52///
53/// # More `WHEN` branches
54/// ```
55/// # include!("../doctest_setup.rs");
56/// #
57/// # fn main() {
58/// #     use schema::users::dsl::*;
59/// #     let connection = &mut establish_connection();
60/// use diesel::dsl::case_when;
61///
62/// let users_with_name: Vec<(i32, Option<i32>)> = users
63///     .select((id, case_when(name.eq("Sean"), id).when(name.eq("Tess"), 2)))
64///     .load(connection)
65///     .unwrap();
66///
67/// assert_eq!(&[(1, Some(1)), (2, Some(2))], users_with_name.as_slice());
68/// # }
69/// ```
70pub fn case_when<C, T, ST>(condition: C, if_true: T) -> helper_types::case_when<C, T, ST>
71where
72    C: Expression,
73    <C as Expression>::SqlType: BoolOrNullableBool,
74    T: AsExpression<ST>,
75    ST: SqlType + TypedExpressionType,
76{
77    CaseWhen {
78        whens: CaseWhenConditionsLeaf {
79            when: Grouped(condition),
80            then: Grouped(if_true.as_expression()),
81        },
82        else_expr: NoElseExpression,
83    }
84}
85
86/// A SQL `CASE WHEN ... END` expression
87#[derive(Debug, Clone, Copy, QueryId, DieselNumericOps, ValidGrouping)]
88pub struct CaseWhen<Whens, E> {
89    whens: Whens,
90    else_expr: E,
91}
92
93impl<Whens, E> CaseWhen<Whens, E> {
94    /// Add an additional `WHEN ... THEN ...` branch to the `CASE` expression
95    ///
96    /// See the [`case_when`] documentation for more details.
97    pub fn when<C, T>(self, condition: C, if_true: T) -> helper_types::When<Self, C, T>
98    where
99        Self: CaseWhenTypesExtractor<Whens = Whens, Else = E>,
100        C: Expression,
101        <C as Expression>::SqlType: BoolOrNullableBool,
102        T: AsExpression<<Self as CaseWhenTypesExtractor>::OutputExpressionSpecifiedSqlType>,
103    {
104        CaseWhen {
105            whens: CaseWhenConditionsIntermediateNode {
106                first_whens: self.whens,
107                last_when: CaseWhenConditionsLeaf {
108                    when: Grouped(condition),
109                    then: Grouped(if_true.as_expression()),
110                },
111            },
112            else_expr: self.else_expr,
113        }
114    }
115}
116
117impl<Whens> CaseWhen<Whens, NoElseExpression> {
118    /// Sets the `ELSE` branch of the `CASE` expression
119    ///
120    /// It is named this way because `else` is a reserved keyword in Rust
121    ///
122    /// See the [`case_when`] documentation for more details.
123    pub fn otherwise<E>(self, if_no_other_branch_matched: E) -> helper_types::Otherwise<Self, E>
124    where
125        Self: CaseWhenTypesExtractor<Whens = Whens, Else = NoElseExpression>,
126        E: AsExpression<<Self as CaseWhenTypesExtractor>::OutputExpressionSpecifiedSqlType>,
127    {
128        CaseWhen {
129            whens: self.whens,
130            else_expr: ElseExpression {
131                expr: Grouped(if_no_other_branch_matched.as_expression()),
132            },
133        }
134    }
135}
136
137pub(crate) use non_public_types::*;
138mod non_public_types {
139    use super::CaseWhen;
140
141    use diesel_derives::{QueryId, ValidGrouping};
142
143    use crate::expression::{
144        AppearsOnTable, Expression, SelectableExpression, TypedExpressionType,
145    };
146    use crate::query_builder::{AstPass, QueryFragment};
147    use crate::query_source::aliasing;
148    use crate::sql_types::{BoolOrNullableBool, IntoNullable, SqlType};
149
150    #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)]
151    pub struct CaseWhenConditionsLeaf<W, T> {
152        pub(super) when: W,
153        pub(super) then: T,
154    }
155
156    #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)]
157    pub struct CaseWhenConditionsIntermediateNode<W, T, Whens> {
158        pub(super) first_whens: Whens,
159        pub(super) last_when: CaseWhenConditionsLeaf<W, T>,
160    }
161
162    pub trait CaseWhenConditions {
163        type OutputExpressionSpecifiedSqlType: SqlType + TypedExpressionType;
164    }
165    impl<W, T: Expression> CaseWhenConditions for CaseWhenConditionsLeaf<W, T>
166    where
167        <T as Expression>::SqlType: SqlType + TypedExpressionType,
168    {
169        type OutputExpressionSpecifiedSqlType = T::SqlType;
170    }
171    // This intentionally doesn't re-check inner `Whens` here, because this trait is
172    // only used to allow expression SQL type inference for `.when` calls so we
173    // want to make it as lightweight as possible for fast compilation. Actual
174    // guarantees are provided by the other implementations below
175    impl<W, T: Expression, Whens> CaseWhenConditions for CaseWhenConditionsIntermediateNode<W, T, Whens>
176    where
177        <T as Expression>::SqlType: SqlType + TypedExpressionType,
178    {
179        type OutputExpressionSpecifiedSqlType = T::SqlType;
180    }
181
182    #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)]
183    pub struct NoElseExpression;
184    #[derive(Debug, Clone, Copy, QueryId, ValidGrouping)]
185    pub struct ElseExpression<E> {
186        pub(super) expr: E,
187    }
188
189    /// Largely internal trait used to define the [`When`] and [`Otherwise`]
190    /// type aliases
191    ///
192    /// It should typically not be needed in user code unless writing extremely
193    /// generic functions
194    pub trait CaseWhenTypesExtractor {
195        /// The
196        /// This may not be the actual output expression type: if there is no
197        /// `else` it will be made `Nullable`
198        type OutputExpressionSpecifiedSqlType: SqlType + TypedExpressionType;
199        type Whens;
200        type Else;
201    }
202    impl<Whens, E> CaseWhenTypesExtractor for CaseWhen<Whens, E>
203    where
204        Whens: CaseWhenConditions,
205    {
206        type OutputExpressionSpecifiedSqlType = Whens::OutputExpressionSpecifiedSqlType;
207        type Whens = Whens;
208        type Else = E;
209    }
210
211    impl<W, T, QS> SelectableExpression<QS> for CaseWhen<CaseWhenConditionsLeaf<W, T>, NoElseExpression>
212    where
213        CaseWhen<CaseWhenConditionsLeaf<W, T>, NoElseExpression>: AppearsOnTable<QS>,
214        W: SelectableExpression<QS>,
215        T: SelectableExpression<QS>,
216    {
217    }
218
219    impl<W, T, E, QS> SelectableExpression<QS>
220        for CaseWhen<CaseWhenConditionsLeaf<W, T>, ElseExpression<E>>
221    where
222        CaseWhen<CaseWhenConditionsLeaf<W, T>, ElseExpression<E>>: AppearsOnTable<QS>,
223        W: SelectableExpression<QS>,
224        T: SelectableExpression<QS>,
225        E: SelectableExpression<QS>,
226    {
227    }
228
229    impl<W, T, Whens, E, QS> SelectableExpression<QS>
230        for CaseWhen<CaseWhenConditionsIntermediateNode<W, T, Whens>, E>
231    where
232        Self: AppearsOnTable<QS>,
233        W: SelectableExpression<QS>,
234        T: SelectableExpression<QS>,
235        CaseWhen<Whens, E>: SelectableExpression<QS>,
236    {
237    }
238
239    impl<W, T, QS> AppearsOnTable<QS> for CaseWhen<CaseWhenConditionsLeaf<W, T>, NoElseExpression>
240    where
241        CaseWhen<CaseWhenConditionsLeaf<W, T>, NoElseExpression>: Expression,
242        W: AppearsOnTable<QS>,
243        T: AppearsOnTable<QS>,
244    {
245    }
246
247    impl<W, T, E, QS> AppearsOnTable<QS> for CaseWhen<CaseWhenConditionsLeaf<W, T>, ElseExpression<E>>
248    where
249        CaseWhen<CaseWhenConditionsLeaf<W, T>, ElseExpression<E>>: Expression,
250        W: AppearsOnTable<QS>,
251        T: AppearsOnTable<QS>,
252        E: AppearsOnTable<QS>,
253    {
254    }
255
256    impl<W, T, Whens, E, QS> AppearsOnTable<QS>
257        for CaseWhen<CaseWhenConditionsIntermediateNode<W, T, Whens>, E>
258    where
259        Self: Expression,
260        W: AppearsOnTable<QS>,
261        T: AppearsOnTable<QS>,
262        CaseWhen<Whens, E>: AppearsOnTable<QS>,
263    {
264    }
265
266    impl<W, T> Expression for CaseWhen<CaseWhenConditionsLeaf<W, T>, NoElseExpression>
267    where
268        W: Expression,
269        <W as Expression>::SqlType: BoolOrNullableBool,
270        T: Expression,
271        <T as Expression>::SqlType: IntoNullable,
272        <<T as Expression>::SqlType as IntoNullable>::Nullable: SqlType + TypedExpressionType,
273    {
274        type SqlType = <<T as Expression>::SqlType as IntoNullable>::Nullable;
275    }
276    impl<W, T, E> Expression for CaseWhen<CaseWhenConditionsLeaf<W, T>, ElseExpression<E>>
277    where
278        W: Expression,
279        <W as Expression>::SqlType: BoolOrNullableBool,
280        T: Expression,
281    {
282        type SqlType = T::SqlType;
283    }
284    impl<W, T, Whens, E> Expression for CaseWhen<CaseWhenConditionsIntermediateNode<W, T, Whens>, E>
285    where
286        CaseWhen<CaseWhenConditionsLeaf<W, T>, E>: Expression,
287        CaseWhen<Whens, E>: Expression<
288            SqlType = <CaseWhen<CaseWhenConditionsLeaf<W, T>, E> as Expression>::SqlType,
289        >,
290    {
291        type SqlType = <CaseWhen<CaseWhenConditionsLeaf<W, T>, E> as Expression>::SqlType;
292    }
293
294    impl<Whens, E, DB> QueryFragment<DB> for CaseWhen<Whens, E>
295    where
296        DB: crate::backend::Backend,
297        Whens: QueryFragment<DB>,
298        E: QueryFragment<DB>,
299    {
300        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::QueryResult<()> {
301            out.push_sql("CASE");
302            self.whens.walk_ast(out.reborrow())?;
303            self.else_expr.walk_ast(out.reborrow())?;
304            out.push_sql(" END");
305            Ok(())
306        }
307    }
308
309    impl<W, T, DB> QueryFragment<DB> for CaseWhenConditionsLeaf<W, T>
310    where
311        DB: crate::backend::Backend,
312        W: QueryFragment<DB>,
313        T: QueryFragment<DB>,
314    {
315        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::QueryResult<()> {
316            out.push_sql(" WHEN ");
317            self.when.walk_ast(out.reborrow())?;
318            out.push_sql(" THEN ");
319            self.then.walk_ast(out.reborrow())?;
320            Ok(())
321        }
322    }
323
324    impl<W, T, Whens, DB> QueryFragment<DB> for CaseWhenConditionsIntermediateNode<W, T, Whens>
325    where
326        DB: crate::backend::Backend,
327        Whens: QueryFragment<DB>,
328        W: QueryFragment<DB>,
329        T: QueryFragment<DB>,
330    {
331        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::QueryResult<()> {
332            self.first_whens.walk_ast(out.reborrow())?;
333            self.last_when.walk_ast(out.reborrow())?;
334            Ok(())
335        }
336    }
337
338    impl<DB> QueryFragment<DB> for NoElseExpression
339    where
340        DB: crate::backend::Backend,
341    {
342        fn walk_ast<'b>(&'b self, out: AstPass<'_, 'b, DB>) -> crate::result::QueryResult<()> {
343            let _ = out;
344            Ok(())
345        }
346    }
347    impl<E, DB> QueryFragment<DB> for ElseExpression<E>
348    where
349        E: QueryFragment<DB>,
350        DB: crate::backend::Backend,
351    {
352        fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, DB>) -> crate::result::QueryResult<()> {
353            out.push_sql(" ELSE ");
354            self.expr.walk_ast(out.reborrow())?;
355            Ok(())
356        }
357    }
358
359    impl<S, Conditions, E> aliasing::FieldAliasMapper<S> for CaseWhen<Conditions, E>
360    where
361        S: aliasing::AliasSource,
362        Conditions: aliasing::FieldAliasMapper<S>,
363        E: aliasing::FieldAliasMapper<S>,
364    {
365        type Out = CaseWhen<
366            <Conditions as aliasing::FieldAliasMapper<S>>::Out,
367            <E as aliasing::FieldAliasMapper<S>>::Out,
368        >;
369        fn map(self, alias: &aliasing::Alias<S>) -> Self::Out {
370            CaseWhen {
371                whens: self.whens.map(alias),
372                else_expr: self.else_expr.map(alias),
373            }
374        }
375    }
376
377    impl<S, W, T> aliasing::FieldAliasMapper<S> for CaseWhenConditionsLeaf<W, T>
378    where
379        S: aliasing::AliasSource,
380        W: aliasing::FieldAliasMapper<S>,
381        T: aliasing::FieldAliasMapper<S>,
382    {
383        type Out = CaseWhenConditionsLeaf<
384            <W as aliasing::FieldAliasMapper<S>>::Out,
385            <T as aliasing::FieldAliasMapper<S>>::Out,
386        >;
387        fn map(self, alias: &aliasing::Alias<S>) -> Self::Out {
388            CaseWhenConditionsLeaf {
389                when: self.when.map(alias),
390                then: self.then.map(alias),
391            }
392        }
393    }
394
395    impl<S, W, T, Whens> aliasing::FieldAliasMapper<S>
396        for CaseWhenConditionsIntermediateNode<W, T, Whens>
397    where
398        S: aliasing::AliasSource,
399        W: aliasing::FieldAliasMapper<S>,
400        T: aliasing::FieldAliasMapper<S>,
401        Whens: aliasing::FieldAliasMapper<S>,
402    {
403        type Out = CaseWhenConditionsIntermediateNode<
404            <W as aliasing::FieldAliasMapper<S>>::Out,
405            <T as aliasing::FieldAliasMapper<S>>::Out,
406            <Whens as aliasing::FieldAliasMapper<S>>::Out,
407        >;
408        fn map(self, alias: &aliasing::Alias<S>) -> Self::Out {
409            CaseWhenConditionsIntermediateNode {
410                first_whens: self.first_whens.map(alias),
411                last_when: self.last_when.map(alias),
412            }
413        }
414    }
415}