// diesel_derives/selectable.rs

1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned};
3use std::borrow::Cow;
4use syn::spanned::Spanned;
5use syn::{parse_quote, DeriveInput, Result};
6
7use crate::field::Field;
8use crate::model::{CheckForBackend, Model};
9use crate::util::wrap_in_dummy_mod;
10
11pub fn derive(
12    item: DeriveInput,
13    check_for_backend: Option<CheckForBackend>,
14) -> Result<TokenStream> {
15    let model = Model::from_item(&item, false, false)?;
16
17    let (original_impl_generics, ty_generics, original_where_clause) =
18        item.generics.split_for_impl();
19
20    let mut generics = item.generics.clone();
21    generics
22        .params
23        .push(parse_quote!(__DB: diesel::backend::Backend));
24
25    for embed_field in model.fields().iter().filter(|f| f.embed()) {
26        let embed_ty = &embed_field.ty;
27        generics
28            .make_where_clause()
29            .predicates
30            .push(parse_quote!(#embed_ty: Selectable<__DB>));
31    }
32
33    let (impl_generics, _, where_clause) = generics.split_for_impl();
34
35    let struct_name = &item.ident;
36
37    let mut compile_errors: Vec<syn::Error> = Vec::new();
38    let field_select_expression_type_builders = model
39        .fields()
40        .iter()
41        .map(|f| field_select_expression_ty_builder(f, &model, &mut compile_errors))
42        .collect::<Result<Vec<_>>>()?;
43    let field_select_expression_types = field_select_expression_type_builders
44        .iter()
45        .map(|f| f.type_with_backend(&parse_quote!(__DB)))
46        .collect::<Vec<_>>();
47    let field_select_expressions = model
48        .fields()
49        .iter()
50        .map(|f| field_column_inst(f, &model))
51        .collect::<Result<Vec<_>>>()?;
52
53    let check_function = if let Some(backends) = model
54        .check_for_backend
55        .as_ref()
56        .or(check_for_backend.as_ref())
57        .and_then(|c| match c {
58            CheckForBackend::Backends(punctuated) => Some(punctuated),
59            CheckForBackend::Disabled(_lit_bool) => None,
60        }) {
61        let field_check_bound = model
62            .fields()
63            .iter()
64            .zip(&field_select_expression_type_builders)
65            .flat_map(|(f, ty_builder)| {
66                backends.iter().map(move |b| {
67                    let span = Span::mixed_site().located_at(f.ty.span());
68                    let field_ty = to_field_ty_bound(f.ty_for_deserialize())?;
69                    let ty = ty_builder.type_with_backend(b);
70                    Ok(syn::parse_quote_spanned! {span =>
71                        #field_ty: diesel::deserialize::FromSqlRow<diesel::dsl::SqlTypeOf<#ty>, #b>
72                    })
73                })
74            })
75            .collect::<Result<Vec<_>>>()?;
76        let where_clause = &mut original_where_clause.cloned();
77        let where_clause = where_clause.get_or_insert_with(|| parse_quote!(where));
78        for field_check in field_check_bound {
79            where_clause.predicates.push(field_check);
80        }
81        Some(quote::quote! {
82            fn _check_field_compatibility #original_impl_generics()
83                #where_clause
84            {}
85        })
86    } else {
87        None
88    };
89
90    let errors: TokenStream = compile_errors
91        .into_iter()
92        .map(|e| e.into_compile_error())
93        .collect();
94
95    Ok(wrap_in_dummy_mod(quote! {
96        use diesel::expression::Selectable;
97
98        impl #impl_generics Selectable<__DB>
99            for #struct_name #ty_generics
100        #where_clause
101        {
102            type SelectExpression = (#(#field_select_expression_types,)*);
103
104            fn construct_selection() -> Self::SelectExpression {
105                (#(#field_select_expressions,)*)
106            }
107        }
108
109        #check_function
110
111        #errors
112    }))
113}
114
115fn to_field_ty_bound(field_ty: &syn::Type) -> Result<TokenStream> {
116    match field_ty {
117        syn::Type::Reference(r) => {
118            use crate::quote::ToTokens;
119            // references are not supported for checking for now
120            //
121            // (How ever you can even have references in a `Queryable` struct anyway)
122            Err(syn::Error::new(
123                field_ty.span(),
124                format!(
125                    "references are not supported in `Queryable` types\n\
126                         consider using `std::borrow::Cow<'{}, {}>` instead",
127                    r.lifetime
128                        .as_ref()
129                        .expect("It's a struct field so it must have a named lifetime")
130                        .ident,
131                    r.elem.to_token_stream()
132                ),
133            ))
134        }
135        field_ty => Ok(quote::quote! {
136            #field_ty
137        }),
138    }
139}
140
141fn field_select_expression_ty_builder<'a>(
142    field: &'a Field,
143    model: &Model,
144    compile_errors: &mut Vec<syn::Error>,
145) -> Result<FieldSelectExpressionTyBuilder<'a>> {
146    if let Some(ref select_expression) = field.select_expression {
147        use dsl_auto_type::auto_type::expression_type_inference as type_inference;
148        let expr = &select_expression.item;
149        let (inferred_type, errors) = type_inference::infer_expression_type(
150            expr,
151            field.select_expression_type.as_ref().map(|t| &t.item),
152            &type_inference::InferrerSettings::builder()
153                .dsl_path(parse_quote!(diesel::dsl))
154                .function_types_case(crate::AUTO_TYPE_DEFAULT_FUNCTION_TYPE_CASE)
155                .method_types_case(crate::AUTO_TYPE_DEFAULT_METHOD_TYPE_CASE)
156                .build(),
157        );
158        compile_errors.extend(errors);
159        Ok(FieldSelectExpressionTyBuilder::Always(
160            quote::quote!(#inferred_type),
161        ))
162    } else if let Some(ref select_expression_type) = field.select_expression_type {
163        let ty = &select_expression_type.item;
164        Ok(FieldSelectExpressionTyBuilder::Always(quote!(#ty)))
165    } else if field.embed() {
166        Ok(FieldSelectExpressionTyBuilder::EmbedSelectable {
167            embed_ty: &field.ty,
168        })
169    } else {
170        let table_name = &model.table_names()[0];
171        let column_name = field.column_name()?.to_ident()?;
172        let span = Span::call_site();
173        Ok(FieldSelectExpressionTyBuilder::Always(
174            quote_spanned!(span=> #table_name::#column_name),
175        ))
176    }
177}
178
179enum FieldSelectExpressionTyBuilder<'a> {
180    Always(TokenStream),
181    EmbedSelectable { embed_ty: &'a syn::Type },
182}
183
184impl FieldSelectExpressionTyBuilder<'_> {
185    fn type_with_backend(&self, backend: &syn::TypePath) -> Cow<'_, TokenStream> {
186        match self {
187            FieldSelectExpressionTyBuilder::Always(ty) => Cow::Borrowed(ty),
188            FieldSelectExpressionTyBuilder::EmbedSelectable { embed_ty } => {
189                Cow::Owned(quote!(<#embed_ty as Selectable<#backend>>::SelectExpression))
190            }
191        }
192    }
193}
194
195fn field_column_inst(field: &Field, model: &Model) -> Result<TokenStream> {
196    if let Some(ref select_expression) = field.select_expression {
197        let expr = &select_expression.item;
198        Ok(quote!(#expr))
199    } else if field.embed() {
200        let embed_ty = &field.ty;
201        Ok(quote!(<#embed_ty as Selectable<__DB>>::construct_selection()))
202    } else {
203        let table_name = &model.table_names()[0];
204        let column_name = field.column_name()?.to_ident()?;
205        let span = Span::call_site();
206        Ok(quote_spanned!(span=> #table_name::#column_name))
207    }
208}