diesel_derives/
selectable.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use std::borrow::Cow;
4use syn::spanned::Spanned;
5use syn::{parse_quote, DeriveInput, Result};
6
7use crate::field::Field;
8use crate::model::Model;
9use crate::util::wrap_in_dummy_mod;
10
11pub fn derive(item: DeriveInput) -> Result<TokenStream> {
12    let model = Model::from_item(&item, false, false)?;
13
14    let (original_impl_generics, ty_generics, original_where_clause) =
15        item.generics.split_for_impl();
16
17    let mut generics = item.generics.clone();
18    generics
19        .params
20        .push(parse_quote!(__DB: diesel::backend::Backend));
21
22    for embed_field in model.fields().iter().filter(|f| f.embed()) {
23        let embed_ty = &embed_field.ty;
24        generics
25            .make_where_clause()
26            .predicates
27            .push(parse_quote!(#embed_ty: Selectable<__DB>));
28    }
29
30    let (impl_generics, _, where_clause) = generics.split_for_impl();
31
32    let struct_name = &item.ident;
33
34    let mut compile_errors: Vec<syn::Error> = Vec::new();
35    let field_select_expression_type_builders = model
36        .fields()
37        .iter()
38        .map(|f| field_select_expression_ty_builder(f, &model, &mut compile_errors))
39        .collect::<Result<Vec<_>>>()?;
40    let field_select_expression_types = field_select_expression_type_builders
41        .iter()
42        .map(|f| f.type_with_backend(&parse_quote!(__DB)))
43        .collect::<Vec<_>>();
44    let field_select_expressions = model
45        .fields()
46        .iter()
47        .map(|f| field_column_inst(f, &model))
48        .collect::<Result<Vec<_>>>()?;
49
50    let check_function = if let Some(ref backends) = model.check_for_backend {
51        let field_check_bound = model
52            .fields()
53            .iter()
54            .zip(&field_select_expression_type_builders)
55            .flat_map(|(f, ty_builder)| {
56                backends.iter().map(move |b| {
57                    let span = f.ty.span();
58                    let field_ty = to_field_ty_bound(f.ty_for_deserialize())?;
59                    let ty = ty_builder.type_with_backend(b);
60                    Ok(syn::parse_quote_spanned! {span =>
61                        #field_ty: diesel::deserialize::FromSqlRow<diesel::dsl::SqlTypeOf<#ty>, #b>
62                    })
63                })
64            })
65            .collect::<Result<Vec<_>>>()?;
66        let where_clause = &mut original_where_clause.cloned();
67        let where_clause = where_clause.get_or_insert_with(|| parse_quote!(where));
68        for field_check in field_check_bound {
69            where_clause.predicates.push(field_check);
70        }
71        Some(quote::quote! {
72            fn _check_field_compatibility #original_impl_generics()
73                #where_clause
74            {}
75        })
76    } else {
77        None
78    };
79
80    let errors: TokenStream = compile_errors
81        .into_iter()
82        .map(|e| e.into_compile_error())
83        .collect();
84
85    Ok(wrap_in_dummy_mod(quote! {
86        use diesel::expression::Selectable;
87
88        impl #impl_generics Selectable<__DB>
89            for #struct_name #ty_generics
90        #where_clause
91        {
92            type SelectExpression = (#(#field_select_expression_types,)*);
93
94            fn construct_selection() -> Self::SelectExpression {
95                (#(#field_select_expressions,)*)
96            }
97        }
98
99        #check_function
100
101        #errors
102    }))
103}
104
105fn to_field_ty_bound(field_ty: &syn::Type) -> Result<TokenStream> {
106    match field_ty {
107        syn::Type::Reference(r) => {
108            use crate::quote::ToTokens;
109            // references are not supported for checking for now
110            //
111            // (How ever you can even have references in a `Queryable` struct anyway)
112            Err(syn::Error::new(
113                field_ty.span(),
114                format!(
115                    "References are not supported in `Queryable` types\n\
116                         Consider using `std::borrow::Cow<'{}, {}>` instead",
117                    r.lifetime
118                        .as_ref()
119                        .expect("It's a struct field so it must have a named lifetime")
120                        .ident,
121                    r.elem.to_token_stream()
122                ),
123            ))
124        }
125        field_ty => Ok(quote::quote! {
126            #field_ty
127        }),
128    }
129}
130
131fn field_select_expression_ty_builder<'a>(
132    field: &'a Field,
133    model: &Model,
134    compile_errors: &mut Vec<syn::Error>,
135) -> Result<FieldSelectExpressionTyBuilder<'a>> {
136    if let Some(ref select_expression) = field.select_expression {
137        use dsl_auto_type::auto_type::expression_type_inference as type_inference;
138        let expr = &select_expression.item;
139        let (inferred_type, errors) = type_inference::infer_expression_type(
140            expr,
141            field.select_expression_type.as_ref().map(|t| &t.item),
142            &type_inference::InferrerSettings::builder()
143                .dsl_path(parse_quote!(diesel::dsl))
144                .function_types_case(crate::AUTO_TYPE_DEFAULT_FUNCTION_TYPE_CASE)
145                .method_types_case(crate::AUTO_TYPE_DEFAULT_METHOD_TYPE_CASE)
146                .build(),
147        );
148        compile_errors.extend(errors);
149        Ok(FieldSelectExpressionTyBuilder::Always(
150            quote::quote!(#inferred_type),
151        ))
152    } else if let Some(ref select_expression_type) = field.select_expression_type {
153        let ty = &select_expression_type.item;
154        Ok(FieldSelectExpressionTyBuilder::Always(quote!(#ty)))
155    } else if field.embed() {
156        Ok(FieldSelectExpressionTyBuilder::EmbedSelectable {
157            embed_ty: &field.ty,
158        })
159    } else {
160        let table_name = &model.table_names()[0];
161        let column_name = field.column_name()?.to_ident()?;
162        Ok(FieldSelectExpressionTyBuilder::Always(
163            quote!(#table_name::#column_name),
164        ))
165    }
166}
167
168enum FieldSelectExpressionTyBuilder<'a> {
169    Always(TokenStream),
170    EmbedSelectable { embed_ty: &'a syn::Type },
171}
172
173impl FieldSelectExpressionTyBuilder<'_> {
174    fn type_with_backend(&self, backend: &syn::TypePath) -> Cow<'_, TokenStream> {
175        match self {
176            FieldSelectExpressionTyBuilder::Always(ty) => Cow::Borrowed(ty),
177            FieldSelectExpressionTyBuilder::EmbedSelectable { embed_ty } => {
178                Cow::Owned(quote!(<#embed_ty as Selectable<#backend>>::SelectExpression))
179            }
180        }
181    }
182}
183
184fn field_column_inst(field: &Field, model: &Model) -> Result<TokenStream> {
185    if let Some(ref select_expression) = field.select_expression {
186        let expr = &select_expression.item;
187        Ok(quote!(#expr))
188    } else if field.embed() {
189        let embed_ty = &field.ty;
190        Ok(quote!(<#embed_ty as Selectable<__DB>>::construct_selection()))
191    } else {
192        let table_name = &model.table_names()[0];
193        let column_name = field.column_name()?.to_ident()?;
194        Ok(quote!(#table_name::#column_name))
195    }
196}