Skip to main content

diesel_derives/
model.rs

1use proc_macro2::Span;
2use std::slice::from_ref;
3use syn::Result;
4use syn::punctuated::Punctuated;
5use syn::token::Comma;
6use syn::{
7    Data, DataStruct, DeriveInput, Field as SynField, Fields, FieldsNamed, FieldsUnnamed, Ident,
8    LitBool, Path, Type,
9};
10
11use crate::attrs::{StructAttr, parse_attributes};
12use crate::field::Field;
13use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
14use crate::util::camel_to_snake;
15
16pub struct Model {
17    name: Path,
18    table_names: Vec<Path>,
19    pub primary_key_names: Vec<Ident>,
20    treat_none_as_default_value: Option<LitBool>,
21    treat_none_as_null: Option<LitBool>,
22    pub belongs_to: Vec<BelongsTo>,
23    pub sql_types: Vec<Type>,
24    pub aggregate: bool,
25    pub not_sized: bool,
26    pub foreign_derive: bool,
27    pub enum_type: bool,
28    pub mysql_type: Option<MysqlType>,
29    pub sqlite_type: Option<SqliteType>,
30    pub postgres_type: Option<PostgresType>,
31    pub check_for_backend: Option<CheckForBackend>,
32    pub base_query: Option<syn::Expr>,
33    pub base_query_type: Option<syn::Type>,
34    fields: Vec<Field>,
35}
36
37pub enum CheckForBackend {
38    Backends(syn::punctuated::Punctuated<syn::TypePath, ::syn::token::Commasyn::Token![,]>),
39    Disabled(LitBool),
40}
41
42impl Model {
43    pub fn from_item(
44        item: &DeriveInput,
45        allow_unit_structs: bool,
46        allow_multiple_table: bool,
47    ) -> Result<Self> {
48        let DeriveInput {
49            data, ident, attrs, ..
50        } = item;
51
52        let fields = match *data {
53            Data::Struct(DataStruct {
54                fields: Fields::Named(FieldsNamed { ref named, .. }),
55                ..
56            }) => Some(named),
57            Data::Struct(DataStruct {
58                fields: Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }),
59                ..
60            }) => Some(unnamed),
61            _ if !allow_unit_structs => {
62                return Err(syn::Error::new(
63                    proc_macro2::Span::mixed_site(),
64                    "this derive can only be used on non-unit structs",
65                ));
66            }
67            _ => None,
68        };
69
70        let mut table_names = ::alloc::vec::Vec::new()vec![];
71        let mut primary_key_names = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [Ident::new("id", Span::mixed_site())]))vec![Ident::new("id", Span::mixed_site())];
72        let mut treat_none_as_default_value = None;
73        let mut treat_none_as_null = None;
74        let mut belongs_to = ::alloc::vec::Vec::new()vec![];
75        let mut sql_types = ::alloc::vec::Vec::new()vec![];
76        let mut aggregate = false;
77        let mut not_sized = false;
78        let mut enum_type = false;
79        let mut foreign_derive = false;
80        let mut mysql_type = None;
81        let mut sqlite_type = None;
82        let mut postgres_type = None;
83        let mut check_for_backend = None;
84        let mut base_query = None;
85        let mut base_query_type = None;
86
87        for attr in parse_attributes(attrs)? {
88            match attr.item {
89                StructAttr::SqlType(_, value) => sql_types.push(Type::Path(value)),
90                StructAttr::TableName(ident, value) => {
91                    if !allow_multiple_table && !table_names.is_empty() {
92                        return Err(syn::Error::new(
93                            ident.span(),
94                            "expected a single table name attribute\n\
95                             note: remove this attribute",
96                        ));
97                    }
98                    table_names.push(value)
99                }
100                StructAttr::PrimaryKey(_, keys) => {
101                    primary_key_names = keys.into_iter().collect();
102                }
103                StructAttr::TreatNoneAsDefaultValue(_, val) => {
104                    treat_none_as_default_value = Some(val)
105                }
106                StructAttr::TreatNoneAsNull(_, val) => treat_none_as_null = Some(val),
107                StructAttr::BelongsTo(_, val) => belongs_to.push(val),
108                StructAttr::Aggregate(_) => aggregate = true,
109                StructAttr::NotSized(_) => not_sized = true,
110                StructAttr::ForeignDerive(_) => foreign_derive = true,
111                StructAttr::EnumType(_) => enum_type = true,
112                StructAttr::MysqlType(_, val) => mysql_type = Some(val),
113                StructAttr::SqliteType(_, val) => sqlite_type = Some(val),
114                StructAttr::PostgresType(_, val) => postgres_type = Some(val),
115                StructAttr::CheckForBackend(_, b) => {
116                    check_for_backend = Some(b);
117                }
118                StructAttr::BaseQuery(_, e) => base_query = Some(e),
119                StructAttr::BaseQueryType(_, t) => base_query_type = Some(t),
120                StructAttr::RenameAll(_, _) => { /*ignore here as only relevant for enums*/ }
121            }
122        }
123
124        let name = Ident::new(&infer_table_name(&ident.to_string()), ident.span()).into();
125
126        Ok(Self {
127            name,
128            table_names,
129            primary_key_names,
130            treat_none_as_default_value,
131            treat_none_as_null,
132            belongs_to,
133            sql_types,
134            aggregate,
135            not_sized,
136            foreign_derive,
137            mysql_type,
138            sqlite_type,
139            postgres_type,
140            fields: fields_from_item_data(fields)?,
141            check_for_backend,
142            base_query,
143            base_query_type,
144            enum_type,
145        })
146    }
147
148    pub fn table_names(&self) -> &[Path] {
149        match self.table_names.len() {
150            0 => from_ref(&self.name),
151            _ => &self.table_names,
152        }
153    }
154
155    pub fn fields(&self) -> &[Field] {
156        &self.fields
157    }
158
159    pub fn find_column(&self, column_name: &Ident) -> Result<&Field> {
160        self.fields()
161            .iter()
162            .find(|f| {
163                f.column_name()
164                    .map(|c| c == *column_name)
165                    .unwrap_or_default()
166            })
167            .ok_or_else(|| {
168                syn::Error::new(
169                    column_name.span(),
170                    ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("no field with column name `{0}`",
                column_name))
    })format!("no field with column name `{column_name}`"),
171                )
172            })
173    }
174
175    pub fn treat_none_as_default_value(&self) -> bool {
176        self.treat_none_as_default_value
177            .as_ref()
178            .map(|v| v.value())
179            .unwrap_or(true)
180    }
181
182    pub fn treat_none_as_null(&self) -> bool {
183        self.treat_none_as_null
184            .as_ref()
185            .map(|v| v.value())
186            .unwrap_or(false)
187    }
188}
189
190fn fields_from_item_data(fields: Option<&Punctuated<SynField, Comma>>) -> Result<Vec<Field>> {
191    fields
192        .map(|fields| {
193            fields
194                .iter()
195                .enumerate()
196                .map(|(i, f)| Field::from_struct_field(f, i))
197                .collect::<Result<Vec<_>>>()
198        })
199        .unwrap_or_else(|| Ok(Vec::new()))
200}
201
202pub fn infer_table_name(name: &str) -> String {
203    let mut result = camel_to_snake(name);
204    result.push('s');
205    result
206}
207
#[test]
fn infer_table_name_pluralizes_and_downcases() {
    // (type name, expected table name)
    let cases = [("Foo", "foos"), ("Bar", "bars")];
    for (type_name, expected) in cases {
        assert_eq!(expected, infer_table_name(type_name));
    }
}
213
#[test]
fn infer_table_name_properly_handles_underscores() {
    // Each interior capital starts a new `_`-separated word.
    let cases = [("FooBar", "foo_bars"), ("FooBarBaz", "foo_bar_bazs")];
    for (type_name, expected) in cases {
        assert_eq!(expected, infer_table_name(type_name));
    }
}