1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::model::CheckForBackend;
14use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
15use crate::util::{
16 BASE_QUERY_NOTE, BASE_QUERY_TYPE_NOTE, BELONGS_TO_NOTE, COLUMN_NAME_NOTE, DESERIALIZE_AS_NOTE,
17 MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, SELECT_EXPRESSION_NOTE, SELECT_EXPRESSION_TYPE_NOTE,
18 SERIALIZE_AS_NOTE, SQL_TYPE_NOTE, SQLITE_TYPE_NOTE, TABLE_NAME_NOTE,
19 TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE, parse_eq, parse_paren,
20 unknown_attribute,
21};
22
23use crate::util::{CHECK_FOR_BACKEND_NOTE, parse_paren_list};
24
25pub trait MySpanned {
26 fn span(&self) -> Span;
27}
28
29#[derive(#[automatically_derived]
impl<T: ::core::clone::Clone> ::core::clone::Clone for AttributeSpanWrapper<T>
{
#[inline]
fn clone(&self) -> AttributeSpanWrapper<T> {
AttributeSpanWrapper {
item: ::core::clone::Clone::clone(&self.item),
attribute_span: ::core::clone::Clone::clone(&self.attribute_span),
ident_span: ::core::clone::Clone::clone(&self.ident_span),
}
}
}Clone)]
30pub struct AttributeSpanWrapper<T> {
31 pub item: T,
32 pub attribute_span: Span,
33 pub ident_span: Span,
34}
35
36pub enum FieldAttr {
37 Embed(Ident),
38 SkipInsertion(Ident),
39 SkipUpdate(Ident),
40
41 ColumnName(Ident, SqlIdentifier),
42 SqlType(Ident, TypePath),
43 TreatNoneAsDefaultValue(Ident, LitBool),
44 TreatNoneAsNull(Ident, LitBool),
45
46 SerializeAs(Ident, TypePath),
47 DeserializeAs(Ident, TypePath),
48 SelectExpression(Ident, Expr),
49 SelectExpressionType(Ident, Type),
50}
51
52#[derive(#[automatically_derived]
impl ::core::clone::Clone for SqlIdentifier {
#[inline]
fn clone(&self) -> SqlIdentifier {
SqlIdentifier {
field_name: ::core::clone::Clone::clone(&self.field_name),
span: ::core::clone::Clone::clone(&self.span),
}
}
}Clone)]
53pub struct SqlIdentifier {
54 field_name: String,
55 span: Span,
56}
57
58impl SqlIdentifier {
59 pub fn span(&self) -> Span {
60 self.span
61 }
62
63 pub fn to_ident(&self) -> Result<Ident> {
64 match syn::parse_str::<Ident>(&::alloc::__export::must_use({
::alloc::fmt::format(format_args!("r#{0}", self.field_name))
})format!("r#{}", self.field_name)) {
65 Ok(mut ident) => {
66 ident.set_span(self.span);
67 Ok(ident)
68 }
69 Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
70 self.span(),
71 ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel does not support column names with whitespaces yet",
self.field_name))
})format!(
72 "expected valid identifier, found `{0}`. \
73 Diesel does not support column names with whitespaces yet",
74 self.field_name
75 ),
76 )),
77 Err(_e) => Err(syn::Error::new(
78 self.span(),
79 ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("expected valid identifier, found `{0}`. Diesel automatically renames invalid identifiers, perhaps you meant to write `{0}_`?",
self.field_name))
})format!(
80 "expected valid identifier, found `{0}`. \
81 Diesel automatically renames invalid identifiers, \
82 perhaps you meant to write `{0}_`?",
83 self.field_name
84 ),
85 )),
86 }
87 }
88}
89
90impl ToTokens for SqlIdentifier {
91 fn to_tokens(&self, tokens: &mut TokenStream) {
92 if self.field_name.starts_with("r#") {
93 Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
94 } else {
95 Ident::new(&self.field_name, self.span).to_tokens(tokens)
96 }
97 }
98}
99
100impl Display for SqlIdentifier {
101 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
102 let mut start = 0;
103 if self.field_name.starts_with("r#") {
104 start = 2;
105 }
106 f.write_str(&self.field_name[start..])
107 }
108}
109
110impl PartialEq<Ident> for SqlIdentifier {
111 fn eq(&self, other: &Ident) -> bool {
112 *other == self.field_name
113 }
114}
115
116impl From<&'_ Ident> for SqlIdentifier {
117 fn from(ident: &'_ Ident) -> Self {
118 use syn::ext::IdentExt;
119 let ident = ident.unraw();
120 Self {
121 span: ident.span(),
122 field_name: ident.to_string(),
123 }
124 }
125}
126
127impl Parse for SqlIdentifier {
128 fn parse(input: ParseStream) -> Result<Self> {
129 let fork = input.fork();
130
131 if let Ok(ident) = fork.parse::<Ident>() {
132 input.advance_to(&fork);
133 Ok((&ident).into())
134 } else {
135 let name = input.parse::<LitStr>()?;
136 Ok(Self {
137 field_name: name.value(),
138 span: name.span(),
139 })
140 }
141 }
142}
143
144impl Parse for FieldAttr {
145 fn parse(input: ParseStream) -> Result<Self> {
146 let name: Ident = input.parse()?;
147 let name_str = name.to_string();
148
149 match &*name_str {
150 "embed" => Ok(FieldAttr::Embed(name)),
151 "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
152 "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
153
154 "column_name" => Ok(FieldAttr::ColumnName(
155 name,
156 parse_eq(input, COLUMN_NAME_NOTE)?,
157 )),
158 "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
159 "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
160 name,
161 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
162 )),
163 "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
164 name,
165 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
166 )),
167 "serialize_as" => Ok(FieldAttr::SerializeAs(
168 name,
169 parse_eq(input, SERIALIZE_AS_NOTE)?,
170 )),
171 "deserialize_as" => Ok(FieldAttr::DeserializeAs(
172 name,
173 parse_eq(input, DESERIALIZE_AS_NOTE)?,
174 )),
175 "select_expression" => Ok(FieldAttr::SelectExpression(
176 name,
177 parse_eq(input, SELECT_EXPRESSION_NOTE)?,
178 )),
179 "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
180 name,
181 parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
182 )),
183 _ => Err(unknown_attribute(
184 &name,
185 &[
186 "embed",
187 "skip_insertion",
188 "column_name",
189 "sql_type",
190 "treat_none_as_default_value",
191 "treat_none_as_null",
192 "serialize_as",
193 "deserialize_as",
194 "select_expression",
195 "select_expression_type",
196 ],
197 )),
198 }
199 }
200}
201
202impl MySpanned for FieldAttr {
203 fn span(&self) -> Span {
204 match self {
205 FieldAttr::Embed(ident)
206 | FieldAttr::SkipInsertion(ident)
207 | FieldAttr::SkipUpdate(ident)
208 | FieldAttr::ColumnName(ident, _)
209 | FieldAttr::SqlType(ident, _)
210 | FieldAttr::TreatNoneAsNull(ident, _)
211 | FieldAttr::TreatNoneAsDefaultValue(ident, _)
212 | FieldAttr::SerializeAs(ident, _)
213 | FieldAttr::DeserializeAs(ident, _)
214 | FieldAttr::SelectExpression(ident, _)
215 | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
216 }
217 }
218}
219
220#[allow(clippy::large_enum_variant)]
221pub enum StructAttr {
222 Aggregate(Ident),
223 NotSized(Ident),
224 ForeignDerive(Ident),
225
226 TableName(Ident, Path),
227 SqlType(Ident, TypePath),
228 TreatNoneAsDefaultValue(Ident, LitBool),
229 TreatNoneAsNull(Ident, LitBool),
230
231 BelongsTo(Ident, BelongsTo),
232 MysqlType(Ident, MysqlType),
233 SqliteType(Ident, SqliteType),
234 PostgresType(Ident, PostgresType),
235 PrimaryKey(Ident, Punctuated<Ident, Comma>),
236 CheckForBackend(Ident, CheckForBackend),
237 BaseQuery(Ident, Expr),
238 BaseQueryType(Ident, Type),
239}
240
241impl Parse for StructAttr {
242 fn parse(input: ParseStream) -> Result<Self> {
243 let name: Ident = input.parse()?;
244 let name_str = name.to_string();
245
246 match &*name_str {
247 "aggregate" => Ok(StructAttr::Aggregate(name)),
248 "not_sized" => Ok(StructAttr::NotSized(name)),
249 "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
250
251 "table_name" => Ok(StructAttr::TableName(
252 name,
253 parse_eq(input, TABLE_NAME_NOTE)?,
254 )),
255 "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
256 "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
257 name,
258 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
259 )),
260 "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
261 name,
262 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
263 )),
264
265 "belongs_to" => Ok(StructAttr::BelongsTo(
266 name,
267 parse_paren(input, BELONGS_TO_NOTE)?,
268 )),
269 "mysql_type" => Ok(StructAttr::MysqlType(
270 name,
271 parse_paren(input, MYSQL_TYPE_NOTE)?,
272 )),
273 "sqlite_type" => Ok(StructAttr::SqliteType(
274 name,
275 parse_paren(input, SQLITE_TYPE_NOTE)?,
276 )),
277 "postgres_type" => Ok(StructAttr::PostgresType(
278 name,
279 parse_paren(input, POSTGRES_TYPE_NOTE)?,
280 )),
281 "primary_key" => Ok(StructAttr::PrimaryKey(
282 name,
283 parse_paren_list(input, "key1, key2", ::syn::token::Commasyn::Token![,])?,
284 )),
285 "check_for_backend" => {
286 let value = if parse_paren::<DisabledCheckForBackend>(&input.fork(), "").is_ok() {
287 CheckForBackend::Disabled(
288 parse_paren::<DisabledCheckForBackend>(&input.fork(), "")?.value,
289 )
290 } else {
291 CheckForBackend::Backends(parse_paren_list(
292 input,
293 CHECK_FOR_BACKEND_NOTE,
294 ::syn::token::Commasyn::Token![,],
295 )?)
296 };
297 Ok(StructAttr::CheckForBackend(name, value))
298 }
299 "base_query" => Ok(StructAttr::BaseQuery(
300 name,
301 parse_eq(input, BASE_QUERY_NOTE)?,
302 )),
303 "base_query_type" => Ok(StructAttr::BaseQueryType(
304 name,
305 parse_eq(input, BASE_QUERY_TYPE_NOTE)?,
306 )),
307 _ => Err(unknown_attribute(
308 &name,
309 &[
310 "aggregate",
311 "not_sized",
312 "foreign_derive",
313 "table_name",
314 "sql_type",
315 "treat_none_as_default_value",
316 "treat_none_as_null",
317 "belongs_to",
318 "mysql_type",
319 "sqlite_type",
320 "postgres_type",
321 "primary_key",
322 "check_for_backend",
323 "base_query",
324 "base_query_type",
325 ],
326 )),
327 }
328 }
329}
330
331impl MySpanned for StructAttr {
332 fn span(&self) -> Span {
333 match self {
334 StructAttr::Aggregate(ident)
335 | StructAttr::NotSized(ident)
336 | StructAttr::ForeignDerive(ident)
337 | StructAttr::TableName(ident, _)
338 | StructAttr::SqlType(ident, _)
339 | StructAttr::TreatNoneAsDefaultValue(ident, _)
340 | StructAttr::TreatNoneAsNull(ident, _)
341 | StructAttr::BelongsTo(ident, _)
342 | StructAttr::MysqlType(ident, _)
343 | StructAttr::SqliteType(ident, _)
344 | StructAttr::PostgresType(ident, _)
345 | StructAttr::CheckForBackend(ident, _)
346 | StructAttr::BaseQuery(ident, _)
347 | StructAttr::BaseQueryType(ident, _)
348 | StructAttr::PrimaryKey(ident, _) => ident.span(),
349 }
350 }
351}
352
353pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
354where
355 T: Parse + ParseDeprecated + MySpanned,
356{
357 let mut out = Vec::new();
358 for attr in attrs {
359 if attr.meta.path().is_ident("diesel") {
360 let map = attr
361 .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
362 .into_iter()
363 .map(|a| AttributeSpanWrapper {
364 ident_span: a.span(),
365 item: a,
366 attribute_span: attr.meta.span(),
367 });
368 out.extend(map);
369 } else if truecfg!(all(
370 not(feature = "without-deprecated"),
371 feature = "with-deprecated"
372 )) {
373 let path = attr.meta.path();
374 let ident = path.get_ident().map(|f| f.to_string());
375
376 if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
377 | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
378 ident.as_deref().unwrap_or_default()
379 {
380 let m = &attr.meta;
381 let ts = {
let mut _s = ::quote::__private::TokenStream::new();
::quote::ToTokens::to_tokens(&m, &mut _s);
_s
}quote::quote!(#m).into();
382 let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
383
384 if let Some(value) = value {
385 out.push(AttributeSpanWrapper {
386 ident_span: value.span(),
387 item: value,
388 attribute_span: attr.meta.span(),
389 });
390 }
391 }
392 }
393 }
394 Ok(out)
395}
396
397struct DisabledCheckForBackend {
398 value: LitBool,
399}
400
401impl syn::parse::Parse for DisabledCheckForBackend {
402 fn parse(input: ParseStream) -> Result<Self> {
403 let ident = input.parse::<Ident>()?;
404 if ident != "disable" {
405 return Err(syn::Error::new(
406 ident.span(),
407 ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("expected `disable`, but got `{0}`",
ident))
})format!("expected `disable`, but got `{ident}`"),
408 ));
409 }
410 let lit = parse_eq::<LitBool>(input, "")?;
411 if !lit.value {
412 return Err(syn::Error::new(
413 lit.span(),
414 "only `true` is accepted in this position. \
415 If you want to enable these checks, just skip the attribute entirely",
416 ));
417 }
418 Ok(Self { value: lit })
419 }
420}