diesel_derives/
selectable.rs1use proc_macro2::{Span, TokenStream};
2use quote::{quote, quote_spanned};
3use std::borrow::Cow;
4use syn::spanned::Spanned;
5use syn::{parse_quote, DeriveInput, Result};
6
7use crate::field::Field;
8use crate::model::{CheckForBackend, Model};
9use crate::util::wrap_in_dummy_mod;
10
11pub fn derive(
12 item: DeriveInput,
13 check_for_backend: Option<CheckForBackend>,
14) -> Result<TokenStream> {
15 let model = Model::from_item(&item, false, false)?;
16
17 let (original_impl_generics, ty_generics, original_where_clause) =
18 item.generics.split_for_impl();
19
20 let mut generics = item.generics.clone();
21 generics
22 .params
23 .push(parse_quote!(__DB: diesel::backend::Backend));
24
25 for embed_field in model.fields().iter().filter(|f| f.embed()) {
26 let embed_ty = &embed_field.ty;
27 generics
28 .make_where_clause()
29 .predicates
30 .push(parse_quote!(#embed_ty: Selectable<__DB>));
31 }
32
33 let (impl_generics, _, where_clause) = generics.split_for_impl();
34
35 let struct_name = &item.ident;
36
37 let mut compile_errors: Vec<syn::Error> = Vec::new();
38 let field_select_expression_type_builders = model
39 .fields()
40 .iter()
41 .map(|f| field_select_expression_ty_builder(f, &model, &mut compile_errors))
42 .collect::<Result<Vec<_>>>()?;
43 let field_select_expression_types = field_select_expression_type_builders
44 .iter()
45 .map(|f| f.type_with_backend(&parse_quote!(__DB)))
46 .collect::<Vec<_>>();
47 let field_select_expressions = model
48 .fields()
49 .iter()
50 .map(|f| field_column_inst(f, &model))
51 .collect::<Result<Vec<_>>>()?;
52
53 let check_function = if let Some(backends) = model
54 .check_for_backend
55 .as_ref()
56 .or(check_for_backend.as_ref())
57 .and_then(|c| match c {
58 CheckForBackend::Backends(punctuated) => Some(punctuated),
59 CheckForBackend::Disabled(_lit_bool) => None,
60 }) {
61 let field_check_bound = model
62 .fields()
63 .iter()
64 .zip(&field_select_expression_type_builders)
65 .flat_map(|(f, ty_builder)| {
66 backends.iter().map(move |b| {
67 let span = Span::mixed_site().located_at(f.ty.span());
68 let field_ty = to_field_ty_bound(f.ty_for_deserialize())?;
69 let ty = ty_builder.type_with_backend(b);
70 Ok(syn::parse_quote_spanned! {span =>
71 #field_ty: diesel::deserialize::FromSqlRow<diesel::dsl::SqlTypeOf<#ty>, #b>
72 })
73 })
74 })
75 .collect::<Result<Vec<_>>>()?;
76 let where_clause = &mut original_where_clause.cloned();
77 let where_clause = where_clause.get_or_insert_with(|| parse_quote!(where));
78 for field_check in field_check_bound {
79 where_clause.predicates.push(field_check);
80 }
81 Some(quote::quote! {
82 fn _check_field_compatibility #original_impl_generics()
83 #where_clause
84 {}
85 })
86 } else {
87 None
88 };
89
90 let errors: TokenStream = compile_errors
91 .into_iter()
92 .map(|e| e.into_compile_error())
93 .collect();
94
95 Ok(wrap_in_dummy_mod(quote! {
96 use diesel::expression::Selectable;
97
98 impl #impl_generics Selectable<__DB>
99 for #struct_name #ty_generics
100 #where_clause
101 {
102 type SelectExpression = (#(#field_select_expression_types,)*);
103
104 fn construct_selection() -> Self::SelectExpression {
105 (#(#field_select_expressions,)*)
106 }
107 }
108
109 #check_function
110
111 #errors
112 }))
113}
114
115fn to_field_ty_bound(field_ty: &syn::Type) -> Result<TokenStream> {
116 match field_ty {
117 syn::Type::Reference(r) => {
118 use crate::quote::ToTokens;
119 Err(syn::Error::new(
123 field_ty.span(),
124 format!(
125 "references are not supported in `Queryable` types\n\
126 consider using `std::borrow::Cow<'{}, {}>` instead",
127 r.lifetime
128 .as_ref()
129 .expect("It's a struct field so it must have a named lifetime")
130 .ident,
131 r.elem.to_token_stream()
132 ),
133 ))
134 }
135 field_ty => Ok(quote::quote! {
136 #field_ty
137 }),
138 }
139}
140
141fn field_select_expression_ty_builder<'a>(
142 field: &'a Field,
143 model: &Model,
144 compile_errors: &mut Vec<syn::Error>,
145) -> Result<FieldSelectExpressionTyBuilder<'a>> {
146 if let Some(ref select_expression) = field.select_expression {
147 use dsl_auto_type::auto_type::expression_type_inference as type_inference;
148 let expr = &select_expression.item;
149 let (inferred_type, errors) = type_inference::infer_expression_type(
150 expr,
151 field.select_expression_type.as_ref().map(|t| &t.item),
152 &type_inference::InferrerSettings::builder()
153 .dsl_path(parse_quote!(diesel::dsl))
154 .function_types_case(crate::AUTO_TYPE_DEFAULT_FUNCTION_TYPE_CASE)
155 .method_types_case(crate::AUTO_TYPE_DEFAULT_METHOD_TYPE_CASE)
156 .build(),
157 );
158 compile_errors.extend(errors);
159 Ok(FieldSelectExpressionTyBuilder::Always(
160 quote::quote!(#inferred_type),
161 ))
162 } else if let Some(ref select_expression_type) = field.select_expression_type {
163 let ty = &select_expression_type.item;
164 Ok(FieldSelectExpressionTyBuilder::Always(quote!(#ty)))
165 } else if field.embed() {
166 Ok(FieldSelectExpressionTyBuilder::EmbedSelectable {
167 embed_ty: &field.ty,
168 })
169 } else {
170 let table_name = &model.table_names()[0];
171 let column_name = field.column_name()?.to_ident()?;
172 let span = Span::call_site();
173 Ok(FieldSelectExpressionTyBuilder::Always(
174 quote_spanned!(span=> #table_name::#column_name),
175 ))
176 }
177}
178
179enum FieldSelectExpressionTyBuilder<'a> {
180 Always(TokenStream),
181 EmbedSelectable { embed_ty: &'a syn::Type },
182}
183
184impl FieldSelectExpressionTyBuilder<'_> {
185 fn type_with_backend(&self, backend: &syn::TypePath) -> Cow<'_, TokenStream> {
186 match self {
187 FieldSelectExpressionTyBuilder::Always(ty) => Cow::Borrowed(ty),
188 FieldSelectExpressionTyBuilder::EmbedSelectable { embed_ty } => {
189 Cow::Owned(quote!(<#embed_ty as Selectable<#backend>>::SelectExpression))
190 }
191 }
192 }
193}
194
195fn field_column_inst(field: &Field, model: &Model) -> Result<TokenStream> {
196 if let Some(ref select_expression) = field.select_expression {
197 let expr = &select_expression.item;
198 Ok(quote!(#expr))
199 } else if field.embed() {
200 let embed_ty = &field.ty;
201 Ok(quote!(<#embed_ty as Selectable<__DB>>::construct_selection()))
202 } else {
203 let table_name = &model.table_names()[0];
204 let column_name = field.column_name()?.to_ident()?;
205 let span = Span::call_site();
206 Ok(quote_spanned!(span=> #table_name::#column_name))
207 }
208}