diesel_derives/insertable.rs

1use crate::attrs::AttributeSpanWrapper;
2use crate::field::Field;
3use crate::model::Model;
4use crate::util::{inner_of_option_ty, is_option_ty, wrap_in_dummy_mod};
5use proc_macro2::TokenStream;
6use quote::quote;
7use quote::quote_spanned;
8use syn::parse_quote;
9use syn::spanned::Spanned as _;
10use syn::{DeriveInput, Expr, Path, Result, Type};
11
12pub fn derive(item: DeriveInput) -> Result<TokenStream> {
13    let model = Model::from_item(&item, false, true)?;
14
15    let tokens = model
16        .table_names()
17        .iter()
18        .map(|table_name| derive_into_single_table(&item, &model, table_name))
19        .collect::<Result<Vec<_>>>()?;
20
21    Ok(wrap_in_dummy_mod(quote! {
22        #(#tokens)*
23    }))
24}
25
26fn derive_into_single_table(
27    item: &DeriveInput,
28    model: &Model,
29    table_name: &Path,
30) -> Result<TokenStream> {
31    let treat_none_as_default_value = model.treat_none_as_default_value();
32    let struct_name = &item.ident;
33
34    let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();
35
36    let mut generate_borrowed_insert = true;
37
38    let mut direct_field_ty = Vec::with_capacity(model.fields().len());
39    let mut direct_field_assign = Vec::with_capacity(model.fields().len());
40    let mut ref_field_ty = Vec::with_capacity(model.fields().len());
41    let mut ref_field_assign = Vec::with_capacity(model.fields().len());
42
43    for field in model.fields() {
44        // skip this field while generating the insertion
45        if field.skip_insertion() {
46            continue;
47        }
48        // Use field-level attr. with fallback to the struct-level one.
49        let treat_none_as_default_value = match &field.treat_none_as_default_value {
50            Some(attr) => {
51                if let Some(embed) = &field.embed {
52                    return Err(syn::Error::new(
53                        embed.attribute_span,
54                        "`embed` and `treat_none_as_default_value` are mutually exclusive",
55                    ));
56                }
57
58                if !is_option_ty(&field.ty) {
59                    return Err(syn::Error::new(
60                        field.ty.span(),
61                        "expected `treat_none_as_default_value` field to be of type `Option<_>`",
62                    ));
63                }
64
65                attr.item
66            }
67            None => treat_none_as_default_value,
68        };
69
70        match (field.serialize_as.as_ref(), field.embed()) {
71            (None, true) => {
72                direct_field_ty.push(field_ty_embed(field, None));
73                direct_field_assign.push(field_expr_embed(field, None));
74                ref_field_ty.push(field_ty_embed(field, Some(quote!(&'insert))));
75                ref_field_assign.push(field_expr_embed(field, Some(quote!(&))));
76            }
77            (None, false) => {
78                direct_field_ty.push(field_ty(
79                    field,
80                    table_name,
81                    None,
82                    treat_none_as_default_value,
83                )?);
84                direct_field_assign.push(field_expr(
85                    field,
86                    table_name,
87                    None,
88                    treat_none_as_default_value,
89                )?);
90                ref_field_ty.push(field_ty(
91                    field,
92                    table_name,
93                    Some(quote!(&'insert)),
94                    treat_none_as_default_value,
95                )?);
96                ref_field_assign.push(field_expr(
97                    field,
98                    table_name,
99                    Some(quote!(&)),
100                    treat_none_as_default_value,
101                )?);
102            }
103            (Some(AttributeSpanWrapper { item: ty, .. }), false) => {
104                direct_field_ty.push(field_ty_serialize_as(
105                    field,
106                    table_name,
107                    ty,
108                    treat_none_as_default_value,
109                )?);
110                direct_field_assign.push(field_expr_serialize_as(
111                    field,
112                    table_name,
113                    ty,
114                    treat_none_as_default_value,
115                )?);
116
117                generate_borrowed_insert = false; // as soon as we hit one field with #[diesel(serialize_as)] there is no point in generating the impl of Insertable for borrowed structs
118            }
119            (Some(AttributeSpanWrapper { attribute_span, .. }), true) => {
120                return Err(syn::Error::new(
121                    *attribute_span,
122                    "`#[diesel(embed)]` cannot be combined with `#[diesel(serialize_as)]`",
123                ));
124            }
125        }
126    }
127
128    let insert_owned = quote! {
129        impl #impl_generics diesel::insertable::Insertable<#table_name::table> for #struct_name #ty_generics
130            #where_clause
131        {
132            type Values = <(#(#direct_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values;
133
134            fn values(self) -> <(#(#direct_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values {
135                diesel::insertable::Insertable::<#table_name::table>::values((#(#direct_field_assign,)*))
136            }
137        }
138    };
139
140    let insert_borrowed = if generate_borrowed_insert {
141        let mut impl_generics = item.generics.clone();
142        impl_generics.params.push(parse_quote!('insert));
143        let (impl_generics, ..) = impl_generics.split_for_impl();
144
145        quote! {
146            impl #impl_generics diesel::insertable::Insertable<#table_name::table>
147                for &'insert #struct_name #ty_generics
148            #where_clause
149            {
150                type Values = <(#(#ref_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values;
151
152                fn values(self) -> <(#(#ref_field_ty,)*) as diesel::insertable::Insertable<#table_name::table>>::Values {
153                    diesel::insertable::Insertable::<#table_name::table>::values((#(#ref_field_assign,)*))
154                }
155            }
156        }
157    } else {
158        quote! {}
159    };
160
161    Ok(quote! {
162        #insert_owned
163
164        #insert_borrowed
165
166        impl #impl_generics diesel::internal::derives::insertable::UndecoratedInsertRecord<#table_name::table>
167                for #struct_name #ty_generics
168            #where_clause
169        {
170        }
171    })
172}
173
174fn field_ty_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
175    let field_ty = &field.ty;
176    let span = field.span;
177    quote_spanned!(span=> #lifetime #field_ty)
178}
179
180fn field_expr_embed(field: &Field, lifetime: Option<TokenStream>) -> TokenStream {
181    let field_name = &field.name;
182    quote!(#lifetime self.#field_name)
183}
184
185fn field_ty_serialize_as(
186    field: &Field,
187    table_name: &Path,
188    ty: &Type,
189    treat_none_as_default_value: bool,
190) -> Result<TokenStream> {
191    let column_name = field.column_name()?.to_ident()?;
192    let span = field.span;
193    if treat_none_as_default_value {
194        let inner_ty = inner_of_option_ty(ty);
195
196        Ok(quote_spanned! {span=>
197            std::option::Option<diesel::dsl::Eq<
198                #table_name::#column_name,
199                #inner_ty,
200            >>
201        })
202    } else {
203        Ok(quote_spanned! {span=>
204            diesel::dsl::Eq<
205                #table_name::#column_name,
206                #ty,
207            >
208        })
209    }
210}
211
212fn field_expr_serialize_as(
213    field: &Field,
214    table_name: &Path,
215    ty: &Type,
216    treat_none_as_default_value: bool,
217) -> Result<TokenStream> {
218    let field_name = &field.name;
219    let column_name = field.column_name()?.to_ident()?;
220    let column = quote!(#table_name::#column_name);
221    if treat_none_as_default_value {
222        if is_option_ty(ty) {
223            Ok(
224                quote!(::std::convert::Into::<#ty>::into(self.#field_name).map(|v| diesel::ExpressionMethods::eq(#column, v))),
225            )
226        } else {
227            Ok(
228                quote!(std::option::Option::Some(diesel::ExpressionMethods::eq(#column, ::std::convert::Into::<#ty>::into(self.#field_name)))),
229            )
230        }
231    } else {
232        Ok(
233            quote!(diesel::ExpressionMethods::eq(#column, ::std::convert::Into::<#ty>::into(self.#field_name))),
234        )
235    }
236}
237
238fn field_ty(
239    field: &Field,
240    table_name: &Path,
241    lifetime: Option<TokenStream>,
242    treat_none_as_default_value: bool,
243) -> Result<TokenStream> {
244    let column_name = field.column_name()?.to_ident()?;
245    let span = field.span;
246    if treat_none_as_default_value {
247        let inner_ty = inner_of_option_ty(&field.ty);
248
249        Ok(quote_spanned! {span=>
250            std::option::Option<diesel::dsl::Eq<
251                #table_name::#column_name,
252                #lifetime #inner_ty,
253            >>
254        })
255    } else {
256        let inner_ty = &field.ty;
257
258        Ok(quote_spanned! {span=>
259            diesel::dsl::Eq<
260                #table_name::#column_name,
261                #lifetime #inner_ty,
262            >
263        })
264    }
265}
266
267fn field_expr(
268    field: &Field,
269    table_name: &Path,
270    lifetime: Option<TokenStream>,
271    treat_none_as_default_value: bool,
272) -> Result<TokenStream> {
273    let field_name = &field.name;
274    let column_name = field.column_name()?.to_ident()?;
275
276    let column: Expr = parse_quote!(#table_name::#column_name);
277    if treat_none_as_default_value {
278        if is_option_ty(&field.ty) {
279            if lifetime.is_some() {
280                Ok(
281                    quote!(self.#field_name.as_ref().map(|x| diesel::ExpressionMethods::eq(#column, x))),
282                )
283            } else {
284                Ok(quote!(self.#field_name.map(|x| diesel::ExpressionMethods::eq(#column, x))))
285            }
286        } else {
287            Ok(
288                quote!(std::option::Option::Some(diesel::ExpressionMethods::eq(#column, #lifetime self.#field_name))),
289            )
290        }
291    } else {
292        Ok(quote!(diesel::ExpressionMethods::eq(#column, #lifetime self.#field_name)))
293    }
294}