// diesel_derives/table.rs

1use diesel_table_macro_syntax::{ColumnDef, TableDecl};
2use proc_macro2::{Span, TokenStream};
3use syn::parse_quote;
4use syn::Ident;
5
/// Name of the column used as primary key when a `table!` declaration does not list one
/// explicitly. `expand` reports an error if such a table has no column with this name.
const DEFAULT_PRIMARY_KEY_NAME: &str = "id";
7
8pub(crate) fn expand(input: TableDecl) -> TokenStream {
9    if input.column_defs.len() > super::diesel_for_each_tuple::MAX_TUPLE_SIZE as usize {
10        let txt = if input.column_defs.len() > 128 {
11            "you reached the end. \nhelp: diesel does not support tables with \
12             more than 128 columns.\nhelp: consider using less columns."
13        } else if input.column_defs.len() > 64 {
14            "table contains more than 64 columns. \n consider enabling the \
15             `128-column-tables` feature to enable diesels support for \
16             tables with more than 64 columns."
17        } else if input.column_defs.len() > 32 {
18            "table contains more than 32 columns. \nhelp: consider enabling the \
19             `64-column-tables` feature to enable diesels support for \
20             tables with more than 32 columns."
21        } else {
22            "table contains more than 16 columns. \nhelp: consider enabling the \
23             `32-column-tables` feature to enable diesels support for \
24             tables with more than 16 columns."
25        };
26        return quote::quote! {
27            compile_error!(#txt);
28        };
29    }
30
31    let meta = &input.meta;
32    let table_name = &input.table_name;
33    let imports = if input.use_statements.is_empty() {
34        vec![parse_quote!(
35            use diesel::sql_types::*;
36        )]
37    } else {
38        input.use_statements.clone()
39    };
40    let column_names = input
41        .column_defs
42        .iter()
43        .map(|c| &c.column_name)
44        .collect::<Vec<_>>();
45    let column_names = &column_names;
46    let primary_key: TokenStream = match input.primary_keys.as_ref() {
47        None if column_names.contains(&&syn::Ident::new(
48            DEFAULT_PRIMARY_KEY_NAME,
49            proc_macro2::Span::mixed_site(),
50        )) =>
51        {
52            let id = syn::Ident::new(DEFAULT_PRIMARY_KEY_NAME, proc_macro2::Span::mixed_site());
53            parse_quote! {
54                #id
55            }
56        }
57        None => {
58            let mut message = format!(
59                "neither an explicit primary key found nor does an `id` column exist.\n\
60                 consider explicitly defining a primary key. \n\
61                 for example for specifying `{key}` as primary key:\n\n\
62                 table! {{\n
63                     {table}({key}){{\n",
64                key = column_names[0],
65                table = input.table_name,
66            );
67            message += &format!("\t{table_name} ({}) {{\n", &column_names[0]);
68            for c in &input.column_defs {
69                let tpe = c
70                    .tpe
71                    .path
72                    .segments
73                    .iter()
74                    .map(|p| p.ident.to_string())
75                    .collect::<Vec<_>>()
76                    .join("::");
77                message += &format!("\t\t{} -> {tpe},\n", c.column_name);
78            }
79            message += "\t}\n}";
80
81            let span = Span::mixed_site().located_at(input.table_name.span());
82            return quote::quote_spanned! {span=>
83                compile_error!(#message);
84            };
85        }
86        Some(a) if a.keys.len() == 1 => {
87            let k = a.keys.first().unwrap();
88            parse_quote! {
89                #k
90            }
91        }
92        Some(a) => {
93            let keys = a.keys.iter();
94
95            parse_quote! {
96                (#(#keys,)*)
97            }
98        }
99    };
100
101    let column_defs = input.column_defs.iter().map(expand_column_def);
102    let column_ty = input.column_defs.iter().map(|c| &c.tpe);
103    let valid_grouping_for_table_columns = generate_valid_grouping_for_table_columns(&input);
104
105    let sql_name = &input.sql_name;
106    let static_query_fragment_impl_for_table = if let Some(schema) = input.schema {
107        let schema_name = schema.to_string();
108        quote::quote! {
109            impl diesel::internal::table_macro::StaticQueryFragment for table {
110                type Component = diesel::internal::table_macro::InfixNode<
111                        diesel::internal::table_macro::Identifier<'static>,
112                    diesel::internal::table_macro::Identifier<'static>,
113                    &'static str
114                        >;
115                const STATIC_COMPONENT: &'static Self::Component = &diesel::internal::table_macro::InfixNode::new(
116                    diesel::internal::table_macro::Identifier(#schema_name),
117                    diesel::internal::table_macro::Identifier(#sql_name),
118                    "."
119                );
120            }
121        }
122    } else {
123        quote::quote! {
124            impl diesel::internal::table_macro::StaticQueryFragment for table {
125                type Component = diesel::internal::table_macro::Identifier<'static>;
126                const STATIC_COMPONENT: &'static Self::Component = &diesel::internal::table_macro::Identifier(#sql_name);
127            }
128        }
129    };
130
131    let reexport_column_from_dsl = input.column_defs.iter().map(|c| {
132        let column_name = &c.column_name;
133        if c.column_name == *table_name {
134            let span = Span::mixed_site().located_at(c.column_name.span());
135            let message = format!(
136                "column `{column_name}` cannot be named the same as it's table.\n\
137                 you may use `#[sql_name = \"{column_name}\"]` to reference the table's \
138                 `{column_name}` column \n\
139                 docs available at: `https://docs.diesel.rs/master/diesel/macro.table.html`\n"
140            );
141            quote::quote_spanned! { span =>
142                compile_error!(#message);
143            }
144        } else {
145            quote::quote! {
146                pub use super::columns::#column_name;
147            }
148        }
149    });
150
151    let backend_specific_table_impls = if cfg!(feature = "postgres") {
152        Some(quote::quote! {
153            impl<S> diesel::JoinTo<diesel::query_builder::Only<S>> for table
154            where
155                diesel::query_builder::Only<S>: diesel::JoinTo<table>,
156            {
157                type FromClause = diesel::query_builder::Only<S>;
158                type OnClause = <diesel::query_builder::Only<S> as diesel::JoinTo<table>>::OnClause;
159
160                fn join_target(__diesel_internal_rhs: diesel::query_builder::Only<S>) -> (Self::FromClause, Self::OnClause) {
161                    let (_, __diesel_internal_on_clause) = diesel::query_builder::Only::<S>::join_target(table);
162                    (__diesel_internal_rhs, __diesel_internal_on_clause)
163                }
164            }
165
166            impl diesel::query_source::AppearsInFromClause<diesel::query_builder::Only<table>>
167                for table
168            {
169                type Count = diesel::query_source::Once;
170            }
171
172            impl diesel::query_source::AppearsInFromClause<table>
173                for diesel::query_builder::Only<table>
174            {
175                type Count = diesel::query_source::Once;
176            }
177
178            impl<S, TSM> diesel::JoinTo<diesel::query_builder::Tablesample<S, TSM>> for table
179            where
180                diesel::query_builder::Tablesample<S, TSM>: diesel::JoinTo<table>,
181                TSM: diesel::internal::table_macro::TablesampleMethod
182            {
183                type FromClause = diesel::query_builder::Tablesample<S, TSM>;
184                type OnClause = <diesel::query_builder::Tablesample<S, TSM> as diesel::JoinTo<table>>::OnClause;
185
186                fn join_target(__diesel_internal_rhs: diesel::query_builder::Tablesample<S, TSM>) -> (Self::FromClause, Self::OnClause) {
187                    let (_, __diesel_internal_on_clause) = diesel::query_builder::Tablesample::<S, TSM>::join_target(table);
188                    (__diesel_internal_rhs, __diesel_internal_on_clause)
189                }
190            }
191
192            impl<TSM> diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<table, TSM>>
193                for table
194                    where
195                TSM: diesel::internal::table_macro::TablesampleMethod
196            {
197                type Count = diesel::query_source::Once;
198            }
199
200            impl<TSM> diesel::query_source::AppearsInFromClause<table>
201                for diesel::query_builder::Tablesample<table, TSM>
202                    where
203                TSM: diesel::internal::table_macro::TablesampleMethod
204            {
205                type Count = diesel::query_source::Once;
206            }
207        })
208    } else {
209        None
210    };
211
212    let imports_for_column_module = imports.iter().map(fix_import_for_submodule);
213
214    quote::quote! {
215        #(#meta)*
216        #[allow(unused_imports, dead_code, unreachable_pub, unused_qualifications)]
217        pub mod #table_name {
218            use ::diesel;
219            pub use self::columns::*;
220            #(#imports)*
221
222            /// Re-exports all of the columns of this table, as well as the
223            /// table struct renamed to the module name. This is meant to be
224            /// glob imported for functions which only deal with one table.
225            pub mod dsl {
226                #(#reexport_column_from_dsl)*
227                pub use super::table as #table_name;
228            }
229
230            #[allow(non_upper_case_globals, dead_code)]
231            /// A tuple of all of the columns on this table
232            pub const all_columns: (#(#column_names,)*) = (#(#column_names,)*);
233
234            #[allow(non_camel_case_types)]
235            #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, Default)]
236            /// The actual table struct
237            ///
238            /// This is the type which provides the base methods of the query
239            /// builder, such as `.select` and `.filter`.
240            pub struct table;
241
242            impl table {
243                #[allow(dead_code)]
244                /// Represents `table_name.*`, which is sometimes necessary
245                /// for efficient count queries. It cannot be used in place of
246                /// `all_columns`
247                pub fn star(&self) -> star {
248                    star
249                }
250            }
251
252            /// The SQL type of all of the columns on this table
253            pub type SqlType = (#(#column_ty,)*);
254
255            /// Helper type for representing a boxed query from this table
256            pub type BoxedQuery<'a, DB, ST = SqlType> = diesel::internal::table_macro::BoxedSelectStatement<'a, ST, diesel::internal::table_macro::FromClause<table>, DB>;
257
258            impl diesel::QuerySource for table {
259                type FromClause = diesel::internal::table_macro::StaticQueryFragmentInstance<table>;
260                type DefaultSelection = <Self as diesel::Table>::AllColumns;
261
262                fn from_clause(&self) -> Self::FromClause {
263                    diesel::internal::table_macro::StaticQueryFragmentInstance::new()
264                }
265
266                fn default_selection(&self) -> Self::DefaultSelection {
267                    use diesel::Table;
268                    Self::all_columns()
269                }
270            }
271
272            impl<DB> diesel::query_builder::QueryFragment<DB> for table where
273                DB: diesel::backend::Backend,
274                <table as diesel::internal::table_macro::StaticQueryFragment>::Component: diesel::query_builder::QueryFragment<DB>
275            {
276                fn walk_ast<'b>(&'b self, __diesel_internal_pass: diesel::query_builder::AstPass<'_, 'b, DB>) -> diesel::result::QueryResult<()> {
277                    <table as diesel::internal::table_macro::StaticQueryFragment>::STATIC_COMPONENT.walk_ast(__diesel_internal_pass)
278                }
279            }
280
281            #static_query_fragment_impl_for_table
282
283            impl diesel::query_builder::AsQuery for table {
284                type SqlType = SqlType;
285                type Query = diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<Self>>;
286
287                fn as_query(self) -> Self::Query {
288                    diesel::internal::table_macro::SelectStatement::simple(self)
289                }
290            }
291
292            impl diesel::Table for table {
293                type PrimaryKey = #primary_key;
294                type AllColumns = (#(#column_names,)*);
295
296                fn primary_key(&self) -> Self::PrimaryKey {
297                    #primary_key
298                }
299
300                fn all_columns() -> Self::AllColumns {
301                    (#(#column_names,)*)
302                }
303            }
304
305            impl diesel::associations::HasTable for table {
306                type Table = Self;
307
308                fn table() -> Self::Table {
309                    table
310                }
311            }
312
313            impl diesel::query_builder::IntoUpdateTarget for table {
314                type WhereClause = <<Self as diesel::query_builder::AsQuery>::Query as diesel::query_builder::IntoUpdateTarget>::WhereClause;
315
316                fn into_update_target(self) -> diesel::query_builder::UpdateTarget<Self::Table, Self::WhereClause> {
317                    use diesel::query_builder::AsQuery;
318                    let q: diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<table>> = self.as_query();
319                    q.into_update_target()
320                }
321            }
322
323            impl diesel::query_source::AppearsInFromClause<table> for table {
324                type Count = diesel::query_source::Once;
325            }
326
327            // impl<S: AliasSource<Table=table>> AppearsInFromClause<table> for Alias<S>
328            impl<S> diesel::internal::table_macro::AliasAppearsInFromClause<S, table> for table
329            where S: diesel::query_source::AliasSource<Target=table>,
330            {
331                type Count = diesel::query_source::Never;
332            }
333
334            // impl<S1: AliasSource<Table=table>, S2: AliasSource<Table=table>> AppearsInFromClause<Alias<S1>> for Alias<S2>
335            // Those are specified by the `alias!` macro, but this impl will allow it to implement this trait even in downstream
336            // crates from the schema
337            impl<S1, S2> diesel::internal::table_macro::AliasAliasAppearsInFromClause<table, S2, S1> for table
338            where S1: diesel::query_source::AliasSource<Target=table>,
339                  S2: diesel::query_source::AliasSource<Target=table>,
340                  S1: diesel::internal::table_macro::AliasAliasAppearsInFromClauseSameTable<S2, table>,
341            {
342                type Count = <S1 as diesel::internal::table_macro::AliasAliasAppearsInFromClauseSameTable<S2, table>>::Count;
343            }
344
345            impl<S> diesel::query_source::AppearsInFromClause<diesel::query_source::Alias<S>> for table
346            where S: diesel::query_source::AliasSource,
347            {
348                type Count = diesel::query_source::Never;
349            }
350
351            impl<S, C> diesel::internal::table_macro::FieldAliasMapperAssociatedTypesDisjointnessTrick<table, S, C> for table
352            where
353                S: diesel::query_source::AliasSource<Target = table> + ::std::clone::Clone,
354                C: diesel::query_source::Column<Table = table>,
355            {
356                type Out = diesel::query_source::AliasedField<S, C>;
357
358                fn map(__diesel_internal_column: C, __diesel_internal_alias: &diesel::query_source::Alias<S>) -> Self::Out {
359                    __diesel_internal_alias.field(__diesel_internal_column)
360                }
361            }
362
363            impl diesel::query_source::AppearsInFromClause<table> for diesel::internal::table_macro::NoFromClause {
364                type Count = diesel::query_source::Never;
365            }
366
367            impl<Left, Right, Kind> diesel::JoinTo<diesel::internal::table_macro::Join<Left, Right, Kind>> for table where
368                diesel::internal::table_macro::Join<Left, Right, Kind>: diesel::JoinTo<table>,
369                Left: diesel::query_source::QuerySource,
370                Right: diesel::query_source::QuerySource,
371            {
372                type FromClause = diesel::internal::table_macro::Join<Left, Right, Kind>;
373                type OnClause = <diesel::internal::table_macro::Join<Left, Right, Kind> as diesel::JoinTo<table>>::OnClause;
374
375                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::Join<Left, Right, Kind>) -> (Self::FromClause, Self::OnClause) {
376                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::Join::join_target(table);
377                    (__diesel_internal_rhs, __diesel_internal_on_clause)
378                }
379            }
380
381            impl<Join, On> diesel::JoinTo<diesel::internal::table_macro::JoinOn<Join, On>> for table where
382                diesel::internal::table_macro::JoinOn<Join, On>: diesel::JoinTo<table>,
383            {
384                type FromClause = diesel::internal::table_macro::JoinOn<Join, On>;
385                type OnClause = <diesel::internal::table_macro::JoinOn<Join, On> as diesel::JoinTo<table>>::OnClause;
386
387                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::JoinOn<Join, On>) -> (Self::FromClause, Self::OnClause) {
388                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::JoinOn::join_target(table);
389                    (__diesel_internal_rhs, __diesel_internal_on_clause)
390                }
391            }
392
393            impl<F, S, D, W, O, L, Of, G> diesel::JoinTo<diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>> for table where
394                diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>: diesel::JoinTo<table>,
395                F: diesel::query_source::QuerySource
396            {
397                type FromClause = diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>;
398                type OnClause = <diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G> as diesel::JoinTo<table>>::OnClause;
399
400                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<F>, S, D, W, O, L, Of, G>) -> (Self::FromClause, Self::OnClause) {
401                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::SelectStatement::join_target(table);
402                    (__diesel_internal_rhs, __diesel_internal_on_clause)
403                }
404            }
405
406            impl<'a, QS, ST, DB> diesel::JoinTo<diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>> for table where
407                diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>: diesel::JoinTo<table>,
408                QS: diesel::query_source::QuerySource,
409            {
410                type FromClause = diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>;
411                type OnClause = <diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB> as diesel::JoinTo<table>>::OnClause;
412                fn join_target(__diesel_internal_rhs: diesel::internal::table_macro::BoxedSelectStatement<'a, diesel::internal::table_macro::FromClause<QS>, ST, DB>) -> (Self::FromClause, Self::OnClause) {
413                    let (_, __diesel_internal_on_clause) = diesel::internal::table_macro::BoxedSelectStatement::join_target(table);
414                    (__diesel_internal_rhs, __diesel_internal_on_clause)
415                }
416            }
417
418            impl<S> diesel::JoinTo<diesel::query_source::Alias<S>> for table
419            where
420                diesel::query_source::Alias<S>: diesel::JoinTo<table>,
421            {
422                type FromClause = diesel::query_source::Alias<S>;
423                type OnClause = <diesel::query_source::Alias<S> as diesel::JoinTo<table>>::OnClause;
424
425                fn join_target(__diesel_internal_rhs: diesel::query_source::Alias<S>) -> (Self::FromClause, Self::OnClause) {
426                    let (_, __diesel_internal_on_clause) = diesel::query_source::Alias::<S>::join_target(table);
427                    (__diesel_internal_rhs, __diesel_internal_on_clause)
428                }
429            }
430
431            // This impl should be able to live in Diesel,
432            // but Rust tries to recurse for no reason
433            impl<T> diesel::insertable::Insertable<T> for table
434            where
435                <table as diesel::query_builder::AsQuery>::Query: diesel::insertable::Insertable<T>,
436            {
437                type Values = <<table as diesel::query_builder::AsQuery>::Query as diesel::insertable::Insertable<T>>::Values;
438
439                fn values(self) -> Self::Values {
440                    use diesel::query_builder::AsQuery;
441                    self.as_query().values()
442                }
443            }
444
445            impl<'a, T> diesel::insertable::Insertable<T> for &'a table
446            where
447                table: diesel::insertable::Insertable<T>,
448            {
449                type Values = <table as diesel::insertable::Insertable<T>>::Values;
450
451                fn values(self) -> Self::Values {
452                    (*self).values()
453                }
454            }
455
456            #backend_specific_table_impls
457
458            /// Contains all of the columns of this table
459            pub mod columns {
460                use ::diesel;
461                use super::table;
462                #(#imports_for_column_module)*
463
464                #[allow(non_camel_case_types, dead_code)]
465                #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
466                /// Represents `table_name.*`, which is sometimes needed for
467                /// efficient count queries. It cannot be used in place of
468                /// `all_columns`, and has a `SqlType` of `()` to prevent it
469                /// being used that way
470                pub struct star;
471
472                impl<__GB> diesel::expression::ValidGrouping<__GB> for star
473                where
474                    (#(#column_names,)*): diesel::expression::ValidGrouping<__GB>,
475                {
476                    type IsAggregate = <(#(#column_names,)*) as diesel::expression::ValidGrouping<__GB>>::IsAggregate;
477                }
478
479                impl diesel::Expression for star {
480                    type SqlType = diesel::expression::expression_types::NotSelectable;
481                }
482
483                impl<DB: diesel::backend::Backend> diesel::query_builder::QueryFragment<DB> for star where
484                    <table as diesel::QuerySource>::FromClause: diesel::query_builder::QueryFragment<DB>,
485                {
486                    #[allow(non_snake_case)]
487                    fn walk_ast<'b>(&'b self, mut __diesel_internal_out: diesel::query_builder::AstPass<'_, 'b, DB>) -> diesel::result::QueryResult<()>
488                    {
489                        use diesel::QuerySource;
490
491                        if !__diesel_internal_out.should_skip_from() {
492                            const FROM_CLAUSE: diesel::internal::table_macro::StaticQueryFragmentInstance<table> = diesel::internal::table_macro::StaticQueryFragmentInstance::new();
493
494                            FROM_CLAUSE.walk_ast(__diesel_internal_out.reborrow())?;
495                            __diesel_internal_out.push_sql(".");
496                        }
497                        __diesel_internal_out.push_sql("*");
498                        Ok(())
499                    }
500                }
501
502                impl diesel::SelectableExpression<table> for star {
503                }
504
505                impl diesel::AppearsOnTable<table> for star {
506                }
507
508                #(#column_defs)*
509
510                #(#valid_grouping_for_table_columns)*
511            }
512        }
513    }
514}
515
516fn generate_valid_grouping_for_table_columns(table: &TableDecl) -> Vec<TokenStream> {
517    let mut ret = Vec::with_capacity(table.column_defs.len() * table.column_defs.len());
518
519    let primary_key = if let Some(ref pk) = table.primary_keys {
520        if pk.keys.len() == 1 {
521            pk.keys.first().map(ToString::to_string)
522        } else {
523            None
524        }
525    } else {
526        Some(DEFAULT_PRIMARY_KEY_NAME.into())
527    };
528
529    for (id, right_col) in table.column_defs.iter().enumerate() {
530        for left_col in table.column_defs.iter().skip(id) {
531            let right_to_left = if Some(left_col.column_name.to_string()) == primary_key {
532                Ident::new("Yes", proc_macro2::Span::mixed_site())
533            } else {
534                Ident::new("No", proc_macro2::Span::mixed_site())
535            };
536
537            let left_to_right = if Some(right_col.column_name.to_string()) == primary_key {
538                Ident::new("Yes", proc_macro2::Span::mixed_site())
539            } else {
540                Ident::new("No", proc_macro2::Span::mixed_site())
541            };
542
543            let left_col = &left_col.column_name;
544            let right_col = &right_col.column_name;
545
546            if left_col != right_col {
547                ret.push(quote::quote! {
548                    impl diesel::expression::IsContainedInGroupBy<#right_col> for #left_col {
549                        type Output = diesel::expression::is_contained_in_group_by::#right_to_left;
550                    }
551
552                    impl diesel::expression::IsContainedInGroupBy<#left_col> for #right_col {
553                        type Output = diesel::expression::is_contained_in_group_by::#left_to_right;
554                    }
555                });
556            }
557        }
558    }
559    ret
560}
561
562fn fix_import_for_submodule(import: &syn::ItemUse) -> syn::ItemUse {
563    let mut ret = import.clone();
564
565    if let syn::UseTree::Path(ref mut path) = ret.tree {
566        // prepend another `super` to the any import
567        // that starts with `super` so that it now refers to the correct
568        // module
569        if path.ident == "super" {
570            let inner = path.clone();
571            path.tree = Box::new(syn::UseTree::Path(inner));
572        }
573    }
574
575    ret
576}
577
578fn is_numeric(ty: &syn::TypePath) -> bool {
579    const NUMERIC_TYPES: &[&str] = &[
580        "SmallInt",
581        "Int2",
582        "Smallint",
583        "SmallSerial",
584        "Integer",
585        "Int4",
586        "Serial",
587        "BigInt",
588        "Int8",
589        "Bigint",
590        "BigSerial",
591        "Decimal",
592        "Float",
593        "Float4",
594        "Float8",
595        "Double",
596        "Numeric",
597    ];
598
599    if let Some(last) = ty.path.segments.last() {
600        match &last.arguments {
601            syn::PathArguments::AngleBracketed(t)
602                if (last.ident == "Nullable" || last.ident == "Unsigned") && t.args.len() == 1 =>
603            {
604                if let Some(syn::GenericArgument::Type(syn::Type::Path(t))) = t.args.first() {
605                    NUMERIC_TYPES.iter().any(|i| {
606                        t.path.segments.last().map(|s| s.ident.to_string())
607                            == Some(String::from(*i))
608                    })
609                } else {
610                    false
611                }
612            }
613            _ => NUMERIC_TYPES.iter().any(|i| last.ident == *i),
614        }
615    } else {
616        false
617    }
618}
619
620fn is_date_time(ty: &syn::TypePath) -> bool {
621    const DATE_TYPES: &[&str] = &["Time", "Date", "Timestamp", "Timestamptz"];
622    if let Some(last) = ty.path.segments.last() {
623        match &last.arguments {
624            syn::PathArguments::AngleBracketed(t)
625                if last.ident == "Nullable" && t.args.len() == 1 =>
626            {
627                if let Some(syn::GenericArgument::Type(syn::Type::Path(t))) = t.args.first() {
628                    DATE_TYPES.iter().any(|i| {
629                        t.path.segments.last().map(|s| s.ident.to_string())
630                            == Some(String::from(*i))
631                    })
632                } else {
633                    false
634                }
635            }
636            _ => DATE_TYPES.iter().any(|i| last.ident == *i),
637        }
638    } else {
639        false
640    }
641}
642
643fn is_network(ty: &syn::TypePath) -> bool {
644    const NETWORK_TYPES: &[&str] = &["Cidr", "Inet"];
645
646    if let Some(last) = ty.path.segments.last() {
647        match &last.arguments {
648            syn::PathArguments::AngleBracketed(t)
649                if last.ident == "Nullable" && t.args.len() == 1 =>
650            {
651                if let Some(syn::GenericArgument::Type(syn::Type::Path(t))) = t.args.first() {
652                    NETWORK_TYPES.iter().any(|i| {
653                        t.path.segments.last().map(|s| s.ident.to_string())
654                            == Some(String::from(*i))
655                    })
656                } else {
657                    false
658                }
659            }
660            _ => NETWORK_TYPES.iter().any(|i| last.ident == *i),
661        }
662    } else {
663        false
664    }
665}
666
667fn generate_op_impl(op: &str, tpe: &syn::Ident) -> TokenStream {
668    let fn_name = syn::Ident::new(&op.to_lowercase(), tpe.span());
669    let op = syn::Ident::new(op, tpe.span());
670    quote::quote! {
671        impl<Rhs> ::std::ops::#op<Rhs> for #tpe
672        where
673            Rhs: diesel::expression::AsExpression<
674                <<#tpe as diesel::Expression>::SqlType as diesel::sql_types::ops::#op>::Rhs,
675            >,
676        {
677            type Output = diesel::internal::table_macro::ops::#op<Self, Rhs::Expression>;
678
679            fn #fn_name(self, __diesel_internal_rhs: Rhs) -> Self::Output {
680                diesel::internal::table_macro::ops::#op::new(self, __diesel_internal_rhs.as_expression())
681            }
682        }
683    }
684}
685
686fn expand_column_def(column_def: &ColumnDef) -> TokenStream {
687    // TODO get a better span here as soon as that's
688    // possible using stable rust
689    let span = Span::mixed_site().located_at(column_def.column_name.span());
690    let meta = &column_def.meta;
691    let column_name = &column_def.column_name;
692    let sql_name = &column_def.sql_name;
693    let sql_type = &column_def.tpe;
694
695    let backend_specific_column_impl = if cfg!(feature = "postgres") {
696        Some(quote::quote! {
697            impl diesel::query_source::AppearsInFromClause<diesel::query_builder::Only<super::table>>
698                for #column_name
699            {
700                type Count = diesel::query_source::Once;
701            }
702            impl diesel::SelectableExpression<diesel::query_builder::Only<super::table>> for #column_name {}
703
704            impl<TSM> diesel::query_source::AppearsInFromClause<diesel::query_builder::Tablesample<super::table, TSM>>
705                for #column_name
706                    where
707                TSM: diesel::internal::table_macro::TablesampleMethod
708            {
709                type Count = diesel::query_source::Once;
710            }
711            impl<TSM> diesel::SelectableExpression<diesel::query_builder::Tablesample<super::table, TSM>>
712                for #column_name where TSM: diesel::internal::table_macro::TablesampleMethod {}
713        })
714    } else {
715        None
716    };
717
718    let ops_impls = if is_numeric(&column_def.tpe) {
719        let add = generate_op_impl("Add", column_name);
720        let sub = generate_op_impl("Sub", column_name);
721        let div = generate_op_impl("Div", column_name);
722        let mul = generate_op_impl("Mul", column_name);
723        Some(quote::quote! {
724            #add
725            #sub
726            #div
727            #mul
728        })
729    } else if is_date_time(&column_def.tpe) || is_network(&column_def.tpe) {
730        let add = generate_op_impl("Add", column_name);
731        let sub = generate_op_impl("Sub", column_name);
732        Some(quote::quote! {
733            #add
734            #sub
735        })
736    } else {
737        None
738    };
739
740    let max_length = column_def.max_length.as_ref().map(|column_max_length| {
741        quote::quote! {
742            impl self::diesel::query_source::SizeRestrictedColumn for #column_name {
743                const MAX_LENGTH: usize = #column_max_length;
744            }
745        }
746    });
747
748    quote::quote_spanned! {span=>
749        #(#meta)*
750        #[allow(non_camel_case_types, dead_code)]
751        #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId, Default)]
752        pub struct #column_name;
753
754        impl diesel::expression::Expression for #column_name {
755            type SqlType = #sql_type;
756        }
757
758        impl<DB> diesel::query_builder::QueryFragment<DB> for #column_name where
759            DB: diesel::backend::Backend,
760            diesel::internal::table_macro::StaticQueryFragmentInstance<table>: diesel::query_builder::QueryFragment<DB>,
761        {
762            #[allow(non_snake_case)]
763            fn walk_ast<'b>(&'b self, mut __diesel_internal_out: diesel::query_builder::AstPass<'_, 'b, DB>) -> diesel::result::QueryResult<()>
764            {
765                if !__diesel_internal_out.should_skip_from() {
766                    const FROM_CLAUSE: diesel::internal::table_macro::StaticQueryFragmentInstance<table> = diesel::internal::table_macro::StaticQueryFragmentInstance::new();
767
768                    FROM_CLAUSE.walk_ast(__diesel_internal_out.reborrow())?;
769                    __diesel_internal_out.push_sql(".");
770                }
771                __diesel_internal_out.push_identifier(#sql_name)
772            }
773        }
774
775        impl diesel::SelectableExpression<super::table> for #column_name {
776        }
777
778        impl<QS> diesel::AppearsOnTable<QS> for #column_name where
779            QS: diesel::query_source::AppearsInFromClause<super::table, Count=diesel::query_source::Once>,
780        {
781        }
782
783        impl<Left, Right> diesel::SelectableExpression<
784                diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::LeftOuter>,
785            > for #column_name where
786            #column_name: diesel::AppearsOnTable<diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::LeftOuter>>,
787            Self: diesel::SelectableExpression<Left>,
788            // If our table is on the right side of this join, only
789            // `Nullable<Self>` can be selected
790            Right: diesel::query_source::AppearsInFromClause<super::table, Count=diesel::query_source::Never> + diesel::query_source::QuerySource,
791            Left: diesel::query_source::QuerySource
792        {
793        }
794
795        impl<Left, Right> diesel::SelectableExpression<
796                diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::Inner>,
797            > for #column_name where
798            #column_name: diesel::AppearsOnTable<diesel::internal::table_macro::Join<Left, Right, diesel::internal::table_macro::Inner>>,
799            Left: diesel::query_source::AppearsInFromClause<super::table> + diesel::query_source::QuerySource,
800            Right: diesel::query_source::AppearsInFromClause<super::table> + diesel::query_source::QuerySource,
801        (Left::Count, Right::Count): diesel::internal::table_macro::Pick<Left, Right>,
802            Self: diesel::SelectableExpression<
803                <(Left::Count, Right::Count) as diesel::internal::table_macro::Pick<Left, Right>>::Selection,
804            >,
805        {
806        }
807
808        // FIXME: Remove this when overlapping marker traits are stable
809        impl<Join, On> diesel::SelectableExpression<diesel::internal::table_macro::JoinOn<Join, On>> for #column_name where
810            #column_name: diesel::SelectableExpression<Join> + diesel::AppearsOnTable<diesel::internal::table_macro::JoinOn<Join, On>>,
811        {
812        }
813
814        // FIXME: Remove this when overlapping marker traits are stable
815        impl<From> diesel::SelectableExpression<diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<From>>> for #column_name where
816            From: diesel::query_source::QuerySource,
817            #column_name: diesel::SelectableExpression<From> + diesel::AppearsOnTable<diesel::internal::table_macro::SelectStatement<diesel::internal::table_macro::FromClause<From>>>,
818        {
819        }
820
821        impl<__GB> diesel::expression::ValidGrouping<__GB> for #column_name
822        where __GB: diesel::expression::IsContainedInGroupBy<#column_name, Output = diesel::expression::is_contained_in_group_by::Yes>,
823        {
824            type IsAggregate = diesel::expression::is_aggregate::Yes;
825        }
826
827        impl diesel::expression::ValidGrouping<()> for #column_name {
828            type IsAggregate = diesel::expression::is_aggregate::No;
829        }
830
831        impl diesel::expression::IsContainedInGroupBy<#column_name> for #column_name {
832            type Output = diesel::expression::is_contained_in_group_by::Yes;
833        }
834
835        impl diesel::query_source::Column for #column_name {
836            type Table = super::table;
837
838            const NAME: &'static str = #sql_name;
839        }
840
841        impl<T> diesel::EqAll<T> for #column_name where
842            T: diesel::expression::AsExpression<#sql_type>,
843            diesel::dsl::Eq<#column_name, T::Expression>: diesel::Expression<SqlType=diesel::sql_types::Bool>,
844        {
845            type Output = diesel::dsl::Eq<Self, T::Expression>;
846
847            fn eq_all(self, __diesel_internal_rhs: T) -> Self::Output {
848                use diesel::expression_methods::ExpressionMethods;
849                self.eq(__diesel_internal_rhs)
850            }
851        }
852
853        #max_length
854
855        #ops_impls
856        #backend_specific_column_impl
857    }
858}