// diesel_derives/attrs.rs

1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::model::CheckForBackend;
14use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
15use crate::util::{
16    BASE_QUERY_NOTE, BASE_QUERY_TYPE_NOTE, BELONGS_TO_NOTE, COLUMN_NAME_NOTE, DESERIALIZE_AS_NOTE,
17    MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, SELECT_EXPRESSION_NOTE, SELECT_EXPRESSION_TYPE_NOTE,
18    SERIALIZE_AS_NOTE, SQL_TYPE_NOTE, SQLITE_TYPE_NOTE, TABLE_NAME_NOTE,
19    TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE, parse_eq, parse_paren,
20    unknown_attribute,
21};
22
23use crate::util::{CHECK_FOR_BACKEND_NOTE, parse_paren_list};
24
25pub trait MySpanned {
26    fn span(&self) -> Span;
27}
28
29#[derive(#[automatically_derived]
impl<T: ::core::clone::Clone> ::core::clone::Clone for AttributeSpanWrapper<T>
    {
    #[inline]
    fn clone(&self) -> AttributeSpanWrapper<T> {
        AttributeSpanWrapper {
            item: ::core::clone::Clone::clone(&self.item),
            attribute_span: ::core::clone::Clone::clone(&self.attribute_span),
            ident_span: ::core::clone::Clone::clone(&self.ident_span),
        }
    }
}Clone)]
30pub struct AttributeSpanWrapper<T> {
31    pub item: T,
32    pub attribute_span: Span,
33    pub ident_span: Span,
34}
35
36pub enum FieldAttr {
37    Embed(Ident),
38    SkipInsertion(Ident),
39    SkipUpdate(Ident),
40
41    ColumnName(Ident, SqlIdentifier),
42    SqlType(Ident, TypePath),
43    TreatNoneAsDefaultValue(Ident, LitBool),
44    TreatNoneAsNull(Ident, LitBool),
45
46    SerializeAs(Ident, TypePath),
47    DeserializeAs(Ident, TypePath),
48    SelectExpression(Ident, Expr),
49    SelectExpressionType(Ident, Type),
50}
51
52#[derive(#[automatically_derived]
impl ::core::clone::Clone for SqlIdentifier {
    #[inline]
    fn clone(&self) -> SqlIdentifier {
        SqlIdentifier {
            field_name: ::core::clone::Clone::clone(&self.field_name),
            span: ::core::clone::Clone::clone(&self.span),
        }
    }
}Clone)]
53pub struct SqlIdentifier {
54    field_name: String,
55    span: Span,
56}
57
58impl SqlIdentifier {
59    pub fn span(&self) -> Span {
60        self.span
61    }
62
63    pub fn to_ident(&self) -> Result<Ident> {
64        match syn::parse_str::<Ident>(&::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("r#{0}", self.field_name))
    })format!("r#{}", self.field_name)) {
65            Ok(mut ident) => {
66                ident.set_span(self.span);
67                Ok(ident)
68            }
69            Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
70                self.span(),
71                ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel does not support column names with whitespaces yet",
                self.field_name))
    })format!(
72                    "expected valid identifier, found `{0}`. \
73                 Diesel does not support column names with whitespaces yet",
74                    self.field_name
75                ),
76            )),
77            Err(_e) => Err(syn::Error::new(
78                self.span(),
79                ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel automatically renames invalid identifiers, perhaps you meant to write `{0}_`?",
                self.field_name))
    })format!(
80                    "expected valid identifier, found `{0}`. \
81                 Diesel automatically renames invalid identifiers, \
82                 perhaps you meant to write `{0}_`?",
83                    self.field_name
84                ),
85            )),
86        }
87    }
88}
89
90impl ToTokens for SqlIdentifier {
91    fn to_tokens(&self, tokens: &mut TokenStream) {
92        if self.field_name.starts_with("r#") {
93            Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
94        } else {
95            Ident::new(&self.field_name, self.span).to_tokens(tokens)
96        }
97    }
98}
99
100impl Display for SqlIdentifier {
101    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102        let mut start = 0;
103        if self.field_name.starts_with("r#") {
104            start = 2;
105        }
106        f.write_str(&self.field_name[start..])
107    }
108}
109
110impl PartialEq<Ident> for SqlIdentifier {
111    fn eq(&self, other: &Ident) -> bool {
112        *other == self.field_name
113    }
114}
115
116impl From<&'_ Ident> for SqlIdentifier {
117    fn from(ident: &'_ Ident) -> Self {
118        use syn::ext::IdentExt;
119        let ident = ident.unraw();
120        Self {
121            span: ident.span(),
122            field_name: ident.to_string(),
123        }
124    }
125}
126
127impl Parse for SqlIdentifier {
128    fn parse(input: ParseStream) -> Result<Self> {
129        let fork = input.fork();
130
131        if let Ok(ident) = fork.parse::<Ident>() {
132            input.advance_to(&fork);
133            Ok((&ident).into())
134        } else {
135            let name = input.parse::<LitStr>()?;
136            Ok(Self {
137                field_name: name.value(),
138                span: name.span(),
139            })
140        }
141    }
142}
143
144impl Parse for FieldAttr {
145    fn parse(input: ParseStream) -> Result<Self> {
146        let name: Ident = input.parse()?;
147        let name_str = name.to_string();
148
149        match &*name_str {
150            "embed" => Ok(FieldAttr::Embed(name)),
151            "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
152            "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
153
154            "column_name" => Ok(FieldAttr::ColumnName(
155                name,
156                parse_eq(input, COLUMN_NAME_NOTE)?,
157            )),
158            "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
159            "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
160                name,
161                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
162            )),
163            "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
164                name,
165                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
166            )),
167            "serialize_as" => Ok(FieldAttr::SerializeAs(
168                name,
169                parse_eq(input, SERIALIZE_AS_NOTE)?,
170            )),
171            "deserialize_as" => Ok(FieldAttr::DeserializeAs(
172                name,
173                parse_eq(input, DESERIALIZE_AS_NOTE)?,
174            )),
175            "select_expression" => Ok(FieldAttr::SelectExpression(
176                name,
177                parse_eq(input, SELECT_EXPRESSION_NOTE)?,
178            )),
179            "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
180                name,
181                parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
182            )),
183            _ => Err(unknown_attribute(
184                &name,
185                &[
186                    "embed",
187                    "skip_insertion",
188                    "column_name",
189                    "sql_type",
190                    "treat_none_as_default_value",
191                    "treat_none_as_null",
192                    "serialize_as",
193                    "deserialize_as",
194                    "select_expression",
195                    "select_expression_type",
196                ],
197            )),
198        }
199    }
200}
201
202impl MySpanned for FieldAttr {
203    fn span(&self) -> Span {
204        match self {
205            FieldAttr::Embed(ident)
206            | FieldAttr::SkipInsertion(ident)
207            | FieldAttr::SkipUpdate(ident)
208            | FieldAttr::ColumnName(ident, _)
209            | FieldAttr::SqlType(ident, _)
210            | FieldAttr::TreatNoneAsNull(ident, _)
211            | FieldAttr::TreatNoneAsDefaultValue(ident, _)
212            | FieldAttr::SerializeAs(ident, _)
213            | FieldAttr::DeserializeAs(ident, _)
214            | FieldAttr::SelectExpression(ident, _)
215            | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
216        }
217    }
218}
219
220#[allow(clippy::large_enum_variant)]
221pub enum StructAttr {
222    Aggregate(Ident),
223    NotSized(Ident),
224    ForeignDerive(Ident),
225
226    TableName(Ident, Path),
227    SqlType(Ident, TypePath),
228    TreatNoneAsDefaultValue(Ident, LitBool),
229    TreatNoneAsNull(Ident, LitBool),
230
231    BelongsTo(Ident, BelongsTo),
232    MysqlType(Ident, MysqlType),
233    SqliteType(Ident, SqliteType),
234    PostgresType(Ident, PostgresType),
235    PrimaryKey(Ident, Punctuated<Ident, Comma>),
236    CheckForBackend(Ident, CheckForBackend),
237    BaseQuery(Ident, Expr),
238    BaseQueryType(Ident, Type),
239}
240
241impl Parse for StructAttr {
242    fn parse(input: ParseStream) -> Result<Self> {
243        let name: Ident = input.parse()?;
244        let name_str = name.to_string();
245
246        match &*name_str {
247            "aggregate" => Ok(StructAttr::Aggregate(name)),
248            "not_sized" => Ok(StructAttr::NotSized(name)),
249            "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
250
251            "table_name" => Ok(StructAttr::TableName(
252                name,
253                parse_eq(input, TABLE_NAME_NOTE)?,
254            )),
255            "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
256            "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
257                name,
258                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
259            )),
260            "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
261                name,
262                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
263            )),
264
265            "belongs_to" => Ok(StructAttr::BelongsTo(
266                name,
267                parse_paren(input, BELONGS_TO_NOTE)?,
268            )),
269            "mysql_type" => Ok(StructAttr::MysqlType(
270                name,
271                parse_paren(input, MYSQL_TYPE_NOTE)?,
272            )),
273            "sqlite_type" => Ok(StructAttr::SqliteType(
274                name,
275                parse_paren(input, SQLITE_TYPE_NOTE)?,
276            )),
277            "postgres_type" => Ok(StructAttr::PostgresType(
278                name,
279                parse_paren(input, POSTGRES_TYPE_NOTE)?,
280            )),
281            "primary_key" => Ok(StructAttr::PrimaryKey(
282                name,
283                parse_paren_list(input, "key1, key2", ::syn::token::Commasyn::Token![,])?,
284            )),
285            "check_for_backend" => {
286                let value = if parse_paren::<DisabledCheckForBackend>(&input.fork(), "").is_ok() {
287                    CheckForBackend::Disabled(
288                        parse_paren::<DisabledCheckForBackend>(&input.fork(), "")?.value,
289                    )
290                } else {
291                    CheckForBackend::Backends(parse_paren_list(
292                        input,
293                        CHECK_FOR_BACKEND_NOTE,
294                        ::syn::token::Commasyn::Token![,],
295                    )?)
296                };
297                Ok(StructAttr::CheckForBackend(name, value))
298            }
299            "base_query" => Ok(StructAttr::BaseQuery(
300                name,
301                parse_eq(input, BASE_QUERY_NOTE)?,
302            )),
303            "base_query_type" => Ok(StructAttr::BaseQueryType(
304                name,
305                parse_eq(input, BASE_QUERY_TYPE_NOTE)?,
306            )),
307            _ => Err(unknown_attribute(
308                &name,
309                &[
310                    "aggregate",
311                    "not_sized",
312                    "foreign_derive",
313                    "table_name",
314                    "sql_type",
315                    "treat_none_as_default_value",
316                    "treat_none_as_null",
317                    "belongs_to",
318                    "mysql_type",
319                    "sqlite_type",
320                    "postgres_type",
321                    "primary_key",
322                    "check_for_backend",
323                    "base_query",
324                    "base_query_type",
325                ],
326            )),
327        }
328    }
329}
330
331impl MySpanned for StructAttr {
332    fn span(&self) -> Span {
333        match self {
334            StructAttr::Aggregate(ident)
335            | StructAttr::NotSized(ident)
336            | StructAttr::ForeignDerive(ident)
337            | StructAttr::TableName(ident, _)
338            | StructAttr::SqlType(ident, _)
339            | StructAttr::TreatNoneAsDefaultValue(ident, _)
340            | StructAttr::TreatNoneAsNull(ident, _)
341            | StructAttr::BelongsTo(ident, _)
342            | StructAttr::MysqlType(ident, _)
343            | StructAttr::SqliteType(ident, _)
344            | StructAttr::PostgresType(ident, _)
345            | StructAttr::CheckForBackend(ident, _)
346            | StructAttr::BaseQuery(ident, _)
347            | StructAttr::BaseQueryType(ident, _)
348            | StructAttr::PrimaryKey(ident, _) => ident.span(),
349        }
350    }
351}
352
353pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
354where
355    T: Parse + ParseDeprecated + MySpanned,
356{
357    let mut out = Vec::new();
358    for attr in attrs {
359        if attr.meta.path().is_ident("diesel") {
360            let map = attr
361                .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
362                .into_iter()
363                .map(|a| AttributeSpanWrapper {
364                    ident_span: a.span(),
365                    item: a,
366                    attribute_span: attr.meta.span(),
367                });
368            out.extend(map);
369        } else if truecfg!(all(
370            not(feature = "without-deprecated"),
371            feature = "with-deprecated"
372        )) {
373            let path = attr.meta.path();
374            let ident = path.get_ident().map(|f| f.to_string());
375
376            if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
377            | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
378                ident.as_deref().unwrap_or_default()
379            {
380                let m = &attr.meta;
381                let ts = {
    let mut _s = ::quote::__private::TokenStream::new();
    ::quote::ToTokens::to_tokens(&m, &mut _s);
    _s
}quote::quote!(#m).into();
382                let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
383
384                if let Some(value) = value {
385                    out.push(AttributeSpanWrapper {
386                        ident_span: value.span(),
387                        item: value,
388                        attribute_span: attr.meta.span(),
389                    });
390                }
391            }
392        }
393    }
394    Ok(out)
395}
396
397struct DisabledCheckForBackend {
398    value: LitBool,
399}
400
401impl syn::parse::Parse for DisabledCheckForBackend {
402    fn parse(input: ParseStream) -> Result<Self> {
403        let ident = input.parse::<Ident>()?;
404        if ident != "disable" {
405            return Err(syn::Error::new(
406                ident.span(),
407                ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("expected `disable`, but got `{0}`",
                ident))
    })format!("expected `disable`, but got `{ident}`"),
408            ));
409        }
410        let lit = parse_eq::<LitBool>(input, "")?;
411        if !lit.value {
412            return Err(syn::Error::new(
413                lit.span(),
414                "only `true` is accepted in this position. \
415                 If you want to enable these checks, just skip the attribute entirely",
416            ));
417        }
418        Ok(Self { value: lit })
419    }
420}