1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4use syn::parse::{Parse, ParseStream, Result};
5use syn::punctuated::Punctuated;
6use syn::{
7 parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue,
8 PathArguments, Token, Type,
9};
10
11pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool) -> TokenStream {
12 let SqlFunctionDecl {
13 mut attributes,
14 fn_token,
15 fn_name,
16 mut generics,
17 args,
18 return_type,
19 } = input;
20
21 let sql_name = attributes
22 .iter()
23 .find(|attr| attr.meta.path().is_ident("sql_name"))
24 .and_then(|attr| {
25 if let Meta::NameValue(MetaNameValue {
26 value:
27 syn::Expr::Lit(syn::ExprLit {
28 lit: syn::Lit::Str(ref lit),
29 ..
30 }),
31 ..
32 }) = attr.meta
33 {
34 Some(lit.value())
35 } else {
36 None
37 }
38 })
39 .unwrap_or_else(|| fn_name.to_string());
40
41 let is_aggregate = attributes
42 .iter()
43 .any(|attr| attr.meta.path().is_ident("aggregate"));
44
45 attributes.retain(|attr| {
46 !attr.meta.path().is_ident("sql_name") && !attr.meta.path().is_ident("aggregate")
47 });
48
49 let args = &args;
50 let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
51 .iter()
52 .map(|StrictFnArg { name, ty, .. }| (name, ty))
53 .unzip();
54 let arg_struct_assign = args.iter().map(
55 |StrictFnArg {
56 name, colon_token, ..
57 }| {
58 let name2 = name.clone();
59 quote!(#name #colon_token #name2.as_expression())
60 },
61 );
62
63 let type_args = &generics
64 .type_params()
65 .map(|type_param| type_param.ident.clone())
66 .collect::<Vec<_>>();
67
68 for StrictFnArg { name, .. } in args {
69 generics.params.push(parse_quote!(#name));
70 }
71
72 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
73 let where_clause = where_clause
76 .map(|w| quote!(#w))
77 .unwrap_or_else(|| quote!(where));
78
79 let mut generics_with_internal = generics.clone();
80 generics_with_internal
81 .params
82 .push(parse_quote!(__DieselInternal));
83 let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();
84
85 let sql_type;
86 let numeric_derive;
87
88 if arg_name.is_empty() {
89 sql_type = None;
90 numeric_derive = None;
92 } else {
93 sql_type = Some(quote!((#(#arg_name),*): Expression,));
94 numeric_derive = Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)]));
95 }
96
97 let helper_type_doc = format!("The return type of [`{fn_name}()`](super::fn_name)");
98
99 let args_iter = args.iter();
100 let mut tokens = quote! {
101 use diesel::{self, QueryResult};
102 use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
103 use diesel::query_builder::{QueryFragment, AstPass};
104 use diesel::sql_types::*;
105 use super::*;
106
107 #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
108 #numeric_derive
109 pub struct #fn_name #ty_generics {
110 #(pub(in super) #args_iter,)*
111 #(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
112 }
113
114 #[doc = #helper_type_doc]
115 pub type HelperType #ty_generics = #fn_name <
116 #(#type_args,)*
117 #(<#arg_name as AsExpression<#arg_type>>::Expression,)*
118 >;
119
120 impl #impl_generics Expression for #fn_name #ty_generics
121 #where_clause
122 #sql_type
123 {
124 type SqlType = #return_type;
125 }
126
127 impl #impl_generics_internal SelectableExpression<__DieselInternal>
129 for #fn_name #ty_generics
130 #where_clause
131 #(#arg_name: SelectableExpression<__DieselInternal>,)*
132 Self: AppearsOnTable<__DieselInternal>,
133 {
134 }
135
136 impl #impl_generics_internal AppearsOnTable<__DieselInternal>
138 for #fn_name #ty_generics
139 #where_clause
140 #(#arg_name: AppearsOnTable<__DieselInternal>,)*
141 Self: Expression,
142 {
143 }
144
145 impl #impl_generics_internal QueryFragment<__DieselInternal>
147 for #fn_name #ty_generics
148 where
149 __DieselInternal: diesel::backend::Backend,
150 #(#arg_name: QueryFragment<__DieselInternal>,)*
151 {
152 #[allow(unused_assignments)]
153 fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{
154 out.push_sql(concat!(#sql_name, "("));
155 let mut needs_comma = false;
157 #(
158 if !self.#arg_name.is_noop(out.backend())? {
159 if needs_comma {
160 out.push_sql(", ");
161 }
162 self.#arg_name.walk_ast(out.reborrow())?;
163 needs_comma = true;
164 }
165 )*
166 out.push_sql(")");
167 Ok(())
168 }
169 }
170 };
171
172 let is_supported_on_sqlite = cfg!(feature = "sqlite")
173 && type_args.is_empty()
174 && is_sqlite_type(&return_type)
175 && arg_type.iter().all(|a| is_sqlite_type(a));
176
177 if is_aggregate {
178 tokens = quote! {
179 #tokens
180
181 impl #impl_generics_internal ValidGrouping<__DieselInternal>
182 for #fn_name #ty_generics
183 {
184 type IsAggregate = diesel::expression::is_aggregate::Yes;
185 }
186 };
187 if is_supported_on_sqlite {
188 tokens = quote! {
189 #tokens
190
191 use diesel::sqlite::{Sqlite, SqliteConnection};
192 use diesel::serialize::ToSql;
193 use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
194 use diesel::sqlite::SqliteAggregateFunction;
195 use diesel::sql_types::IntoNullable;
196 };
197
198 match arg_name.len() {
199 x if x > 1 => {
200 tokens = quote! {
201 #tokens
202
203 #[allow(dead_code)]
204 pub fn register_impl<A, #(#arg_name,)*>(
210 conn: &mut SqliteConnection
211 ) -> QueryResult<()>
212 where
213 A: SqliteAggregateFunction<(#(#arg_name,)*)>
214 + Send
215 + 'static
216 + ::std::panic::UnwindSafe
217 + ::std::panic::RefUnwindSafe,
218 A::Output: ToSql<#return_type, Sqlite>,
219 (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
220 StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
221 ::std::panic::UnwindSafe,
222 {
223 conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
224 }
225 };
226 }
227 1 => {
228 let arg_name = arg_name[0];
229 let arg_type = arg_type[0];
230
231 tokens = quote! {
232 #tokens
233
234 #[allow(dead_code)]
235 pub fn register_impl<A, #arg_name>(
241 conn: &mut SqliteConnection
242 ) -> QueryResult<()>
243 where
244 A: SqliteAggregateFunction<#arg_name>
245 + Send
246 + 'static
247 + std::panic::UnwindSafe
248 + std::panic::RefUnwindSafe,
249 A::Output: ToSql<#return_type, Sqlite>,
250 #arg_name: FromSqlRow<#arg_type, Sqlite> +
251 StaticallySizedRow<#arg_type, Sqlite> +
252 ::std::panic::UnwindSafe,
253 {
254 conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
255 }
256 };
257 }
258 _ => (),
259 }
260 }
261 } else {
262 tokens = quote! {
263 #tokens
264
265 #[derive(ValidGrouping)]
266 pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);
267
268 impl #impl_generics_internal ValidGrouping<__DieselInternal>
269 for #fn_name #ty_generics
270 where
271 __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
272 {
273 type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
274 }
275 };
276
277 if is_supported_on_sqlite && !arg_name.is_empty() {
278 tokens = quote! {
279 #tokens
280
281 use diesel::sqlite::{Sqlite, SqliteConnection};
282 use diesel::serialize::ToSql;
283 use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
284
285 #[allow(dead_code)]
286 pub fn register_impl<F, Ret, #(#arg_name,)*>(
294 conn: &mut SqliteConnection,
295 f: F,
296 ) -> QueryResult<()>
297 where
298 F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
299 (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
300 StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
301 Ret: ToSql<#return_type, Sqlite>,
302 {
303 conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
304 #sql_name,
305 true,
306 move |(#(#arg_name,)*)| f(#(#arg_name,)*),
307 )
308 }
309
310 #[allow(dead_code)]
311 pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
320 conn: &mut SqliteConnection,
321 mut f: F,
322 ) -> QueryResult<()>
323 where
324 F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
325 (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
326 StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
327 Ret: ToSql<#return_type, Sqlite>,
328 {
329 conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
330 #sql_name,
331 false,
332 move |(#(#arg_name,)*)| f(#(#arg_name,)*),
333 )
334 }
335 };
336 }
337
338 if is_supported_on_sqlite && arg_name.is_empty() {
339 tokens = quote! {
340 #tokens
341
342 use diesel::sqlite::{Sqlite, SqliteConnection};
343 use diesel::serialize::ToSql;
344
345 #[allow(dead_code)]
346 pub fn register_impl<F, Ret>(
354 conn: &SqliteConnection,
355 f: F,
356 ) -> QueryResult<()>
357 where
358 F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static,
359 Ret: ToSql<#return_type, Sqlite>,
360 {
361 conn.register_noarg_sql_function::<#return_type, _, _>(
362 #sql_name,
363 true,
364 f,
365 )
366 }
367
368 #[allow(dead_code)]
369 pub fn register_nondeterministic_impl<F, Ret>(
378 conn: &SqliteConnection,
379 mut f: F,
380 ) -> QueryResult<()>
381 where
382 F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
383 Ret: ToSql<#return_type, Sqlite>,
384 {
385 conn.register_noarg_sql_function::<#return_type, _, _>(
386 #sql_name,
387 false,
388 f,
389 )
390 }
391 };
392 }
393 }
394
395 let args_iter = args.iter();
396
397 let (outside_of_module_helper_type, return_type_path, internals_module_name) =
398 if legacy_helper_type_and_module {
399 (None, quote! { #fn_name::HelperType }, fn_name.clone())
400 } else {
401 let internals_module_name = Ident::new(&format!("{fn_name}_utils"), fn_name.span());
402 (
403 Some(quote! {
404 #[allow(non_camel_case_types, non_snake_case)]
405 #[doc = #helper_type_doc]
406 pub type #fn_name #ty_generics = #internals_module_name::#fn_name <
407 #(#type_args,)*
408 #(<#arg_name as ::diesel::expression::AsExpression<#arg_type>>::Expression,)*
409 >;
410 }),
411 quote! { #fn_name },
412 internals_module_name,
413 )
414 };
415
416 quote! {
417 #(#attributes)*
418 #[allow(non_camel_case_types)]
419 pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
420 -> #return_type_path #ty_generics
421 #where_clause
422 #(#arg_name: ::diesel::expression::AsExpression<#arg_type>,)*
423 {
424 #internals_module_name::#fn_name {
425 #(#arg_struct_assign,)*
426 #(#type_args: ::std::marker::PhantomData,)*
427 }
428 }
429
430 #outside_of_module_helper_type
431
432 #[doc(hidden)]
433 #[allow(non_camel_case_types, non_snake_case, unused_imports)]
434 pub(crate) mod #internals_module_name {
435 #tokens
436 }
437 }
438}
439
440pub(crate) struct SqlFunctionDecl {
441 attributes: Vec<Attribute>,
442 fn_token: Token![fn],
443 fn_name: Ident,
444 generics: Generics,
445 args: Punctuated<StrictFnArg, Token![,]>,
446 return_type: Type,
447}
448
449impl Parse for SqlFunctionDecl {
450 fn parse(input: ParseStream) -> Result<Self> {
451 let attributes = Attribute::parse_outer(input)?;
452 let fn_token: Token![fn] = input.parse()?;
453 let fn_name = Ident::parse(input)?;
454 let generics = Generics::parse(input)?;
455 let args;
456 let _paren = parenthesized!(args in input);
457 let args = args.parse_terminated(StrictFnArg::parse, Token![,])?;
458 let return_type = if Option::<Token![->]>::parse(input)?.is_some() {
459 Type::parse(input)?
460 } else {
461 parse_quote!(diesel::expression::expression_types::NotSelectable)
462 };
463 let _semi = Option::<Token![;]>::parse(input)?;
464
465 Ok(Self {
466 attributes,
467 fn_token,
468 fn_name,
469 generics,
470 args,
471 return_type,
472 })
473 }
474}
475
476struct StrictFnArg {
478 name: Ident,
479 colon_token: Token![:],
480 ty: Type,
481}
482
483impl Parse for StrictFnArg {
484 fn parse(input: ParseStream) -> Result<Self> {
485 let name = input.parse()?;
486 let colon_token = input.parse()?;
487 let ty = input.parse()?;
488 Ok(Self {
489 name,
490 colon_token,
491 ty,
492 })
493 }
494}
495
496impl ToTokens for StrictFnArg {
497 fn to_tokens(&self, tokens: &mut TokenStream) {
498 self.name.to_tokens(tokens);
499 self.colon_token.to_tokens(tokens);
500 self.name.to_tokens(tokens);
501 }
502}
503
504fn is_sqlite_type(ty: &Type) -> bool {
505 let last_segment = if let Type::Path(tp) = ty {
506 if let Some(segment) = tp.path.segments.last() {
507 segment
508 } else {
509 return false;
510 }
511 } else {
512 return false;
513 };
514
515 let ident = last_segment.ident.to_string();
516 if ident == "Nullable" {
517 if let PathArguments::AngleBracketed(ref ab) = last_segment.arguments {
518 if let Some(GenericArgument::Type(ty)) = ab.args.first() {
519 return is_sqlite_type(ty);
520 }
521 }
522 return false;
523 }
524
525 [
526 "BigInt",
527 "Binary",
528 "Bool",
529 "Date",
530 "Double",
531 "Float",
532 "Integer",
533 "Numeric",
534 "SmallInt",
535 "Text",
536 "Time",
537 "Timestamp",
538 ]
539 .contains(&ident.as_str())
540}