diesel_derives/
queryable_by_name.rs1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::spanned::Spanned;
4use syn::{parse_quote, parse_quote_spanned, DeriveInput, Ident, LitStr, Result, Type};
5
6use crate::attrs::AttributeSpanWrapper;
7use crate::field::{Field, FieldName};
8use crate::model::{CheckForBackend, Model};
9use crate::util::wrap_in_dummy_mod;
10
11pub fn derive(item: DeriveInput) -> Result<TokenStream> {
12 let model = Model::from_item(&item, false, false)?;
13
14 let struct_name = &item.ident;
15 let fields = &model.fields().iter().map(get_ident).collect::<Vec<_>>();
16 let field_names = model.fields().iter().map(|f| &f.name);
17
18 let initial_field_expr = model
19 .fields()
20 .iter()
21 .map(|f| {
22 let field_ty = &f.ty;
23
24 if f.embed() {
25 Ok(quote!(<#field_ty as QueryableByName<__DB>>::build(row)?))
26 } else {
27 let st = sql_type(f, &model)?;
28 let deserialize_ty = f.ty_for_deserialize();
29 let name = f.column_name()?;
30 let name = LitStr::new(&name.to_string(), name.span());
31 Ok(quote!(
32 {
33 let field = diesel::row::NamedRow::get::<#st, #deserialize_ty>(row, #name)?;
34 <#deserialize_ty as std::convert::Into<#field_ty>>::into(field)
35 }
36 ))
37 }
38 })
39 .collect::<Result<Vec<_>>>()?;
40
41 let (_, ty_generics, ..) = item.generics.split_for_impl();
42 let mut generics = item.generics.clone();
43 generics
44 .params
45 .push(parse_quote!(__DB: diesel::backend::Backend));
46
47 for field in model.fields() {
48 let where_clause = generics.where_clause.get_or_insert(parse_quote!(where));
49 let span = Span::mixed_site().located_at(field.ty.span());
50 let field_ty = field.ty_for_deserialize();
51 if field.embed() {
52 where_clause.predicates.push(
53 parse_quote_spanned!(span=> #field_ty: diesel::deserialize::QueryableByName<__DB>),
54 );
55 } else {
56 let st = sql_type(field, &model)?;
57 where_clause.predicates.push(
58 parse_quote_spanned!(span=> #field_ty: diesel::deserialize::FromSql<#st, __DB>),
59 );
60 }
61 }
62 let model = &model;
63 let check_function = if let Some(ref backends) = model.check_for_backend {
64 let field_check_bound = model.fields().iter().filter(|f| !f.embed()).flat_map(|f| {
65 if let CheckForBackend::Backends(backends) = backends {
66 let iter = backends.iter().map(move |b| {
67 let field_ty = f.ty_for_deserialize();
68 let span = Span::mixed_site().located_at(f.ty.span());
69 let ty = sql_type(f, model).unwrap();
70 quote::quote_spanned! {span =>
71 #field_ty: diesel::deserialize::FromSqlRow<#ty, #b>
72 }
73 });
74 Box::new(iter) as Box<dyn Iterator<Item = _>>
75 } else {
76 Box::new(std::iter::empty())
77 }
78 });
79 Some(quote::quote! {
80 fn _check_field_compatibility()
81 where
82 #(#field_check_bound,)*
83 {}
84 })
85 } else {
86 None
87 };
88
89 let (impl_generics, _, where_clause) = generics.split_for_impl();
90
91 Ok(wrap_in_dummy_mod(quote! {
92
93 impl #impl_generics diesel::deserialize::QueryableByName<__DB>
94 for #struct_name #ty_generics
95 #where_clause
96 {
97 fn build<'__a>(row: &impl diesel::row::NamedRow<'__a, __DB>) -> diesel::deserialize::Result<Self>
98 {
99 #(
100 let mut #fields = #initial_field_expr;
101 )*
102 diesel::deserialize::Result::Ok(Self {
103 #(
104 #field_names: #fields,
105 )*
106 })
107 }
108 }
109
110 #check_function
111 }))
112}
113
114fn get_ident(field: &Field) -> Ident {
115 match &field.name {
116 FieldName::Named(n) => n.clone(),
117 FieldName::Unnamed(i) => Ident::new(&format!("field_{}", i.index), i.span),
118 }
119}
120
121fn sql_type(field: &Field, model: &Model) -> Result<Type> {
122 let table_name = &model.table_names()[0];
123
124 match field.sql_type {
125 Some(AttributeSpanWrapper { item: ref st, .. }) => Ok(st.clone()),
126 None => {
127 let column_name = field.column_name()?.to_ident()?;
128 Ok(parse_quote!(diesel::dsl::SqlTypeOf<#table_name::#column_name>))
129 }
130 }
131}