diesel_derives/
queryable_by_name.rs
1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::{parse_quote, parse_quote_spanned, DeriveInput, Ident, LitStr, Result, Type};
4
5use crate::attrs::AttributeSpanWrapper;
6use crate::field::{Field, FieldName};
7use crate::model::Model;
8use crate::util::wrap_in_dummy_mod;
9
10pub fn derive(item: DeriveInput) -> Result<TokenStream> {
11 let model = Model::from_item(&item, false, false)?;
12
13 let struct_name = &item.ident;
14 let fields = &model.fields().iter().map(get_ident).collect::<Vec<_>>();
15 let field_names = model.fields().iter().map(|f| &f.name);
16
17 let initial_field_expr = model
18 .fields()
19 .iter()
20 .map(|f| {
21 let field_ty = &f.ty;
22
23 if f.embed() {
24 Ok(quote!(<#field_ty as QueryableByName<__DB>>::build(row)?))
25 } else {
26 let st = sql_type(f, &model)?;
27 let deserialize_ty = f.ty_for_deserialize();
28 let name = f.column_name()?;
29 let name = LitStr::new(&name.to_string(), name.span());
30 Ok(quote!(
31 {
32 let field = diesel::row::NamedRow::get::<#st, #deserialize_ty>(row, #name)?;
33 <#deserialize_ty as std::convert::Into<#field_ty>>::into(field)
34 }
35 ))
36 }
37 })
38 .collect::<Result<Vec<_>>>()?;
39
40 let (_, ty_generics, ..) = item.generics.split_for_impl();
41 let mut generics = item.generics.clone();
42 generics
43 .params
44 .push(parse_quote!(__DB: diesel::backend::Backend));
45
46 for field in model.fields() {
47 let where_clause = generics.where_clause.get_or_insert(parse_quote!(where));
48 let span = field.span;
49 let field_ty = field.ty_for_deserialize();
50 if field.embed() {
51 where_clause.predicates.push(
52 parse_quote_spanned!(span=> #field_ty: diesel::deserialize::QueryableByName<__DB>),
53 );
54 } else {
55 let st = sql_type(field, &model)?;
56 where_clause.predicates.push(
57 parse_quote_spanned!(span=> #field_ty: diesel::deserialize::FromSql<#st, __DB>),
58 );
59 }
60 }
61 let model = &model;
62 let check_function = if let Some(ref backends) = model.check_for_backend {
63 let field_check_bound = model.fields().iter().filter(|f| !f.embed()).flat_map(|f| {
64 backends.iter().map(move |b| {
65 let field_ty = f.ty_for_deserialize();
66 let span = f.span;
67 let ty = sql_type(f, model).unwrap();
68 quote::quote_spanned! {span =>
69 #field_ty: diesel::deserialize::FromSqlRow<#ty, #b>
70 }
71 })
72 });
73 Some(quote::quote! {
74 fn _check_field_compatibility()
75 where
76 #(#field_check_bound,)*
77 {}
78 })
79 } else {
80 None
81 };
82
83 let (impl_generics, _, where_clause) = generics.split_for_impl();
84
85 Ok(wrap_in_dummy_mod(quote! {
86
87 impl #impl_generics diesel::deserialize::QueryableByName<__DB>
88 for #struct_name #ty_generics
89 #where_clause
90 {
91 fn build<'__a>(row: &impl diesel::row::NamedRow<'__a, __DB>) -> diesel::deserialize::Result<Self>
92 {
93 #(
94 let mut #fields = #initial_field_expr;
95 )*
96 diesel::deserialize::Result::Ok(Self {
97 #(
98 #field_names: #fields,
99 )*
100 })
101 }
102 }
103
104 #check_function
105 }))
106}
107
108fn get_ident(field: &Field) -> Ident {
109 match &field.name {
110 FieldName::Named(n) => n.clone(),
111 FieldName::Unnamed(i) => Ident::new(&format!("field_{}", i.index), i.span),
112 }
113}
114
115fn sql_type(field: &Field, model: &Model) -> Result<Type> {
116 let table_name = &model.table_names()[0];
117
118 match field.sql_type {
119 Some(AttributeSpanWrapper { item: ref st, .. }) => Ok(st.clone()),
120 None => {
121 let column_name = field.column_name()?.to_ident()?;
122 Ok(parse_quote!(diesel::dsl::SqlTypeOf<#table_name::#column_name>))
123 }
124 }
125}