1use crate::expression::{Expression, ValidGrouping};
2use crate::pg::Pg;
3use crate::query_builder::{AsQuery, AstPass, FromClause, QueryFragment, QueryId, SelectStatement};
4use crate::query_source::QuerySource;
5use crate::result::QueryResult;
6use crate::sql_types::{Double, SmallInt};
7use crate::{JoinTo, SelectableExpression, Table};
8use std::marker::PhantomData;
9
/// A sampling method usable in a PostgreSQL `TABLESAMPLE` clause.
///
/// The implementors in this module are unit marker types. The only thing a
/// method contributes is the SQL keyword that names it.
#[doc(hidden)]
pub trait TablesampleMethod: Clone {
    /// The keyword written directly after `TABLESAMPLE ` (for example
    /// `"BERNOULLI"`).
    fn method_name_sql() -> &'static str;
}
14
15#[derive(Clone, Copy, Debug)]
16pub struct BernoulliMethod;
18
19impl TablesampleMethod for BernoulliMethod {
20 fn method_name_sql() -> &'static str {
21 "BERNOULLI"
22 }
23}
24
25#[derive(Clone, Copy, Debug)]
26pub struct SystemMethod;
28
29impl TablesampleMethod for SystemMethod {
30 fn method_name_sql() -> &'static str {
31 "SYSTEM"
32 }
33}
34
35#[derive(Debug, Clone, Copy)]
37pub struct Tablesample<S, TSM>
38where
39 TSM: TablesampleMethod,
40{
41 source: S,
42 method: PhantomData<TSM>,
43 portion: i16,
44 seed: Option<f64>,
45}
46
47impl<S, TSM> Tablesample<S, TSM>
48where
49 TSM: TablesampleMethod,
50{
51 pub(crate) fn new(source: S, portion: i16) -> Tablesample<S, TSM> {
52 Tablesample {
53 source,
54 method: PhantomData,
55 portion,
56 seed: None,
57 }
58 }
59
60 pub fn with_seed(self, seed: f64) -> Tablesample<S, TSM> {
63 Tablesample {
64 source: self.source,
65 method: self.method,
66 portion: self.portion,
67 seed: Some(seed),
68 }
69 }
70}
71
72impl<S, TSM> QueryId for Tablesample<S, TSM>
73where
74 S: QueryId,
75 TSM: TablesampleMethod,
76{
77 type QueryId = ();
78 const HAS_STATIC_QUERY_ID: bool = false;
79}
80
81impl<S, TSM> QuerySource for Tablesample<S, TSM>
82where
83 S: Table + Clone,
84 TSM: TablesampleMethod,
85 <S as QuerySource>::DefaultSelection:
86 ValidGrouping<()> + SelectableExpression<Tablesample<S, TSM>>,
87{
88 type FromClause = Self;
89 type DefaultSelection = <S as QuerySource>::DefaultSelection;
90
91 fn from_clause(&self) -> Self::FromClause {
92 self.clone()
93 }
94
95 fn default_selection(&self) -> Self::DefaultSelection {
96 self.source.default_selection()
97 }
98}
99
100impl<S, TSM> QueryFragment<Pg> for Tablesample<S, TSM>
101where
102 S: QueryFragment<Pg>,
103 TSM: TablesampleMethod,
104{
105 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
106 self.source.walk_ast(out.reborrow())?;
107 out.push_sql(" TABLESAMPLE ");
108 out.push_sql(TSM::method_name_sql());
109 out.push_sql("(");
110 out.push_bind_param::<SmallInt, _>(&self.portion)?;
111 out.push_sql(")");
112 if let Some(f) = &self.seed {
113 out.push_sql(" REPEATABLE(");
114 out.push_bind_param::<Double, _>(f)?;
115 out.push_sql(")");
116 }
117 Ok(())
118 }
119}
120
121impl<S, TSM> AsQuery for Tablesample<S, TSM>
122where
123 S: Table + Clone,
124 TSM: TablesampleMethod,
125 <S as QuerySource>::DefaultSelection:
126 ValidGrouping<()> + SelectableExpression<Tablesample<S, TSM>>,
127{
128 type SqlType = <<Self as QuerySource>::DefaultSelection as Expression>::SqlType;
129 type Query = SelectStatement<FromClause<Self>>;
130 fn as_query(self) -> Self::Query {
131 SelectStatement::simple(self)
132 }
133}
134
135impl<S, T, TSM> JoinTo<T> for Tablesample<S, TSM>
136where
137 S: JoinTo<T>,
138 T: Table,
139 S: Table,
140 TSM: TablesampleMethod,
141{
142 type FromClause = <S as JoinTo<T>>::FromClause;
143 type OnClause = <S as JoinTo<T>>::OnClause;
144
145 fn join_target(rhs: T) -> (Self::FromClause, Self::OnClause) {
146 <S as JoinTo<T>>::join_target(rhs)
147 }
148}
149
150impl<S, TSM> Table for Tablesample<S, TSM>
151where
152 S: Table + Clone + AsQuery,
153 TSM: TablesampleMethod,
154
155 <S as Table>::PrimaryKey: SelectableExpression<Tablesample<S, TSM>>,
156 <S as Table>::AllColumns: SelectableExpression<Tablesample<S, TSM>>,
157 <S as QuerySource>::DefaultSelection:
158 ValidGrouping<()> + SelectableExpression<Tablesample<S, TSM>>,
159{
160 type PrimaryKey = <S as Table>::PrimaryKey;
161 type AllColumns = <S as Table>::AllColumns;
162
163 fn primary_key(&self) -> Self::PrimaryKey {
164 self.source.primary_key()
165 }
166
167 fn all_columns() -> Self::AllColumns {
168 S::all_columns()
169 }
170}
171
#[cfg(test)]
mod test {
    use super::*;
    use crate::backend::Backend;
    use crate::query_builder::QueryBuilder;
    use diesel::dsl::*;
    use diesel::*;

    table! {
        users {
            id -> Integer,
            name -> VarChar,
        }
    }

    /// Renders `query` to SQL with the PostgreSQL query builder and returns
    /// the resulting string. Panics if rendering fails.
    fn render_pg<Q: QueryFragment<Pg>>(query: Q) -> String {
        let mut builder = <Pg as Backend>::QueryBuilder::default();
        <Q as QueryFragment<Pg>>::to_sql(&query, &mut builder, &Pg).unwrap();
        builder.finish()
    }

    #[diesel_test_helper::test]
    fn test_generated_tablesample_sql() {
        let bernoulli = render_pg(users::table.tablesample_bernoulli(10));
        assert_eq!(bernoulli, "\"users\" TABLESAMPLE BERNOULLI($1)");

        let system = render_pg(users::table.tablesample_system(10));
        assert_eq!(system, "\"users\" TABLESAMPLE SYSTEM($1)");

        // A seed appends a REPEATABLE clause with its own bind parameter.
        let seeded = render_pg(users::table.tablesample_system(10).with_seed(42.0));
        assert_eq!(seeded, "\"users\" TABLESAMPLE SYSTEM($1) REPEATABLE($2)");
    }
}