1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
14use crate::util::{
15 parse_eq, parse_paren, unknown_attribute, BELONGS_TO_NOTE, COLUMN_NAME_NOTE,
16 DESERIALIZE_AS_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, SELECT_EXPRESSION_NOTE,
17 SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQLITE_TYPE_NOTE, SQL_TYPE_NOTE,
18 TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE,
19};
20
21use crate::util::{parse_paren_list, CHECK_FOR_BACKEND_NOTE};
22
/// Access to the span used when reporting errors about a parsed attribute item.
///
/// The `FieldAttr` and `StructAttr` implementations in this module return the
/// span of the option's name identifier (e.g. the `column_name` in
/// `column_name = foo`).
pub trait MySpanned {
    /// Returns the span diagnostics about this item should point at.
    fn span(&self) -> Span;
}
26
/// A parsed attribute item bundled with the spans needed for error reporting.
///
/// Values of this type are produced by `parse_attributes` in this module.
pub struct AttributeSpanWrapper<T> {
    /// The parsed attribute item itself.
    pub item: T,
    /// Span of the whole attribute meta (`attr.meta.span()`) the item came from.
    pub attribute_span: Span,
    /// Span returned by the item's `MySpanned::span`, i.e. the option's name
    /// identifier for the attribute types defined in this module.
    pub ident_span: Span,
}
32
/// A single field-level option from a `#[diesel(...)]` attribute.
///
/// Every variant stores the option's name identifier first (used for error
/// spans via `MySpanned`); options that take a value store the parsed value
/// second.
pub enum FieldAttr {
    /// The bare `embed` flag.
    Embed(Ident),
    /// The bare `skip_insertion` flag.
    SkipInsertion(Ident),
    /// The bare `skip_update` flag.
    SkipUpdate(Ident),

    /// `column_name`, whose value is an identifier or a string literal.
    ColumnName(Ident, SqlIdentifier),
    /// `sql_type`, whose value is a type path.
    SqlType(Ident, TypePath),
    /// `treat_none_as_default_value`, whose value is a boolean literal.
    TreatNoneAsDefaultValue(Ident, LitBool),
    /// `treat_none_as_null`, whose value is a boolean literal.
    TreatNoneAsNull(Ident, LitBool),

    /// `serialize_as`, whose value is a type path.
    SerializeAs(Ident, TypePath),
    /// `deserialize_as`, whose value is a type path.
    DeserializeAs(Ident, TypePath),
    /// `select_expression`, whose value is an arbitrary expression.
    SelectExpression(Ident, Expr),
    /// `select_expression_type`, whose value is a type.
    SelectExpressionType(Ident, Type),
}
48
/// A SQL column name written either as a Rust identifier or as a string literal.
///
/// Names built from an identifier have any `r#` raw prefix removed; names built
/// from a string literal are stored verbatim and may therefore still start with
/// `r#`.
#[derive(Clone)]
pub struct SqlIdentifier {
    /// The column name text.
    field_name: String,
    /// Span of the source tokens the name was parsed from; attached to generated
    /// tokens and to error messages.
    span: Span,
}
54
55impl SqlIdentifier {
56 pub fn span(&self) -> Span {
57 self.span
58 }
59
60 pub fn to_ident(&self) -> Result<Ident> {
61 match syn::parse_str::<Ident>(&format!("r#{}", self.field_name)) {
62 Ok(mut ident) => {
63 ident.set_span(self.span);
64 Ok(ident)
65 }
66 Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
67 self.span(),
68 format!(
69 "Expected valid identifier, found `{0}`. \
70 Diesel does not support column names with whitespaces yet",
71 self.field_name
72 ),
73 )),
74 Err(_e) => Err(syn::Error::new(
75 self.span(),
76 format!(
77 "Expected valid identifier, found `{0}`. \
78 Diesel automatically renames invalid identifiers, \
79 perhaps you meant to write `{0}_`?",
80 self.field_name
81 ),
82 )),
83 }
84 }
85}
86
87impl ToTokens for SqlIdentifier {
88 fn to_tokens(&self, tokens: &mut TokenStream) {
89 if self.field_name.starts_with("r#") {
90 Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
91 } else {
92 Ident::new(&self.field_name, self.span).to_tokens(tokens)
93 }
94 }
95}
96
97impl Display for SqlIdentifier {
98 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
99 let mut start = 0;
100 if self.field_name.starts_with("r#") {
101 start = 2;
102 }
103 f.write_str(&self.field_name[start..])
104 }
105}
106
107impl PartialEq<Ident> for SqlIdentifier {
108 fn eq(&self, other: &Ident) -> bool {
109 *other == self.field_name
110 }
111}
112
113impl From<&'_ Ident> for SqlIdentifier {
114 fn from(ident: &'_ Ident) -> Self {
115 use syn::ext::IdentExt;
116 let ident = ident.unraw();
117 Self {
118 span: ident.span(),
119 field_name: ident.to_string(),
120 }
121 }
122}
123
124impl Parse for SqlIdentifier {
125 fn parse(input: ParseStream) -> Result<Self> {
126 let fork = input.fork();
127
128 if let Ok(ident) = fork.parse::<Ident>() {
129 input.advance_to(&fork);
130 Ok((&ident).into())
131 } else {
132 let name = input.parse::<LitStr>()?;
133 Ok(Self {
134 field_name: name.value(),
135 span: name.span(),
136 })
137 }
138 }
139}
140
141impl Parse for FieldAttr {
142 fn parse(input: ParseStream) -> Result<Self> {
143 let name: Ident = input.parse()?;
144 let name_str = name.to_string();
145
146 match &*name_str {
147 "embed" => Ok(FieldAttr::Embed(name)),
148 "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
149 "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
150
151 "column_name" => Ok(FieldAttr::ColumnName(
152 name,
153 parse_eq(input, COLUMN_NAME_NOTE)?,
154 )),
155 "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
156 "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
157 name,
158 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
159 )),
160 "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
161 name,
162 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
163 )),
164 "serialize_as" => Ok(FieldAttr::SerializeAs(
165 name,
166 parse_eq(input, SERIALIZE_AS_NOTE)?,
167 )),
168 "deserialize_as" => Ok(FieldAttr::DeserializeAs(
169 name,
170 parse_eq(input, DESERIALIZE_AS_NOTE)?,
171 )),
172 "select_expression" => Ok(FieldAttr::SelectExpression(
173 name,
174 parse_eq(input, SELECT_EXPRESSION_NOTE)?,
175 )),
176 "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
177 name,
178 parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
179 )),
180 _ => Err(unknown_attribute(
181 &name,
182 &[
183 "embed",
184 "skip_insertion",
185 "column_name",
186 "sql_type",
187 "treat_none_as_default_value",
188 "treat_none_as_null",
189 "serialize_as",
190 "deserialize_as",
191 "select_expression",
192 "select_expression_type",
193 ],
194 )),
195 }
196 }
197}
198
199impl MySpanned for FieldAttr {
200 fn span(&self) -> Span {
201 match self {
202 FieldAttr::Embed(ident)
203 | FieldAttr::SkipInsertion(ident)
204 | FieldAttr::SkipUpdate(ident)
205 | FieldAttr::ColumnName(ident, _)
206 | FieldAttr::SqlType(ident, _)
207 | FieldAttr::TreatNoneAsNull(ident, _)
208 | FieldAttr::TreatNoneAsDefaultValue(ident, _)
209 | FieldAttr::SerializeAs(ident, _)
210 | FieldAttr::DeserializeAs(ident, _)
211 | FieldAttr::SelectExpression(ident, _)
212 | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
213 }
214 }
215}
216
217#[allow(clippy::large_enum_variant)]
218pub enum StructAttr {
219 Aggregate(Ident),
220 NotSized(Ident),
221 ForeignDerive(Ident),
222
223 TableName(Ident, Path),
224 SqlType(Ident, TypePath),
225 TreatNoneAsDefaultValue(Ident, LitBool),
226 TreatNoneAsNull(Ident, LitBool),
227
228 BelongsTo(Ident, BelongsTo),
229 MysqlType(Ident, MysqlType),
230 SqliteType(Ident, SqliteType),
231 PostgresType(Ident, PostgresType),
232 PrimaryKey(Ident, Punctuated<Ident, Comma>),
233 CheckForBackend(Ident, syn::punctuated::Punctuated<TypePath, syn::Token![,]>),
234}
235
236impl Parse for StructAttr {
237 fn parse(input: ParseStream) -> Result<Self> {
238 let name: Ident = input.parse()?;
239 let name_str = name.to_string();
240
241 match &*name_str {
242 "aggregate" => Ok(StructAttr::Aggregate(name)),
243 "not_sized" => Ok(StructAttr::NotSized(name)),
244 "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
245
246 "table_name" => Ok(StructAttr::TableName(
247 name,
248 parse_eq(input, TABLE_NAME_NOTE)?,
249 )),
250 "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
251 "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
252 name,
253 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
254 )),
255 "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
256 name,
257 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
258 )),
259
260 "belongs_to" => Ok(StructAttr::BelongsTo(
261 name,
262 parse_paren(input, BELONGS_TO_NOTE)?,
263 )),
264 "mysql_type" => Ok(StructAttr::MysqlType(
265 name,
266 parse_paren(input, MYSQL_TYPE_NOTE)?,
267 )),
268 "sqlite_type" => Ok(StructAttr::SqliteType(
269 name,
270 parse_paren(input, SQLITE_TYPE_NOTE)?,
271 )),
272 "postgres_type" => Ok(StructAttr::PostgresType(
273 name,
274 parse_paren(input, POSTGRES_TYPE_NOTE)?,
275 )),
276 "primary_key" => Ok(StructAttr::PrimaryKey(
277 name,
278 parse_paren_list(input, "key1, key2", syn::Token![,])?,
279 )),
280 "check_for_backend" => Ok(StructAttr::CheckForBackend(
281 name,
282 parse_paren_list(input, CHECK_FOR_BACKEND_NOTE, syn::Token![,])?,
283 )),
284
285 _ => Err(unknown_attribute(
286 &name,
287 &[
288 "aggregate",
289 "not_sized",
290 "foreign_derive",
291 "table_name",
292 "sql_type",
293 "treat_none_as_default_value",
294 "treat_none_as_null",
295 "belongs_to",
296 "mysql_type",
297 "sqlite_type",
298 "postgres_type",
299 "primary_key",
300 "check_for_backend",
301 ],
302 )),
303 }
304 }
305}
306
307impl MySpanned for StructAttr {
308 fn span(&self) -> Span {
309 match self {
310 StructAttr::Aggregate(ident)
311 | StructAttr::NotSized(ident)
312 | StructAttr::ForeignDerive(ident)
313 | StructAttr::TableName(ident, _)
314 | StructAttr::SqlType(ident, _)
315 | StructAttr::TreatNoneAsDefaultValue(ident, _)
316 | StructAttr::TreatNoneAsNull(ident, _)
317 | StructAttr::BelongsTo(ident, _)
318 | StructAttr::MysqlType(ident, _)
319 | StructAttr::SqliteType(ident, _)
320 | StructAttr::PostgresType(ident, _)
321 | StructAttr::CheckForBackend(ident, _)
322 | StructAttr::PrimaryKey(ident, _) => ident.span(),
323 }
324 }
325}
326
327pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
328where
329 T: Parse + ParseDeprecated + MySpanned,
330{
331 let mut out = Vec::new();
332 for attr in attrs {
333 if attr.meta.path().is_ident("diesel") {
334 let map = attr
335 .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
336 .into_iter()
337 .map(|a| AttributeSpanWrapper {
338 ident_span: a.span(),
339 item: a,
340 attribute_span: attr.meta.span(),
341 });
342 out.extend(map);
343 } else if cfg!(all(
344 not(feature = "without-deprecated"),
345 feature = "with-deprecated"
346 )) {
347 let path = attr.meta.path();
348 let ident = path.get_ident().map(|f| f.to_string());
349
350 if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
351 | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
352 ident.as_deref().unwrap_or_default()
353 {
354 let m = &attr.meta;
355 let ts = quote::quote!(#m).into();
356 let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
357
358 if let Some(value) = value {
359 out.push(AttributeSpanWrapper {
360 ident_span: value.span(),
361 item: value,
362 attribute_span: attr.meta.span(),
363 });
364 }
365 }
366 }
367 }
368 Ok(out)
369}