diesel_derives/
model.rs

1use proc_macro2::Span;
2use std::slice::from_ref;
3use syn::punctuated::Punctuated;
4use syn::token::Comma;
5use syn::Result;
6use syn::{
7    Data, DataStruct, DeriveInput, Field as SynField, Fields, FieldsNamed, FieldsUnnamed, Ident,
8    LitBool, Path, Type,
9};
10
11use crate::attrs::{parse_attributes, StructAttr};
12use crate::field::Field;
13use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
14use crate::util::camel_to_snake;
15
16pub struct Model {
17    name: Path,
18    table_names: Vec<Path>,
19    pub primary_key_names: Vec<Ident>,
20    treat_none_as_default_value: Option<LitBool>,
21    treat_none_as_null: Option<LitBool>,
22    pub belongs_to: Vec<BelongsTo>,
23    pub sql_types: Vec<Type>,
24    pub aggregate: bool,
25    pub not_sized: bool,
26    pub foreign_derive: bool,
27    pub mysql_type: Option<MysqlType>,
28    pub sqlite_type: Option<SqliteType>,
29    pub postgres_type: Option<PostgresType>,
30    pub check_for_backend: Option<syn::punctuated::Punctuated<syn::TypePath, syn::Token![,]>>,
31    fields: Vec<Field>,
32}
33
34impl Model {
35    pub fn from_item(
36        item: &DeriveInput,
37        allow_unit_structs: bool,
38        allow_multiple_table: bool,
39    ) -> Result<Self> {
40        let DeriveInput {
41            data, ident, attrs, ..
42        } = item;
43
44        let fields = match *data {
45            Data::Struct(DataStruct {
46                fields: Fields::Named(FieldsNamed { ref named, .. }),
47                ..
48            }) => Some(named),
49            Data::Struct(DataStruct {
50                fields: Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }),
51                ..
52            }) => Some(unnamed),
53            _ if !allow_unit_structs => {
54                return Err(syn::Error::new(
55                    proc_macro2::Span::call_site(),
56                    "This derive can only be used on non-unit structs",
57                ));
58            }
59            _ => None,
60        };
61
62        let mut table_names = vec![];
63        let mut primary_key_names = vec![Ident::new("id", Span::call_site())];
64        let mut treat_none_as_default_value = None;
65        let mut treat_none_as_null = None;
66        let mut belongs_to = vec![];
67        let mut sql_types = vec![];
68        let mut aggregate = false;
69        let mut not_sized = false;
70        let mut foreign_derive = false;
71        let mut mysql_type = None;
72        let mut sqlite_type = None;
73        let mut postgres_type = None;
74        let mut check_for_backend = None;
75
76        for attr in parse_attributes(attrs)? {
77            match attr.item {
78                StructAttr::SqlType(_, value) => sql_types.push(Type::Path(value)),
79                StructAttr::TableName(ident, value) => {
80                    if !allow_multiple_table && !table_names.is_empty() {
81                        return Err(syn::Error::new(
82                            ident.span(),
83                            "expected a single table name attribute\n\
84                             note: remove this attribute",
85                        ));
86                    }
87                    table_names.push(value)
88                }
89                StructAttr::PrimaryKey(_, keys) => {
90                    primary_key_names = keys.into_iter().collect();
91                }
92                StructAttr::TreatNoneAsDefaultValue(_, val) => {
93                    treat_none_as_default_value = Some(val)
94                }
95                StructAttr::TreatNoneAsNull(_, val) => treat_none_as_null = Some(val),
96                StructAttr::BelongsTo(_, val) => belongs_to.push(val),
97                StructAttr::Aggregate(_) => aggregate = true,
98                StructAttr::NotSized(_) => not_sized = true,
99                StructAttr::ForeignDerive(_) => foreign_derive = true,
100                StructAttr::MysqlType(_, val) => mysql_type = Some(val),
101                StructAttr::SqliteType(_, val) => sqlite_type = Some(val),
102                StructAttr::PostgresType(_, val) => postgres_type = Some(val),
103                StructAttr::CheckForBackend(_, b) => {
104                    check_for_backend = Some(b);
105                }
106            }
107        }
108
109        let name = Ident::new(&infer_table_name(&ident.to_string()), ident.span()).into();
110
111        Ok(Self {
112            name,
113            table_names,
114            primary_key_names,
115            treat_none_as_default_value,
116            treat_none_as_null,
117            belongs_to,
118            sql_types,
119            aggregate,
120            not_sized,
121            foreign_derive,
122            mysql_type,
123            sqlite_type,
124            postgres_type,
125            fields: fields_from_item_data(fields)?,
126            check_for_backend,
127        })
128    }
129
130    pub fn table_names(&self) -> &[Path] {
131        match self.table_names.len() {
132            0 => from_ref(&self.name),
133            _ => &self.table_names,
134        }
135    }
136
137    pub fn fields(&self) -> &[Field] {
138        &self.fields
139    }
140
141    pub fn find_column(&self, column_name: &Ident) -> Result<&Field> {
142        self.fields()
143            .iter()
144            .find(|f| {
145                f.column_name()
146                    .map(|c| c == *column_name)
147                    .unwrap_or_default()
148            })
149            .ok_or_else(|| {
150                syn::Error::new(
151                    column_name.span(),
152                    format!("No field with column name {column_name}"),
153                )
154            })
155    }
156
157    pub fn treat_none_as_default_value(&self) -> bool {
158        self.treat_none_as_default_value
159            .as_ref()
160            .map(|v| v.value())
161            .unwrap_or(true)
162    }
163
164    pub fn treat_none_as_null(&self) -> bool {
165        self.treat_none_as_null
166            .as_ref()
167            .map(|v| v.value())
168            .unwrap_or(false)
169    }
170}
171
172fn fields_from_item_data(fields: Option<&Punctuated<SynField, Comma>>) -> Result<Vec<Field>> {
173    fields
174        .map(|fields| {
175            fields
176                .iter()
177                .enumerate()
178                .map(|(i, f)| Field::from_struct_field(f, i))
179                .collect::<Result<Vec<_>>>()
180        })
181        .unwrap_or_else(|| Ok(Vec::new()))
182}
183
184pub fn infer_table_name(name: &str) -> String {
185    let mut result = camel_to_snake(name);
186    result.push('s');
187    result
188}
189
/// Single-word struct names are lowercased and get a trailing `s`.
#[test]
fn infer_table_name_pluralizes_and_downcases() {
    for (struct_name, expected) in [("Foo", "foos"), ("Bar", "bars")] {
        assert_eq!(expected, infer_table_name(struct_name));
    }
}
195
/// Each CamelCase word boundary becomes an underscore before the plural `s` is appended.
#[test]
fn infer_table_name_properly_handles_underscores() {
    let cases = [("FooBar", "foo_bars"), ("FooBarBaz", "foo_bar_bazs")];
    for (struct_name, expected) in cases {
        assert_eq!(expected, infer_table_name(struct_name), "input: {struct_name}");
    }
}