1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::model::CheckForBackend;
14use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
15use crate::util::{
16 parse_eq, parse_paren, unknown_attribute, BASE_QUERY_NOTE, BASE_QUERY_TYPE_NOTE,
17 BELONGS_TO_NOTE, COLUMN_NAME_NOTE, DESERIALIZE_AS_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE,
18 SELECT_EXPRESSION_NOTE, SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQLITE_TYPE_NOTE,
19 SQL_TYPE_NOTE, TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE,
20};
21
22use crate::util::{parse_paren_list, CHECK_FOR_BACKEND_NOTE};
23
24pub trait MySpanned {
25 fn span(&self) -> Span;
26}
27
28#[derive(Clone)]
29pub struct AttributeSpanWrapper<T> {
30 pub item: T,
31 pub attribute_span: Span,
32 pub ident_span: Span,
33}
34
35pub enum FieldAttr {
36 Embed(Ident),
37 SkipInsertion(Ident),
38 SkipUpdate(Ident),
39
40 ColumnName(Ident, SqlIdentifier),
41 SqlType(Ident, TypePath),
42 TreatNoneAsDefaultValue(Ident, LitBool),
43 TreatNoneAsNull(Ident, LitBool),
44
45 SerializeAs(Ident, TypePath),
46 DeserializeAs(Ident, TypePath),
47 SelectExpression(Ident, Expr),
48 SelectExpressionType(Ident, Type),
49}
50
51#[derive(Clone)]
52pub struct SqlIdentifier {
53 field_name: String,
54 span: Span,
55}
56
57impl SqlIdentifier {
58 pub fn span(&self) -> Span {
59 self.span
60 }
61
62 pub fn to_ident(&self) -> Result<Ident> {
63 match syn::parse_str::<Ident>(&format!("r#{}", self.field_name)) {
64 Ok(mut ident) => {
65 ident.set_span(self.span);
66 Ok(ident)
67 }
68 Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
69 self.span(),
70 format!(
71 "expected valid identifier, found `{0}`. \
72 Diesel does not support column names with whitespaces yet",
73 self.field_name
74 ),
75 )),
76 Err(_e) => Err(syn::Error::new(
77 self.span(),
78 format!(
79 "expected valid identifier, found `{0}`. \
80 Diesel automatically renames invalid identifiers, \
81 perhaps you meant to write `{0}_`?",
82 self.field_name
83 ),
84 )),
85 }
86 }
87}
88
89impl ToTokens for SqlIdentifier {
90 fn to_tokens(&self, tokens: &mut TokenStream) {
91 if self.field_name.starts_with("r#") {
92 Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
93 } else {
94 Ident::new(&self.field_name, self.span).to_tokens(tokens)
95 }
96 }
97}
98
99impl Display for SqlIdentifier {
100 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
101 let mut start = 0;
102 if self.field_name.starts_with("r#") {
103 start = 2;
104 }
105 f.write_str(&self.field_name[start..])
106 }
107}
108
109impl PartialEq<Ident> for SqlIdentifier {
110 fn eq(&self, other: &Ident) -> bool {
111 *other == self.field_name
112 }
113}
114
115impl From<&'_ Ident> for SqlIdentifier {
116 fn from(ident: &'_ Ident) -> Self {
117 use syn::ext::IdentExt;
118 let ident = ident.unraw();
119 Self {
120 span: ident.span(),
121 field_name: ident.to_string(),
122 }
123 }
124}
125
126impl Parse for SqlIdentifier {
127 fn parse(input: ParseStream) -> Result<Self> {
128 let fork = input.fork();
129
130 if let Ok(ident) = fork.parse::<Ident>() {
131 input.advance_to(&fork);
132 Ok((&ident).into())
133 } else {
134 let name = input.parse::<LitStr>()?;
135 Ok(Self {
136 field_name: name.value(),
137 span: name.span(),
138 })
139 }
140 }
141}
142
143impl Parse for FieldAttr {
144 fn parse(input: ParseStream) -> Result<Self> {
145 let name: Ident = input.parse()?;
146 let name_str = name.to_string();
147
148 match &*name_str {
149 "embed" => Ok(FieldAttr::Embed(name)),
150 "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
151 "skip_update" => Ok(FieldAttr::SkipUpdate(name)),
152
153 "column_name" => Ok(FieldAttr::ColumnName(
154 name,
155 parse_eq(input, COLUMN_NAME_NOTE)?,
156 )),
157 "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
158 "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
159 name,
160 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
161 )),
162 "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
163 name,
164 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
165 )),
166 "serialize_as" => Ok(FieldAttr::SerializeAs(
167 name,
168 parse_eq(input, SERIALIZE_AS_NOTE)?,
169 )),
170 "deserialize_as" => Ok(FieldAttr::DeserializeAs(
171 name,
172 parse_eq(input, DESERIALIZE_AS_NOTE)?,
173 )),
174 "select_expression" => Ok(FieldAttr::SelectExpression(
175 name,
176 parse_eq(input, SELECT_EXPRESSION_NOTE)?,
177 )),
178 "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
179 name,
180 parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
181 )),
182 _ => Err(unknown_attribute(
183 &name,
184 &[
185 "embed",
186 "skip_insertion",
187 "column_name",
188 "sql_type",
189 "treat_none_as_default_value",
190 "treat_none_as_null",
191 "serialize_as",
192 "deserialize_as",
193 "select_expression",
194 "select_expression_type",
195 ],
196 )),
197 }
198 }
199}
200
201impl MySpanned for FieldAttr {
202 fn span(&self) -> Span {
203 match self {
204 FieldAttr::Embed(ident)
205 | FieldAttr::SkipInsertion(ident)
206 | FieldAttr::SkipUpdate(ident)
207 | FieldAttr::ColumnName(ident, _)
208 | FieldAttr::SqlType(ident, _)
209 | FieldAttr::TreatNoneAsNull(ident, _)
210 | FieldAttr::TreatNoneAsDefaultValue(ident, _)
211 | FieldAttr::SerializeAs(ident, _)
212 | FieldAttr::DeserializeAs(ident, _)
213 | FieldAttr::SelectExpression(ident, _)
214 | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
215 }
216 }
217}
218
219#[allow(clippy::large_enum_variant)]
220pub enum StructAttr {
221 Aggregate(Ident),
222 NotSized(Ident),
223 ForeignDerive(Ident),
224
225 TableName(Ident, Path),
226 SqlType(Ident, TypePath),
227 TreatNoneAsDefaultValue(Ident, LitBool),
228 TreatNoneAsNull(Ident, LitBool),
229
230 BelongsTo(Ident, BelongsTo),
231 MysqlType(Ident, MysqlType),
232 SqliteType(Ident, SqliteType),
233 PostgresType(Ident, PostgresType),
234 PrimaryKey(Ident, Punctuated<Ident, Comma>),
235 CheckForBackend(Ident, CheckForBackend),
236 BaseQuery(Ident, Expr),
237 BaseQueryType(Ident, Type),
238}
239
240impl Parse for StructAttr {
241 fn parse(input: ParseStream) -> Result<Self> {
242 let name: Ident = input.parse()?;
243 let name_str = name.to_string();
244
245 match &*name_str {
246 "aggregate" => Ok(StructAttr::Aggregate(name)),
247 "not_sized" => Ok(StructAttr::NotSized(name)),
248 "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
249
250 "table_name" => Ok(StructAttr::TableName(
251 name,
252 parse_eq(input, TABLE_NAME_NOTE)?,
253 )),
254 "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
255 "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
256 name,
257 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
258 )),
259 "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
260 name,
261 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
262 )),
263
264 "belongs_to" => Ok(StructAttr::BelongsTo(
265 name,
266 parse_paren(input, BELONGS_TO_NOTE)?,
267 )),
268 "mysql_type" => Ok(StructAttr::MysqlType(
269 name,
270 parse_paren(input, MYSQL_TYPE_NOTE)?,
271 )),
272 "sqlite_type" => Ok(StructAttr::SqliteType(
273 name,
274 parse_paren(input, SQLITE_TYPE_NOTE)?,
275 )),
276 "postgres_type" => Ok(StructAttr::PostgresType(
277 name,
278 parse_paren(input, POSTGRES_TYPE_NOTE)?,
279 )),
280 "primary_key" => Ok(StructAttr::PrimaryKey(
281 name,
282 parse_paren_list(input, "key1, key2", syn::Token![,])?,
283 )),
284 "check_for_backend" => {
285 let value = if parse_paren::<DisabledCheckForBackend>(&input.fork(), "").is_ok() {
286 CheckForBackend::Disabled(
287 parse_paren::<DisabledCheckForBackend>(&input.fork(), "")?.value,
288 )
289 } else {
290 CheckForBackend::Backends(parse_paren_list(
291 input,
292 CHECK_FOR_BACKEND_NOTE,
293 syn::Token![,],
294 )?)
295 };
296 Ok(StructAttr::CheckForBackend(name, value))
297 }
298 "base_query" => Ok(StructAttr::BaseQuery(
299 name,
300 parse_eq(input, BASE_QUERY_NOTE)?,
301 )),
302 "base_query_type" => Ok(StructAttr::BaseQueryType(
303 name,
304 parse_eq(input, BASE_QUERY_TYPE_NOTE)?,
305 )),
306 _ => Err(unknown_attribute(
307 &name,
308 &[
309 "aggregate",
310 "not_sized",
311 "foreign_derive",
312 "table_name",
313 "sql_type",
314 "treat_none_as_default_value",
315 "treat_none_as_null",
316 "belongs_to",
317 "mysql_type",
318 "sqlite_type",
319 "postgres_type",
320 "primary_key",
321 "check_for_backend",
322 "base_query",
323 "base_query_type",
324 ],
325 )),
326 }
327 }
328}
329
330impl MySpanned for StructAttr {
331 fn span(&self) -> Span {
332 match self {
333 StructAttr::Aggregate(ident)
334 | StructAttr::NotSized(ident)
335 | StructAttr::ForeignDerive(ident)
336 | StructAttr::TableName(ident, _)
337 | StructAttr::SqlType(ident, _)
338 | StructAttr::TreatNoneAsDefaultValue(ident, _)
339 | StructAttr::TreatNoneAsNull(ident, _)
340 | StructAttr::BelongsTo(ident, _)
341 | StructAttr::MysqlType(ident, _)
342 | StructAttr::SqliteType(ident, _)
343 | StructAttr::PostgresType(ident, _)
344 | StructAttr::CheckForBackend(ident, _)
345 | StructAttr::BaseQuery(ident, _)
346 | StructAttr::BaseQueryType(ident, _)
347 | StructAttr::PrimaryKey(ident, _) => ident.span(),
348 }
349 }
350}
351
352pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
353where
354 T: Parse + ParseDeprecated + MySpanned,
355{
356 let mut out = Vec::new();
357 for attr in attrs {
358 if attr.meta.path().is_ident("diesel") {
359 let map = attr
360 .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
361 .into_iter()
362 .map(|a| AttributeSpanWrapper {
363 ident_span: a.span(),
364 item: a,
365 attribute_span: attr.meta.span(),
366 });
367 out.extend(map);
368 } else if cfg!(all(
369 not(feature = "without-deprecated"),
370 feature = "with-deprecated"
371 )) {
372 let path = attr.meta.path();
373 let ident = path.get_ident().map(|f| f.to_string());
374
375 if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
376 | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
377 ident.as_deref().unwrap_or_default()
378 {
379 let m = &attr.meta;
380 let ts = quote::quote!(#m).into();
381 let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
382
383 if let Some(value) = value {
384 out.push(AttributeSpanWrapper {
385 ident_span: value.span(),
386 item: value,
387 attribute_span: attr.meta.span(),
388 });
389 }
390 }
391 }
392 }
393 Ok(out)
394}
395
396struct DisabledCheckForBackend {
397 value: LitBool,
398}
399
400impl syn::parse::Parse for DisabledCheckForBackend {
401 fn parse(input: ParseStream) -> Result<Self> {
402 let ident = input.parse::<Ident>()?;
403 if ident != "disable" {
404 return Err(syn::Error::new(
405 ident.span(),
406 format!("expected `disable`, but got `{ident}`"),
407 ));
408 }
409 let lit = parse_eq::<LitBool>(input, "")?;
410 if !lit.value {
411 return Err(syn::Error::new(
412 lit.span(),
413 "only `true` is accepted in this position. \
414 If you want to enable these checks, just skip the attribute entirely",
415 ));
416 }
417 Ok(Self { value: lit })
418 }
419}