diesel_derives/
model.rs

1use proc_macro2::Span;
2use std::slice::from_ref;
3use syn::punctuated::Punctuated;
4use syn::token::Comma;
5use syn::Result;
6use syn::{
7    Data, DataStruct, DeriveInput, Field as SynField, Fields, FieldsNamed, FieldsUnnamed, Ident,
8    LitBool, Path, Type,
9};
10
11use crate::attrs::{parse_attributes, StructAttr};
12use crate::field::Field;
13use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
14use crate::util::camel_to_snake;
15
16pub struct Model {
17    name: Path,
18    table_names: Vec<Path>,
19    pub primary_key_names: Vec<Ident>,
20    treat_none_as_default_value: Option<LitBool>,
21    treat_none_as_null: Option<LitBool>,
22    pub belongs_to: Vec<BelongsTo>,
23    pub sql_types: Vec<Type>,
24    pub aggregate: bool,
25    pub not_sized: bool,
26    pub foreign_derive: bool,
27    pub mysql_type: Option<MysqlType>,
28    pub sqlite_type: Option<SqliteType>,
29    pub postgres_type: Option<PostgresType>,
30    pub check_for_backend: Option<CheckForBackend>,
31    pub base_query: Option<syn::Expr>,
32    pub base_query_type: Option<syn::Type>,
33    fields: Vec<Field>,
34}
35
36pub enum CheckForBackend {
37    Backends(syn::punctuated::Punctuated<syn::TypePath, syn::Token![,]>),
38    Disabled(LitBool),
39}
40
41impl Model {
42    pub fn from_item(
43        item: &DeriveInput,
44        allow_unit_structs: bool,
45        allow_multiple_table: bool,
46    ) -> Result<Self> {
47        let DeriveInput {
48            data, ident, attrs, ..
49        } = item;
50
51        let fields = match *data {
52            Data::Struct(DataStruct {
53                fields: Fields::Named(FieldsNamed { ref named, .. }),
54                ..
55            }) => Some(named),
56            Data::Struct(DataStruct {
57                fields: Fields::Unnamed(FieldsUnnamed { ref unnamed, .. }),
58                ..
59            }) => Some(unnamed),
60            _ if !allow_unit_structs => {
61                return Err(syn::Error::new(
62                    proc_macro2::Span::mixed_site(),
63                    "this derive can only be used on non-unit structs",
64                ));
65            }
66            _ => None,
67        };
68
69        let mut table_names = vec![];
70        let mut primary_key_names = vec![Ident::new("id", Span::mixed_site())];
71        let mut treat_none_as_default_value = None;
72        let mut treat_none_as_null = None;
73        let mut belongs_to = vec![];
74        let mut sql_types = vec![];
75        let mut aggregate = false;
76        let mut not_sized = false;
77        let mut foreign_derive = false;
78        let mut mysql_type = None;
79        let mut sqlite_type = None;
80        let mut postgres_type = None;
81        let mut check_for_backend = None;
82        let mut base_query = None;
83        let mut base_query_type = None;
84
85        for attr in parse_attributes(attrs)? {
86            match attr.item {
87                StructAttr::SqlType(_, value) => sql_types.push(Type::Path(value)),
88                StructAttr::TableName(ident, value) => {
89                    if !allow_multiple_table && !table_names.is_empty() {
90                        return Err(syn::Error::new(
91                            ident.span(),
92                            "expected a single table name attribute\n\
93                             note: remove this attribute",
94                        ));
95                    }
96                    table_names.push(value)
97                }
98                StructAttr::PrimaryKey(_, keys) => {
99                    primary_key_names = keys.into_iter().collect();
100                }
101                StructAttr::TreatNoneAsDefaultValue(_, val) => {
102                    treat_none_as_default_value = Some(val)
103                }
104                StructAttr::TreatNoneAsNull(_, val) => treat_none_as_null = Some(val),
105                StructAttr::BelongsTo(_, val) => belongs_to.push(val),
106                StructAttr::Aggregate(_) => aggregate = true,
107                StructAttr::NotSized(_) => not_sized = true,
108                StructAttr::ForeignDerive(_) => foreign_derive = true,
109                StructAttr::MysqlType(_, val) => mysql_type = Some(val),
110                StructAttr::SqliteType(_, val) => sqlite_type = Some(val),
111                StructAttr::PostgresType(_, val) => postgres_type = Some(val),
112                StructAttr::CheckForBackend(_, b) => {
113                    check_for_backend = Some(b);
114                }
115                StructAttr::BaseQuery(_, e) => base_query = Some(e),
116                StructAttr::BaseQueryType(_, t) => base_query_type = Some(t),
117            }
118        }
119
120        let name = Ident::new(&infer_table_name(&ident.to_string()), ident.span()).into();
121
122        Ok(Self {
123            name,
124            table_names,
125            primary_key_names,
126            treat_none_as_default_value,
127            treat_none_as_null,
128            belongs_to,
129            sql_types,
130            aggregate,
131            not_sized,
132            foreign_derive,
133            mysql_type,
134            sqlite_type,
135            postgres_type,
136            fields: fields_from_item_data(fields)?,
137            check_for_backend,
138            base_query,
139            base_query_type,
140        })
141    }
142
143    pub fn table_names(&self) -> &[Path] {
144        match self.table_names.len() {
145            0 => from_ref(&self.name),
146            _ => &self.table_names,
147        }
148    }
149
150    pub fn fields(&self) -> &[Field] {
151        &self.fields
152    }
153
154    pub fn find_column(&self, column_name: &Ident) -> Result<&Field> {
155        self.fields()
156            .iter()
157            .find(|f| {
158                f.column_name()
159                    .map(|c| c == *column_name)
160                    .unwrap_or_default()
161            })
162            .ok_or_else(|| {
163                syn::Error::new(
164                    column_name.span(),
165                    format!("no field with column name `{column_name}`"),
166                )
167            })
168    }
169
170    pub fn treat_none_as_default_value(&self) -> bool {
171        self.treat_none_as_default_value
172            .as_ref()
173            .map(|v| v.value())
174            .unwrap_or(true)
175    }
176
177    pub fn treat_none_as_null(&self) -> bool {
178        self.treat_none_as_null
179            .as_ref()
180            .map(|v| v.value())
181            .unwrap_or(false)
182    }
183}
184
185fn fields_from_item_data(fields: Option<&Punctuated<SynField, Comma>>) -> Result<Vec<Field>> {
186    fields
187        .map(|fields| {
188            fields
189                .iter()
190                .enumerate()
191                .map(|(i, f)| Field::from_struct_field(f, i))
192                .collect::<Result<Vec<_>>>()
193        })
194        .unwrap_or_else(|| Ok(Vec::new()))
195}
196
197pub fn infer_table_name(name: &str) -> String {
198    let mut result = camel_to_snake(name);
199    result.push('s');
200    result
201}
202
// Single-word type names are lower-cased and get an `s` appended.
#[test]
fn infer_table_name_pluralizes_and_downcases() {
    let cases = [("Foo", "foos"), ("Bar", "bars")];
    for &(type_name, expected) in &cases {
        assert_eq!(expected, &infer_table_name(type_name));
    }
}
208
// Multi-word camel-case names are split with underscores before the `s` is
// appended.
#[test]
fn infer_table_name_properly_handles_underscores() {
    let cases = [("FooBar", "foo_bars"), ("FooBarBaz", "foo_bar_bazs")];
    for &(type_name, expected) in &cases {
        assert_eq!(expected, &infer_table_name(type_name));
    }
}
213}