diesel_derives/
attrs.rs

1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::model::CheckForBackend;
14use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
15use crate::util::{
16    parse_eq, parse_paren, unknown_attribute, BASE_QUERY_NOTE, BASE_QUERY_TYPE_NOTE,
17    BELONGS_TO_NOTE, COLUMN_NAME_NOTE, DESERIALIZE_AS_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE,
18    SELECT_EXPRESSION_NOTE, SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQLITE_TYPE_NOTE,
19    SQL_TYPE_NOTE, TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE,
20};
21
22use crate::util::{parse_paren_list, CHECK_FOR_BACKEND_NOTE};
23
24pub trait MySpanned {
25    fn span(&self) -> Span;
26}
27
28#[derive(Clone)]
29pub struct AttributeSpanWrapper<T> {
30    pub item: T,
31    pub attribute_span: Span,
32    pub ident_span: Span,
33}
34
35pub enum FieldAttr {
36    Embed(Ident),
37    SkipInsertion(Ident),
38    SkipUpdate(Ident),
39
40    ColumnName(Ident, SqlIdentifier),
41    SqlType(Ident, TypePath),
42    TreatNoneAsDefaultValue(Ident, LitBool),
43    TreatNoneAsNull(Ident, LitBool),
44
45    SerializeAs(Ident, TypePath),
46    DeserializeAs(Ident, TypePath),
47    SelectExpression(Ident, Expr),
48    SelectExpressionType(Ident, Type),
49}
50
51#[derive(Clone)]
52pub struct SqlIdentifier {
53    field_name: String,
54    span: Span,
55}
56
57impl SqlIdentifier {
58    pub fn span(&self) -> Span {
59        self.span
60    }
61
62    pub fn to_ident(&self) -> Result<Ident> {
63        match syn::parse_str::<Ident>(&format!("r#{}", self.field_name)) {
64            Ok(mut ident) => {
65                ident.set_span(self.span);
66                Ok(ident)
67            }
68            Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
69                self.span(),
70                format!(
71                    "expected valid identifier, found `{0}`. \
72                 Diesel does not support column names with whitespaces yet",
73                    self.field_name
74                ),
75            )),
76            Err(_e) => Err(syn::Error::new(
77                self.span(),
78                format!(
79                    "expected valid identifier, found `{0}`. \
80                 Diesel automatically renames invalid identifiers, \
81                 perhaps you meant to write `{0}_`?",
82                    self.field_name
83                ),
84            )),
85        }
86    }
87}
88
89impl ToTokens for SqlIdentifier {
90    fn to_tokens(&self, tokens: &mut TokenStream) {
91        if self.field_name.starts_with("r#") {
92            Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
93        } else {
94            Ident::new(&self.field_name, self.span).to_tokens(tokens)
95        }
96    }
97}
98
99impl Display for SqlIdentifier {
100    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101        let mut start = 0;
102        if self.field_name.starts_with("r#") {
103            start = 2;
104        }
105        f.write_str(&self.field_name[start..])
106    }
107}
108
109impl PartialEq<Ident> for SqlIdentifier {
110    fn eq(&self, other: &Ident) -> bool {
111        *other == self.field_name
112    }
113}
114
115impl From<&'_ Ident> for SqlIdentifier {
116    fn from(ident: &'_ Ident) -> Self {
117        use syn::ext::IdentExt;
118        let ident = ident.unraw();
119        Self {
120            span: ident.span(),
121            field_name: ident.to_string(),
122        }
123    }
124}
125
126impl Parse for SqlIdentifier {
127    fn parse(input: ParseStream) -> Result<Self> {
128        let fork = input.fork();
129
130        if let Ok(ident) = fork.parse::<Ident>() {
131            input.advance_to(&fork);
132            Ok((&ident).into())
133        } else {
134            let name = input.parse::<LitStr>()?;
135            Ok(Self {
136                field_name: name.value(),
137                span: name.span(),
138            })
139        }
140    }
141}
142
143impl Parse for FieldAttr {
144    fn parse(input: ParseStream) -> Result<Self> {
145        let name: Ident = input.parse()?;
146        let name_str = name.to_string();
147
148        match &*name_str {
149            "embed" => Ok(FieldAttr::Embed(name)),
150            "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
151            "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
152
153            "column_name" => Ok(FieldAttr::ColumnName(
154                name,
155                parse_eq(input, COLUMN_NAME_NOTE)?,
156            )),
157            "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
158            "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
159                name,
160                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
161            )),
162            "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
163                name,
164                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
165            )),
166            "serialize_as" => Ok(FieldAttr::SerializeAs(
167                name,
168                parse_eq(input, SERIALIZE_AS_NOTE)?,
169            )),
170            "deserialize_as" => Ok(FieldAttr::DeserializeAs(
171                name,
172                parse_eq(input, DESERIALIZE_AS_NOTE)?,
173            )),
174            "select_expression" => Ok(FieldAttr::SelectExpression(
175                name,
176                parse_eq(input, SELECT_EXPRESSION_NOTE)?,
177            )),
178            "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
179                name,
180                parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
181            )),
182            _ => Err(unknown_attribute(
183                &name,
184                &[
185                    "embed",
186                    "skip_insertion",
187                    "column_name",
188                    "sql_type",
189                    "treat_none_as_default_value",
190                    "treat_none_as_null",
191                    "serialize_as",
192                    "deserialize_as",
193                    "select_expression",
194                    "select_expression_type",
195                ],
196            )),
197        }
198    }
199}
200
201impl MySpanned for FieldAttr {
202    fn span(&self) -> Span {
203        match self {
204            FieldAttr::Embed(ident)
205            | FieldAttr::SkipInsertion(ident)
206            | FieldAttr::SkipUpdate(ident)
207            | FieldAttr::ColumnName(ident, _)
208            | FieldAttr::SqlType(ident, _)
209            | FieldAttr::TreatNoneAsNull(ident, _)
210            | FieldAttr::TreatNoneAsDefaultValue(ident, _)
211            | FieldAttr::SerializeAs(ident, _)
212            | FieldAttr::DeserializeAs(ident, _)
213            | FieldAttr::SelectExpression(ident, _)
214            | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
215        }
216    }
217}
218
219#[allow(clippy::large_enum_variant)]
220pub enum StructAttr {
221    Aggregate(Ident),
222    NotSized(Ident),
223    ForeignDerive(Ident),
224
225    TableName(Ident, Path),
226    SqlType(Ident, TypePath),
227    TreatNoneAsDefaultValue(Ident, LitBool),
228    TreatNoneAsNull(Ident, LitBool),
229
230    BelongsTo(Ident, BelongsTo),
231    MysqlType(Ident, MysqlType),
232    SqliteType(Ident, SqliteType),
233    PostgresType(Ident, PostgresType),
234    PrimaryKey(Ident, Punctuated<Ident, Comma>),
235    CheckForBackend(Ident, CheckForBackend),
236    BaseQuery(Ident, Expr),
237    BaseQueryType(Ident, Type),
238}
239
240impl Parse for StructAttr {
241    fn parse(input: ParseStream) -> Result<Self> {
242        let name: Ident = input.parse()?;
243        let name_str = name.to_string();
244
245        match &*name_str {
246            "aggregate" => Ok(StructAttr::Aggregate(name)),
247            "not_sized" => Ok(StructAttr::NotSized(name)),
248            "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
249
250            "table_name" => Ok(StructAttr::TableName(
251                name,
252                parse_eq(input, TABLE_NAME_NOTE)?,
253            )),
254            "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
255            "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
256                name,
257                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
258            )),
259            "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
260                name,
261                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
262            )),
263
264            "belongs_to" => Ok(StructAttr::BelongsTo(
265                name,
266                parse_paren(input, BELONGS_TO_NOTE)?,
267            )),
268            "mysql_type" => Ok(StructAttr::MysqlType(
269                name,
270                parse_paren(input, MYSQL_TYPE_NOTE)?,
271            )),
272            "sqlite_type" => Ok(StructAttr::SqliteType(
273                name,
274                parse_paren(input, SQLITE_TYPE_NOTE)?,
275            )),
276            "postgres_type" => Ok(StructAttr::PostgresType(
277                name,
278                parse_paren(input, POSTGRES_TYPE_NOTE)?,
279            )),
280            "primary_key" => Ok(StructAttr::PrimaryKey(
281                name,
282                parse_paren_list(input, "key1, key2", syn::Token![,])?,
283            )),
284            "check_for_backend" => {
285                let value = if parse_paren::<DisabledCheckForBackend>(&input.fork(), "").is_ok() {
286                    CheckForBackend::Disabled(
287                        parse_paren::<DisabledCheckForBackend>(&input.fork(), "")?.value,
288                    )
289                } else {
290                    CheckForBackend::Backends(parse_paren_list(
291                        input,
292                        CHECK_FOR_BACKEND_NOTE,
293                        syn::Token![,],
294                    )?)
295                };
296                Ok(StructAttr::CheckForBackend(name, value))
297            }
298            "base_query" => Ok(StructAttr::BaseQuery(
299                name,
300                parse_eq(input, BASE_QUERY_NOTE)?,
301            )),
302            "base_query_type" => Ok(StructAttr::BaseQueryType(
303                name,
304                parse_eq(input, BASE_QUERY_TYPE_NOTE)?,
305            )),
306            _ => Err(unknown_attribute(
307                &name,
308                &[
309                    "aggregate",
310                    "not_sized",
311                    "foreign_derive",
312                    "table_name",
313                    "sql_type",
314                    "treat_none_as_default_value",
315                    "treat_none_as_null",
316                    "belongs_to",
317                    "mysql_type",
318                    "sqlite_type",
319                    "postgres_type",
320                    "primary_key",
321                    "check_for_backend",
322                    "base_query",
323                    "base_query_type",
324                ],
325            )),
326        }
327    }
328}
329
330impl MySpanned for StructAttr {
331    fn span(&self) -> Span {
332        match self {
333            StructAttr::Aggregate(ident)
334            | StructAttr::NotSized(ident)
335            | StructAttr::ForeignDerive(ident)
336            | StructAttr::TableName(ident, _)
337            | StructAttr::SqlType(ident, _)
338            | StructAttr::TreatNoneAsDefaultValue(ident, _)
339            | StructAttr::TreatNoneAsNull(ident, _)
340            | StructAttr::BelongsTo(ident, _)
341            | StructAttr::MysqlType(ident, _)
342            | StructAttr::SqliteType(ident, _)
343            | StructAttr::PostgresType(ident, _)
344            | StructAttr::CheckForBackend(ident, _)
345            | StructAttr::BaseQuery(ident, _)
346            | StructAttr::BaseQueryType(ident, _)
347            | StructAttr::PrimaryKey(ident, _) => ident.span(),
348        }
349    }
350}
351
352pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
353where
354    T: Parse + ParseDeprecated + MySpanned,
355{
356    let mut out = Vec::new();
357    for attr in attrs {
358        if attr.meta.path().is_ident("diesel") {
359            let map = attr
360                .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
361                .into_iter()
362                .map(|a| AttributeSpanWrapper {
363                    ident_span: a.span(),
364                    item: a,
365                    attribute_span: attr.meta.span(),
366                });
367            out.extend(map);
368        } else if cfg!(all(
369            not(feature = "without-deprecated"),
370            feature = "with-deprecated"
371        )) {
372            let path = attr.meta.path();
373            let ident = path.get_ident().map(|f| f.to_string());
374
375            if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
376            | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
377                ident.as_deref().unwrap_or_default()
378            {
379                let m = &attr.meta;
380                let ts = quote::quote!(#m).into();
381                let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
382
383                if let Some(value) = value {
384                    out.push(AttributeSpanWrapper {
385                        ident_span: value.span(),
386                        item: value,
387                        attribute_span: attr.meta.span(),
388                    });
389                }
390            }
391        }
392    }
393    Ok(out)
394}
395
396struct DisabledCheckForBackend {
397    value: LitBool,
398}
399
400impl syn::parse::Parse for DisabledCheckForBackend {
401    fn parse(input: ParseStream) -> Result<Self> {
402        let ident = input.parse::<Ident>()?;
403        if ident != "disable" {
404            return Err(syn::Error::new(
405                ident.span(),
406                format!("expected `disable`, but got `{ident}`"),
407            ));
408        }
409        let lit = parse_eq::<LitBool>(input, "")?;
410        if !lit.value {
411            return Err(syn::Error::new(
412                lit.span(),
413                "only `true` is accepted in this position. \
414                 If you want to enable these checks, just skip the attribute entirely",
415            ));
416        }
417        Ok(Self { value: lit })
418    }
419}