diesel_derives/
queryable_by_name.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{parse_quote, parse_quote_spanned, DeriveInput, Ident, LitStr, Result, Type};
4
5use crate::attrs::AttributeSpanWrapper;
6use crate::field::{Field, FieldName};
7use crate::model::Model;
8use crate::util::wrap_in_dummy_mod;
9
10pub fn derive(item: DeriveInput) -> Result<TokenStream> {
11    let model = Model::from_item(&item, false, false)?;
12
13    let struct_name = &item.ident;
14    let fields = &model.fields().iter().map(get_ident).collect::<Vec<_>>();
15    let field_names = model.fields().iter().map(|f| &f.name);
16
17    let initial_field_expr = model
18        .fields()
19        .iter()
20        .map(|f| {
21            let field_ty = &f.ty;
22
23            if f.embed() {
24                Ok(quote!(<#field_ty as QueryableByName<__DB>>::build(row)?))
25            } else {
26                let st = sql_type(f, &model)?;
27                let deserialize_ty = f.ty_for_deserialize();
28                let name = f.column_name()?;
29                let name = LitStr::new(&name.to_string(), name.span());
30                Ok(quote!(
31                   {
32                       let field = diesel::row::NamedRow::get::<#st, #deserialize_ty>(row, #name)?;
33                       <#deserialize_ty as Into<#field_ty>>::into(field)
34                   }
35                ))
36            }
37        })
38        .collect::<Result<Vec<_>>>()?;
39
40    let (_, ty_generics, ..) = item.generics.split_for_impl();
41    let mut generics = item.generics.clone();
42    generics
43        .params
44        .push(parse_quote!(__DB: diesel::backend::Backend));
45
46    for field in model.fields() {
47        let where_clause = generics.where_clause.get_or_insert(parse_quote!(where));
48        let span = field.span;
49        let field_ty = field.ty_for_deserialize();
50        if field.embed() {
51            where_clause
52                .predicates
53                .push(parse_quote_spanned!(span=> #field_ty: QueryableByName<__DB>));
54        } else {
55            let st = sql_type(field, &model)?;
56            where_clause.predicates.push(
57                parse_quote_spanned!(span=> #field_ty: diesel::deserialize::FromSql<#st, __DB>),
58            );
59        }
60    }
61    let model = &model;
62    let check_function = if let Some(ref backends) = model.check_for_backend {
63        let field_check_bound = model.fields().iter().filter(|f| !f.embed()).flat_map(|f| {
64            backends.iter().map(move |b| {
65                let field_ty = f.ty_for_deserialize();
66                let span = f.span;
67                let ty = sql_type(f, model).unwrap();
68                quote::quote_spanned! {span =>
69                    #field_ty: diesel::deserialize::FromSqlRow<#ty, #b>
70                }
71            })
72        });
73        Some(quote::quote! {
74            fn _check_field_compatibility()
75            where
76                #(#field_check_bound,)*
77            {}
78        })
79    } else {
80        None
81    };
82
83    let (impl_generics, _, where_clause) = generics.split_for_impl();
84
85    Ok(wrap_in_dummy_mod(quote! {
86        use diesel::deserialize::{self, QueryableByName};
87        use diesel::row::{NamedRow};
88        use diesel::sql_types::Untyped;
89
90        impl #impl_generics QueryableByName<__DB>
91            for #struct_name #ty_generics
92        #where_clause
93        {
94            fn build<'__a>(row: &impl NamedRow<'__a, __DB>) -> deserialize::Result<Self>
95            {
96                #(
97                    let mut #fields = #initial_field_expr;
98                )*
99                deserialize::Result::Ok(Self {
100                    #(
101                        #field_names: #fields,
102                    )*
103                })
104            }
105        }
106
107        #check_function
108    }))
109}
110
111fn get_ident(field: &Field) -> Ident {
112    match &field.name {
113        FieldName::Named(n) => n.clone(),
114        FieldName::Unnamed(i) => Ident::new(&format!("field_{}", i.index), i.span),
115    }
116}
117
118fn sql_type(field: &Field, model: &Model) -> Result<Type> {
119    let table_name = &model.table_names()[0];
120
121    match field.sql_type {
122        Some(AttributeSpanWrapper { item: ref st, .. }) => Ok(st.clone()),
123        None => {
124            let column_name = field.column_name()?.to_ident()?;
125            Ok(parse_quote!(diesel::dsl::SqlTypeOf<#table_name::#column_name>))
126        }
127    }
128}