// diesel/expression/functions/aggregate_expressions.rs

1use crate::backend::Backend;
2use crate::expression::{AsExpression, ValidGrouping};
3use crate::query_builder::{AstPass, NotSpecialized, QueryFragment, QueryId};
4use crate::sql_types::Bool;
5use crate::{AppearsOnTable, Expression, QueryResult, SelectableExpression};
6
// Declares a unit struct standing for an absent (empty) SQL clause.
//
// The generated type derives `Debug`, `Clone`, `Copy` and `QueryId`, and its
// `QueryFragment` implementation writes nothing at all to the query.
macro_rules! empty_clause {
    ($clause:ident) => {
        #[derive(Debug, Clone, Copy, QueryId)]
        pub struct $clause;

        impl<DB> crate::query_builder::QueryFragment<DB> for $clause
        where
            DB: crate::backend::Backend + crate::backend::DieselReserveSpecialization,
        {
            // Nothing to render: the AST pass is intentionally left untouched.
            fn walk_ast<'b>(
                &'b self,
                _out: crate::query_builder::AstPass<'_, 'b, DB>,
            ) -> crate::QueryResult<()> {
                Ok(())
            }
        }
    };
}
25
26mod aggregate_filter;
27mod aggregate_order;
28pub(crate) mod frame_clause;
29mod over_clause;
30mod partition_by;
31mod prefix;
32
33use self::aggregate_filter::{FilterDsl, NoFilter};
34pub use self::aggregate_order::Order;
35use self::aggregate_order::{NoOrder, OrderAggregateDsl, OrderWindowDsl};
36use self::frame_clause::{FrameDsl, NoFrame};
37pub use self::over_clause::OverClause;
38use self::over_clause::{NoWindow, OverDsl};
39use self::partition_by::PartitionByDsl;
40use self::prefix::{AllDsl, DistinctDsl, NoPrefix};
41
42#[derive(QueryId, Debug)]
43pub struct AggregateExpression<
44    Fn,
45    Prefix = NoPrefix,
46    Order = NoOrder,
47    Filter = NoFilter,
48    Window = NoWindow,
49> {
50    prefix: Prefix,
51    function: Fn,
52    order: Order,
53    filter: Filter,
54    window: Window,
55}
56
57impl<Fn, Prefix, Order, Filter, Window, DB> QueryFragment<DB>
58    for AggregateExpression<Fn, Prefix, Order, Filter, Window>
59where
60    DB: crate::backend::Backend + crate::backend::DieselReserveSpecialization,
61    Fn: FunctionFragment<DB>,
62    Prefix: QueryFragment<DB>,
63    Order: QueryFragment<DB>,
64    Filter: QueryFragment<DB>,
65    Window: QueryFragment<DB> + WindowFunctionFragment<Fn, DB>,
66{
67    fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
68        pass.push_sql(Fn::FUNCTION_NAME);
69        pass.push_sql("(");
70        self.prefix.walk_ast(pass.reborrow())?;
71        self.function.walk_arguments(pass.reborrow())?;
72        self.order.walk_ast(pass.reborrow())?;
73        pass.push_sql(")");
74        self.filter.walk_ast(pass.reborrow())?;
75        self.window.walk_ast(pass.reborrow())?;
76        Ok(())
77    }
78}
79
80impl<Fn, Prefix, Order, Filter, GB> ValidGrouping<GB>
81    for AggregateExpression<Fn, Prefix, Order, Filter>
82where
83    Fn: ValidGrouping<GB>,
84{
85    type IsAggregate = <Fn as ValidGrouping<GB>>::IsAggregate;
86}
87
88impl<Fn, Prefix, Order, Filter, GB, Partition, WindowOrder, Frame> ValidGrouping<GB>
89    for AggregateExpression<Fn, Prefix, Order, Filter, OverClause<Partition, WindowOrder, Frame>>
90where
91    Fn: IsWindowFunction,
92    Fn::ArgTypes: ValidGrouping<GB>,
93{
94    type IsAggregate = <Fn::ArgTypes as ValidGrouping<GB>>::IsAggregate;
95}
96
97impl<Fn, Prefix, Order, Filter, Window> Expression
98    for AggregateExpression<Fn, Prefix, Order, Filter, Window>
99where
100    Fn: Expression,
101{
102    type SqlType = <Fn as Expression>::SqlType;
103}
104
105impl<Fn, Prefix, Order, Filter, Window, QS> AppearsOnTable<QS>
106    for AggregateExpression<Fn, Prefix, Order, Filter, Window>
107where
108    Self: Expression,
109    Fn: AppearsOnTable<QS>,
110{
111}
112
113impl<Fn, Prefix, Order, Filter, Window, QS> SelectableExpression<QS>
114    for AggregateExpression<Fn, Prefix, Order, Filter, Window>
115where
116    Self: Expression,
117    Fn: SelectableExpression<QS>,
118{
119}
120
/// A helper marker trait indicating that a function is a window function.
///
/// This is only used to gate the `WindowExpressionMethods` trait. It does not
/// check whether the construct is valid for a given backend. That check is
/// postponed until the query is built via `QueryFragment`, where the backend
/// type is available.
#[diagnostic::on_unimplemented(
    message = "{Self} is not a window function",
    label = "remove this function call to use `{Self}` as normal SQL function",
    note = "try removing any method call to `WindowExpressionMethods` and use it as normal SQL function"
)]
pub trait IsWindowFunction {
    /// A tuple of all argument types of the function
    ///
    /// Used to determine whether a windowed call counts as an aggregate
    /// for `ValidGrouping` purposes.
    type ArgTypes;
}
135
136/// A helper marker trait that this function is a valid window function
137/// for the given backend
138/// this trait is used to transport information that
139/// a certain function can be used as window function for a specific
140/// backend
141/// We allow to specialize this function for different SQL dialects
142pub trait WindowFunctionFragment<Fn, DB: Backend, SP = NotSpecialized> {}
143
/// A helper marker trait indicating that a function is an aggregate function.
///
/// This is only used to gate the `AggregateExpressionMethods` trait. It does
/// not check whether the construct is valid for a given backend. That check is
/// postponed until the query is built via `QueryFragment`, where the backend
/// type is available.
pub trait IsAggregateFunction {}
150
151/// A specialized QueryFragment helper trait that allows us to walk the function name
152/// and the function arguments in separate steps
153pub trait FunctionFragment<DB: Backend> {
154    /// The name of the sql function
155    const FUNCTION_NAME: &'static str;
156
157    /// Walk the function argument part (everything between ())
158    fn walk_arguments<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()>;
159}
160
161/// Expression methods to build aggregate function expressions
162pub trait AggregateExpressionMethods: Sized {
163    /// `DISTINCT` modifier for aggregate functions
164    ///
165    /// This modifies the aggregate function call to only
166    /// include distinct items
167    ///
168    /// # Example
169    ///
170    /// ```rust
171    /// # include!("../../doctest_setup.rs");
172    /// #
173    /// # fn main() {
174    /// #     run_test().unwrap();
175    /// # }
176    /// #
177    /// # fn run_test() -> QueryResult<()> {
178    /// #     use schema::posts::dsl::*;
179    /// #     use diesel::dsl;
180    /// #     let connection = &mut establish_connection();
181    /// let without_distinct = posts
182    ///     .select(dsl::count(user_id))
183    ///     .get_result::<i64>(connection)?;
184    /// let with_distinct = posts
185    ///     .select(dsl::count(user_id).aggregate_distinct())
186    ///     .get_result::<i64>(connection)?;
187    ///
188    /// assert_eq!(3, without_distinct);
189    /// assert_eq!(2, with_distinct);
190    /// #     Ok(())
191    /// # }
192    /// ```
193    fn aggregate_distinct(self) -> self::dsl::AggregateDistinct<Self>
194    where
195        Self: DistinctDsl,
196    {
197        <Self as DistinctDsl>::distinct(self)
198    }
199
200    /// `ALL` modifier for aggregate functions
201    ///
202    /// This modifies the aggregate function call to include
203    /// all items. This is the default behaviour.
204    ///
205    /// # Example
206    ///
207    /// ```rust
208    /// # include!("../../doctest_setup.rs");
209    /// #
210    /// # fn main() {
211    /// #     run_test().unwrap();
212    /// # }
213    /// #
214    /// # fn run_test() -> QueryResult<()> {
215    /// #     use schema::posts::dsl::*;
216    /// #     use diesel::dsl;
217    /// #     let connection = &mut establish_connection();
218    /// let without_all = posts
219    ///     .select(dsl::count(user_id))
220    ///     .get_result::<i64>(connection)?;
221    /// let with_all = posts
222    ///     .select(dsl::count(user_id).aggregate_all())
223    ///     .get_result::<i64>(connection)?;
224    ///
225    /// assert_eq!(3, without_all);
226    /// assert_eq!(3, with_all);
227    /// #     Ok(())
228    /// # }
229    /// ```
230    fn aggregate_all(self) -> self::dsl::AggregateAll<Self>
231    where
232        Self: AllDsl,
233    {
234        <Self as AllDsl>::all(self)
235    }
236
237    /// Add an aggregate function filter
238    ///
239    /// This function modifies an aggregate function
240    /// call to use only items matching the provided
241    /// filter
242    ///
243    /// # Example
244    ///
245    /// ```rust
246    /// # include!("../../doctest_setup.rs");
247    /// #
248    /// # fn main() {
249    /// #     #[cfg(not(feature = "mysql"))]
250    /// #     run_test().unwrap();
251    /// # }
252    /// #
253    /// # #[cfg(not(feature = "mysql"))]
254    /// # fn run_test() -> QueryResult<()> {
255    /// #     use schema::posts::dsl::*;
256    /// #     use diesel::dsl;
257    /// #     let connection = &mut establish_connection();
258    /// let without_filter = posts
259    ///     .select(dsl::count(user_id))
260    ///     .get_result::<i64>(connection)?;
261    /// let with_filter = posts
262    ///     .select(dsl::count(user_id).aggregate_filter(title.like("%first post%")))
263    ///     .get_result::<i64>(connection)?;
264    ///
265    /// assert_eq!(3, without_filter);
266    /// assert_eq!(2, with_filter);
267    /// #     Ok(())
268    /// # }
269    /// ```
270    fn aggregate_filter<P>(self, f: P) -> self::dsl::AggregateFilter<Self, P>
271    where
272        P: AsExpression<Bool>,
273        Self: FilterDsl<P::Expression>,
274    {
275        <Self as FilterDsl<P::Expression>>::filter(self, f.as_expression())
276    }
277
278    /// Add an aggregate function order
279    ///
280    /// This function orders the items passed into an
281    /// aggregate function
282    ///
283    /// For sqlite this is only supported starting with SQLite 3.44
284    ///
285    /// # Example
286    ///
287    /// ```rust
288    /// # include!("../../doctest_setup.rs");
289    /// #
290    /// # fn main() {
291    /// #     #[cfg(not(feature = "mysql"))]
292    /// #     run_test().unwrap();
293    /// # }
294    /// #
295    /// # #[cfg(not(feature = "mysql"))]
296    /// # fn run_test() -> QueryResult<()> {
297    /// #     use schema::posts::dsl::*;
298    /// #     use diesel::dsl;
299    /// #     let connection = &mut establish_connection();
300    /// #     #[cfg(feature = "sqlite")]
301    /// #     assert_version!(connection, 3, 44, 0);
302    /// // This example is not meaningful yet,
303    /// // modify it as soon as we support more
304    /// // meaningful functions here
305    /// let res = posts
306    ///     .select(dsl::count(user_id).aggregate_order(title))
307    ///     .get_result::<i64>(connection)?;
308    /// assert_eq!(3, res);
309    /// #     Ok(())
310    /// # }
311    /// ```
312    fn aggregate_order<O>(self, o: O) -> self::dsl::AggregateOrder<Self, O>
313    where
314        Self: OrderAggregateDsl<O>,
315    {
316        <Self as OrderAggregateDsl<O>>::order(self, o)
317    }
318}
319
320impl<T> AggregateExpressionMethods for T {}
321
322/// Methods to construct a window function call
323pub trait WindowExpressionMethods: Sized {
324    /// Turn a function call into a window function call
325    ///
326    /// This function turns a ordinary SQL function call
327    /// into a window function call by adding an empty `OVER ()`
328    /// clause
329    ///
330    /// # Example
331    ///
332    /// ```rust
333    /// # include!("../../doctest_setup.rs");
334    /// #
335    /// # fn main() {
336    /// #     run_test().unwrap();
337    /// # }
338    /// #
339    /// # fn run_test() -> QueryResult<()> {
340    /// #     use schema::posts::dsl::*;
341    /// #     use diesel::dsl;
342    /// #     let connection = &mut establish_connection();
343    /// let res = posts
344    ///     .select(dsl::count(user_id).over())
345    ///     .load::<i64>(connection)?;
346    /// assert_eq!(vec![3, 3, 3], res);
347    /// #     Ok(())
348    /// # }
349    /// ```
350    fn over(self) -> self::dsl::Over<Self>
351    where
352        Self: OverDsl,
353    {
354        <Self as OverDsl>::over(self)
355    }
356
357    /// Add a filter to the current window function
358    ///
359    ///
360    /// # Example
361    ///
362    /// ```rust
363    /// # include!("../../doctest_setup.rs");
364    /// #
365    /// # fn main() {
366    /// #     #[cfg(not(feature = "mysql"))]
367    /// #     run_test().unwrap();
368    /// # }
369    /// #
370    /// # #[cfg(not(feature = "mysql"))]
371    /// # fn run_test() -> QueryResult<()> {
372    /// #     use schema::posts::dsl::*;
373    /// #     use diesel::dsl;
374    /// #     let connection = &mut establish_connection();
375    /// let res = posts
376    ///     .select(dsl::count(user_id).window_filter(user_id.eq(1)))
377    ///     .load::<i64>(connection)?;
378    /// assert_eq!(vec![2], res);
379    /// #     Ok(())
380    /// # }
381    /// ```
382    fn window_filter<P>(self, f: P) -> self::dsl::WindowFilter<Self, P>
383    where
384        P: AsExpression<Bool>,
385        Self: FilterDsl<P::Expression>,
386    {
387        <Self as FilterDsl<P::Expression>>::filter(self, f.as_expression())
388    }
389
390    /// Add a partition clause to the current window function
391    ///
392    /// This function adds a `PARTITION BY` clause to your window function call
393    ///
394    /// # Example
395    ///
396    /// ```rust
397    /// # include!("../../doctest_setup.rs");
398    /// #
399    /// # fn main() {
400    /// #     run_test().unwrap();
401    /// # }
402    /// #
403    /// # fn run_test() -> QueryResult<()> {
404    /// #     use schema::posts::dsl::*;
405    /// #     use diesel::dsl;
406    /// #     let connection = &mut establish_connection();
407    /// let res = posts
408    ///     .select(dsl::count(user_id).partition_by(user_id))
409    ///     .load::<i64>(connection)?;
410    /// assert_eq!(vec![2, 2, 1], res);
411    /// #     Ok(())
412    /// # }
413    /// ```
414    fn partition_by<E>(self, expr: E) -> self::dsl::PartitionBy<Self, E>
415    where
416        Self: PartitionByDsl<E>,
417    {
418        <Self as PartitionByDsl<E>>::partition_by(self, expr)
419    }
420
421    /// Add a order clause to the current window function
422    ///
423    /// Add a `ORDER BY` clause to your window function call
424    ///
425    /// # Example
426    ///
427    /// ```rust
428    /// # include!("../../doctest_setup.rs");
429    /// #
430    /// # fn main() {
431    /// #     run_test().unwrap();
432    /// # }
433    /// #
434    /// # fn run_test() -> QueryResult<()> {
435    /// #     use schema::posts::dsl::*;
436    /// #     use diesel::dsl;
437    /// #     let connection = &mut establish_connection();
438    /// let res = posts
439    ///     .select(dsl::first_value(user_id).window_order(title))
440    ///     .load::<i32>(connection)?;
441    /// assert_eq!(vec![1, 1, 1], res);
442    /// #     Ok(())
443    /// # }
444    /// ```
445    fn window_order<E>(self, expr: E) -> self::dsl::WindowOrder<Self, E>
446    where
447        Self: OrderWindowDsl<E>,
448    {
449        <Self as OrderWindowDsl<E>>::order(self, expr)
450    }
451
452    /// Add a frame clause to the current window function
453    ///
454    /// This function adds a frame clause to your window function call.
455    /// Accepts the following items:
456    ///
457    /// * [`dsl::frame::Groups`](crate::dsl::frame::Groups)
458    /// * [`dsl::frame::Rows`](crate::dsl::frame::Rows)
459    /// * [`dsl::frame::Range`](crate::dsl::frame::Range)
460    ///
461    /// # Example
462    ///
463    /// ```rust
464    /// # include!("../../doctest_setup.rs");
465    /// #
466    /// # fn main() {
467    /// #     run_test().unwrap();
468    /// # }
469    /// #
470    /// # fn run_test() -> QueryResult<()> {
471    /// #     use schema::posts::dsl::*;
472    /// #     use diesel::dsl;
473    /// #     let connection = &mut establish_connection();
474    /// let res = posts
475    ///     .select(
476    ///         dsl::count(user_id).frame_by(dsl::frame::Rows.frame_start_with(dsl::frame::CurrentRow)),
477    ///     )
478    ///     .load::<i64>(connection)?;
479    /// assert_eq!(vec![1, 1, 1], res);
480    /// #     Ok(())
481    /// # }
482    /// ```
483    fn frame_by<E>(self, expr: E) -> self::dsl::FrameBy<Self, E>
484    where
485        Self: FrameDsl<E>,
486    {
487        <Self as FrameDsl<E>>::frame(self, expr)
488    }
489}
490
491impl<T> WindowExpressionMethods for T {}
492
493pub(super) mod dsl {
494    #[cfg(doc)]
495    use super::frame_clause::{FrameBoundDsl, FrameClauseDsl};
496    use super::*;
497
498    /// Return type of [`WindowExpressionMethods::over`]
499    pub type Over<Fn> = <Fn as OverDsl>::Output;
500
501    /// Return type of [`WindowExpressionMethods::window_filter`]
502    pub type WindowFilter<Fn, P> = <Fn as FilterDsl<crate::dsl::AsExprOf<P, Bool>>>::Output;
503
504    /// Return type of [`WindowExpressionMethods::partition_by`]
505    pub type PartitionBy<Fn, E> = <Fn as PartitionByDsl<E>>::Output;
506
507    /// Return type of [`WindowExpressionMethods::window_order`]
508    pub type WindowOrder<Fn, E> = <Fn as OrderWindowDsl<E>>::Output;
509
510    /// Return type of [`WindowExpressionMethods::frame_by`]
511    pub type FrameBy<Fn, E> = <Fn as FrameDsl<E>>::Output;
512
513    /// Return type of [`AggregateExpressionMethods::aggregate_distinct`]
514    pub type AggregateDistinct<Fn> = <Fn as DistinctDsl>::Output;
515
516    /// Return type of [`AggregateExpressionMethods::aggregate_all`]
517    pub type AggregateAll<Fn> = <Fn as AllDsl>::Output;
518
519    /// Return type of [`AggregateExpressionMethods::aggregate_filter`]
520    pub type AggregateFilter<Fn, P> = <Fn as FilterDsl<crate::dsl::AsExprOf<P, Bool>>>::Output;
521
522    /// Return type of [`AggregateExpressionMethods::aggregate_order`]
523    pub type AggregateOrder<Fn, O> = <Fn as OrderAggregateDsl<O>>::Output;
524
525    /// Return type of [`FrameClauseDsl::frame_start_with`]
526    pub type FrameStartWith<S, T> = self::frame_clause::StartFrame<S, T>;
527
528    /// Return type of [`FrameClauseDsl::frame_start_with_exclusion`]
529    pub type FrameStartWithExclusion<S, T, E> = self::frame_clause::StartFrame<S, T, E>;
530
531    /// Return type of [`FrameClauseDsl::frame_between`]
532    pub type FrameBetween<S, E1, E2> = self::frame_clause::BetweenFrame<S, E1, E2>;
533
534    /// Return type of [`FrameClauseDsl::frame_between_with_exclusion`]
535    pub type FrameBetweenWithExclusion<S, E1, E2, E> =
536        self::frame_clause::BetweenFrame<S, E1, E2, E>;
537
538    /// Return type of [`FrameBoundDsl::preceding`]
539    pub type Preceding<I> = self::frame_clause::OffsetPreceding<I>;
540
541    /// Return type of [`FrameBoundDsl::following`]
542    pub type Following<I> = self::frame_clause::OffsetFollowing<I>;
543}