Skip to main content

diesel_derives/
attrs.rs

1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::model::CheckForBackend;
14use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
15use crate::util::{
16    BASE_QUERY_NOTE, BASE_QUERY_TYPE_NOTE, BELONGS_TO_NOTE, COLUMN_NAME_NOTE, DESERIALIZE_AS_NOTE,
17    MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, RENAME_ALL_NOTE, RENAME_NOTE, SELECT_EXPRESSION_NOTE,
18    SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQL_TYPE_NOTE, SQLITE_TYPE_NOTE,
19    TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE, parse_eq,
20    parse_eq_type, parse_paren, unknown_attribute,
21};
22
23use crate::util::{CHECK_FOR_BACKEND_NOTE, parse_paren_list};
24
25pub trait MySpanned {
26    fn span(&self) -> Span;
27}
28
29#[derive(#[automatically_derived]
impl<T: ::core::clone::Clone> ::core::clone::Clone for AttributeSpanWrapper<T>
    {
    #[inline]
    fn clone(&self) -> AttributeSpanWrapper<T> {
        AttributeSpanWrapper {
            item: ::core::clone::Clone::clone(&self.item),
            attribute_span: ::core::clone::Clone::clone(&self.attribute_span),
            ident_span: ::core::clone::Clone::clone(&self.ident_span),
        }
    }
}Clone)]
30pub struct AttributeSpanWrapper<T> {
31    pub item: T,
32    pub attribute_span: Span,
33    pub ident_span: Span,
34}
35
36pub enum FieldAttr {
37    Embed(Ident),
38    SkipInsertion(Ident),
39    SkipUpdate(Ident),
40
41    ColumnName(Ident, SqlIdentifier),
42    SqlType(Ident, TypePath),
43    TreatNoneAsDefaultValue(Ident, LitBool),
44    TreatNoneAsNull(Ident, LitBool),
45
46    SerializeAs(Ident, Type),
47    DeserializeAs(Ident, Type),
48    SelectExpression(Ident, Expr),
49    SelectExpressionType(Ident, Type),
50    Rename(Ident, LitStr),
51}
52
53#[derive(#[automatically_derived]
impl ::core::clone::Clone for SqlIdentifier {
    #[inline]
    fn clone(&self) -> SqlIdentifier {
        SqlIdentifier {
            field_name: ::core::clone::Clone::clone(&self.field_name),
            span: ::core::clone::Clone::clone(&self.span),
        }
    }
}Clone)]
54pub struct SqlIdentifier {
55    field_name: String,
56    span: Span,
57}
58
59impl SqlIdentifier {
60    pub fn span(&self) -> Span {
61        self.span
62    }
63
64    pub fn to_ident(&self) -> Result<Ident> {
65        match syn::parse_str::<Ident>(&::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("r#{0}", self.field_name))
    })format!("r#{}", self.field_name)) {
66            Ok(mut ident) => {
67                ident.set_span(self.span);
68                Ok(ident)
69            }
70            Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
71                self.span(),
72                ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel does not support column names with whitespaces yet",
                self.field_name))
    })format!(
73                    "expected valid identifier, found `{0}`. \
74                 Diesel does not support column names with whitespaces yet",
75                    self.field_name
76                ),
77            )),
78            Err(_e) => Err(syn::Error::new(
79                self.span(),
80                ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel automatically renames invalid identifiers, perhaps you meant to write `{0}_`?",
                self.field_name))
    })format!(
81                    "expected valid identifier, found `{0}`. \
82                 Diesel automatically renames invalid identifiers, \
83                 perhaps you meant to write `{0}_`?",
84                    self.field_name
85                ),
86            )),
87        }
88    }
89}
90
91impl ToTokens for SqlIdentifier {
92    fn to_tokens(&self, tokens: &mut TokenStream) {
93        if self.field_name.starts_with("r#") {
94            Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
95        } else {
96            Ident::new(&self.field_name, self.span).to_tokens(tokens)
97        }
98    }
99}
100
101impl Display for SqlIdentifier {
102    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103        let mut start = 0;
104        if self.field_name.starts_with("r#") {
105            start = 2;
106        }
107        f.write_str(&self.field_name[start..])
108    }
109}
110
111impl PartialEq<Ident> for SqlIdentifier {
112    fn eq(&self, other: &Ident) -> bool {
113        *other == self.field_name
114    }
115}
116
117impl From<&'_ Ident> for SqlIdentifier {
118    fn from(ident: &'_ Ident) -> Self {
119        use syn::ext::IdentExt;
120        let ident = ident.unraw();
121        Self {
122            span: ident.span(),
123            field_name: ident.to_string(),
124        }
125    }
126}
127
128impl Parse for SqlIdentifier {
129    fn parse(input: ParseStream) -> Result<Self> {
130        let fork = input.fork();
131
132        if let Ok(ident) = fork.parse::<Ident>() {
133            input.advance_to(&fork);
134            Ok((&ident).into())
135        } else {
136            let name = input.parse::<LitStr>()?;
137            Ok(Self {
138                field_name: name.value(),
139                span: name.span(),
140            })
141        }
142    }
143}
144
145impl Parse for FieldAttr {
146    fn parse(input: ParseStream) -> Result<Self> {
147        let name: Ident = input.parse()?;
148        let name_str = name.to_string();
149
150        match &*name_str {
151            "embed" => Ok(FieldAttr::Embed(name)),
152            "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
153            "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
154
155            "column_name" => Ok(FieldAttr::ColumnName(
156                name,
157                parse_eq(input, COLUMN_NAME_NOTE)?,
158            )),
159            "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
160            "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
161                name,
162                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
163            )),
164            "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
165                name,
166                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
167            )),
168            "serialize_as" => Ok(FieldAttr::SerializeAs(
169                name,
170                parse_eq_type(input, SERIALIZE_AS_NOTE)?,
171            )),
172            "deserialize_as" => Ok(FieldAttr::DeserializeAs(
173                name,
174                parse_eq_type(input, DESERIALIZE_AS_NOTE)?,
175            )),
176            "select_expression" => Ok(FieldAttr::SelectExpression(
177                name,
178                parse_eq(input, SELECT_EXPRESSION_NOTE)?,
179            )),
180            "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
181                name,
182                parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
183            )),
184            "rename" => Ok(FieldAttr::Rename(name, parse_eq(input, RENAME_NOTE)?)),
185            _ => Err(unknown_attribute(
186                &name,
187                &[
188                    "embed",
189                    "skip_insertion",
190                    "column_name",
191                    "sql_type",
192                    "treat_none_as_default_value",
193                    "treat_none_as_null",
194                    "serialize_as",
195                    "deserialize_as",
196                    "select_expression",
197                    "select_expression_type",
198                    "rename",
199                ],
200            )),
201        }
202    }
203}
204
205impl MySpanned for FieldAttr {
206    fn span(&self) -> Span {
207        match self {
208            FieldAttr::Embed(ident)
209            | FieldAttr::SkipInsertion(ident)
210            | FieldAttr::SkipUpdate(ident)
211            | FieldAttr::ColumnName(ident, _)
212            | FieldAttr::SqlType(ident, _)
213            | FieldAttr::TreatNoneAsNull(ident, _)
214            | FieldAttr::TreatNoneAsDefaultValue(ident, _)
215            | FieldAttr::SerializeAs(ident, _)
216            | FieldAttr::DeserializeAs(ident, _)
217            | FieldAttr::SelectExpression(ident, _)
218            | FieldAttr::SelectExpressionType(ident, _)
219            | FieldAttr::Rename(ident, _) => ident.span(),
220        }
221    }
222}
223
224#[allow(clippy::large_enum_variant)]
225pub enum StructAttr {
226    Aggregate(Ident),
227    NotSized(Ident),
228    ForeignDerive(Ident),
229    EnumType(Ident),
230
231    TableName(Ident, Path),
232    SqlType(Ident, TypePath),
233    TreatNoneAsDefaultValue(Ident, LitBool),
234    TreatNoneAsNull(Ident, LitBool),
235
236    BelongsTo(Ident, BelongsTo),
237    MysqlType(Ident, MysqlType),
238    SqliteType(Ident, SqliteType),
239    PostgresType(Ident, PostgresType),
240    PrimaryKey(Ident, Punctuated<Ident, Comma>),
241    CheckForBackend(Ident, CheckForBackend),
242    BaseQuery(Ident, Expr),
243    BaseQueryType(Ident, Type),
244    RenameAll(Ident, RenameVariants),
245}
246
247impl Parse for StructAttr {
248    fn parse(input: ParseStream) -> Result<Self> {
249        let name: Ident = input.parse()?;
250        let name_str = name.to_string();
251
252        match &*name_str {
253            "aggregate" => Ok(StructAttr::Aggregate(name)),
254            "not_sized" => Ok(StructAttr::NotSized(name)),
255            "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
256            "enum_type" => Ok(StructAttr::EnumType(name)),
257
258            "table_name" => Ok(StructAttr::TableName(
259                name,
260                parse_eq(input, TABLE_NAME_NOTE)?,
261            )),
262            "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
263            "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
264                name,
265                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
266            )),
267            "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
268                name,
269                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
270            )),
271
272            "belongs_to" => Ok(StructAttr::BelongsTo(
273                name,
274                parse_paren(input, BELONGS_TO_NOTE)?,
275            )),
276            "mysql_type" => Ok(StructAttr::MysqlType(
277                name,
278                parse_paren(input, MYSQL_TYPE_NOTE)?,
279            )),
280            "sqlite_type" => Ok(StructAttr::SqliteType(
281                name,
282                parse_paren(input, SQLITE_TYPE_NOTE)?,
283            )),
284            "postgres_type" => Ok(StructAttr::PostgresType(
285                name,
286                parse_paren(input, POSTGRES_TYPE_NOTE)?,
287            )),
288            "primary_key" => Ok(StructAttr::PrimaryKey(
289                name,
290                parse_paren_list(input, "key1, key2", ::syn::token::Commasyn::Token![,])?,
291            )),
292            "check_for_backend" => {
293                let value = if parse_paren::<DisabledCheckForBackend>(&input.fork(), "").is_ok() {
294                    CheckForBackend::Disabled(
295                        parse_paren::<DisabledCheckForBackend>(&input.fork(), "")?.value,
296                    )
297                } else {
298                    CheckForBackend::Backends(parse_paren_list(
299                        input,
300                        CHECK_FOR_BACKEND_NOTE,
301                        ::syn::token::Commasyn::Token![,],
302                    )?)
303                };
304                Ok(StructAttr::CheckForBackend(name, value))
305            }
306            "base_query" => Ok(StructAttr::BaseQuery(
307                name,
308                parse_eq(input, BASE_QUERY_NOTE)?,
309            )),
310            "base_query_type" => Ok(StructAttr::BaseQueryType(
311                name,
312                parse_eq(input, BASE_QUERY_TYPE_NOTE)?,
313            )),
314            "rename_all" => Ok(StructAttr::RenameAll(
315                name,
316                parse_eq(input, RENAME_ALL_NOTE)?,
317            )),
318            _ => Err(unknown_attribute(
319                &name,
320                &[
321                    "aggregate",
322                    "not_sized",
323                    "foreign_derive",
324                    "table_name",
325                    "sql_type",
326                    "treat_none_as_default_value",
327                    "treat_none_as_null",
328                    "belongs_to",
329                    "mysql_type",
330                    "sqlite_type",
331                    "postgres_type",
332                    "primary_key",
333                    "check_for_backend",
334                    "base_query",
335                    "base_query_type",
336                    "enum_type",
337                    "rename_all",
338                ],
339            )),
340        }
341    }
342}
343
344impl MySpanned for StructAttr {
345    fn span(&self) -> Span {
346        match self {
347            StructAttr::Aggregate(ident)
348            | StructAttr::NotSized(ident)
349            | StructAttr::ForeignDerive(ident)
350            | StructAttr::EnumType(ident)
351            | StructAttr::TableName(ident, _)
352            | StructAttr::SqlType(ident, _)
353            | StructAttr::TreatNoneAsDefaultValue(ident, _)
354            | StructAttr::TreatNoneAsNull(ident, _)
355            | StructAttr::BelongsTo(ident, _)
356            | StructAttr::MysqlType(ident, _)
357            | StructAttr::SqliteType(ident, _)
358            | StructAttr::PostgresType(ident, _)
359            | StructAttr::CheckForBackend(ident, _)
360            | StructAttr::BaseQuery(ident, _)
361            | StructAttr::BaseQueryType(ident, _)
362            | StructAttr::PrimaryKey(ident, _)
363            | StructAttr::RenameAll(ident, _) => ident.span(),
364        }
365    }
366}
367
368pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
369where
370    T: Parse + ParseDeprecated + MySpanned,
371{
372    let mut out = Vec::new();
373    for attr in attrs {
374        if attr.meta.path().is_ident("diesel") {
375            let map = attr
376                .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
377                .into_iter()
378                .map(|a| AttributeSpanWrapper {
379                    ident_span: a.span(),
380                    item: a,
381                    attribute_span: attr.meta.span(),
382                });
383            out.extend(map);
384        } else if truecfg!(all(
385            not(feature = "without-deprecated"),
386            feature = "with-deprecated"
387        )) {
388            let path = attr.meta.path();
389            let ident = path.get_ident().map(|f| f.to_string());
390
391            if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
392            | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
393                ident.as_deref().unwrap_or_default()
394            {
395                let m = &attr.meta;
396                let ts = {
    let mut _s = ::quote::__private::TokenStream::new();
    ::quote::ToTokens::to_tokens(&m, &mut _s);
    _s
}quote::quote!(#m).into();
397                let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
398
399                if let Some(value) = value {
400                    out.push(AttributeSpanWrapper {
401                        ident_span: value.span(),
402                        item: value,
403                        attribute_span: attr.meta.span(),
404                    });
405                }
406            }
407        }
408    }
409    Ok(out)
410}
411
412struct DisabledCheckForBackend {
413    value: LitBool,
414}
415
416impl syn::parse::Parse for DisabledCheckForBackend {
417    fn parse(input: ParseStream) -> Result<Self> {
418        let ident = input.parse::<Ident>()?;
419        if ident != "disable" {
420            return Err(syn::Error::new(
421                ident.span(),
422                ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("expected `disable`, but got `{0}`",
                ident))
    })format!("expected `disable`, but got `{ident}`"),
423            ));
424        }
425        let lit = parse_eq::<LitBool>(input, "")?;
426        if !lit.value {
427            return Err(syn::Error::new(
428                lit.span(),
429                "only `true` is accepted in this position. \
430                 If you want to enable these checks, just skip the attribute entirely",
431            ));
432        }
433        Ok(Self { value: lit })
434    }
435}
436
/// Case conventions accepted by `rename_all = "..."`.
///
/// The doc on each variant is the exact string that selects it.
#[derive(Debug, Clone, Copy)]
// Every variant name ends in "Case" on purpose: they mirror the convention names.
#[allow(clippy::enum_variant_names)]
pub enum RenameVariants {
    /// `"lowercase"`
    LowerCase,
    /// `"UPPERCASE"`
    UpperCase,
    /// `"PascalCase"`
    PascalCase,
    /// `"camelCase"`
    CamelCase,
    /// `"snake_case"`
    SnakeCase,
    /// `"SCREAMING_SNAKE_CASE"`
    ScreamingSnakeCase,
    /// `"kebab-case"`
    KebabCase,
    /// `"SCREAMING-KEBAB-CASE"`
    ScreamingKebabCase,
}
449
450impl syn::parse::Parse for RenameVariants {
451    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
452        let lit = input.parse::<syn::LitStr>()?;
453        let v = lit.value();
454        let v = match v.as_str() {
455            "lowercase" => Self::LowerCase,
456            "UPPERCASE" => Self::UpperCase,
457            "PascalCase" => Self::PascalCase,
458            "camelCase" => Self::CamelCase,
459            "snake_case" => Self::SnakeCase,
460            "SCREAMING_SNAKE_CASE" => Self::ScreamingSnakeCase,
461            "kebab-case" => Self::KebabCase,
462            "SCREAMING-KEBAB-CASE" => Self::ScreamingKebabCase,
463            s => {
464                return Err(syn::Error::new(
465                    lit.span(),
466                    ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!("got invalid case identifier: `{0}`\nonly: `lowercase`, `UPPERCASE`, `PascalCase`, `camelCase`, `snake_case`, `SCREAMING_SNAKE_CASE`, `kebab-case` and `SCREAMING-KEBAB-CASE` are supported",
                s))
    })format!(
467                        "got invalid case identifier: `{s}`\n\
468                         only: `lowercase`, `UPPERCASE`, `PascalCase`, `camelCase`, \
469                         `snake_case`, `SCREAMING_SNAKE_CASE`, `kebab-case` \
470                         and `SCREAMING-KEBAB-CASE` are supported"
471                    ),
472                ));
473            }
474        };
475        Ok(v)
476    }
477}
478
479impl RenameVariants {
480    pub fn apply_case_to_enum_variant(&self, input: String) -> String {
481        match self {
482            // Rust enum variants are already pascal case
483            Self::PascalCase => input,
484            Self::LowerCase => input.to_ascii_lowercase(),
485            Self::UpperCase => input.to_ascii_uppercase(),
486            Self::CamelCase => input[..1].to_ascii_lowercase() + &input[1..],
487            Self::SnakeCase => {
488                let mut snake = String::new();
489                for (i, ch) in input.char_indices() {
490                    if i > 0 && ch.is_uppercase() {
491                        snake.push('_');
492                    }
493                    snake.push(ch.to_ascii_lowercase());
494                }
495                snake
496            }
497            Self::ScreamingSnakeCase => Self::SnakeCase
498                .apply_case_to_enum_variant(input)
499                .to_ascii_uppercase(),
500            Self::KebabCase => Self::SnakeCase
501                .apply_case_to_enum_variant(input)
502                .replace('_', "-"),
503            Self::ScreamingKebabCase => Self::ScreamingSnakeCase
504                .apply_case_to_enum_variant(input)
505                .replace('_', "-"),
506        }
507    }
508}