diesel_derives/
attrs.rs

1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
14use crate::util::{
15    parse_eq, parse_paren, unknown_attribute, BELONGS_TO_NOTE, COLUMN_NAME_NOTE,
16    DESERIALIZE_AS_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, SELECT_EXPRESSION_NOTE,
17    SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQLITE_TYPE_NOTE, SQL_TYPE_NOTE,
18    TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE,
19};
20
21use crate::util::{parse_paren_list, CHECK_FOR_BACKEND_NOTE};
22
23pub trait MySpanned {
24    fn span(&self) -> Span;
25}
26
27pub struct AttributeSpanWrapper<T> {
28    pub item: T,
29    pub attribute_span: Span,
30    pub ident_span: Span,
31}
32
33pub enum FieldAttr {
34    Embed(Ident),
35    SkipInsertion(Ident),
36
37    ColumnName(Ident, SqlIdentifier),
38    SqlType(Ident, TypePath),
39    TreatNoneAsDefaultValue(Ident, LitBool),
40    TreatNoneAsNull(Ident, LitBool),
41
42    SerializeAs(Ident, TypePath),
43    DeserializeAs(Ident, TypePath),
44    SelectExpression(Ident, Expr),
45    SelectExpressionType(Ident, Type),
46}
47
48#[derive(Clone)]
49pub struct SqlIdentifier {
50    field_name: String,
51    span: Span,
52}
53
54impl SqlIdentifier {
55    pub fn span(&self) -> Span {
56        self.span
57    }
58
59    pub fn to_ident(&self) -> Result<Ident> {
60        match syn::parse_str::<Ident>(&format!("r#{}", self.field_name)) {
61            Ok(mut ident) => {
62                ident.set_span(self.span);
63                Ok(ident)
64            }
65            Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
66                self.span(),
67                format!(
68                    "Expected valid identifier, found `{0}`. \
69                 Diesel does not support column names with whitespaces yet",
70                    self.field_name
71                ),
72            )),
73            Err(_e) => Err(syn::Error::new(
74                self.span(),
75                format!(
76                    "Expected valid identifier, found `{0}`. \
77                 Diesel automatically renames invalid identifiers, \
78                 perhaps you meant to write `{0}_`?",
79                    self.field_name
80                ),
81            )),
82        }
83    }
84}
85
86impl ToTokens for SqlIdentifier {
87    fn to_tokens(&self, tokens: &mut TokenStream) {
88        if self.field_name.starts_with("r#") {
89            Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
90        } else {
91            Ident::new(&self.field_name, self.span).to_tokens(tokens)
92        }
93    }
94}
95
96impl Display for SqlIdentifier {
97    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
98        let mut start = 0;
99        if self.field_name.starts_with("r#") {
100            start = 2;
101        }
102        f.write_str(&self.field_name[start..])
103    }
104}
105
106impl PartialEq<Ident> for SqlIdentifier {
107    fn eq(&self, other: &Ident) -> bool {
108        *other == self.field_name
109    }
110}
111
112impl From<&'_ Ident> for SqlIdentifier {
113    fn from(ident: &'_ Ident) -> Self {
114        use syn::ext::IdentExt;
115        let ident = ident.unraw();
116        Self {
117            span: ident.span(),
118            field_name: ident.to_string(),
119        }
120    }
121}
122
123impl Parse for SqlIdentifier {
124    fn parse(input: ParseStream) -> Result<Self> {
125        let fork = input.fork();
126
127        if let Ok(ident) = fork.parse::<Ident>() {
128            input.advance_to(&fork);
129            Ok((&ident).into())
130        } else {
131            let name = input.parse::<LitStr>()?;
132            Ok(Self {
133                field_name: name.value(),
134                span: name.span(),
135            })
136        }
137    }
138}
139
140impl Parse for FieldAttr {
141    fn parse(input: ParseStream) -> Result<Self> {
142        let name: Ident = input.parse()?;
143        let name_str = name.to_string();
144
145        match &*name_str {
146            "embed" => Ok(FieldAttr::Embed(name)),
147            "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
148
149            "column_name" => Ok(FieldAttr::ColumnName(
150                name,
151                parse_eq(input, COLUMN_NAME_NOTE)?,
152            )),
153            "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
154            "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
155                name,
156                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
157            )),
158            "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
159                name,
160                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
161            )),
162            "serialize_as" => Ok(FieldAttr::SerializeAs(
163                name,
164                parse_eq(input, SERIALIZE_AS_NOTE)?,
165            )),
166            "deserialize_as" => Ok(FieldAttr::DeserializeAs(
167                name,
168                parse_eq(input, DESERIALIZE_AS_NOTE)?,
169            )),
170            "select_expression" => Ok(FieldAttr::SelectExpression(
171                name,
172                parse_eq(input, SELECT_EXPRESSION_NOTE)?,
173            )),
174            "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
175                name,
176                parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
177            )),
178            _ => Err(unknown_attribute(
179                &name,
180                &[
181                    "embed",
182                    "skip_insertion",
183                    "column_name",
184                    "sql_type",
185                    "treat_none_as_default_value",
186                    "treat_none_as_null",
187                    "serialize_as",
188                    "deserialize_as",
189                    "select_expression",
190                    "select_expression_type",
191                ],
192            )),
193        }
194    }
195}
196
197impl MySpanned for FieldAttr {
198    fn span(&self) -> Span {
199        match self {
200            FieldAttr::Embed(ident)
201            | FieldAttr::SkipInsertion(ident)
202            | FieldAttr::ColumnName(ident, _)
203            | FieldAttr::SqlType(ident, _)
204            | FieldAttr::TreatNoneAsNull(ident, _)
205            | FieldAttr::TreatNoneAsDefaultValue(ident, _)
206            | FieldAttr::SerializeAs(ident, _)
207            | FieldAttr::DeserializeAs(ident, _)
208            | FieldAttr::SelectExpression(ident, _)
209            | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
210        }
211    }
212}
213
214#[allow(clippy::large_enum_variant)]
215pub enum StructAttr {
216    Aggregate(Ident),
217    NotSized(Ident),
218    ForeignDerive(Ident),
219
220    TableName(Ident, Path),
221    SqlType(Ident, TypePath),
222    TreatNoneAsDefaultValue(Ident, LitBool),
223    TreatNoneAsNull(Ident, LitBool),
224
225    BelongsTo(Ident, BelongsTo),
226    MysqlType(Ident, MysqlType),
227    SqliteType(Ident, SqliteType),
228    PostgresType(Ident, PostgresType),
229    PrimaryKey(Ident, Punctuated<Ident, Comma>),
230    CheckForBackend(Ident, syn::punctuated::Punctuated<TypePath, syn::Token![,]>),
231}
232
233impl Parse for StructAttr {
234    fn parse(input: ParseStream) -> Result<Self> {
235        let name: Ident = input.parse()?;
236        let name_str = name.to_string();
237
238        match &*name_str {
239            "aggregate" => Ok(StructAttr::Aggregate(name)),
240            "not_sized" => Ok(StructAttr::NotSized(name)),
241            "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
242
243            "table_name" => Ok(StructAttr::TableName(
244                name,
245                parse_eq(input, TABLE_NAME_NOTE)?,
246            )),
247            "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
248            "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
249                name,
250                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
251            )),
252            "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
253                name,
254                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
255            )),
256
257            "belongs_to" => Ok(StructAttr::BelongsTo(
258                name,
259                parse_paren(input, BELONGS_TO_NOTE)?,
260            )),
261            "mysql_type" => Ok(StructAttr::MysqlType(
262                name,
263                parse_paren(input, MYSQL_TYPE_NOTE)?,
264            )),
265            "sqlite_type" => Ok(StructAttr::SqliteType(
266                name,
267                parse_paren(input, SQLITE_TYPE_NOTE)?,
268            )),
269            "postgres_type" => Ok(StructAttr::PostgresType(
270                name,
271                parse_paren(input, POSTGRES_TYPE_NOTE)?,
272            )),
273            "primary_key" => Ok(StructAttr::PrimaryKey(
274                name,
275                parse_paren_list(input, "key1, key2", syn::Token![,])?,
276            )),
277            "check_for_backend" => Ok(StructAttr::CheckForBackend(
278                name,
279                parse_paren_list(input, CHECK_FOR_BACKEND_NOTE, syn::Token![,])?,
280            )),
281
282            _ => Err(unknown_attribute(
283                &name,
284                &[
285                    "aggregate",
286                    "not_sized",
287                    "foreign_derive",
288                    "table_name",
289                    "sql_type",
290                    "treat_none_as_default_value",
291                    "treat_none_as_null",
292                    "belongs_to",
293                    "mysql_type",
294                    "sqlite_type",
295                    "postgres_type",
296                    "primary_key",
297                    "check_for_backend",
298                ],
299            )),
300        }
301    }
302}
303
304impl MySpanned for StructAttr {
305    fn span(&self) -> Span {
306        match self {
307            StructAttr::Aggregate(ident)
308            | StructAttr::NotSized(ident)
309            | StructAttr::ForeignDerive(ident)
310            | StructAttr::TableName(ident, _)
311            | StructAttr::SqlType(ident, _)
312            | StructAttr::TreatNoneAsDefaultValue(ident, _)
313            | StructAttr::TreatNoneAsNull(ident, _)
314            | StructAttr::BelongsTo(ident, _)
315            | StructAttr::MysqlType(ident, _)
316            | StructAttr::SqliteType(ident, _)
317            | StructAttr::PostgresType(ident, _)
318            | StructAttr::CheckForBackend(ident, _)
319            | StructAttr::PrimaryKey(ident, _) => ident.span(),
320        }
321    }
322}
323
324pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
325where
326    T: Parse + ParseDeprecated + MySpanned,
327{
328    let mut out = Vec::new();
329    for attr in attrs {
330        if attr.meta.path().is_ident("diesel") {
331            let map = attr
332                .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
333                .into_iter()
334                .map(|a| AttributeSpanWrapper {
335                    ident_span: a.span(),
336                    item: a,
337                    attribute_span: attr.meta.span(),
338                });
339            out.extend(map);
340        } else if cfg!(all(
341            not(feature = "without-deprecated"),
342            feature = "with-deprecated"
343        )) {
344            let path = attr.meta.path();
345            let ident = path.get_ident().map(|f| f.to_string());
346
347            if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
348            | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
349                ident.as_deref().unwrap_or_default()
350            {
351                let m = &attr.meta;
352                let ts = quote::quote!(#m).into();
353                let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
354
355                if let Some(value) = value {
356                    out.push(AttributeSpanWrapper {
357                        ident_span: value.span(),
358                        item: value,
359                        attribute_span: attr.meta.span(),
360                    });
361                }
362            }
363        }
364    }
365    Ok(out)
366}