diesel_derives/
insertable.rs

1use crate::attrs::AttributeSpanWrapper;
2use crate::field::Field;
3use crate::model::Model;
4use crate::util::{inner_of_option_ty, is_option_ty, wrap_in_dummy_mod};
5use proc_macro2::TokenStream;
6use quote::quote;
7use quote::quote_spanned;
8use syn::parse_quote;
9use syn::spanned::Spanned as _;
10use syn::{DeriveInput, Expr, Path, Result, Type};
11
12pub fn derive(item: DeriveInput) -> Result<TokenStream> {
13    let model = Model::from_item(&item, false, true)?;
14
15    let tokens = model
16        .table_names()
17        .iter()
18        .map(|table_name| derive_into_single_table(&item, &model, table_name))
19        .collect::<Result<Vec<_>>>()?;
20
21    Ok(wrap_in_dummy_mod(quote! {
22        use diesel::insertable::Insertable;
23        use diesel::internal::derives::insertable::UndecoratedInsertRecord;
24        use diesel::prelude::*;
25
26        #(#tokens)*
27    }))
28}
29
30fn derive_into_single_table(
31    item: &DeriveInput,
32    model: &Model,
33    table_name: &Path,
34) -> Result<TokenStream> {
35    let treat_none_as_default_value = model.treat_none_as_default_value();
36    let struct_name = &item.ident;
37
38    let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
39
40    let mut generate_borrowed_insert = true;
41
42    let mut direct_field_ty = Vec::with_capacity(model.fields().len());
43    let mut direct_field_assign = Vec::with_capacity(model.fields().len());
44    let mut ref_field_ty = Vec::with_capacity(model.fields().len());
45    let mut ref_field_assign = Vec::with_capacity(model.fields().len());
46
47    for field in model.fields() {
48        // skip this field while generating the insertion
49        if field.skip_insertion() {
50            continue;
51        }
52        // Use field-level attr. with fallback to the struct-level one.
53        let treat_none_as_default_value = match &field.treat_none_as_default_value {
54            Some(attr) => {
55                if let Some(embed) = &field.embed {
56                    return Err(syn::Error::new(
57                        embed.attribute_span,
58                        "`embed` and `treat_none_as_default_value` are mutually exclusive",
59                    ));
60                }
61
62                if !is_option_ty(&field.ty) {
63                    return Err(syn::Error::new(
64                        field.ty.span(),
65                        "expected `treat_none_as_default_value` field to be of type `Option<_>`",
66                    ));
67                }
68
69                attr.item
70            }
71            None => treat_none_as_default_value,
72        };
73
74        match (field.serialize_as.as_ref(), field.embed()) {
75            (None, true) => {
76                direct_field_ty.push(field_ty_embed(field, None));
77                direct_field_assign.push(field_expr_embed(field, None));
78                ref_field_ty.push(field_ty_embed(field, Some(quote!(&'insert))));
79                ref_field_assign.push(field_expr_embed(field, Some(quote!(&))));
80            }
81            (None, false) => {
82                direct_field_ty.push(field_ty(
83                    field,
84                    table_name,
85                    None,
86                    treat_none_as_default_value,
87                )?);
88                direct_field_assign.push(field_expr(
89                    field,
90                    table_name,
91                    None,
92                    treat_none_as_default_value,
93                )?);
94                ref_field_ty.push(field_ty(
95                    field,
96                    table_name,
97                    Some(quote!(&'insert)),
98                    treat_none_as_default_value,
99                )?);
100                ref_field_assign.push(field_expr(
101                    field,
102                    table_name,
103                    Some(quote!(&)),
104                    treat_none_as_default_value,
105                )?);
106            }
107            (Some(AttributeSpanWrapper { item: ty, .. }), false) => {
108                direct_field_ty.push(field_ty_serialize_as(
109                    field,
110                    table_name,
111                    ty,
112                    treat_none_as_default_value,
113                )?);
114                direct_field_assign.push(field_expr_serialize_as(
115                    field,
116                    table_name,
117                    ty,
118                    treat_none_as_default_value,
119                )?);
120
121                generate_borrowed_insert = false; // as soon as we hit one field with #[diesel(serialize_as)] there is no point in generating the impl of Insertable for borrowed structs
122            }
123            (Some(AttributeSpanWrapper { attribute_span, .. }), true) => {
124                return Err(syn::Error::new(
125                    *attribute_span,
126                    "`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`",
127                ));
128            }
129        }
130    }
131
132    let insert_owned = quote! {
133        impl #impl_generics Insertable<#table_name::table> for #struct_name #ty_generics
134            #where_clause
135        {
136            type Values = <(#(#direct_field_ty,)*) as Insertable<#table_name::table>>::Values;
137
138            fn values(self) -> <(#(#direct_field_ty,)*) as Insertable<#table_name::table>>::Values {
139                (#(#direct_field_assign,)*).values()
140            }
141        }
142    };
143
144    let insert_borrowed = if generate_borrowed_insert {
145        let mut impl_generics = item.generics.clone();
146        impl_generics.params.push(parse_quote!('insert));
147        let (impl_generics, ..) = impl_generics.split_for_impl();
148
149        quote! {
150            impl #impl_generics Insertable<#table_name::table>
151                for &'insert #struct_name #ty_generics
152            #where_clause
153            {
154                type Values = <(#(#ref_field_ty,)*) as Insertable<#table_name::table>>::Values;
155
156                fn values(self) -> <(#(#ref_field_ty,)*) as Insertable<#table_name::table>>::Values {
157                    (#(#ref_field_assign,)*).values()
158                }
159            }
160        }
161    } else {
162        quote! {}
163    };
164
165    Ok(quote! {
166        #[allow(unused_qualifications)]
167        #insert_owned
168
169        #[allow(unused_qualifications)]
170        #insert_borrowed
171
172        impl #impl_generics UndecoratedInsertRecord<#table_name::table>
173                for #struct_name #ty_generics
174            #where_clause
175        {
176        }
177    })
178}
179
180fn field_ty_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
181    let field_ty = &field.ty;
182    let span = field.span;
183    quote_spanned!(span=> #lifetime #field_ty)
184}
185
186fn field_expr_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
187    let field_name = &field.name;
188    quote!(#lifetime self.#field_name)
189}
190
191fn field_ty_serialize_as(
192    field: &Field,
193    table_name: &Path,
194    ty: &Type,
195    treat_none_as_default_value: bool,
196) -> Result<TokenStream> {
197    let column_name = field.column_name()?.to_ident()?;
198    let span = field.span;
199    if treat_none_as_default_value {
200        let inner_ty = inner_of_option_ty(ty);
201
202        Ok(quote_spanned! {span=>
203            std::option::Option<diesel::dsl::Eq<
204                #table_name::#column_name,
205                #inner_ty,
206            >>
207        })
208    } else {
209        Ok(quote_spanned! {span=>
210            diesel::dsl::Eq<
211                #table_name::#column_name,
212                #ty,
213            >
214        })
215    }
216}
217
218fn field_expr_serialize_as(
219    field: &Field,
220    table_name: &Path,
221    ty: &Type,
222    treat_none_as_default_value: bool,
223) -> Result<TokenStream> {
224    let field_name = &field.name;
225    let column_name = field.column_name()?.to_ident()?;
226    let column = quote!(#table_name::#column_name);
227    if treat_none_as_default_value {
228        if is_option_ty(ty) {
229            Ok(quote!(::std::convert::Into::<#ty>::into(self.#field_name).map(|v| #column.eq(v))))
230        } else {
231            Ok(
232                quote!(std::option::Option::Some(#column.eq(::std::convert::Into::<#ty>::into(self.#field_name)))),
233            )
234        }
235    } else {
236        Ok(quote!(#column.eq(::std::convert::Into::<#ty>::into(self.#field_name))))
237    }
238}
239
240fn field_ty(
241    field: &Field,
242    table_name: &Path,
243    lifetime: Option<TokenStream>,
244    treat_none_as_default_value: bool,
245) -> Result<TokenStream> {
246    let column_name = field.column_name()?.to_ident()?;
247    let span = field.span;
248    if treat_none_as_default_value {
249        let inner_ty = inner_of_option_ty(&field.ty);
250
251        Ok(quote_spanned! {span=>
252            std::option::Option<diesel::dsl::Eq<
253                #table_name::#column_name,
254                #lifetime #inner_ty,
255            >>
256        })
257    } else {
258        let inner_ty = &field.ty;
259
260        Ok(quote_spanned! {span=>
261            diesel::dsl::Eq<
262                #table_name::#column_name,
263                #lifetime #inner_ty,
264            >
265        })
266    }
267}
268
269fn field_expr(
270    field: &Field,
271    table_name: &Path,
272    lifetime: Option<TokenStream>,
273    treat_none_as_default_value: bool,
274) -> Result<TokenStream> {
275    let field_name = &field.name;
276    let column_name = field.column_name()?.to_ident()?;
277
278    let column: Expr = parse_quote!(#table_name::#column_name);
279    if treat_none_as_default_value {
280        if is_option_ty(&field.ty) {
281            if lifetime.is_some() {
282                Ok(quote!(self.#field_name.as_ref().map(|x| #column.eq(x))))
283            } else {
284                Ok(quote!(self.#field_name.map(|x| #column.eq(x))))
285            }
286        } else {
287            Ok(quote!(std::option::Option::Some(#column.eq(#lifetime self.#field_name))))
288        }
289    } else {
290        Ok(quote!(#column.eq(#lifetime self.#field_name)))
291    }
292}