// diesel/expression/functions/aggregate_expressions/aggregate_filter.rs

1use super::aggregate_order::NoOrder;
2use super::over_clause::ValidAggregateFilterForWindow;
3use super::prefix::NoPrefix;
4use super::AggregateExpression;
5use super::IsAggregateFunction;
6use super::NoWindow;
7use crate::backend::{sql_dialect, Backend, SqlDialect};
8use crate::query_builder::where_clause::NoWhereClause;
9use crate::query_builder::where_clause::WhereAnd;
10use crate::query_builder::QueryFragment;
11use crate::query_builder::{AstPass, QueryId};
12use crate::sql_types::BoolOrNullableBool;
13use crate::Expression;
14use crate::QueryResult;
15
16empty_clause!(NoFilter);
17
18#[derive(QueryId, Copy, Clone, Debug)]
19pub struct Filter<P>(P);
20
21impl<P, DB> QueryFragment<DB> for Filter<P>
22where
23    Self: QueryFragment<DB, DB::AggregateFunctionExpressions>,
24    DB: Backend,
25{
26    fn walk_ast<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
27        <Self as QueryFragment<DB, DB::AggregateFunctionExpressions>>::walk_ast(self, pass)
28    }
29}
30
31impl<P, DB>
32    QueryFragment<
33        DB,
34        sql_dialect::aggregate_function_expressions::PostgresLikeAggregateFunctionExpressions,
35    > for Filter<P>
36where
37    P: QueryFragment<DB>,
38    DB: Backend + SqlDialect<AggregateFunctionExpressions = sql_dialect::aggregate_function_expressions::PostgresLikeAggregateFunctionExpressions>,
39{
40    fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
41        pass.push_sql(" FILTER (");
42        self.0.walk_ast(pass.reborrow())?;
43        pass.push_sql(")");
44        Ok(())
45    }
46}
47
/// Attaches a `FILTER` predicate to an aggregate function call.
///
/// `P` is the predicate expression. The impls in this module accept only
/// predicates whose SQL type is `BoolOrNullableBool`.
pub trait FilterDsl<P> {
    /// Expression type returned by [`FilterDsl::filter`].
    type Output;

    /// Adds `f` as a filter predicate.
    ///
    /// In the impls in this module, an existing filter predicate is combined
    /// with `f` through `WhereAnd`.
    fn filter(self, f: P) -> Self::Output;
}
53
54impl<P, T, ST> FilterDsl<P> for T
55where
56    T: IsAggregateFunction,
57    P: Expression<SqlType = ST>,
58    ST: BoolOrNullableBool,
59{
60    type Output =
61        AggregateExpression<T, NoPrefix, NoOrder, Filter<<NoWhereClause as WhereAnd<P>>::Output>>;
62
63    fn filter(self, f: P) -> Self::Output {
64        AggregateExpression {
65            prefix: NoPrefix,
66            function: self,
67            order: NoOrder,
68            filter: Filter(NoWhereClause.and(f)),
69            window: NoWindow,
70        }
71    }
72}
73
74impl<Fn, P, Prefix, Order, F, Window, ST> FilterDsl<P>
75    for AggregateExpression<Fn, Prefix, Order, Filter<F>, Window>
76where
77    P: Expression<SqlType = ST>,
78    ST: BoolOrNullableBool,
79    F: WhereAnd<P>,
80    Filter<<F as WhereAnd<P>>::Output>: ValidAggregateFilterForWindow<Fn, Window>,
81{
82    type Output =
83        AggregateExpression<Fn, Prefix, Order, Filter<<F as WhereAnd<P>>::Output>, Window>;
84
85    fn filter(self, f: P) -> Self::Output {
86        AggregateExpression {
87            prefix: self.prefix,
88            function: self.function,
89            order: self.order,
90            filter: Filter(WhereAnd::<P>::and(self.filter.0, f)),
91            window: self.window,
92        }
93    }
94}
95
96impl<Fn, P, Prefix, Order, Window, ST> FilterDsl<P>
97    for AggregateExpression<Fn, Prefix, Order, NoFilter, Window>
98where
99    P: Expression<SqlType = ST>,
100    ST: BoolOrNullableBool,
101    NoWhereClause: WhereAnd<P>,
102    Filter<<NoWhereClause as WhereAnd<P>>::Output>: ValidAggregateFilterForWindow<Fn, Window>,
103{
104    type Output = AggregateExpression<
105        Fn,
106        Prefix,
107        Order,
108        Filter<<NoWhereClause as WhereAnd<P>>::Output>,
109        Window,
110    >;
111
112    fn filter(self, f: P) -> Self::Output {
113        AggregateExpression {
114            prefix: self.prefix,
115            function: self.function,
116            order: self.order,
117            filter: Filter(WhereAnd::<P>::and(NoWhereClause, f)),
118            window: self.window,
119        }
120    }
121}