1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::model::CheckForBackend;
14use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
15use crate::util::{
16 BASE_QUERY_NOTE, BASE_QUERY_TYPE_NOTE, BELONGS_TO_NOTE, COLUMN_NAME_NOTE, DESERIALIZE_AS_NOTE,
17 MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, RENAME_ALL_NOTE, RENAME_NOTE, SELECT_EXPRESSION_NOTE,
18 SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQL_TYPE_NOTE, SQLITE_TYPE_NOTE,
19 TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE, parse_eq,
20 parse_eq_type, parse_paren, unknown_attribute,
21};
22
23use crate::util::{CHECK_FOR_BACKEND_NOTE, parse_paren_list};
24
25pub trait MySpanned {
26 fn span(&self) -> Span;
27}
28
29#[derive(#[automatically_derived]
impl<T: ::core::clone::Clone> ::core::clone::Clone for AttributeSpanWrapper<T>
{
#[inline]
fn clone(&self) -> AttributeSpanWrapper<T> {
AttributeSpanWrapper {
item: ::core::clone::Clone::clone(&self.item),
attribute_span: ::core::clone::Clone::clone(&self.attribute_span),
ident_span: ::core::clone::Clone::clone(&self.ident_span),
}
}
}Clone)]
30pub struct AttributeSpanWrapper<T> {
31 pub item: T,
32 pub attribute_span: Span,
33 pub ident_span: Span,
34}
35
/// A single item inside a field-level `#[diesel(...)]` attribute.
///
/// Every variant stores the attribute's name identifier as its first field; the
/// `MySpanned` impl for this type uses it to locate diagnostics.
pub enum FieldAttr {
    // Flag attributes: parsed from the bare name, no value follows.
    Embed(Ident),
    SkipInsertion(Ident),
    SkipUpdate(Ident),

    // Attributes whose value is parsed with `parse_eq`.
    ColumnName(Ident, SqlIdentifier),
    SqlType(Ident, TypePath),
    TreatNoneAsDefaultValue(Ident, LitBool),
    TreatNoneAsNull(Ident, LitBool),

    // `serialize_as` / `deserialize_as` use `parse_eq_type`; the rest use `parse_eq`.
    SerializeAs(Ident, Type),
    DeserializeAs(Ident, Type),
    SelectExpression(Ident, Expr),
    SelectExpressionType(Ident, Type),
    Rename(Ident, LitStr),
}
52
53#[derive(#[automatically_derived]
impl ::core::clone::Clone for SqlIdentifier {
#[inline]
fn clone(&self) -> SqlIdentifier {
SqlIdentifier {
field_name: ::core::clone::Clone::clone(&self.field_name),
span: ::core::clone::Clone::clone(&self.span),
}
}
}Clone)]
54pub struct SqlIdentifier {
55 field_name: String,
56 span: Span,
57}
58
59impl SqlIdentifier {
60 pub fn span(&self) -> Span {
61 self.span
62 }
63
64 pub fn to_ident(&self) -> Result<Ident> {
65 match syn::parse_str::<Ident>(&::alloc::__export::must_use({
::alloc::fmt::format(format_args!("r#{0}", self.field_name))
})format!("r#{}", self.field_name)) {
66 Ok(mut ident) => {
67 ident.set_span(self.span);
68 Ok(ident)
69 }
70 Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
71 self.span(),
72 ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel does not support column names with whitespaces yet",
self.field_name))
})format!(
73 "expected valid identifier, found `{0}`. \
74 Diesel does not support column names with whitespaces yet",
75 self.field_name
76 ),
77 )),
78 Err(_e) => Err(syn::Error::new(
79 self.span(),
80 ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel automatically renames invalid identifiers, perhaps you meant to write `{0}_`?",
self.field_name))
})format!(
81 "expected valid identifier, found `{0}`. \
82 Diesel automatically renames invalid identifiers, \
83 perhaps you meant to write `{0}_`?",
84 self.field_name
85 ),
86 )),
87 }
88 }
89}
90
91impl ToTokens for SqlIdentifier {
92 fn to_tokens(&self, tokens: &mut TokenStream) {
93 if self.field_name.starts_with("r#") {
94 Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
95 } else {
96 Ident::new(&self.field_name, self.span).to_tokens(tokens)
97 }
98 }
99}
100
101impl Display for SqlIdentifier {
102 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
103 let mut start = 0;
104 if self.field_name.starts_with("r#") {
105 start = 2;
106 }
107 f.write_str(&self.field_name[start..])
108 }
109}
110
111impl PartialEq<Ident> for SqlIdentifier {
112 fn eq(&self, other: &Ident) -> bool {
113 *other == self.field_name
114 }
115}
116
117impl From<&'_ Ident> for SqlIdentifier {
118 fn from(ident: &'_ Ident) -> Self {
119 use syn::ext::IdentExt;
120 let ident = ident.unraw();
121 Self {
122 span: ident.span(),
123 field_name: ident.to_string(),
124 }
125 }
126}
127
128impl Parse for SqlIdentifier {
129 fn parse(input: ParseStream) -> Result<Self> {
130 let fork = input.fork();
131
132 if let Ok(ident) = fork.parse::<Ident>() {
133 input.advance_to(&fork);
134 Ok((&ident).into())
135 } else {
136 let name = input.parse::<LitStr>()?;
137 Ok(Self {
138 field_name: name.value(),
139 span: name.span(),
140 })
141 }
142 }
143}
144
145impl Parse for FieldAttr {
146 fn parse(input: ParseStream) -> Result<Self> {
147 let name: Ident = input.parse()?;
148 let name_str = name.to_string();
149
150 match &*name_str {
151 "embed" => Ok(FieldAttr::Embed(name)),
152 "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
153 "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
154
155 "column_name" => Ok(FieldAttr::ColumnName(
156 name,
157 parse_eq(input, COLUMN_NAME_NOTE)?,
158 )),
159 "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
160 "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
161 name,
162 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
163 )),
164 "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
165 name,
166 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
167 )),
168 "serialize_as" => Ok(FieldAttr::SerializeAs(
169 name,
170 parse_eq_type(input, SERIALIZE_AS_NOTE)?,
171 )),
172 "deserialize_as" => Ok(FieldAttr::DeserializeAs(
173 name,
174 parse_eq_type(input, DESERIALIZE_AS_NOTE)?,
175 )),
176 "select_expression" => Ok(FieldAttr::SelectExpression(
177 name,
178 parse_eq(input, SELECT_EXPRESSION_NOTE)?,
179 )),
180 "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
181 name,
182 parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
183 )),
184 "rename" => Ok(FieldAttr::Rename(name, parse_eq(input, RENAME_NOTE)?)),
185 _ => Err(unknown_attribute(
186 &name,
187 &[
188 "embed",
189 "skip_insertion",
190 "column_name",
191 "sql_type",
192 "treat_none_as_default_value",
193 "treat_none_as_null",
194 "serialize_as",
195 "deserialize_as",
196 "select_expression",
197 "select_expression_type",
198 "rename",
199 ],
200 )),
201 }
202 }
203}
204
205impl MySpanned for FieldAttr {
206 fn span(&self) -> Span {
207 match self {
208 FieldAttr::Embed(ident)
209 | FieldAttr::SkipInsertion(ident)
210 | FieldAttr::SkipUpdate(ident)
211 | FieldAttr::ColumnName(ident, _)
212 | FieldAttr::SqlType(ident, _)
213 | FieldAttr::TreatNoneAsNull(ident, _)
214 | FieldAttr::TreatNoneAsDefaultValue(ident, _)
215 | FieldAttr::SerializeAs(ident, _)
216 | FieldAttr::DeserializeAs(ident, _)
217 | FieldAttr::SelectExpression(ident, _)
218 | FieldAttr::SelectExpressionType(ident, _)
219 | FieldAttr::Rename(ident, _) => ident.span(),
220 }
221 }
222}
223
/// A single item inside a struct-level `#[diesel(...)]` attribute.
///
/// Every variant stores the attribute's name identifier as its first field; the
/// `MySpanned` impl for this type uses it to locate diagnostics.
// NOTE(review): presumably allowed because payload sizes differ widely between
// variants (bare `Ident` vs. `Expr`/`Type`/parsed sub-attributes); confirm.
#[allow(clippy::large_enum_variant)]
pub enum StructAttr {
    // Flag attributes: parsed from the bare name, no value follows.
    Aggregate(Ident),
    NotSized(Ident),
    ForeignDerive(Ident),
    EnumType(Ident),

    // Attributes whose value is parsed with `parse_eq`.
    TableName(Ident, Path),
    SqlType(Ident, TypePath),
    TreatNoneAsDefaultValue(Ident, LitBool),
    TreatNoneAsNull(Ident, LitBool),

    // Parenthesized attributes (`parse_paren` / `parse_paren_list`), followed by
    // further `parse_eq` attributes (`base_query`, `base_query_type`, `rename_all`).
    BelongsTo(Ident, BelongsTo),
    MysqlType(Ident, MysqlType),
    SqliteType(Ident, SqliteType),
    PostgresType(Ident, PostgresType),
    PrimaryKey(Ident, Punctuated<Ident, Comma>),
    CheckForBackend(Ident, CheckForBackend),
    BaseQuery(Ident, Expr),
    BaseQueryType(Ident, Type),
    RenameAll(Ident, RenameVariants),
}
246
247impl Parse for StructAttr {
248 fn parse(input: ParseStream) -> Result<Self> {
249 let name: Ident = input.parse()?;
250 let name_str = name.to_string();
251
252 match &*name_str {
253 "aggregate" => Ok(StructAttr::Aggregate(name)),
254 "not_sized" => Ok(StructAttr::NotSized(name)),
255 "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
256 "enum_type" => Ok(StructAttr::EnumType(name)),
257
258 "table_name" => Ok(StructAttr::TableName(
259 name,
260 parse_eq(input, TABLE_NAME_NOTE)?,
261 )),
262 "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
263 "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
264 name,
265 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
266 )),
267 "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
268 name,
269 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
270 )),
271
272 "belongs_to" => Ok(StructAttr::BelongsTo(
273 name,
274 parse_paren(input, BELONGS_TO_NOTE)?,
275 )),
276 "mysql_type" => Ok(StructAttr::MysqlType(
277 name,
278 parse_paren(input, MYSQL_TYPE_NOTE)?,
279 )),
280 "sqlite_type" => Ok(StructAttr::SqliteType(
281 name,
282 parse_paren(input, SQLITE_TYPE_NOTE)?,
283 )),
284 "postgres_type" => Ok(StructAttr::PostgresType(
285 name,
286 parse_paren(input, POSTGRES_TYPE_NOTE)?,
287 )),
288 "primary_key" => Ok(StructAttr::PrimaryKey(
289 name,
290 parse_paren_list(input, "key1, key2", ::syn::token::Commasyn::Token![,])?,
291 )),
292 "check_for_backend" => {
293 let value = if parse_paren::<DisabledCheckForBackend>(&input.fork(), "").is_ok() {
294 CheckForBackend::Disabled(
295 parse_paren::<DisabledCheckForBackend>(&input.fork(), "")?.value,
296 )
297 } else {
298 CheckForBackend::Backends(parse_paren_list(
299 input,
300 CHECK_FOR_BACKEND_NOTE,
301 ::syn::token::Commasyn::Token![,],
302 )?)
303 };
304 Ok(StructAttr::CheckForBackend(name, value))
305 }
306 "base_query" => Ok(StructAttr::BaseQuery(
307 name,
308 parse_eq(input, BASE_QUERY_NOTE)?,
309 )),
310 "base_query_type" => Ok(StructAttr::BaseQueryType(
311 name,
312 parse_eq(input, BASE_QUERY_TYPE_NOTE)?,
313 )),
314 "rename_all" => Ok(StructAttr::RenameAll(
315 name,
316 parse_eq(input, RENAME_ALL_NOTE)?,
317 )),
318 _ => Err(unknown_attribute(
319 &name,
320 &[
321 "aggregate",
322 "not_sized",
323 "foreign_derive",
324 "table_name",
325 "sql_type",
326 "treat_none_as_default_value",
327 "treat_none_as_null",
328 "belongs_to",
329 "mysql_type",
330 "sqlite_type",
331 "postgres_type",
332 "primary_key",
333 "check_for_backend",
334 "base_query",
335 "base_query_type",
336 "enum_type",
337 "rename_all",
338 ],
339 )),
340 }
341 }
342}
343
344impl MySpanned for StructAttr {
345 fn span(&self) -> Span {
346 match self {
347 StructAttr::Aggregate(ident)
348 | StructAttr::NotSized(ident)
349 | StructAttr::ForeignDerive(ident)
350 | StructAttr::EnumType(ident)
351 | StructAttr::TableName(ident, _)
352 | StructAttr::SqlType(ident, _)
353 | StructAttr::TreatNoneAsDefaultValue(ident, _)
354 | StructAttr::TreatNoneAsNull(ident, _)
355 | StructAttr::BelongsTo(ident, _)
356 | StructAttr::MysqlType(ident, _)
357 | StructAttr::SqliteType(ident, _)
358 | StructAttr::PostgresType(ident, _)
359 | StructAttr::CheckForBackend(ident, _)
360 | StructAttr::BaseQuery(ident, _)
361 | StructAttr::BaseQueryType(ident, _)
362 | StructAttr::PrimaryKey(ident, _)
363 | StructAttr::RenameAll(ident, _) => ident.span(),
364 }
365 }
366}
367
368pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
369where
370 T: Parse + ParseDeprecated + MySpanned,
371{
372 let mut out = Vec::new();
373 for attr in attrs {
374 if attr.meta.path().is_ident("diesel") {
375 let map = attr
376 .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
377 .into_iter()
378 .map(|a| AttributeSpanWrapper {
379 ident_span: a.span(),
380 item: a,
381 attribute_span: attr.meta.span(),
382 });
383 out.extend(map);
384 } else if truecfg!(all(
385 not(feature = "without-deprecated"),
386 feature = "with-deprecated"
387 )) {
388 let path = attr.meta.path();
389 let ident = path.get_ident().map(|f| f.to_string());
390
391 if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
392 | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
393 ident.as_deref().unwrap_or_default()
394 {
395 let m = &attr.meta;
396 let ts = {
let mut _s = ::quote::__private::TokenStream::new();
::quote::ToTokens::to_tokens(&m, &mut _s);
_s
}quote::quote!(#m).into();
397 let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
398
399 if let Some(value) = value {
400 out.push(AttributeSpanWrapper {
401 ident_span: value.span(),
402 item: value,
403 attribute_span: attr.meta.span(),
404 });
405 }
406 }
407 }
408 }
409 Ok(out)
410}
411
412struct DisabledCheckForBackend {
413 value: LitBool,
414}
415
416impl syn::parse::Parse for DisabledCheckForBackend {
417 fn parse(input: ParseStream) -> Result<Self> {
418 let ident = input.parse::<Ident>()?;
419 if ident != "disable" {
420 return Err(syn::Error::new(
421 ident.span(),
422 ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("expected `disable`, but got `{0}`",
ident))
})format!("expected `disable`, but got `{ident}`"),
423 ));
424 }
425 let lit = parse_eq::<LitBool>(input, "")?;
426 if !lit.value {
427 return Err(syn::Error::new(
428 lit.span(),
429 "only `true` is accepted in this position. \
430 If you want to enable these checks, just skip the attribute entirely",
431 ));
432 }
433 Ok(Self { value: lit })
434 }
435}
436
/// Naming conventions accepted by the `rename_all` attribute.
#[derive(Debug, Clone, Copy)]
// All variants intentionally end in `Case`, mirroring the convention names.
#[allow(clippy::enum_variant_names)]
pub enum RenameVariants {
    LowerCase,
    UpperCase,
    PascalCase,
    CamelCase,
    SnakeCase,
    ScreamingSnakeCase,
    KebabCase,
    ScreamingKebabCase,
}
449
450impl syn::parse::Parse for RenameVariants {
451 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
452 let lit = input.parse::<syn::LitStr>()?;
453 let v = lit.value();
454 let v = match v.as_str() {
455 "lowercase" => Self::LowerCase,
456 "UPPERCASE" => Self::UpperCase,
457 "PascalCase" => Self::PascalCase,
458 "camelCase" => Self::CamelCase,
459 "snake_case" => Self::SnakeCase,
460 "SCREAMING_SNAKE_CASE" => Self::ScreamingSnakeCase,
461 "kebab-case" => Self::KebabCase,
462 "SCREAMING-KEBAB-CASE" => Self::ScreamingKebabCase,
463 s => {
464 return Err(syn::Error::new(
465 lit.span(),
466 ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("got invalid case identifier: `{0}`\nonly: `lowercase`, `UPPERCASE`, `PascalCase`, `camelCase`, `snake_case`, `SCREAMING_SNAKE_CASE`, `kebab-case` and `SCREAMING-KEBAB-CASE` are supported",
s))
})format!(
467 "got invalid case identifier: `{s}`\n\
468 only: `lowercase`, `UPPERCASE`, `PascalCase`, `camelCase`, \
469 `snake_case`, `SCREAMING_SNAKE_CASE`, `kebab-case` \
470 and `SCREAMING-KEBAB-CASE` are supported"
471 ),
472 ));
473 }
474 };
475 Ok(v)
476 }
477}
478
479impl RenameVariants {
480 pub fn apply_case_to_enum_variant(&self, input: String) -> String {
481 match self {
482 Self::PascalCase => input,
484 Self::LowerCase => input.to_ascii_lowercase(),
485 Self::UpperCase => input.to_ascii_uppercase(),
486 Self::CamelCase => input[..1].to_ascii_lowercase() + &input[1..],
487 Self::SnakeCase => {
488 let mut snake = String::new();
489 for (i, ch) in input.char_indices() {
490 if i > 0 && ch.is_uppercase() {
491 snake.push('_');
492 }
493 snake.push(ch.to_ascii_lowercase());
494 }
495 snake
496 }
497 Self::ScreamingSnakeCase => Self::SnakeCase
498 .apply_case_to_enum_variant(input)
499 .to_ascii_uppercase(),
500 Self::KebabCase => Self::SnakeCase
501 .apply_case_to_enum_variant(input)
502 .replace('_', "-"),
503 Self::ScreamingKebabCase => Self::ScreamingSnakeCase
504 .apply_case_to_enum_variant(input)
505 .replace('_', "-"),
506 }
507 }
508}