diesel_derives/
table.rs

1use diesel_table_macro_syntax::{ColumnDef, TableDecl};
2use proc_macro2::TokenStream;
3use syn::parse_quote;
4use syn::Ident;
5
/// Column name used as primary key when a `table!` declaration does not
/// specify one explicitly (only if a column with this name exists).
const DEFAULT_PRIMARY_KEY_NAME: &str = "id";
7
8pub(crate) fn expand(input: TableDecl) -> TokenStream {
9    if input.column_defs.len() > super::diesel_for_each_tuple::MAX_TUPLE_SIZE as usize {
10        let txt = if input.column_defs.len() > 128 {
11            "You reached the end. Diesel does not support tables with \
12             more than 128 columns. Consider using less columns."
13        } else if input.column_defs.len() > 64 {
14            "Table contains more than 64 columns. Consider enabling the \
15             `128-column-tables` feature to enable diesels support for \
16             tables with more than 64 columns."
17        } else if input.column_defs.len() > 32 {
18            "Table contains more than 32 columns. Consider enabling the \
19             `64-column-tables` feature to enable diesels support for \
20             tables with more than 32 columns."
21        } else {
22            "Table contains more than 16 columns. Consider enabling the \
23             `32-column-tables` feature to enable diesels support for \
24             tables with more than 16 columns."
25        };
26        return quote::quote! {
27            compile_error!(#txt);
28        };
29    }
30
31    let meta = &input.meta;
32    let table_name = &input.table_name;
33    let imports = if input.use_statements.is_empty() {
34        vec![parse_quote!(
35            use diesel::sql_types::*;
36        )]
37    } else {
38        input.use_statements.clone()
39    };
40    let column_names = input
41        .column_defs
42        .iter()
43        .map(|c| &c.column_name)
44        .collect::<Vec<_>>();
45    let column_names = &column_names;
46    let primary_key: TokenStream = match input.primary_keys.as_ref() {
47        None if column_names.contains(&&syn::Ident::new(
48            DEFAULT_PRIMARY_KEY_NAME,
49            proc_macro2::Span::call_site(),
50        )) =>
51        {
52            let id = syn::Ident::new(DEFAULT_PRIMARY_KEY_NAME, proc_macro2::Span::call_site());
53            parse_quote! {
54                #id
55            }
56        }
57        None => {
58            let mut message = format!(
59                "Neither an explicit primary key found nor does an `id` column exist.\n\
60                 Consider explicitly defining a primary key. \n\
61                 For example for specifying `{}` as primary key:\n\n\
62                 table! {{\n",
63                column_names[0],
64            );
65            message += &format!("\t{table_name} ({}) {{\n", &column_names[0]);
66            for c in &input.column_defs {
67                let tpe = c
68                    .tpe
69                    .path
70                    .segments
71                    .iter()
72                    .map(|p| p.ident.to_string())
73                    .collect::<Vec<_>>()
74                    .join("::");
75                message += &format!("\t\t{} -> {tpe},\n", c.column_name);
76            }
77            message += "\t}\n}";
78
79            let span = input.table_name.span();
80            return quote::quote_spanned! {span=>
81                compile_error!(#message);
82            };
83        }
84        Some(a) if a.keys.len() == 1 => {
85            let k = a.keys.first().unwrap();
86            parse_quote! {
87                #k
88            }
89        }
90        Some(a) => {
91            let keys = a.keys.iter();
92
93            parse_quote! {
94                (#(#keys,)*)
95            }
96        }
97    };
98
99    let column_defs = input.column_defs.iter().map(expand_column_def);
100    let column_ty = input.column_defs.iter().map(|c| &c.tpe);
101    let valid_grouping_for_table_columns = generate_valid_grouping_for_table_columns(&input);
102
103    let sql_name = &input.sql_name;
104    let static_query_fragment_impl_for_table = if let Some(schema) = input.schema {
105        let schema_name = schema.to_string();
106        quote::quote! {
107            impl diesel::internal::table_macro::StaticQueryFragment for table {
108                type Component = diesel::internal::table_macro::InfixNode<
109                        diesel::internal::table_macro::Identifier<'static>,
110                    diesel::internal::table_macro::Identifier<'static>,
111                    &'static str
112                        >;
113                const STATIC_COMPONENT: &'static Self::Component = &diesel::internal::table_macro::InfixNode::new(
114                    diesel::internal::table_macro::Identifier(#schema_name),
115                    diesel::internal::table_macro::Identifier(#sql_name),
116                    "."
117                );
118            }
119        }
120    } else {
121        quote::quote! {
122            impl diesel::internal::table_macro::StaticQueryFragment for table {
123                type Component = diesel::internal::table_macro::Identifier<'static>;
124                const STATIC_COMPONENT: &'static Self::Component = &diesel::internal::table_macro::Identifier(#sql_name);
125            }
126        }
127    };
128
129    let reexport_column_from_dsl = input.column_defs.iter().map(|c| {
130        let column_name = &c.column_name;
131        if c.column_name == *table_name {
132            let span = c.column_name.span();
133            let message = format!(
134                "Column `{column_name}` cannot be named the same as it's table.\n\
135                 You may use `#[sql_name = \"{column_name}\"]` to reference the table's \
136                 `{column_name}` column \n\
137                 Docs available at: `https://docs.diesel.rs/master/diesel/macro.table.html`\n"
138            );
139            quote::quote_spanned! { span =>
140                compile_error!(#message);
141            }
142        } else {
143            quote::quote! {
144                pub use super::columns::#column_name;
145            }
146        }
147    });
148
149    let backend_specific_table_impls = if cfg!(feature = "postgres") {
150        Some(quote::quote! {
151            impl<S> diesel::JoinTo<diesel::query_builder::Only<S>> for table
152            where
153                diesel::query_builder::Only<S>: diesel::JoinTo<table>,
154            {
155                type FromClause = diesel::query_builder::Only<S>;
156                type OnClause = <diesel::query_builder::Only<S> as diesel::JoinTo<table>>::OnClause;
157
158                fn join_target(__diesel_internal_rhs: diesel::query_builder::Only<S>) -> (Self::FromClause, Self::OnClause) {
159                    let (_, __diesel_internal_on_clause) = diesel::query_builder::Only::<S>::join_target(table);
160                    (__diesel_internal_rhs, __diesel_internal_on_clause)
161                }
162            }
163
164            impl diesel::query_source::AppearsInFromClause<diesel::query_builder::Only<table>>
165                for table
166            {
167                type Count = diesel::query_source::Once;
168            }
169
170            impl diesel::query_source::AppearsInFromClause<table>
171                for diesel::query_builder::Only<table>
172            {
173                type Count = diesel::query_source::Once;
174            }
175
176            impl<S, TSM> diesel::JoinTo<diesel::query_builder::Tablesample<S, TSM>> for table
177            where
178                diesel::query_builder::Tablesample<S, TSM>: diesel::JoinTo<table>,
179                TSM: diesel::internal::table_macro::TablesampleMethod
180            {
181                type FromClause = diesel::query_builder::Tablesample<S, TSM>;
182                type OnClause = <diesel::query_builder::Tablesample<S, TSM> as diesel::JoinTo<table>>::OnClause;
183
184                fn join_target(__diesel_internal_rhs: diesel::query_builder::Tablesample<S, TSM>) -> (Self::FromClause, Self::OnClause) {
185                    let (_, __diesel_internal_on_clause) = diesel::query_builder::Tablesample::<S, TSM>::join_target(table);
186                    (__diesel_internal_rhs, __diesel_internal_on_clause)
187                }
188            }
189
190            impl<TSM> diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<table, TSM>>
191                for table
192                    where
193                TSM: diesel::internal::table_macro::TablesampleMethod
194            {
195                type Count = diesel::query_source::Once;
196            }
197
198            impl<TSM> diesel::query_source::AppearsInFromClause<table>
199                for diesel::query_builder::Tablesample<table, TSM>
200                    where
201                TSM: diesel::internal::table_macro::TablesampleMethod
202            {
203                type Count = diesel::query_source::Once;
204            }
205        })
206    } else {
207        None
208    };
209
210    let imports_for_column_module = imports.iter().map(fix_import_for_submodule);
211
212    quote::quote! {
213        #(#meta)*
214        #[allow(unused_imports, dead_code, unreachable_pub, unused_qualifications)]
215        pub mod #table_name {
216            use ::diesel;
217            pub use self::columns::*;
218            #(#imports)*
219
220            /// Re-exports all of the columns of this table, as well as the
221            /// table struct renamed to the module name. This is meant to be
222            /// glob imported for functions which only deal with one table.
223            pub mod dsl {
224                #(#reexport_column_from_dsl)*
225                pub use super::table as #table_name;
226            }
227
228            #[allow(non_upper_case_globals, dead_code)]
229            /// A tuple of all of the columns on this table
230            pub const all_columns: (#(#column_names,)*) = (#(#column_names,)*);
231
232            #[allow(non_camel_case_types)]
233            #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, Default)]
234            /// The actual table struct
235            ///
236            /// This is the type which provides the base methods of the query
237            /// builder, such as `.select` and `.filter`.
238            pub struct table;
239
240            impl table {
241                #[allow(dead_code)]
242                /// Represents `table_name.*`, which is sometimes necessary
243                /// for efficient count queries. It cannot be used in place of
244                /// `all_columns`
245                pub fn star(&self) -> star {
246                    star
247                }
248            }
249
250            /// The SQL type of all of the columns on this table
251            pub type SqlType = (#(#column_ty,)*);
252
253            /// Helper type for representing a boxed query from this table
254            pub type BoxedQuery<'a, DB, ST = SqlType> = diesel::internal::table_macro::BoxedSelectStatement<'a, ST, diesel::internal::table_macro::FromClause<table>, DB>;
255
256            impl diesel::QuerySource for table {
257                type FromClause = diesel::internal::table_macro::StaticQueryFragmentInstance<table>;
258                type DefaultSelection = <Self as diesel::Table>::AllColumns;
259
260                fn from_clause(&self) -> Self::FromClause {
261                    diesel::internal::table_macro::StaticQueryFragmentInstance::new()
262                }
263
264                fn default_selection(&self) -> Self::DefaultSelection {
265                    use diesel::Table;
266                    Self::all_columns()
267                }
268            }
269
270            impl<DB> diesel::query_builder::QueryFragment<DB> for table where
271                DB: diesel::backend::Backend,
272                <table as diesel::internal::table_macro::StaticQueryFragment>::Component: diesel::query_builder::QueryFragment<DB>
273            {
274                fn walk_ast<'b>(&'b self, __diesel_internal_pass: diesel::query_builder::AstPass<'_, 'b, DB>) -> diesel::result::QueryResult<()> {
275                    <table as diesel::internal::table_macro::StaticQueryFragment>::STATIC_COMPONENT.walk_ast(__diesel_internal_pass)
276                }
277            }
278
279            #static_query_fragment_impl_for_table
280
281            impl diesel::query_builder::AsQuery for table {
282                type SqlType = SqlType;
283                type Query = diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<Self>>;
284
285                fn as_query(self) -> Self::Query {
286                    diesel::internal::table_macro::SelectStatement::simple(self)
287                }
288            }
289
290            impl diesel::Table for table {
291                type PrimaryKey = #primary_key;
292                type AllColumns = (#(#column_names,)*);
293
294                fn primary_key(&self) -> Self::PrimaryKey {
295                    #primary_key
296                }
297
298                fn all_columns() -> Self::AllColumns {
299                    (#(#column_names,)*)
300                }
301            }
302
303            impl diesel::associations::HasTable for table {
304                type Table = Self;
305
306                fn table() -> Self::Table {
307                    table
308                }
309            }
310
311            impl diesel::query_builder::IntoUpdateTarget for table {
312                type WhereClause = <<Self as diesel::query_builder::AsQuery>::Query as diesel::query_builder::IntoUpdateTarget>::WhereClause;
313
314                fn into_update_target(self) -> diesel::query_builder::UpdateTarget<Self::Table, Self::WhereClause> {
315                    use diesel::query_builder::AsQuery;
316                    let q: diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<table>> = self.as_query();
317                    q.into_update_target()
318                }
319            }
320
321            impl diesel::query_source::AppearsInFromClause<table> for table {
322                type Count = diesel::query_source::Once;
323            }
324
325            // impl<S: AliasSource<Table=table>> AppearsInFromClause<table> for Alias<S>
326            impl<S> diesel::internal::table_macro::AliasAppearsInFromClause<S, table> for table
327            where S: diesel::query_source::AliasSource<Target=table>,
328            {
329                type Count = diesel::query_source::Never;
330            }
331
332            // impl<S1: AliasSource<Table=table>, S2: AliasSource<Table=table>> AppearsInFromClause<Alias<S1>> for Alias<S2>
333            // Those are specified by the `alias!` macro, but this impl will allow it to implement this trait even in downstream
334            // crates from the schema
335            impl<S1, S2> diesel::internal::table_macro::AliasAliasAppearsInFromClause<table, S2, S1> for table
336            where S1: diesel::query_source::AliasSource<Target=table>,
337                  S2: diesel::query_source::AliasSource<Target=table>,
338                  S1: diesel::internal::table_macro::AliasAliasAppearsInFromClauseSameTable<S2, table>,
339            {
340                type Count = <S1 as diesel::internal::table_macro::AliasAliasAppearsInFromClauseSameTable<S2, table>>::Count;
341            }
342
343            impl<S> diesel::query_source::AppearsInFromClause<diesel::query_source::Alias<S>> for table
344            where S: diesel::query_source::AliasSource,
345            {
346                type Count = diesel::query_source::Never;
347            }
348
349            impl<S, C> diesel::internal::table_macro::FieldAliasMapperAssociatedTypesDisjointnessTrick<table, S, C> for table
350            where
351                S: diesel::query_source::AliasSource<Target = table> + ::std::clone::Clone,
352                C: diesel::query_source::Column<Table = table>,
353            {
354                type Out = diesel::query_source::AliasedField<S, C>;
355
356                fn map(__diesel_internal_column: C, __diesel_internal_alias: &diesel::query_source::Alias<S>) -> Self::Out {
357                    __diesel_internal_alias.field(__diesel_internal_column)
358                }
359            }
360
361            impl diesel::query_source::AppearsInFromClause<table> for diesel::internal::table_macro::NoFromClause {
362                type Count = diesel::query_source::Never;
363            }
364
365            impl<Left, Right, Kind> diesel::JoinTo<diesel::internal::table_macro::Join<Left, Right, Kind>> for table where
366                diesel::internal::table_macro::Join<Left, Right, Kind>: diesel::JoinTo<table>,
367                Left: diesel::query_source::QuerySource,
368                Right: diesel::query_source::QuerySource,
369            {
370                type FromClause = diesel::internal::table_macro::Join<Left, Right, Kind>;
371                type OnClause = <diesel::internal::table_macro::Join<Left, Right, Kind> as diesel::JoinTo<table>>::OnClause;
372
373                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::Join<Left, Right, Kind>) -> (Self::FromClause, Self::OnClause) {
374                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::Join::join_target(table);
375                    (__diesel_internal_rhs, __diesel_internal_on_clause)
376                }
377            }
378
379            impl<Join, On> diesel::JoinTo<diesel::internal::table_macro::JoinOn<Join, On>> for table where
380                diesel::internal::table_macro::JoinOn<Join, On>: diesel::JoinTo<table>,
381            {
382                type FromClause = diesel::internal::table_macro::JoinOn<Join, On>;
383                type OnClause = <diesel::internal::table_macro::JoinOn<Join, On> as diesel::JoinTo<table>>::OnClause;
384
385                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::JoinOn<Join, On>) -> (Self::FromClause, Self::OnClause) {
386                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::JoinOn::join_target(table);
387                    (__diesel_internal_rhs, __diesel_internal_on_clause)
388                }
389            }
390
391            impl<F, S, D, W, O, L, Of, G> diesel::JoinTo<diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>> for table where
392                diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>: diesel::JoinTo<table>,
393                F: diesel::query_source::QuerySource
394            {
395                type FromClause = diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>;
396                type OnClause = <diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G> as diesel::JoinTo<table>>::OnClause;
397
398                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>) -> (Self::FromClause, Self::OnClause) {
399                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::SelectStatement::join_target(table);
400                    (__diesel_internal_rhs, __diesel_internal_on_clause)
401                }
402            }
403
404            impl<'a, QS, ST, DB> diesel::JoinTo<diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>> for table where
405                diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>: diesel::JoinTo<table>,
406                QS: diesel::query_source::QuerySource,
407            {
408                type FromClause = diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>;
409                type OnClause = <diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB> as diesel::JoinTo<table>>::OnClause;
410                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>) -> (Self::FromClause, Self::OnClause) {
411                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::BoxedSelectStatement::join_target(table);
412                    (__diesel_internal_rhs, __diesel_internal_on_clause)
413                }
414            }
415
416            impl<S> diesel::JoinTo<diesel::query_source::Alias<S>> for table
417            where
418                diesel::query_source::Alias<S>: diesel::JoinTo<table>,
419            {
420                type FromClause = diesel::query_source::Alias<S>;
421                type OnClause = <diesel::query_source::Alias<S> as diesel::JoinTo<table>>::OnClause;
422
423                fn join_target(__diesel_internal_rhs: diesel::query_source::Alias<S>) -> (Self::FromClause, Self::OnClause) {
424                    let (_, __diesel_internal_on_clause) = diesel::query_source::Alias::<S>::join_target(table);
425                    (__diesel_internal_rhs, __diesel_internal_on_clause)
426                }
427            }
428
429            // This impl should be able to live in Diesel,
430            // but Rust tries to recurse for no reason
431            impl<T> diesel::insertable::Insertable<T> for table
432            where
433                <table as diesel::query_builder::AsQuery>::Query: diesel::insertable::Insertable<T>,
434            {
435                type Values = <<table as diesel::query_builder::AsQuery>::Query as diesel::insertable::Insertable<T>>::Values;
436
437                fn values(self) -> Self::Values {
438                    use diesel::query_builder::AsQuery;
439                    self.as_query().values()
440                }
441            }
442
443            impl<'a, T> diesel::insertable::Insertable<T> for &'a table
444            where
445                table: diesel::insertable::Insertable<T>,
446            {
447                type Values = <table as diesel::insertable::Insertable<T>>::Values;
448
449                fn values(self) -> Self::Values {
450                    (*self).values()
451                }
452            }
453
454            #backend_specific_table_impls
455
456            /// Contains all of the columns of this table
457            pub mod columns {
458                use ::diesel;
459                use super::table;
460                #(#imports_for_column_module)*
461
462                #[allow(non_camel_case_types, dead_code)]
463                #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
464                /// Represents `table_name.*`, which is sometimes needed for
465                /// efficient count queries. It cannot be used in place of
466                /// `all_columns`, and has a `SqlType` of `()` to prevent it
467                /// being used that way
468                pub struct star;
469
470                impl<__GB> diesel::expression::ValidGrouping<__GB> for star
471                where
472                    (#(#column_names,)*): diesel::expression::ValidGrouping<__GB>,
473                {
474                    type IsAggregate = <(#(#column_names,)*) as diesel::expression::ValidGrouping<__GB>>::IsAggregate;
475                }
476
477                impl diesel::Expression for star {
478                    type SqlType = diesel::expression::expression_types::NotSelectable;
479                }
480
481                impl<DB: diesel::backend::Backend> diesel::query_builder::QueryFragment<DB> for star where
482                    <table as diesel::QuerySource>::FromClause: diesel::query_builder::QueryFragment<DB>,
483                {
484                    #[allow(non_snake_case)]
485                    fn walk_ast<'b>(&'b self, mut __diesel_internal_out: diesel::query_builder::AstPass<'_, 'b, DB>) -> diesel::result::QueryResult<()>
486                    {
487                        use diesel::QuerySource;
488
489                        if !__diesel_internal_out.should_skip_from() {
490                            const FROM_CLAUSE: diesel::internal::table_macro::StaticQueryFragmentInstance<table> = diesel::internal::table_macro::StaticQueryFragmentInstance::new();
491
492                            FROM_CLAUSE.walk_ast(__diesel_internal_out.reborrow())?;
493                            __diesel_internal_out.push_sql(".");
494                        }
495                        __diesel_internal_out.push_sql("*");
496                        Ok(())
497                    }
498                }
499
500                impl diesel::SelectableExpression<table> for star {
501                }
502
503                impl diesel::AppearsOnTable<table> for star {
504                }
505
506                #(#column_defs)*
507
508                #(#valid_grouping_for_table_columns)*
509            }
510        }
511    }
512}
513
514fn generate_valid_grouping_for_table_columns(table: &TableDecl) -> Vec<TokenStream> {
515    let mut ret = Vec::with_capacity(table.column_defs.len() * table.column_defs.len());
516
517    let primary_key = if let Some(ref pk) = table.primary_keys {
518        if pk.keys.len() == 1 {
519            pk.keys.first().map(ToString::to_string)
520        } else {
521            None
522        }
523    } else {
524        Some(DEFAULT_PRIMARY_KEY_NAME.into())
525    };
526
527    for (id, right_col) in table.column_defs.iter().enumerate() {
528        for left_col in table.column_defs.iter().skip(id) {
529            let right_to_left = if Some(left_col.column_name.to_string()) == primary_key {
530                Ident::new("Yes", proc_macro2::Span::call_site())
531            } else {
532                Ident::new("No", proc_macro2::Span::call_site())
533            };
534
535            let left_to_right = if Some(right_col.column_name.to_string()) == primary_key {
536                Ident::new("Yes", proc_macro2::Span::call_site())
537            } else {
538                Ident::new("No", proc_macro2::Span::call_site())
539            };
540
541            let left_col = &left_col.column_name;
542            let right_col = &right_col.column_name;
543
544            if left_col != right_col {
545                ret.push(quote::quote! {
546                    impl diesel::expression::IsContainedInGroupBy<#right_col> for #left_col {
547                        type Output = diesel::expression::is_contained_in_group_by::#right_to_left;
548                    }
549
550                    impl diesel::expression::IsContainedInGroupBy<#left_col> for #right_col {
551                        type Output = diesel::expression::is_contained_in_group_by::#left_to_right;
552                    }
553                });
554            }
555        }
556    }
557    ret
558}
559
560fn fix_import_for_submodule(import: &syn::ItemUse) -> syn::ItemUse {
561    let mut ret = import.clone();
562
563    if let syn::UseTree::Path(ref mut path) = ret.tree {
564        // prepend another `super` to the any import
565        // that starts with `super` so that it now refers to the correct
566        // module
567        if path.ident == "super" {
568            let inner = path.clone();
569            path.tree = Box::new(syn::UseTree::Path(inner));
570        }
571    }
572
573    ret
574}
575
576fn is_numeric(ty: &syn::TypePath) -> bool {
577    const NUMERIC_TYPES: &[&str] = &[
578        "SmallInt",
579        "Int2",
580        "Smallint",
581        "SmallSerial",
582        "Integer",
583        "Int4",
584        "Serial",
585        "BigInt",
586        "Int8",
587        "Bigint",
588        "BigSerial",
589        "Decimal",
590        "Float",
591        "Float4",
592        "Float8",
593        "Double",
594        "Numeric",
595    ];
596
597    if let Some(last) = ty.path.segments.last() {
598        match &last.arguments {
599            syn::PathArguments::AngleBracketed(t)
600                if (last.ident == "Nullable" || last.ident == "Unsigned") && t.args.len() == 1 =>
601            {
602                if let Some(syn::GenericArgument::Type(syn::Type::Path(t))) = t.args.first() {
603                    NUMERIC_TYPES.iter().any(|i| {
604                        t.path.segments.last().map(|s| s.ident.to_string())
605                            == Some(String::from(*i))
606                    })
607                } else {
608                    false
609                }
610            }
611            _ => NUMERIC_TYPES.iter().any(|i| last.ident == *i),
612        }
613    } else {
614        false
615    }
616}
617
618fn is_date_time(ty: &syn::TypePath) -> bool {
619    const DATE_TYPES: &[&str] = &["Time", "Date", "Timestamp", "Timestamptz"];
620    if let Some(last) = ty.path.segments.last() {
621        match &last.arguments {
622            syn::PathArguments::AngleBracketed(t)
623                if last.ident == "Nullable" && t.args.len() == 1 =>
624            {
625                if let Some(syn::GenericArgument::Type(syn::Type::Path(t))) = t.args.first() {
626                    DATE_TYPES.iter().any(|i| {
627                        t.path.segments.last().map(|s| s.ident.to_string())
628                            == Some(String::from(*i))
629                    })
630                } else {
631                    false
632                }
633            }
634            _ => DATE_TYPES.iter().any(|i| last.ident == *i),
635        }
636    } else {
637        false
638    }
639}
640
641fn is_network(ty: &syn::TypePath) -> bool {
642    const NETWORK_TYPES: &[&str] = &["Cidr", "Inet"];
643
644    if let Some(last) = ty.path.segments.last() {
645        match &last.arguments {
646            syn::PathArguments::AngleBracketed(t)
647                if last.ident == "Nullable" && t.args.len() == 1 =>
648            {
649                if let Some(syn::GenericArgument::Type(syn::Type::Path(t))) = t.args.first() {
650                    NETWORK_TYPES.iter().any(|i| {
651                        t.path.segments.last().map(|s| s.ident.to_string())
652                            == Some(String::from(*i))
653                    })
654                } else {
655                    false
656                }
657            }
658            _ => NETWORK_TYPES.iter().any(|i| last.ident == *i),
659        }
660    } else {
661        false
662    }
663}
664
665fn generate_op_impl(op: &str, tpe: &syn::Ident) -> TokenStream {
666    let fn_name = syn::Ident::new(&op.to_lowercase(), tpe.span());
667    let op = syn::Ident::new(op, tpe.span());
668    quote::quote! {
669        impl<Rhs> ::std::ops::#op<Rhs> for #tpe
670        where
671            Rhs: diesel::expression::AsExpression<
672                <<#tpe as diesel::Expression>::SqlType as diesel::sql_types::ops::#op>::Rhs,
673            >,
674        {
675            type Output = diesel::internal::table_macro::ops::#op<Self, Rhs::Expression>;
676
677            fn #fn_name(self, __diesel_internal_rhs: Rhs) -> Self::Output {
678                diesel::internal::table_macro::ops::#op::new(self, __diesel_internal_rhs.as_expression())
679            }
680        }
681    }
682}
683
684fn expand_column_def(column_def: &ColumnDef) -> TokenStream {
685    // TODO get a better span here as soon as that's
686    // possible using stable rust
687    let span = column_def.column_name.span();
688    let meta = &column_def.meta;
689    let column_name = &column_def.column_name;
690    let sql_name = &column_def.sql_name;
691    let sql_type = &column_def.tpe;
692
693    let backend_specific_column_impl = if cfg!(feature = "postgres") {
694        Some(quote::quote! {
695            impl diesel::query_source::AppearsInFromClause<diesel::query_builder::Only<super::table>>
696                for #column_name
697            {
698                type Count = diesel::query_source::Once;
699            }
700            impl diesel::SelectableExpression<diesel::query_builder::Only<super::table>> for #column_name {}
701
702            impl<TSM> diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<super::table, TSM>>
703                for #column_name
704                    where
705                TSM: diesel::internal::table_macro::TablesampleMethod
706            {
707                type Count = diesel::query_source::Once;
708            }
709            impl<TSM> diesel::SelectableExpression<diesel::query_builder::Tablesample<super::table, TSM>>
710                for #column_name where TSM: diesel::internal::table_macro::TablesampleMethod {}
711        })
712    } else {
713        None
714    };
715
716    let ops_impls = if is_numeric(&column_def.tpe) {
717        let add = generate_op_impl("Add", column_name);
718        let sub = generate_op_impl("Sub", column_name);
719        let div = generate_op_impl("Div", column_name);
720        let mul = generate_op_impl("Mul", column_name);
721        Some(quote::quote! {
722            #add
723            #sub
724            #div
725            #mul
726        })
727    } else if is_date_time(&column_def.tpe) || is_network(&column_def.tpe) {
728        let add = generate_op_impl("Add", column_name);
729        let sub = generate_op_impl("Sub", column_name);
730        Some(quote::quote! {
731            #add
732            #sub
733        })
734    } else {
735        None
736    };
737
738    let max_length = column_def.max_length.as_ref().map(|column_max_length| {
739        quote::quote! {
740            impl self::diesel::query_source::SizeRestrictedColumn for #column_name {
741                const MAX_LENGTH: usize = #column_max_length;
742            }
743        }
744    });
745
746    quote::quote_spanned! {span=>
747        #(#meta)*
748        #[allow(non_camel_case_types, dead_code)]
749        #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, Default)]
750        pub struct #column_name;
751
752        impl diesel::expression::Expression for #column_name {
753            type SqlType = #sql_type;
754        }
755
756        impl<DB> diesel::query_builder::QueryFragment<DB> for #column_name where
757            DB: diesel::backend::Backend,
758            diesel::internal::table_macro::StaticQueryFragmentInstance<table>: diesel::query_builder::QueryFragment<DB>,
759        {
760            #[allow(non_snake_case)]
761            fn walk_ast<'b>(&'b self, mut __diesel_internal_out: diesel::query_builder::AstPass<'_, 'b, DB>) -> diesel::result::QueryResult<()>
762            {
763                if !__diesel_internal_out.should_skip_from() {
764                    const FROM_CLAUSE: diesel::internal::table_macro::StaticQueryFragmentInstance<table> = diesel::internal::table_macro::StaticQueryFragmentInstance::new();
765
766                    FROM_CLAUSE.walk_ast(__diesel_internal_out.reborrow())?;
767                    __diesel_internal_out.push_sql(".");
768                }
769                __diesel_internal_out.push_identifier(#sql_name)
770            }
771        }
772
773        impl diesel::SelectableExpression<super::table> for #column_name {
774        }
775
776        impl<QS> diesel::AppearsOnTable<QS> for #column_name where
777            QS: diesel::query_source::AppearsInFromClause<super::table, Count=diesel::query_source::Once>,
778        {
779        }
780
781        impl<Left, Right> diesel::SelectableExpression<
782                diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::LeftOuter>,
783            > for #column_name where
784            #column_name: diesel::AppearsOnTable<diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::LeftOuter>>,
785            Self: diesel::SelectableExpression<Left>,
786            // If our table is on the right side of this join, only
787            // `Nullable<Self>` can be selected
788            Right: diesel::query_source::AppearsInFromClause<super::table, Count=diesel::query_source::Never> + diesel::query_source::QuerySource,
789            Left: diesel::query_source::QuerySource
790        {
791        }
792
793        impl<Left, Right> diesel::SelectableExpression<
794                diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::Inner>,
795            > for #column_name where
796            #column_name: diesel::AppearsOnTable<diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::Inner>>,
797            Left: diesel::query_source::AppearsInFromClause<super::table> + diesel::query_source::QuerySource,
798            Right: diesel::query_source::AppearsInFromClause<super::table> + diesel::query_source::QuerySource,
799        (Left::Count, Right::Count): diesel::internal::table_macro::Pick<Left, Right>,
800            Self: diesel::SelectableExpression<
801                <(Left::Count, Right::Count) as diesel::internal::table_macro::Pick<Left, Right>>::Selection,
802            >,
803        {
804        }
805
806        // FIXME: Remove this when overlapping marker traits are stable
807        impl<Join, On> diesel::SelectableExpression<diesel::internal::table_macro::JoinOn<Join, On>> for #column_name where
808            #column_name: diesel::SelectableExpression<Join> + diesel::AppearsOnTable<diesel::internal::table_macro::JoinOn<Join, On>>,
809        {
810        }
811
812        // FIXME: Remove this when overlapping marker traits are stable
813        impl<From> diesel::SelectableExpression<diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<From>>> for #column_name where
814            From: diesel::query_source::QuerySource,
815            #column_name: diesel::SelectableExpression<From> + diesel::AppearsOnTable<diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<From>>>,
816        {
817        }
818
819        impl<__GB> diesel::expression::ValidGrouping<__GB> for #column_name
820        where __GB: diesel::expression::IsContainedInGroupBy<#column_name, Output = diesel::expression::is_contained_in_group_by::Yes>,
821        {
822            type IsAggregate = diesel::expression::is_aggregate::Yes;
823        }
824
825        impl diesel::expression::ValidGrouping<()> for #column_name {
826            type IsAggregate = diesel::expression::is_aggregate::No;
827        }
828
829        impl diesel::expression::IsContainedInGroupBy<#column_name> for #column_name {
830            type Output = diesel::expression::is_contained_in_group_by::Yes;
831        }
832
833        impl diesel::query_source::Column for #column_name {
834            type Table = super::table;
835
836            const NAME: &'static str = #sql_name;
837        }
838
839        impl<T> diesel::EqAll<T> for #column_name where
840            T: diesel::expression::AsExpression<#sql_type>,
841            diesel::dsl::Eq<#column_name, T::Expression>: diesel::Expression<SqlType=diesel::sql_types::Bool>,
842        {
843            type Output = diesel::dsl::Eq<Self, T::Expression>;
844
845            fn eq_all(self, __diesel_internal_rhs: T) -> Self::Output {
846                use diesel::expression_methods::ExpressionMethods;
847                self.eq(__diesel_internal_rhs)
848            }
849        }
850
851        #max_length
852
853        #ops_impls
854        #backend_specific_column_impl
855    }
856}