diesel_derives/associations.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::fold::Fold;
4use syn::parse_quote;
5use syn::{DeriveInput, Ident, Lifetime, Result};
6
7use crate::model::Model;
8use crate::parsers::BelongsTo;
9use crate::util::{camel_to_snake, wrap_in_dummy_mod};
10
11pub fn derive(item: DeriveInput) -> Result<TokenStream> {
12    let model = Model::from_item(&item, false, false)?;
13
14    if model.belongs_to.is_empty() {
15        return Err(syn::Error::new(
16            proc_macro2::Span::call_site(),
17            "At least one `belongs_to` is needed for deriving `Associations` on a structure.",
18        ));
19    }
20
21    let tokens = model
22        .belongs_to
23        .iter()
24        .map(|assoc| derive_belongs_to(&item, &model, assoc))
25        .collect::<Result<Vec<_>>>()?;
26
27    Ok(wrap_in_dummy_mod(quote!(#(#tokens)*)))
28}
29
30fn derive_belongs_to(item: &DeriveInput, model: &Model, assoc: &BelongsTo) -> Result<TokenStream> {
31    let (_, ty_generics, _) = item.generics.split_for_impl();
32
33    let struct_name = &item.ident;
34    let table_name = &model.table_names()[0];
35
36    let foreign_key = &foreign_key(assoc);
37
38    let foreign_key_field = model.find_column(foreign_key)?;
39    let foreign_key_name = &foreign_key_field.name;
40    let foreign_key_ty = &foreign_key_field.ty;
41
42    let mut generics = item.generics.clone();
43
44    let parent_struct = ReplacePathLifetimes::new(|i, span| {
45        let letter = char::from(b'b' + i as u8);
46        let lifetime = Lifetime::new(&format!("'__{letter}"), span);
47        generics.params.push(parse_quote!(#lifetime));
48        lifetime
49    })
50    .fold_type_path(assoc.parent.clone());
51
52    generics.params.push(parse_quote!(__FK));
53    {
54        let where_clause = generics.where_clause.get_or_insert(parse_quote!(where));
55        where_clause
56            .predicates
57            .push(parse_quote!(__FK: std::hash::Hash + std::cmp::Eq));
58        where_clause.predicates.push(
59                parse_quote!(for<'__a> &'__a #foreign_key_ty: std::convert::Into<::std::option::Option<&'__a __FK>>),
60            );
61        where_clause.predicates.push(
62                parse_quote!(for<'__a> &'__a #parent_struct: diesel::associations::Identifiable<Id = &'__a __FK>),
63            );
64    }
65
66    let foreign_key_expr = quote!(std::convert::Into::into(&self.#foreign_key_name));
67    let foreign_key_ty = quote!(__FK);
68
69    let (impl_generics, _, where_clause) = generics.split_for_impl();
70
71    Ok(quote! {
72        impl #impl_generics diesel::associations::BelongsTo<#parent_struct>
73            for #struct_name #ty_generics
74        #where_clause
75        {
76            type ForeignKey = #foreign_key_ty;
77            type ForeignKeyColumn = #table_name::#foreign_key;
78
79            fn foreign_key(&self) -> std::option::Option<&Self::ForeignKey> {
80                #foreign_key_expr
81            }
82
83            fn foreign_key_column() -> Self::ForeignKeyColumn {
84                #table_name::#foreign_key
85            }
86        }
87
88        impl #impl_generics diesel::associations::BelongsTo<&'_ #parent_struct>
89            for #struct_name #ty_generics
90        #where_clause
91        {
92            type ForeignKey = #foreign_key_ty;
93            type ForeignKeyColumn = #table_name::#foreign_key;
94
95            fn foreign_key(&self) -> std::option::Option<&Self::ForeignKey> {
96                #foreign_key_expr
97            }
98
99            fn foreign_key_column() -> Self::ForeignKeyColumn {
100                #table_name::#foreign_key
101            }
102        }
103    })
104}
105
106fn foreign_key(assoc: &BelongsTo) -> Ident {
107    let ident = &assoc
108        .parent
109        .path
110        .segments
111        .last()
112        .expect("paths always have at least one segment")
113        .ident;
114
115    assoc
116        .foreign_key
117        .clone()
118        .unwrap_or_else(|| infer_foreign_key(ident))
119}
120
121fn infer_foreign_key(name: &Ident) -> Ident {
122    let snake_case = camel_to_snake(&name.to_string());
123    Ident::new(&format!("{snake_case}_id"), name.span())
124}
125
/// Folder that rewrites every anonymous `'_` lifetime it visits.
///
/// Each `'_` is passed to the callback `f` together with its zero-based occurrence index and
/// its original span. The lifetime `f` returns takes the place of the `'_`.
struct ReplacePathLifetimes<F> {
    /// Number of `'_` lifetimes replaced so far; this is the index for the next `f` call.
    count: usize,
    /// Produces the replacement lifetime for a given occurrence index and span.
    f: F,
}

impl<F> ReplacePathLifetimes<F> {
    /// Wraps `f` in a folder whose occurrence counter starts at zero.
    fn new(f: F) -> Self {
        ReplacePathLifetimes { f, count: 0 }
    }
}
136
137impl<F> Fold for ReplacePathLifetimes<F>
138where
139    F: FnMut(usize, Span) -> Lifetime,
140{
141    fn fold_lifetime(&mut self, mut lt: Lifetime) -> Lifetime {
142        if lt.ident == "_" {
143            lt = (self.f)(self.count, lt.span());
144            self.count += 1;
145        }
146        lt
147    }
148}