diesel/expression/functions/aggregate_expressions.rs
1use crate::backend::Backend;
2use crate::expression::{AsExpression, ValidGrouping};
3use crate::query_builder::{AstPass, NotSpecialized, QueryFragment, QueryId};
4use crate::sql_types::Bool;
5use crate::{AppearsOnTable, Expression, QueryResult, SelectableExpression};
6
// Declares a public unit struct named by the macro argument that stands in
// for an absent clause. Its `QueryFragment` implementation (for any `DB`
// satisfying `Backend + DieselReserveSpecialization`) writes no SQL at all.
macro_rules! empty_clause {
    ($clause_name:ident) => {
        #[derive(Debug, Clone, Copy, QueryId)]
        pub struct $clause_name;

        impl<DB> crate::query_builder::QueryFragment<DB> for $clause_name
        where
            DB: crate::backend::Backend + crate::backend::DieselReserveSpecialization,
        {
            // An empty clause contributes nothing to the generated query,
            // so the pass is ignored entirely.
            fn walk_ast<'b>(
                &'b self,
                _: crate::query_builder::AstPass<'_, 'b, DB>,
            ) -> crate::QueryResult<()> {
                Ok(())
            }
        }
    };
}
25
26mod aggregate_filter;
27mod aggregate_order;
28pub(crate) mod frame_clause;
29mod over_clause;
30mod partition_by;
31mod prefix;
32
33use self::aggregate_filter::{FilterDsl, NoFilter};
34pub use self::aggregate_order::Order;
35use self::aggregate_order::{NoOrder, OrderAggregateDsl, OrderWindowDsl};
36use self::frame_clause::{FrameDsl, NoFrame};
37pub use self::over_clause::OverClause;
38use self::over_clause::{NoWindow, OverDsl};
39use self::partition_by::PartitionByDsl;
40use self::prefix::{AllDsl, DistinctDsl, NoPrefix};
41
42#[derive(QueryId, Debug)]
43pub struct AggregateExpression<
44 Fn,
45 Prefix = NoPrefix,
46 Order = NoOrder,
47 Filter = NoFilter,
48 Window = NoWindow,
49> {
50 prefix: Prefix,
51 function: Fn,
52 order: Order,
53 filter: Filter,
54 window: Window,
55}
56
57impl<Fn, Prefix, Order, Filter, Window, DB> QueryFragment<DB>
58 for AggregateExpression<Fn, Prefix, Order, Filter, Window>
59where
60 DB: crate::backend::Backend + crate::backend::DieselReserveSpecialization,
61 Fn: FunctionFragment<DB>,
62 Prefix: QueryFragment<DB>,
63 Order: QueryFragment<DB>,
64 Filter: QueryFragment<DB>,
65 Window: QueryFragment<DB> + WindowFunctionFragment<Fn, DB>,
66{
67 fn walk_ast<'b>(&'b self, mut pass: AstPass<'_, 'b, DB>) -> QueryResult<()> {
68 pass.push_sql(Fn::FUNCTION_NAME);
69 pass.push_sql("(");
70 self.prefix.walk_ast(pass.reborrow())?;
71 self.function.walk_arguments(pass.reborrow())?;
72 self.order.walk_ast(pass.reborrow())?;
73 pass.push_sql(")");
74 self.filter.walk_ast(pass.reborrow())?;
75 self.window.walk_ast(pass.reborrow())?;
76 Ok(())
77 }
78}
79
80impl<Fn, Prefix, Order, Filter, GB> ValidGrouping<GB>
81 for AggregateExpression<Fn, Prefix, Order, Filter>
82where
83 Fn: ValidGrouping<GB>,
84{
85 type IsAggregate = <Fn as ValidGrouping<GB>>::IsAggregate;
86}
87
88impl<Fn, Prefix, Order, Filter, GB, Partition, WindowOrder, Frame> ValidGrouping<GB>
89 for AggregateExpression<Fn, Prefix, Order, Filter, OverClause<Partition, WindowOrder, Frame>>
90where
91 Fn: IsWindowFunction,
92 Fn::ArgTypes: ValidGrouping<GB>,
93{
94 type IsAggregate = <Fn::ArgTypes as ValidGrouping<GB>>::IsAggregate;
95}
96
97impl<Fn, Prefix, Order, Filter, Window> Expression
98 for AggregateExpression<Fn, Prefix, Order, Filter, Window>
99where
100 Fn: Expression,
101{
102 type SqlType = <Fn as Expression>::SqlType;
103}
104
105impl<Fn, Prefix, Order, Filter, Window, QS> AppearsOnTable<QS>
106 for AggregateExpression<Fn, Prefix, Order, Filter, Window>
107where
108 Self: Expression,
109 Fn: AppearsOnTable<QS>,
110{
111}
112
113impl<Fn, Prefix, Order, Filter, Window, QS> SelectableExpression<QS>
114 for AggregateExpression<Fn, Prefix, Order, Filter, Window>
115where
116 Self: Expression,
117 Fn: SelectableExpression<QS>,
118{
119}
120
/// Marker trait for functions that can be used as window functions.
///
/// This trait is only used to gate the methods of
/// [`WindowExpressionMethods`]. It does not check whether the resulting
/// construct is valid for a given backend. That check is deferred until the
/// query is built via `QueryFragment`, where the backend type is known.
#[diagnostic::on_unimplemented(
    message = "{Self} is not a window function",
    label = "remove this function call to use `{Self}` as normal SQL function",
    note = "try removing any method call to `WindowExpressionMethods` and use it as normal SQL function"
)]
pub trait IsWindowFunction {
    /// A tuple of all argument types of the function
    ///
    /// When the function is used together with an `OverClause`, the
    /// `ValidGrouping` behaviour of the resulting expression is taken
    /// from this type.
    type ArgTypes;
}
135
136/// A helper marker trait that this function is a valid window function
137/// for the given backend
138/// this trait is used to transport information that
139/// a certain function can be used as window function for a specific
140/// backend
141/// We allow to specialize this function for different SQL dialects
142pub trait WindowFunctionFragment<Fn, DB: Backend, SP = NotSpecialized> {}
143
/// Marker trait for functions that can be used as aggregate functions.
///
/// This trait is only used to gate the methods of
/// [`AggregateExpressionMethods`]. It does not check whether the resulting
/// construct is valid for a given backend. That check is deferred until the
/// query is built via `QueryFragment`, where the backend type is known.
pub trait IsAggregateFunction {}
150
151/// A specialized QueryFragment helper trait that allows us to walk the function name
152/// and the function arguments in separate steps
153pub trait FunctionFragment<DB: Backend> {
154 /// The name of the sql function
155 const FUNCTION_NAME: &'static str;
156
157 /// Walk the function argument part (everything between ())
158 fn walk_arguments<'b>(&'b self, pass: AstPass<'_, 'b, DB>) -> QueryResult<()>;
159}
160
161/// Expression methods to build aggregate function expressions
162pub trait AggregateExpressionMethods: Sized {
163 /// `DISTINCT` modifier for aggregate functions
164 ///
165 /// This modifies the aggregate function call to only
166 /// include distinct items
167 ///
168 /// # Example
169 ///
170 /// ```rust
171 /// # include!("../../doctest_setup.rs");
172 /// #
173 /// # fn main() {
174 /// # run_test().unwrap();
175 /// # }
176 /// #
177 /// # fn run_test() -> QueryResult<()> {
178 /// # use schema::posts::dsl::*;
179 /// # use diesel::dsl;
180 /// # let connection = &mut establish_connection();
181 /// let without_distinct = posts
182 /// .select(dsl::count(user_id))
183 /// .get_result::<i64>(connection)?;
184 /// let with_distinct = posts
185 /// .select(dsl::count(user_id).aggregate_distinct())
186 /// .get_result::<i64>(connection)?;
187 ///
188 /// assert_eq!(3, without_distinct);
189 /// assert_eq!(2, with_distinct);
190 /// # Ok(())
191 /// # }
192 /// ```
193 fn aggregate_distinct(self) -> self::dsl::AggregateDistinct<Self>
194 where
195 Self: DistinctDsl,
196 {
197 <Self as DistinctDsl>::distinct(self)
198 }
199
200 /// `ALL` modifier for aggregate functions
201 ///
202 /// This modifies the aggregate function call to include
203 /// all items. This is the default behaviour.
204 ///
205 /// # Example
206 ///
207 /// ```rust
208 /// # include!("../../doctest_setup.rs");
209 /// #
210 /// # fn main() {
211 /// # run_test().unwrap();
212 /// # }
213 /// #
214 /// # fn run_test() -> QueryResult<()> {
215 /// # use schema::posts::dsl::*;
216 /// # use diesel::dsl;
217 /// # let connection = &mut establish_connection();
218 /// let without_all = posts
219 /// .select(dsl::count(user_id))
220 /// .get_result::<i64>(connection)?;
221 /// let with_all = posts
222 /// .select(dsl::count(user_id).aggregate_all())
223 /// .get_result::<i64>(connection)?;
224 ///
225 /// assert_eq!(3, without_all);
226 /// assert_eq!(3, with_all);
227 /// # Ok(())
228 /// # }
229 /// ```
230 fn aggregate_all(self) -> self::dsl::AggregateAll<Self>
231 where
232 Self: AllDsl,
233 {
234 <Self as AllDsl>::all(self)
235 }
236
237 /// Add an aggregate function filter
238 ///
239 /// This function modifies an aggregate function
240 /// call to use only items matching the provided
241 /// filter
242 ///
243 /// # Example
244 ///
245 /// ```rust
246 /// # include!("../../doctest_setup.rs");
247 /// #
248 /// # fn main() {
249 /// # #[cfg(not(feature = "mysql"))]
250 /// # run_test().unwrap();
251 /// # }
252 /// #
253 /// # #[cfg(not(feature = "mysql"))]
254 /// # fn run_test() -> QueryResult<()> {
255 /// # use schema::posts::dsl::*;
256 /// # use diesel::dsl;
257 /// # let connection = &mut establish_connection();
258 /// let without_filter = posts
259 /// .select(dsl::count(user_id))
260 /// .get_result::<i64>(connection)?;
261 /// let with_filter = posts
262 /// .select(dsl::count(user_id).aggregate_filter(title.like("%first post%")))
263 /// .get_result::<i64>(connection)?;
264 ///
265 /// assert_eq!(3, without_filter);
266 /// assert_eq!(2, with_filter);
267 /// # Ok(())
268 /// # }
269 /// ```
270 fn aggregate_filter<P>(self, f: P) -> self::dsl::AggregateFilter<Self, P>
271 where
272 P: AsExpression<Bool>,
273 Self: FilterDsl<P::Expression>,
274 {
275 <Self as FilterDsl<P::Expression>>::filter(self, f.as_expression())
276 }
277
278 /// Add an aggregate function order
279 ///
280 /// This function orders the items passed into an
281 /// aggregate function
282 ///
283 /// For sqlite this is only supported starting with SQLite 3.44
284 ///
285 /// # Example
286 ///
287 /// ```rust
288 /// # include!("../../doctest_setup.rs");
289 /// #
290 /// # fn main() {
291 /// # #[cfg(not(feature = "mysql"))]
292 /// # run_test().unwrap();
293 /// # }
294 /// #
295 /// # #[cfg(not(feature = "mysql"))]
296 /// # fn run_test() -> QueryResult<()> {
297 /// # use schema::posts::dsl::*;
298 /// # use diesel::dsl;
299 /// # let connection = &mut establish_connection();
300 /// # #[cfg(feature = "sqlite")]
301 /// # assert_version!(connection, 3, 44, 0);
302 /// // This example is not meaningful yet,
303 /// // modify it as soon as we support more
304 /// // meaningful functions here
305 /// let res = posts
306 /// .select(dsl::count(user_id).aggregate_order(title))
307 /// .get_result::<i64>(connection)?;
308 /// assert_eq!(3, res);
309 /// # Ok(())
310 /// # }
311 /// ```
312 fn aggregate_order<O>(self, o: O) -> self::dsl::AggregateOrder<Self, O>
313 where
314 Self: OrderAggregateDsl<O>,
315 {
316 <Self as OrderAggregateDsl<O>>::order(self, o)
317 }
318}
319
320impl<T> AggregateExpressionMethods for T {}
321
322/// Methods to construct a window function call
323pub trait WindowExpressionMethods: Sized {
324 /// Turn a function call into a window function call
325 ///
326 /// This function turns a ordinary SQL function call
327 /// into a window function call by adding an empty `OVER ()`
328 /// clause
329 ///
330 /// # Example
331 ///
332 /// ```rust
333 /// # include!("../../doctest_setup.rs");
334 /// #
335 /// # fn main() {
336 /// # run_test().unwrap();
337 /// # }
338 /// #
339 /// # fn run_test() -> QueryResult<()> {
340 /// # use schema::posts::dsl::*;
341 /// # use diesel::dsl;
342 /// # let connection = &mut establish_connection();
343 /// let res = posts
344 /// .select(dsl::count(user_id).over())
345 /// .load::<i64>(connection)?;
346 /// assert_eq!(vec![3, 3, 3], res);
347 /// # Ok(())
348 /// # }
349 /// ```
350 fn over(self) -> self::dsl::Over<Self>
351 where
352 Self: OverDsl,
353 {
354 <Self as OverDsl>::over(self)
355 }
356
357 /// Add a filter to the current window function
358 ///
359 ///
360 /// # Example
361 ///
362 /// ```rust
363 /// # include!("../../doctest_setup.rs");
364 /// #
365 /// # fn main() {
366 /// # #[cfg(not(feature = "mysql"))]
367 /// # run_test().unwrap();
368 /// # }
369 /// #
370 /// # #[cfg(not(feature = "mysql"))]
371 /// # fn run_test() -> QueryResult<()> {
372 /// # use schema::posts::dsl::*;
373 /// # use diesel::dsl;
374 /// # let connection = &mut establish_connection();
375 /// let res = posts
376 /// .select(dsl::count(user_id).window_filter(user_id.eq(1)))
377 /// .load::<i64>(connection)?;
378 /// assert_eq!(vec![2], res);
379 /// # Ok(())
380 /// # }
381 /// ```
382 fn window_filter<P>(self, f: P) -> self::dsl::WindowFilter<Self, P>
383 where
384 P: AsExpression<Bool>,
385 Self: FilterDsl<P::Expression>,
386 {
387 <Self as FilterDsl<P::Expression>>::filter(self, f.as_expression())
388 }
389
390 /// Add a partition clause to the current window function
391 ///
392 /// This function adds a `PARTITION BY` clause to your window function call
393 ///
394 /// # Example
395 ///
396 /// ```rust
397 /// # include!("../../doctest_setup.rs");
398 /// #
399 /// # fn main() {
400 /// # run_test().unwrap();
401 /// # }
402 /// #
403 /// # fn run_test() -> QueryResult<()> {
404 /// # use schema::posts::dsl::*;
405 /// # use diesel::dsl;
406 /// # let connection = &mut establish_connection();
407 /// let res = posts
408 /// .select(dsl::count(user_id).partition_by(user_id))
409 /// .load::<i64>(connection)?;
410 /// assert_eq!(vec![2, 2, 1], res);
411 /// # Ok(())
412 /// # }
413 /// ```
414 fn partition_by<E>(self, expr: E) -> self::dsl::PartitionBy<Self, E>
415 where
416 Self: PartitionByDsl<E>,
417 {
418 <Self as PartitionByDsl<E>>::partition_by(self, expr)
419 }
420
421 /// Add a order clause to the current window function
422 ///
423 /// Add a `ORDER BY` clause to your window function call
424 ///
425 /// # Example
426 ///
427 /// ```rust
428 /// # include!("../../doctest_setup.rs");
429 /// #
430 /// # fn main() {
431 /// # run_test().unwrap();
432 /// # }
433 /// #
434 /// # fn run_test() -> QueryResult<()> {
435 /// # use schema::posts::dsl::*;
436 /// # use diesel::dsl;
437 /// # let connection = &mut establish_connection();
438 /// let res = posts
439 /// .select(dsl::first_value(user_id).window_order(title))
440 /// .load::<i32>(connection)?;
441 /// assert_eq!(vec![1, 1, 1], res);
442 /// # Ok(())
443 /// # }
444 /// ```
445 fn window_order<E>(self, expr: E) -> self::dsl::WindowOrder<Self, E>
446 where
447 Self: OrderWindowDsl<E>,
448 {
449 <Self as OrderWindowDsl<E>>::order(self, expr)
450 }
451
452 /// Add a frame clause to the current window function
453 ///
454 /// This function adds a frame clause to your window function call.
455 /// Accepts the following items:
456 ///
457 /// * [`dsl::frame::Groups`](crate::dsl::frame::Groups)
458 /// * [`dsl::frame::Rows`](crate::dsl::frame::Rows)
459 /// * [`dsl::frame::Range`](crate::dsl::frame::Range)
460 ///
461 /// # Example
462 ///
463 /// ```rust
464 /// # include!("../../doctest_setup.rs");
465 /// #
466 /// # fn main() {
467 /// # run_test().unwrap();
468 /// # }
469 /// #
470 /// # fn run_test() -> QueryResult<()> {
471 /// # use schema::posts::dsl::*;
472 /// # use diesel::dsl;
473 /// # let connection = &mut establish_connection();
474 /// let res = posts
475 /// .select(
476 /// dsl::count(user_id).frame_by(dsl::frame::Rows.frame_start_with(dsl::frame::CurrentRow)),
477 /// )
478 /// .load::<i64>(connection)?;
479 /// assert_eq!(vec![1, 1, 1], res);
480 /// # Ok(())
481 /// # }
482 /// ```
483 fn frame_by<E>(self, expr: E) -> self::dsl::FrameBy<Self, E>
484 where
485 Self: FrameDsl<E>,
486 {
487 <Self as FrameDsl<E>>::frame(self, expr)
488 }
489}
490
491impl<T> WindowExpressionMethods for T {}
492
493pub(super) mod dsl {
494 #[cfg(doc)]
495 use super::frame_clause::{FrameBoundDsl, FrameClauseDsl};
496 use super::*;
497
498 /// Return type of [`WindowExpressionMethods::over`]
499 pub type Over<Fn> = <Fn as OverDsl>::Output;
500
501 /// Return type of [`WindowExpressionMethods::window_filter`]
502 pub type WindowFilter<Fn, P> = <Fn as FilterDsl<crate::dsl::AsExprOf<P, Bool>>>::Output;
503
504 /// Return type of [`WindowExpressionMethods::partition_by`]
505 pub type PartitionBy<Fn, E> = <Fn as PartitionByDsl<E>>::Output;
506
507 /// Return type of [`WindowExpressionMethods::window_order`]
508 pub type WindowOrder<Fn, E> = <Fn as OrderWindowDsl<E>>::Output;
509
510 /// Return type of [`WindowExpressionMethods::frame_by`]
511 pub type FrameBy<Fn, E> = <Fn as FrameDsl<E>>::Output;
512
513 /// Return type of [`AggregateExpressionMethods::aggregate_distinct`]
514 pub type AggregateDistinct<Fn> = <Fn as DistinctDsl>::Output;
515
516 /// Return type of [`AggregateExpressionMethods::aggregate_all`]
517 pub type AggregateAll<Fn> = <Fn as AllDsl>::Output;
518
519 /// Return type of [`AggregateExpressionMethods::aggregate_filter`]
520 pub type AggregateFilter<Fn, P> = <Fn as FilterDsl<crate::dsl::AsExprOf<P, Bool>>>::Output;
521
522 /// Return type of [`AggregateExpressionMethods::aggregate_order`]
523 pub type AggregateOrder<Fn, O> = <Fn as OrderAggregateDsl<O>>::Output;
524
525 /// Return type of [`FrameClauseDsl::frame_start_with`]
526 pub type FrameStartWith<S, T> = self::frame_clause::StartFrame<S, T>;
527
528 /// Return type of [`FrameClauseDsl::frame_start_with_exclusion`]
529 pub type FrameStartWithExclusion<S, T, E> = self::frame_clause::StartFrame<S, T, E>;
530
531 /// Return type of [`FrameClauseDsl::frame_between`]
532 pub type FrameBetween<S, E1, E2> = self::frame_clause::BetweenFrame<S, E1, E2>;
533
534 /// Return type of [`FrameClauseDsl::frame_between_with_exclusion`]
535 pub type FrameBetweenWithExclusion<S, E1, E2, E> =
536 self::frame_clause::BetweenFrame<S, E1, E2, E>;
537
538 /// Return type of [`FrameBoundDsl::preceding`]
539 pub type Preceding<I> = self::frame_clause::OffsetPreceding<I>;
540
541 /// Return type of [`FrameBoundDsl::following`]
542 pub type Following<I> = self::frame_clause::OffsetFollowing<I>;
543}