diesel_derives/
attrs.rs

1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
14use crate::util::{
15    parse_eq, parse_paren, unknown_attribute, BELONGS_TO_NOTE, COLUMN_NAME_NOTE,
16    DESERIALIZE_AS_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, SELECT_EXPRESSION_NOTE,
17    SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQLITE_TYPE_NOTE, SQL_TYPE_NOTE,
18    TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE,
19};
20
21use crate::util::{parse_paren_list, CHECK_FOR_BACKEND_NOTE};
22
23pub trait MySpanned {
24    fn span(&self) -> Span;
25}
26
27pub struct AttributeSpanWrapper<T> {
28    pub item: T,
29    pub attribute_span: Span,
30    pub ident_span: Span,
31}
32
33pub enum FieldAttr {
34    Embed(Ident),
35    SkipInsertion(Ident),
36    SkipUpdate(Ident),
37
38    ColumnName(Ident, SqlIdentifier),
39    SqlType(Ident, TypePath),
40    TreatNoneAsDefaultValue(Ident, LitBool),
41    TreatNoneAsNull(Ident, LitBool),
42
43    SerializeAs(Ident, TypePath),
44    DeserializeAs(Ident, TypePath),
45    SelectExpression(Ident, Expr),
46    SelectExpressionType(Ident, Type),
47}
48
49#[derive(Clone)]
50pub struct SqlIdentifier {
51    field_name: String,
52    span: Span,
53}
54
55impl SqlIdentifier {
56    pub fn span(&self) -> Span {
57        self.span
58    }
59
60    pub fn to_ident(&self) -> Result<Ident> {
61        match syn::parse_str::<Ident>(&format!("r#{}", self.field_name)) {
62            Ok(mut ident) => {
63                ident.set_span(self.span);
64                Ok(ident)
65            }
66            Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
67                self.span(),
68                format!(
69                    "Expected valid identifier, found `{0}`. \
70                 Diesel does not support column names with whitespaces yet",
71                    self.field_name
72                ),
73            )),
74            Err(_e) => Err(syn::Error::new(
75                self.span(),
76                format!(
77                    "Expected valid identifier, found `{0}`. \
78                 Diesel automatically renames invalid identifiers, \
79                 perhaps you meant to write `{0}_`?",
80                    self.field_name
81                ),
82            )),
83        }
84    }
85}
86
87impl ToTokens for SqlIdentifier {
88    fn to_tokens(&self, tokens: &mut TokenStream) {
89        if self.field_name.starts_with("r#") {
90            Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
91        } else {
92            Ident::new(&self.field_name, self.span).to_tokens(tokens)
93        }
94    }
95}
96
97impl Display for SqlIdentifier {
98    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
99        let mut start = 0;
100        if self.field_name.starts_with("r#") {
101            start = 2;
102        }
103        f.write_str(&self.field_name[start..])
104    }
105}
106
107impl PartialEq<Ident> for SqlIdentifier {
108    fn eq(&self, other: &Ident) -> bool {
109        *other == self.field_name
110    }
111}
112
113impl From<&'_ Ident> for SqlIdentifier {
114    fn from(ident: &'_ Ident) -> Self {
115        use syn::ext::IdentExt;
116        let ident = ident.unraw();
117        Self {
118            span: ident.span(),
119            field_name: ident.to_string(),
120        }
121    }
122}
123
124impl Parse for SqlIdentifier {
125    fn parse(input: ParseStream) -> Result<Self> {
126        let fork = input.fork();
127
128        if let Ok(ident) = fork.parse::<Ident>() {
129            input.advance_to(&fork);
130            Ok((&ident).into())
131        } else {
132            let name = input.parse::<LitStr>()?;
133            Ok(Self {
134                field_name: name.value(),
135                span: name.span(),
136            })
137        }
138    }
139}
140
141impl Parse for FieldAttr {
142    fn parse(input: ParseStream) -> Result<Self> {
143        let name: Ident = input.parse()?;
144        let name_str = name.to_string();
145
146        match &*name_str {
147            "embed" => Ok(FieldAttr::Embed(name)),
148            "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
149            "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
150
151            "column_name" => Ok(FieldAttr::ColumnName(
152                name,
153                parse_eq(input, COLUMN_NAME_NOTE)?,
154            )),
155            "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
156            "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
157                name,
158                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
159            )),
160            "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
161                name,
162                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
163            )),
164            "serialize_as" => Ok(FieldAttr::SerializeAs(
165                name,
166                parse_eq(input, SERIALIZE_AS_NOTE)?,
167            )),
168            "deserialize_as" => Ok(FieldAttr::DeserializeAs(
169                name,
170                parse_eq(input, DESERIALIZE_AS_NOTE)?,
171            )),
172            "select_expression" => Ok(FieldAttr::SelectExpression(
173                name,
174                parse_eq(input, SELECT_EXPRESSION_NOTE)?,
175            )),
176            "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
177                name,
178                parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
179            )),
180            _ => Err(unknown_attribute(
181                &name,
182                &[
183                    "embed",
184                    "skip_insertion",
185                    "column_name",
186                    "sql_type",
187                    "treat_none_as_default_value",
188                    "treat_none_as_null",
189                    "serialize_as",
190                    "deserialize_as",
191                    "select_expression",
192                    "select_expression_type",
193                ],
194            )),
195        }
196    }
197}
198
199impl MySpanned for FieldAttr {
200    fn span(&self) -> Span {
201        match self {
202            FieldAttr::Embed(ident)
203            | FieldAttr::SkipInsertion(ident)
204            | FieldAttr::SkipUpdate(ident)
205            | FieldAttr::ColumnName(ident, _)
206            | FieldAttr::SqlType(ident, _)
207            | FieldAttr::TreatNoneAsNull(ident, _)
208            | FieldAttr::TreatNoneAsDefaultValue(ident, _)
209            | FieldAttr::SerializeAs(ident, _)
210            | FieldAttr::DeserializeAs(ident, _)
211            | FieldAttr::SelectExpression(ident, _)
212            | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
213        }
214    }
215}
216
217#[allow(clippy::large_enum_variant)]
218pub enum StructAttr {
219    Aggregate(Ident),
220    NotSized(Ident),
221    ForeignDerive(Ident),
222
223    TableName(Ident, Path),
224    SqlType(Ident, TypePath),
225    TreatNoneAsDefaultValue(Ident, LitBool),
226    TreatNoneAsNull(Ident, LitBool),
227
228    BelongsTo(Ident, BelongsTo),
229    MysqlType(Ident, MysqlType),
230    SqliteType(Ident, SqliteType),
231    PostgresType(Ident, PostgresType),
232    PrimaryKey(Ident, Punctuated<Ident, Comma>),
233    CheckForBackend(Ident, syn::punctuated::Punctuated<TypePath, syn::Token![,]>),
234}
235
236impl Parse for StructAttr {
237    fn parse(input: ParseStream) -> Result<Self> {
238        let name: Ident = input.parse()?;
239        let name_str = name.to_string();
240
241        match &*name_str {
242            "aggregate" => Ok(StructAttr::Aggregate(name)),
243            "not_sized" => Ok(StructAttr::NotSized(name)),
244            "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
245
246            "table_name" => Ok(StructAttr::TableName(
247                name,
248                parse_eq(input, TABLE_NAME_NOTE)?,
249            )),
250            "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
251            "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
252                name,
253                parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
254            )),
255            "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
256                name,
257                parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
258            )),
259
260            "belongs_to" => Ok(StructAttr::BelongsTo(
261                name,
262                parse_paren(input, BELONGS_TO_NOTE)?,
263            )),
264            "mysql_type" => Ok(StructAttr::MysqlType(
265                name,
266                parse_paren(input, MYSQL_TYPE_NOTE)?,
267            )),
268            "sqlite_type" => Ok(StructAttr::SqliteType(
269                name,
270                parse_paren(input, SQLITE_TYPE_NOTE)?,
271            )),
272            "postgres_type" => Ok(StructAttr::PostgresType(
273                name,
274                parse_paren(input, POSTGRES_TYPE_NOTE)?,
275            )),
276            "primary_key" => Ok(StructAttr::PrimaryKey(
277                name,
278                parse_paren_list(input, "key1, key2", syn::Token![,])?,
279            )),
280            "check_for_backend" => Ok(StructAttr::CheckForBackend(
281                name,
282                parse_paren_list(input, CHECK_FOR_BACKEND_NOTE, syn::Token![,])?,
283            )),
284
285            _ => Err(unknown_attribute(
286                &name,
287                &[
288                    "aggregate",
289                    "not_sized",
290                    "foreign_derive",
291                    "table_name",
292                    "sql_type",
293                    "treat_none_as_default_value",
294                    "treat_none_as_null",
295                    "belongs_to",
296                    "mysql_type",
297                    "sqlite_type",
298                    "postgres_type",
299                    "primary_key",
300                    "check_for_backend",
301                ],
302            )),
303        }
304    }
305}
306
307impl MySpanned for StructAttr {
308    fn span(&self) -> Span {
309        match self {
310            StructAttr::Aggregate(ident)
311            | StructAttr::NotSized(ident)
312            | StructAttr::ForeignDerive(ident)
313            | StructAttr::TableName(ident, _)
314            | StructAttr::SqlType(ident, _)
315            | StructAttr::TreatNoneAsDefaultValue(ident, _)
316            | StructAttr::TreatNoneAsNull(ident, _)
317            | StructAttr::BelongsTo(ident, _)
318            | StructAttr::MysqlType(ident, _)
319            | StructAttr::SqliteType(ident, _)
320            | StructAttr::PostgresType(ident, _)
321            | StructAttr::CheckForBackend(ident, _)
322            | StructAttr::PrimaryKey(ident, _) => ident.span(),
323        }
324    }
325}
326
327pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
328where
329    T: Parse + ParseDeprecated + MySpanned,
330{
331    let mut out = Vec::new();
332    for attr in attrs {
333        if attr.meta.path().is_ident("diesel") {
334            let map = attr
335                .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
336                .into_iter()
337                .map(|a| AttributeSpanWrapper {
338                    ident_span: a.span(),
339                    item: a,
340                    attribute_span: attr.meta.span(),
341                });
342            out.extend(map);
343        } else if cfg!(all(
344            not(feature = "without-deprecated"),
345            feature = "with-deprecated"
346        )) {
347            let path = attr.meta.path();
348            let ident = path.get_ident().map(|f| f.to_string());
349
350            if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
351            | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
352                ident.as_deref().unwrap_or_default()
353            {
354                let m = &attr.meta;
355                let ts = quote::quote!(#m).into();
356                let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
357
358                if let Some(value) = value {
359                    out.push(AttributeSpanWrapper {
360                        ident_span: value.span(),
361                        item: value,
362                        attribute_span: attr.meta.span(),
363                    });
364                }
365            }
366        }
367    }
368    Ok(out)
369}