// diesel_derives/associations.rs
1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::fold::Fold;
4use syn::parse_quote;
5use syn::{DeriveInput, Ident, Lifetime, Result};
6
7use crate::model::Model;
8use crate::parsers::BelongsTo;
9use crate::util::{camel_to_snake, wrap_in_dummy_mod};
10
11pub fn derive(item: DeriveInput) -> Result<TokenStream> {
12 let model = Model::from_item(&item, false, false)?;
13
14 if model.belongs_to.is_empty() {
15 return Err(syn::Error::new(
16 proc_macro2::Span::call_site(),
17 "At least one `belongs_to` is needed for deriving `Associations` on a structure.",
18 ));
19 }
20
21 let tokens = model
22 .belongs_to
23 .iter()
24 .map(|assoc| derive_belongs_to(&item, &model, assoc))
25 .collect::<Result<Vec<_>>>()?;
26
27 Ok(wrap_in_dummy_mod(quote!(#(#tokens)*)))
28}
29
30fn derive_belongs_to(item: &DeriveInput, model: &Model, assoc: &BelongsTo) -> Result<TokenStream> {
31 let (_, ty_generics, _) = item.generics.split_for_impl();
32
33 let struct_name = &item.ident;
34 let table_name = &model.table_names()[0];
35
36 let foreign_key = &foreign_key(assoc);
37
38 let foreign_key_field = model.find_column(foreign_key)?;
39 let foreign_key_name = &foreign_key_field.name;
40 let foreign_key_ty = &foreign_key_field.ty;
41
42 let mut generics = item.generics.clone();
43
44 let parent_struct = ReplacePathLifetimes::new(|i, span| {
45 let letter = char::from(b'b' + i as u8);
46 let lifetime = Lifetime::new(&format!("'__{letter}"), span);
47 generics.params.push(parse_quote!(#lifetime));
48 lifetime
49 })
50 .fold_type_path(assoc.parent.clone());
51
52 generics.params.push(parse_quote!(__FK));
53 {
54 let where_clause = generics.where_clause.get_or_insert(parse_quote!(where));
55 where_clause
56 .predicates
57 .push(parse_quote!(__FK: std::hash::Hash + std::cmp::Eq));
58 where_clause.predicates.push(
59 parse_quote!(for<'__a> &'__a #foreign_key_ty: std::convert::Into<::std::option::Option<&'__a __FK>>),
60 );
61 where_clause.predicates.push(
62 parse_quote!(for<'__a> &'__a #parent_struct: diesel::associations::Identifiable<Id = &'__a __FK>),
63 );
64 }
65
66 let foreign_key_expr = quote!(std::convert::Into::into(&self.#foreign_key_name));
67 let foreign_key_ty = quote!(__FK);
68
69 let (impl_generics, _, where_clause) = generics.split_for_impl();
70
71 Ok(quote! {
72 impl #impl_generics diesel::associations::BelongsTo<#parent_struct>
73 for #struct_name #ty_generics
74 #where_clause
75 {
76 type ForeignKey = #foreign_key_ty;
77 type ForeignKeyColumn = #table_name::#foreign_key;
78
79 fn foreign_key(&self) -> std::option::Option<&Self::ForeignKey> {
80 #foreign_key_expr
81 }
82
83 fn foreign_key_column() -> Self::ForeignKeyColumn {
84 #table_name::#foreign_key
85 }
86 }
87
88 impl #impl_generics diesel::associations::BelongsTo<&'_ #parent_struct>
89 for #struct_name #ty_generics
90 #where_clause
91 {
92 type ForeignKey = #foreign_key_ty;
93 type ForeignKeyColumn = #table_name::#foreign_key;
94
95 fn foreign_key(&self) -> std::option::Option<&Self::ForeignKey> {
96 #foreign_key_expr
97 }
98
99 fn foreign_key_column() -> Self::ForeignKeyColumn {
100 #table_name::#foreign_key
101 }
102 }
103 })
104}
105
106fn foreign_key(assoc: &BelongsTo) -> Ident {
107 let ident = &assoc
108 .parent
109 .path
110 .segments
111 .last()
112 .expect("paths always have at least one segment")
113 .ident;
114
115 assoc
116 .foreign_key
117 .clone()
118 .unwrap_or_else(|| infer_foreign_key(ident))
119}
120
121fn infer_foreign_key(name: &Ident) -> Ident {
122 let snake_case = camel_to_snake(&name.to_string());
123 Ident::new(&format!("{snake_case}_id"), name.span())
124}
125
/// Folder that rewrites anonymous `'_` lifetimes using a caller-supplied callback.
///
/// `count` is the number of `'_` lifetimes replaced so far. It is passed to `f`
/// so each replacement can be given a distinct lifetime.
struct ReplacePathLifetimes<F> {
    count: usize,
    f: F,
}

impl<F> ReplacePathLifetimes<F> {
    /// Wraps `f` in a folder whose replacement counter starts at zero.
    fn new(f: F) -> Self {
        ReplacePathLifetimes { f, count: 0 }
    }
}
136
137impl<F> Fold for ReplacePathLifetimes<F>
138where
139 F: FnMut(usize, Span) -> Lifetime,
140{
141 fn fold_lifetime(&mut self, mut lt: Lifetime) -> Lifetime {
142 if lt.ident == "_" {
143 lt = (self.f)(self.count, lt.span());
144 self.count += 1;
145 }
146 lt
147 }
148}