1use std::fmt::{Display, Formatter};
2
3use proc_macro2::{Span, TokenStream};
4use quote::ToTokens;
5use syn::parse::discouraged::Speculative;
6use syn::parse::{Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::token::Comma;
10use syn::{Attribute, Expr, Ident, LitBool, LitStr, Path, Type, TypePath};
11
12use crate::deprecated::ParseDeprecated;
13use crate::parsers::{BelongsTo, MysqlType, PostgresType, SqliteType};
14use crate::util::{
15 parse_eq, parse_paren, unknown_attribute, BELONGS_TO_NOTE, COLUMN_NAME_NOTE,
16 DESERIALIZE_AS_NOTE, MYSQL_TYPE_NOTE, POSTGRES_TYPE_NOTE, SELECT_EXPRESSION_NOTE,
17 SELECT_EXPRESSION_TYPE_NOTE, SERIALIZE_AS_NOTE, SQLITE_TYPE_NOTE, SQL_TYPE_NOTE,
18 TABLE_NAME_NOTE, TREAT_NONE_AS_DEFAULT_VALUE_NOTE, TREAT_NONE_AS_NULL_NOTE,
19};
20
21use crate::util::{parse_paren_list, CHECK_FOR_BACKEND_NOTE};
22
/// Crate-local trait for values that can report a representative [`Span`].
///
/// `parse_attributes` requires it of every parsed option type and stores the
/// returned span as `ident_span`. The implementors in this file (`FieldAttr`
/// and `StructAttr`) return the span of the option's keyword identifier.
pub trait MySpanned {
    /// Returns the span that should represent `self`.
    fn span(&self) -> Span;
}
26
/// A parsed attribute option bundled with the spans it originated from.
///
/// Instances are produced by `parse_attributes`.
pub struct AttributeSpanWrapper<T> {
    /// The parsed option itself.
    pub item: T,
    /// Span of the whole attribute the option was found in
    /// (`attr.meta.span()` in `parse_attributes`).
    pub attribute_span: Span,
    /// Span reported by the option's own `MySpanned::span` implementation.
    pub ident_span: Span,
}
32
/// Options recognised on a struct field.
///
/// Every variant stores the keyword [`Ident`] that selected it as its first
/// field; `MySpanned::span` for this type returns that identifier's span.
/// Variants with a second field additionally hold the value that the `Parse`
/// implementation obtains through `parse_eq`.
pub enum FieldAttr {
    /// Selected by the `embed` keyword; carries no value.
    Embed(Ident),
    /// Selected by the `skip_insertion` keyword; carries no value.
    SkipInsertion(Ident),

    /// Selected by `column_name`; the value is an identifier or string literal
    /// (see `SqlIdentifier`'s `Parse` implementation).
    ColumnName(Ident, SqlIdentifier),
    /// Selected by `sql_type`; the value is a type path.
    SqlType(Ident, TypePath),
    /// Selected by `treat_none_as_default_value`; the value is a boolean literal.
    TreatNoneAsDefaultValue(Ident, LitBool),
    /// Selected by `treat_none_as_null`; the value is a boolean literal.
    TreatNoneAsNull(Ident, LitBool),

    /// Selected by `serialize_as`; the value is a type path.
    SerializeAs(Ident, TypePath),
    /// Selected by `deserialize_as`; the value is a type path.
    DeserializeAs(Ident, TypePath),
    /// Selected by `select_expression`; the value is an arbitrary expression.
    SelectExpression(Ident, Expr),
    /// Selected by `select_expression_type`; the value is a type.
    SelectExpressionType(Ident, Type),
}
47
/// A name paired with the span it was written at.
///
/// Values are created either from a Rust identifier (with any raw `r#` prefix
/// removed, see the `From<&Ident>` impl) or from the contents of a string
/// literal (see the `Parse` impl).
#[derive(Clone)]
pub struct SqlIdentifier {
    // The name as text. The `ToTokens` and `Display` impls treat a leading
    // `r#` in this string as a raw-identifier marker.
    field_name: String,
    // Source location used for generated tokens and error messages.
    span: Span,
}
53
54impl SqlIdentifier {
55 pub fn span(&self) -> Span {
56 self.span
57 }
58
59 pub fn to_ident(&self) -> Result<Ident> {
60 match syn::parse_str::<Ident>(&format!("r#{}", self.field_name)) {
61 Ok(mut ident) => {
62 ident.set_span(self.span);
63 Ok(ident)
64 }
65 Err(_e) if self.field_name.contains(' ') => Err(syn::Error::new(
66 self.span(),
67 format!(
68 "Expected valid identifier, found `{0}`. \
69 Diesel does not support column names with whitespaces yet",
70 self.field_name
71 ),
72 )),
73 Err(_e) => Err(syn::Error::new(
74 self.span(),
75 format!(
76 "Expected valid identifier, found `{0}`. \
77 Diesel automatically renames invalid identifiers, \
78 perhaps you meant to write `{0}_`?",
79 self.field_name
80 ),
81 )),
82 }
83 }
84}
85
86impl ToTokens for SqlIdentifier {
87 fn to_tokens(&self, tokens: &mut TokenStream) {
88 if self.field_name.starts_with("r#") {
89 Ident::new_raw(&self.field_name[2..], self.span).to_tokens(tokens)
90 } else {
91 Ident::new(&self.field_name, self.span).to_tokens(tokens)
92 }
93 }
94}
95
96impl Display for SqlIdentifier {
97 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
98 let mut start = 0;
99 if self.field_name.starts_with("r#") {
100 start = 2;
101 }
102 f.write_str(&self.field_name[start..])
103 }
104}
105
106impl PartialEq<Ident> for SqlIdentifier {
107 fn eq(&self, other: &Ident) -> bool {
108 *other == self.field_name
109 }
110}
111
112impl From<&'_ Ident> for SqlIdentifier {
113 fn from(ident: &'_ Ident) -> Self {
114 use syn::ext::IdentExt;
115 let ident = ident.unraw();
116 Self {
117 span: ident.span(),
118 field_name: ident.to_string(),
119 }
120 }
121}
122
123impl Parse for SqlIdentifier {
124 fn parse(input: ParseStream) -> Result<Self> {
125 let fork = input.fork();
126
127 if let Ok(ident) = fork.parse::<Ident>() {
128 input.advance_to(&fork);
129 Ok((&ident).into())
130 } else {
131 let name = input.parse::<LitStr>()?;
132 Ok(Self {
133 field_name: name.value(),
134 span: name.span(),
135 })
136 }
137 }
138}
139
140impl Parse for FieldAttr {
141 fn parse(input: ParseStream) -> Result<Self> {
142 let name: Ident = input.parse()?;
143 let name_str = name.to_string();
144
145 match &*name_str {
146 "embed" => Ok(FieldAttr::Embed(name)),
147 "skip_insertion" => Ok(FieldAttr::SkipInsertion(name)),
148
149 "column_name" => Ok(FieldAttr::ColumnName(
150 name,
151 parse_eq(input, COLUMN_NAME_NOTE)?,
152 )),
153 "sql_type" => Ok(FieldAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
154 "treat_none_as_default_value" => Ok(FieldAttr::TreatNoneAsDefaultValue(
155 name,
156 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
157 )),
158 "treat_none_as_null" => Ok(FieldAttr::TreatNoneAsNull(
159 name,
160 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
161 )),
162 "serialize_as" => Ok(FieldAttr::SerializeAs(
163 name,
164 parse_eq(input, SERIALIZE_AS_NOTE)?,
165 )),
166 "deserialize_as" => Ok(FieldAttr::DeserializeAs(
167 name,
168 parse_eq(input, DESERIALIZE_AS_NOTE)?,
169 )),
170 "select_expression" => Ok(FieldAttr::SelectExpression(
171 name,
172 parse_eq(input, SELECT_EXPRESSION_NOTE)?,
173 )),
174 "select_expression_type" => Ok(FieldAttr::SelectExpressionType(
175 name,
176 parse_eq(input, SELECT_EXPRESSION_TYPE_NOTE)?,
177 )),
178 _ => Err(unknown_attribute(
179 &name,
180 &[
181 "embed",
182 "skip_insertion",
183 "column_name",
184 "sql_type",
185 "treat_none_as_default_value",
186 "treat_none_as_null",
187 "serialize_as",
188 "deserialize_as",
189 "select_expression",
190 "select_expression_type",
191 ],
192 )),
193 }
194 }
195}
196
197impl MySpanned for FieldAttr {
198 fn span(&self) -> Span {
199 match self {
200 FieldAttr::Embed(ident)
201 | FieldAttr::SkipInsertion(ident)
202 | FieldAttr::ColumnName(ident, _)
203 | FieldAttr::SqlType(ident, _)
204 | FieldAttr::TreatNoneAsNull(ident, _)
205 | FieldAttr::TreatNoneAsDefaultValue(ident, _)
206 | FieldAttr::SerializeAs(ident, _)
207 | FieldAttr::DeserializeAs(ident, _)
208 | FieldAttr::SelectExpression(ident, _)
209 | FieldAttr::SelectExpressionType(ident, _) => ident.span(),
210 }
211 }
212}
213
214#[allow(clippy::large_enum_variant)]
215pub enum StructAttr {
216 Aggregate(Ident),
217 NotSized(Ident),
218 ForeignDerive(Ident),
219
220 TableName(Ident, Path),
221 SqlType(Ident, TypePath),
222 TreatNoneAsDefaultValue(Ident, LitBool),
223 TreatNoneAsNull(Ident, LitBool),
224
225 BelongsTo(Ident, BelongsTo),
226 MysqlType(Ident, MysqlType),
227 SqliteType(Ident, SqliteType),
228 PostgresType(Ident, PostgresType),
229 PrimaryKey(Ident, Punctuated<Ident, Comma>),
230 CheckForBackend(Ident, syn::punctuated::Punctuated<TypePath, syn::Token![,]>),
231}
232
233impl Parse for StructAttr {
234 fn parse(input: ParseStream) -> Result<Self> {
235 let name: Ident = input.parse()?;
236 let name_str = name.to_string();
237
238 match &*name_str {
239 "aggregate" => Ok(StructAttr::Aggregate(name)),
240 "not_sized" => Ok(StructAttr::NotSized(name)),
241 "foreign_derive" => Ok(StructAttr::ForeignDerive(name)),
242
243 "table_name" => Ok(StructAttr::TableName(
244 name,
245 parse_eq(input, TABLE_NAME_NOTE)?,
246 )),
247 "sql_type" => Ok(StructAttr::SqlType(name, parse_eq(input, SQL_TYPE_NOTE)?)),
248 "treat_none_as_default_value" => Ok(StructAttr::TreatNoneAsDefaultValue(
249 name,
250 parse_eq(input, TREAT_NONE_AS_DEFAULT_VALUE_NOTE)?,
251 )),
252 "treat_none_as_null" => Ok(StructAttr::TreatNoneAsNull(
253 name,
254 parse_eq(input, TREAT_NONE_AS_NULL_NOTE)?,
255 )),
256
257 "belongs_to" => Ok(StructAttr::BelongsTo(
258 name,
259 parse_paren(input, BELONGS_TO_NOTE)?,
260 )),
261 "mysql_type" => Ok(StructAttr::MysqlType(
262 name,
263 parse_paren(input, MYSQL_TYPE_NOTE)?,
264 )),
265 "sqlite_type" => Ok(StructAttr::SqliteType(
266 name,
267 parse_paren(input, SQLITE_TYPE_NOTE)?,
268 )),
269 "postgres_type" => Ok(StructAttr::PostgresType(
270 name,
271 parse_paren(input, POSTGRES_TYPE_NOTE)?,
272 )),
273 "primary_key" => Ok(StructAttr::PrimaryKey(
274 name,
275 parse_paren_list(input, "key1, key2", syn::Token![,])?,
276 )),
277 "check_for_backend" => Ok(StructAttr::CheckForBackend(
278 name,
279 parse_paren_list(input, CHECK_FOR_BACKEND_NOTE, syn::Token![,])?,
280 )),
281
282 _ => Err(unknown_attribute(
283 &name,
284 &[
285 "aggregate",
286 "not_sized",
287 "foreign_derive",
288 "table_name",
289 "sql_type",
290 "treat_none_as_default_value",
291 "treat_none_as_null",
292 "belongs_to",
293 "mysql_type",
294 "sqlite_type",
295 "postgres_type",
296 "primary_key",
297 "check_for_backend",
298 ],
299 )),
300 }
301 }
302}
303
304impl MySpanned for StructAttr {
305 fn span(&self) -> Span {
306 match self {
307 StructAttr::Aggregate(ident)
308 | StructAttr::NotSized(ident)
309 | StructAttr::ForeignDerive(ident)
310 | StructAttr::TableName(ident, _)
311 | StructAttr::SqlType(ident, _)
312 | StructAttr::TreatNoneAsDefaultValue(ident, _)
313 | StructAttr::TreatNoneAsNull(ident, _)
314 | StructAttr::BelongsTo(ident, _)
315 | StructAttr::MysqlType(ident, _)
316 | StructAttr::SqliteType(ident, _)
317 | StructAttr::PostgresType(ident, _)
318 | StructAttr::CheckForBackend(ident, _)
319 | StructAttr::PrimaryKey(ident, _) => ident.span(),
320 }
321 }
322}
323
324pub fn parse_attributes<T>(attrs: &[Attribute]) -> Result<Vec<AttributeSpanWrapper<T>>>
325where
326 T: Parse + ParseDeprecated + MySpanned,
327{
328 let mut out = Vec::new();
329 for attr in attrs {
330 if attr.meta.path().is_ident("diesel") {
331 let map = attr
332 .parse_args_with(Punctuated::<T, Comma>::parse_terminated)?
333 .into_iter()
334 .map(|a| AttributeSpanWrapper {
335 ident_span: a.span(),
336 item: a,
337 attribute_span: attr.meta.span(),
338 });
339 out.extend(map);
340 } else if cfg!(all(
341 not(feature = "without-deprecated"),
342 feature = "with-deprecated"
343 )) {
344 let path = attr.meta.path();
345 let ident = path.get_ident().map(|f| f.to_string());
346
347 if let "sql_type" | "column_name" | "table_name" | "changeset_options" | "primary_key"
348 | "belongs_to" | "sqlite_type" | "mysql_type" | "postgres" =
349 ident.as_deref().unwrap_or_default()
350 {
351 let m = &attr.meta;
352 let ts = quote::quote!(#m).into();
353 let value = syn::parse::Parser::parse(T::parse_deprecated, ts)?;
354
355 if let Some(value) = value {
356 out.push(AttributeSpanWrapper {
357 ident_span: value.span(),
358 item: value,
359 attribute_span: attr.meta.span(),
360 });
361 }
362 }
363 }
364 }
365 Ok(out)
366}