1use proc_macro2::Span;
2use proc_macro2::TokenStream;
3use quote::format_ident;
4use quote::quote;
5use quote::ToTokens;
6use quote::TokenStreamExt;
7use std::iter;
8use syn::parse::{Parse, ParseStream, Result};
9use syn::punctuated::Pair;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{
13 parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, ImplGenerics, LitStr,
14 PathArguments, Token, Type, TypeGenerics,
15};
16use syn::{GenericParam, Meta};
17use syn::{LitBool, Path};
18use syn::{LitInt, MetaNameValue};
19
20use crate::attrs::{AttributeSpanWrapper, MySpanned};
21use crate::util::parse_eq;
22
/// Highest variant number generated for a variadic function when
/// `DIESEL_VARIADIC_FUNCTION_ARGS` is unset or is not a valid `usize`.
/// Variants `0..=N` are generated, so the default yields three functions.
const VARIADIC_VARIANTS_DEFAULT: usize = 2;
/// Value of the `DIESEL_VARIADIC_FUNCTION_ARGS` environment variable as seen
/// while compiling this crate (`None` when it was not set).
const VARIADIC_ARG_COUNT_ENV: Option<&str> = option_env!("DIESEL_VARIADIC_FUNCTION_ARGS");
25
26pub(crate) fn expand(
27 input: Vec<SqlFunctionDecl>,
28 legacy_helper_type_and_module: bool,
29 generate_return_type_helpers: bool,
30) -> TokenStream {
31 let mut result = TokenStream::new();
32 let mut return_type_helper_module_paths = vec![];
33
34 for decl in input {
35 let expanded = expand_one(
36 decl,
37 legacy_helper_type_and_module,
38 generate_return_type_helpers,
39 );
40 let expanded = match expanded {
41 Err(err) => err.into_compile_error(),
42 Ok(expanded) => {
43 if let Some(return_type_helper_module_path) =
44 expanded.return_type_helper_module_path
45 {
46 return_type_helper_module_paths.push(return_type_helper_module_path);
47 }
48
49 expanded.tokens
50 }
51 };
52
53 result.append_all(expanded.into_iter());
54 }
55
56 if !generate_return_type_helpers {
57 return result;
58 }
59
60 quote! {
61 #result
62
63 #[allow(unused_imports)]
64 #[doc(hidden)]
65 mod return_type_helpers {
66 #(
67 #[doc(inline)]
68 pub use super:: #return_type_helper_module_paths ::*;
69 )*
70 }
71 }
72}
73
/// Result of expanding one SQL function declaration.
struct ExpandedSqlFunction {
    /// All items generated for the declaration.
    tokens: TokenStream,
    /// Path of the module holding the generated return type helper, or `None`
    /// when no such module was generated.
    return_type_helper_module_path: Option<Path>,
}
78
79fn expand_one(
80 mut input: SqlFunctionDecl,
81 legacy_helper_type_and_module: bool,
82 generate_return_type_helpers: bool,
83) -> syn::Result<ExpandedSqlFunction> {
84 let attributes = &mut input.attributes;
85
86 let variadic_argument_count = attributes.iter().find_map(|attr| {
87 if let SqlFunctionAttribute::Variadic(_, c) = &attr.item {
88 Some((c.base10_parse(), c.span()))
89 } else {
90 None
91 }
92 });
93
94 let Some((variadic_argument_count, variadic_span)) = variadic_argument_count else {
95 let sql_name = parse_sql_name_attr(&mut input);
96
97 return expand_nonvariadic(
98 input,
99 sql_name,
100 legacy_helper_type_and_module,
101 generate_return_type_helpers,
102 );
103 };
104
105 let variadic_argument_count = variadic_argument_count?;
106
107 let variadic_variants = VARIADIC_ARG_COUNT_ENV
108 .and_then(|arg_count| arg_count.parse::<usize>().ok())
109 .unwrap_or(VARIADIC_VARIANTS_DEFAULT);
110
111 let mut result = TokenStream::new();
112 let mut helper_type_modules = vec![];
113 for variant_no in 0..=variadic_variants {
114 let expanded = expand_variadic(
115 input.clone(),
116 legacy_helper_type_and_module,
117 generate_return_type_helpers,
118 variadic_argument_count,
119 variant_no,
120 variadic_span,
121 )?;
122
123 if let Some(return_type_helper_module_path) = expanded.return_type_helper_module_path {
124 helper_type_modules.push(return_type_helper_module_path);
125 }
126
127 result.append_all(expanded.tokens.into_iter());
128 }
129
130 if generate_return_type_helpers {
131 let return_types_module_name = Ident::new(
132 &format!("__{}_return_types", input.fn_name),
133 input.fn_name.span(),
134 );
135 let result = quote! {
136 #result
137
138 #[allow(unused_imports)]
139 #[doc(inline)]
140 mod #return_types_module_name {
141 #(
142 #[doc(inline)]
143 pub use super:: #helper_type_modules ::*;
144 )*
145 }
146 };
147
148 let return_type_helper_module_path = Some(parse_quote! {
149 #return_types_module_name
150 });
151
152 Ok(ExpandedSqlFunction {
153 tokens: result,
154 return_type_helper_module_path,
155 })
156 } else {
157 Ok(ExpandedSqlFunction {
158 tokens: result,
159 return_type_helper_module_path: None,
160 })
161 }
162}
163
164fn expand_variadic(
165 mut input: SqlFunctionDecl,
166 legacy_helper_type_and_module: bool,
167 generate_return_type_helpers: bool,
168 variadic_argument_count: usize,
169 variant_no: usize,
170 variadic_span: Span,
171) -> syn::Result<ExpandedSqlFunction> {
172 add_variadic_doc_comments(&mut input.attributes, &input.fn_name.to_string());
173
174 let sql_name = parse_sql_name_attr(&mut input);
175
176 input.fn_name = format_ident!("{}_{}", input.fn_name, variant_no);
177
178 let nonvariadic_args_count = input
179 .args
180 .len()
181 .checked_sub(variadic_argument_count)
182 .ok_or_else(|| {
183 syn::Error::new(
184 variadic_span,
185 "invalid variadic argument count: not enough function arguments",
186 )
187 })?;
188
189 let mut variadic_generic_indexes = vec![];
190 let mut arguments_with_generic_types = vec![];
191 for (arg_idx, arg) in input.args.iter().skip(nonvariadic_args_count).enumerate() {
192 let Type::Path(ty_path) = arg.ty.clone() else {
194 continue;
195 };
196 let Ok(ty_ident) = ty_path.path.require_ident() else {
197 continue;
198 };
199
200 let idx = input.generics.params.iter().position(|param| match param {
201 GenericParam::Type(type_param) => type_param.ident == *ty_ident,
202 _ => false,
203 });
204
205 if let Some(idx) = idx {
206 variadic_generic_indexes.push(idx);
207 arguments_with_generic_types.push(arg_idx);
208 }
209 }
210
211 let mut args: Vec<_> = input.args.into_pairs().collect();
212 let variadic_args = args.split_off(nonvariadic_args_count);
213 let nonvariadic_args = args;
214
215 let variadic_args: Vec<_> = iter::repeat_n(variadic_args, variant_no)
216 .enumerate()
217 .flat_map(|(arg_group_idx, arg_group)| {
218 let mut resulting_args = vec![];
219
220 for (arg_idx, arg) in arg_group.into_iter().enumerate() {
221 let mut arg = arg.into_value();
222
223 arg.name = format_ident!("{}_{}", arg.name, arg_group_idx + 1);
224
225 if arguments_with_generic_types.contains(&arg_idx) {
226 let Type::Path(mut ty_path) = arg.ty.clone() else {
227 unreachable!("This argument should have path type as checked earlier")
228 };
229 let Ok(ident) = ty_path.path.require_ident() else {
230 unreachable!("This argument should have ident type as checked earlier")
231 };
232
233 ty_path.path.segments[0].ident =
234 format_ident!("{}{}", ident, arg_group_idx + 1);
235 arg.ty = Type::Path(ty_path);
236 }
237
238 let pair = Pair::new(arg, Some(Token])));
239 resulting_args.push(pair);
240 }
241
242 resulting_args
243 })
244 .collect();
245
246 input.args = nonvariadic_args.into_iter().chain(variadic_args).collect();
247
248 let generics: Vec<_> = input.generics.params.into_pairs().collect();
249 input.generics.params = if variant_no == 0 {
250 generics
251 .into_iter()
252 .enumerate()
253 .filter_map(|(generic_idx, generic)| {
254 (!variadic_generic_indexes.contains(&generic_idx)).then_some(generic)
255 })
256 .collect()
257 } else {
258 iter::repeat_n(generics, variant_no)
259 .enumerate()
260 .flat_map(|(generic_group_idx, generic_group)| {
261 let mut resulting_generics = vec![];
262
263 for (generic_idx, generic) in generic_group.into_iter().enumerate() {
264 if !variadic_generic_indexes.contains(&generic_idx) {
265 if generic_group_idx == 0 {
266 resulting_generics.push(generic);
267 }
268
269 continue;
270 }
271
272 let mut generic = generic.into_value();
273
274 if let GenericParam::Type(type_param) = &mut generic {
275 type_param.ident =
276 format_ident!("{}{}", type_param.ident, generic_group_idx + 1);
277 } else {
278 unreachable!("This generic should be a type param as checked earlier")
279 }
280
281 let pair = Pair::new(generic, Some(Token])));
282 resulting_generics.push(pair);
283 }
284
285 resulting_generics
286 })
287 .collect()
288 };
289
290 expand_nonvariadic(
291 input,
292 sql_name,
293 legacy_helper_type_and_module,
294 generate_return_type_helpers,
295 )
296}
297
298fn add_variadic_doc_comments(
299 attributes: &mut Vec<AttributeSpanWrapper<SqlFunctionAttribute>>,
300 fn_name: &str,
301) {
302 let mut doc_comments_end = attributes.len()
303 - attributes
304 .iter()
305 .rev()
306 .position(|attr| matches!(&attr.item, SqlFunctionAttribute::Other(syn::Attribute{ meta: Meta::NameValue(MetaNameValue { path, .. }), ..}) if path.is_ident("doc")))
307 .unwrap_or(attributes.len());
308
309 let fn_family = format!("`{fn_name}_0`, `{fn_name}_1`, ... `{fn_name}_n`");
310
311 let doc_comments: Vec<Attribute> = parse_quote! {
312 #[doc = #fn_family]
319 #[doc(alias = #fn_name)]
335 };
336
337 for new_attribute in doc_comments {
338 attributes.insert(
339 doc_comments_end,
340 AttributeSpanWrapper {
341 item: SqlFunctionAttribute::Other(new_attribute),
342 attribute_span: Span::mixed_site(),
343 ident_span: Span::mixed_site(),
344 },
345 );
346 doc_comments_end += 1;
347 }
348}
349
350fn parse_sql_name_attr(input: &mut SqlFunctionDecl) -> String {
351 let result = input
352 .attributes
353 .iter()
354 .find_map(|attr| match attr.item {
355 SqlFunctionAttribute::SqlName(_, ref value) => Some(value.value()),
356 _ => None,
357 })
358 .unwrap_or_else(|| input.fn_name.to_string());
359
360 result
361}
362
/// Generates all items for one concrete (non-variadic) SQL function.
///
/// The output consists of the public constructor function, an internals
/// module holding the expression struct and its trait implementations, and,
/// depending on the flags, a helper type alias and a return type helper
/// module. `sql_name` is the name emitted into the generated SQL.
///
/// With `legacy_helper_type_and_module` the internals module is named like
/// the function and the helper type is `<fn_name>::HelperType`; otherwise the
/// module is `<fn_name>_utils` and a type alias named like the function is
/// emitted next to it.
///
/// # Errors
///
/// When a return type helper is generated, fails if a generic type parameter
/// is not used directly as the type of any argument.
fn expand_nonvariadic(
    input: SqlFunctionDecl,
    sql_name: String,
    legacy_helper_type_and_module: bool,
    generate_return_type_helpers: bool,
) -> syn::Result<ExpandedSqlFunction> {
    let SqlFunctionDecl {
        attributes,
        fn_token,
        fn_name,
        mut generics,
        args,
        return_type,
    } = input;

    let is_aggregate = attributes
        .iter()
        .any(|attr| matches!(attr.item, SqlFunctionAttribute::Aggregate(..)));

    let can_be_called_directly = !function_cannot_be_called_directly(&attributes);

    let skip_return_type_helper = attributes
        .iter()
        .any(|attr| matches!(attr.item, SqlFunctionAttribute::SkipReturnTypeHelper(..)));

    let window_attrs = attributes
        .iter()
        .filter(|a| matches!(a.item, SqlFunctionAttribute::Window { .. }))
        .cloned()
        .collect::<Vec<_>>();

    // Only the first restriction attribute is honoured.
    let restrictions = attributes
        .iter()
        .find_map(|a| match a.item {
            SqlFunctionAttribute::Restriction(ref r) => Some(r.clone()),
            _ => None,
        })
        .unwrap_or_default();

    // From here on `attributes` only holds the attributes this macro does not
    // interpret itself; they are forwarded onto the generated function.
    let attributes = attributes
        .into_iter()
        .filter_map(|a| match a.item {
            SqlFunctionAttribute::Other(a) => Some(a),
            _ => None,
        })
        .collect::<Vec<_>>();

    let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
        .iter()
        .map(|StrictFnArg { name, ty, .. }| (name, ty))
        .unzip();
    // Struct field initialisers of the form `name: name.as_expression()`.
    let arg_struct_assign = args.iter().map(
        |StrictFnArg {
             name, colon_token, ..
         }| {
            let name2 = name.clone();
            quote!(#name #colon_token #name2.as_expression())
        },
    );

    // The user-declared type parameters, collected before the per-argument
    // generics are appended below.
    let type_args = &generics
        .type_params()
        .map(|type_param| type_param.ident.clone())
        .collect::<Vec<_>>();

    // Every argument gets a generic parameter named exactly like the argument.
    for StrictFnArg { name, .. } in &args {
        generics.params.push(parse_quote!(#name));
    }

    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
    // Fall back to a bare `where` so additional predicates can always be
    // appended after `#where_clause`.
    let where_clause = where_clause
        .map(|w| quote!(#w))
        .unwrap_or_else(|| quote!(where));

    let mut generics_with_internal = generics.clone();
    generics_with_internal
        .params
        .push(parse_quote!(__DieselInternal));
    let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();

    let sql_type;
    let numeric_derive;

    if arg_name.is_empty() {
        sql_type = None;
        numeric_derive = None;
    } else {
        sql_type = Some(quote!((#(#arg_name),*): Expression,));
        numeric_derive = Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)]));
    }

    // NOTE(review): the link target is the literal text `super::fn_name`, the
    // function name is not interpolated there — confirm this is intended.
    let helper_type_doc = format!("The return type of [`{fn_name}()`](super::fn_name)");
    let query_fragment_impl =
        can_be_called_directly.then_some(restrictions.generate_all_queryfragment_impls(
            generics.clone(),
            &ty_generics,
            arg_name,
            &fn_name,
        ));

    let args_iter = args.iter();
    // Contents of the internals module.
    let mut tokens = quote! {
        use diesel::{self, QueryResult};
        use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
        use diesel::query_builder::{QueryFragment, AstPass};
        use diesel::sql_types::*;
        use diesel::internal::sql_functions::*;
        use super::*;

        #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
        #numeric_derive
        pub struct #fn_name #ty_generics {
            #(pub(in super) #args_iter,)*
            #(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
        }

        #[doc = #helper_type_doc]
        pub type HelperType #ty_generics = #fn_name <
            #(#type_args,)*
            #(<#arg_name as AsExpression<#arg_type>>::Expression,)*
        >;

        impl #impl_generics Expression for #fn_name #ty_generics
        #where_clause
            #sql_type
        {
            type SqlType = #return_type;
        }

        impl #impl_generics_internal SelectableExpression<__DieselInternal>
            for #fn_name #ty_generics
        #where_clause
            #(#arg_name: SelectableExpression<__DieselInternal>,)*
            Self: AppearsOnTable<__DieselInternal>,
        {
        }

        impl #impl_generics_internal AppearsOnTable<__DieselInternal>
            for #fn_name #ty_generics
        #where_clause
            #(#arg_name: AppearsOnTable<__DieselInternal>,)*
            Self: Expression,
        {
        }

        impl #impl_generics_internal FunctionFragment<__DieselInternal>
            for #fn_name #ty_generics
        where
            __DieselInternal: diesel::backend::Backend,
            #(#arg_name: QueryFragment<__DieselInternal>,)*
        {
            const FUNCTION_NAME: &'static str = #sql_name;

            #[allow(unused_assignments)]
            fn walk_arguments<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> {
                let mut needs_comma = false;
                #(
                    if !self.#arg_name.is_noop(out.backend())? {
                        if needs_comma {
                            out.push_sql(", ");
                        }
                        self.#arg_name.walk_ast(out.reborrow())?;
                        needs_comma = true;
                    }
                )*
                Ok(())
            }
        }

        #query_fragment_impl
    };

    // SQLite registration helpers are only generated for functions without
    // user-declared generics whose argument and return types all pass
    // `is_sqlite_type`, and only if this crate is built with the `sqlite`
    // feature.
    let is_supported_on_sqlite = cfg!(feature = "sqlite")
        && type_args.is_empty()
        && is_sqlite_type(&return_type)
        && arg_type.iter().all(|a| is_sqlite_type(a));

    for window in &window_attrs {
        tokens.extend(generate_window_function_tokens(
            window,
            generics.clone(),
            &ty_generics,
            &fn_name,
        ));
    }
    if !window_attrs.is_empty() {
        tokens.extend(quote::quote! {
            impl #impl_generics IsWindowFunction for #fn_name #ty_generics {
                type ArgTypes = (#(#arg_name,)*);
            }
        });
    }

    // Functions that are window-only (window attribute but no aggregate
    // attribute) get neither of the two token sets below.
    if is_aggregate {
        tokens = generate_tokens_for_aggregate_functions(
            tokens,
            &impl_generics_internal,
            &impl_generics,
            &fn_name,
            &ty_generics,
            arg_name,
            arg_type,
            is_supported_on_sqlite,
            !window_attrs.is_empty(),
            &return_type,
            &sql_name,
        );
    } else if window_attrs.is_empty() {
        tokens = generate_tokens_for_non_aggregate_functions(
            tokens,
            &impl_generics_internal,
            &fn_name,
            &ty_generics,
            arg_name,
            arg_type,
            is_supported_on_sqlite,
            &return_type,
            &sql_name,
        );
    }

    let args_iter = args.iter();

    let (outside_of_module_helper_type, return_type_path, internals_module_name) =
        if legacy_helper_type_and_module {
            (None, quote! { #fn_name::HelperType }, fn_name.clone())
        } else {
            let internals_module_name = Ident::new(&format!("{fn_name}_utils"), fn_name.span());
            (
                Some(quote! {
                    #[allow(non_camel_case_types, non_snake_case)]
                    #[doc = #helper_type_doc]
                    pub type #fn_name #ty_generics = #internals_module_name::#fn_name <
                        #(#type_args,)*
                        #(<#arg_name as diesel::expression::AsExpression<#arg_type>>::Expression,)*
                    >;
                }),
                quote! { #fn_name },
                internals_module_name,
            )
        };

    let (return_type_helper_module, return_type_helper_module_path) =
        if !generate_return_type_helpers || skip_return_type_helper {
            (None, None)
        } else {
            // For each user-declared generic, find the first argument whose
            // type is exactly that generic; the generic is then derived from
            // that argument's expression SQL type.
            let auto_derived_types = type_args
                .iter()
                .map(|type_arg| {
                    for arg in &args {
                        let Type::Path(path) = &arg.ty else {
                            continue;
                        };

                        let Some(path_ident) = path.path.get_ident() else {
                            continue;
                        };

                        if path_ident == type_arg {
                            return Ok(arg.name.clone());
                        }
                    }

                    Err(syn::Error::new(
                        type_arg.span(),
                        "cannot find argument corresponding to the generic",
                    ))
                })
                .collect::<Result<Vec<_>>>()?;

            let arg_names_iter: Vec<_> = args.iter().map(|arg| arg.name.clone()).collect();

            let return_type_module_name =
                Ident::new(&format!("__{fn_name}_return_type"), fn_name.span());

            let doc =
                format!("Return type of the [`{fn_name}()`](fn@super::{fn_name}) SQL function.");
            let return_type_helper_module = quote! {
                #[allow(non_camel_case_types, non_snake_case, unused_imports)]
                #[doc(inline)]
                mod #return_type_module_name {
                    #[doc = #doc]
                    pub type #fn_name<
                        #(#arg_names_iter,)*
                    > = super::#fn_name<
                        #( <#auto_derived_types as diesel::expression::Expression>::SqlType, )*
                        #(#arg_names_iter,)*
                    >;
                }
            };

            let module_path = parse_quote!(
                #return_type_module_name
            );

            (Some(return_type_helper_module), Some(module_path))
        };

    // Final output: the public constructor function, the optional helper
    // items and the internals module wrapping everything generated above.
    let tokens = quote! {
        #(#attributes)*
        #[allow(non_camel_case_types)]
        pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
            -> #return_type_path #ty_generics
        #where_clause
            #(#arg_name: diesel::expression::AsExpression<#arg_type>,)*
        {
            #internals_module_name::#fn_name {
                #(#arg_struct_assign,)*
                #(#type_args: ::std::marker::PhantomData,)*
            }
        }

        #outside_of_module_helper_type

        #return_type_helper_module

        #[doc(hidden)]
        #[allow(non_camel_case_types, non_snake_case, unused_imports)]
        pub(crate) mod #internals_module_name {
            #tokens
        }
    };

    Ok(ExpandedSqlFunction {
        tokens,
        return_type_helper_module_path,
    })
}
697
698fn generate_window_function_tokens(
699 window: &AttributeSpanWrapper<SqlFunctionAttribute>,
700 generics: Generics,
701 ty_generics: &TypeGenerics<'_>,
702 fn_name: &Ident,
703) -> TokenStream {
704 let SqlFunctionAttribute::Window {
705 restrictions,
706 require_order,
707 ..
708 } = &window.item
709 else {
710 unreachable!("We filtered for window attributes above")
711 };
712 restrictions.generate_all_window_fragment_impls(
713 generics,
714 ty_generics,
715 fn_name,
716 require_order.unwrap_or_default(),
717 )
718}
719
/// Appends the items specific to plain (non-aggregate, non-window) functions
/// to `tokens` and returns the result.
///
/// Always emits a `ValidGrouping` implementation that delegates to a derived
/// helper tuple struct over the argument generics. If `is_supported_on_sqlite`
/// is set, additionally emits `register_impl` and
/// `register_nondeterministic_impl`, in a flavour with or without arguments.
#[allow(clippy::too_many_arguments)]
fn generate_tokens_for_non_aggregate_functions(
    mut tokens: TokenStream,
    impl_generics_internal: &syn::ImplGenerics<'_>,
    fn_name: &syn::Ident,
    ty_generics: &syn::TypeGenerics<'_>,
    arg_name: &[&syn::Ident],
    arg_type: &[&syn::Type],
    is_supported_on_sqlite: bool,
    return_type: &syn::Type,
    sql_name: &str,
) -> TokenStream {
    tokens = quote! {
        #tokens

        #[derive(ValidGrouping)]
        pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);

        impl #impl_generics_internal ValidGrouping<__DieselInternal>
            for #fn_name #ty_generics
        where
            __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
        {
            type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
        }
    };

    // Functions with at least one argument: the closure receives the
    // arguments as a tuple and forwards them to the user callback.
    if is_supported_on_sqlite && !arg_name.is_empty() {
        tokens = quote! {
            #tokens

            use diesel::sqlite::{Sqlite, SqliteConnection};
            use diesel::serialize::ToSql;
            use diesel::deserialize::{FromSqlRow, StaticallySizedRow};

            #[allow(dead_code)]
            pub fn register_impl<F, Ret, #(#arg_name,)*>(
                conn: &mut SqliteConnection,
                f: F,
            ) -> QueryResult<()>
            where
                F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
                (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
                    StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
                Ret: ToSql<#return_type, Sqlite>,
            {
                conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
                    #sql_name,
                    true,
                    move |(#(#arg_name,)*)| f(#(#arg_name,)*),
                )
            }

            #[allow(dead_code)]
            pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
                conn: &mut SqliteConnection,
                mut f: F,
            ) -> QueryResult<()>
            where
                F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
                (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
                    StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
                Ret: ToSql<#return_type, Sqlite>,
            {
                conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
                    #sql_name,
                    false,
                    move |(#(#arg_name,)*)| f(#(#arg_name,)*),
                )
            }
        };
    }

    // Functions without arguments use the dedicated no-argument registration.
    // NOTE(review): these take `&SqliteConnection` while the variants above
    // take `&mut SqliteConnection` — confirm against the signature of
    // `register_noarg_sql_function`.
    if is_supported_on_sqlite && arg_name.is_empty() {
        tokens = quote! {
            #tokens

            use diesel::sqlite::{Sqlite, SqliteConnection};
            use diesel::serialize::ToSql;

            #[allow(dead_code)]
            pub fn register_impl<F, Ret>(
                conn: &SqliteConnection,
                f: F,
            ) -> QueryResult<()>
            where
                F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static,
                Ret: ToSql<#return_type, Sqlite>,
            {
                conn.register_noarg_sql_function::<#return_type, _, _>(
                    #sql_name,
                    true,
                    f,
                )
            }

            #[allow(dead_code)]
            pub fn register_nondeterministic_impl<F, Ret>(
                conn: &SqliteConnection,
                mut f: F,
            ) -> QueryResult<()>
            where
                F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
                Ret: ToSql<#return_type, Sqlite>,
            {
                conn.register_noarg_sql_function::<#return_type, _, _>(
                    #sql_name,
                    false,
                    f,
                )
            }
        };
    }
    tokens
}
865
/// Appends the items specific to aggregate functions to `tokens` and returns
/// the result.
///
/// Always emits a `ValidGrouping` implementation with `IsAggregate = Yes` and
/// an `IsAggregateFunction` implementation. If `is_supported_on_sqlite` is set
/// and the function is not also a window function, a `register_impl` function
/// is emitted for functions with one or with several arguments; functions
/// without arguments get none.
#[allow(clippy::too_many_arguments)]
fn generate_tokens_for_aggregate_functions(
    mut tokens: TokenStream,
    impl_generics_internal: &syn::ImplGenerics<'_>,
    impl_generics: &syn::ImplGenerics<'_>,
    fn_name: &syn::Ident,
    ty_generics: &syn::TypeGenerics<'_>,
    arg_name: &[&syn::Ident],
    arg_type: &[&syn::Type],
    is_supported_on_sqlite: bool,
    is_window: bool,
    return_type: &syn::Type,
    sql_name: &str,
) -> TokenStream {
    tokens = quote! {
        #tokens

        impl #impl_generics_internal ValidGrouping<__DieselInternal>
            for #fn_name #ty_generics
        {
            type IsAggregate = diesel::expression::is_aggregate::Yes;
        }

        impl #impl_generics IsAggregateFunction for #fn_name #ty_generics {}
    };
    if is_supported_on_sqlite && !is_window {
        tokens = quote! {
            #tokens

            use diesel::sqlite::{Sqlite, SqliteConnection};
            use diesel::serialize::ToSql;
            use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
            use diesel::sqlite::SqliteAggregateFunction;
            use diesel::sql_types::IntoNullable;
        };

        match arg_name.len() {
            // Several arguments: the aggregate receives them as a tuple.
            x if x > 1 => {
                tokens = quote! {
                    #tokens

                    #[allow(dead_code)]
                    pub fn register_impl<A, #(#arg_name,)*>(
                        conn: &mut SqliteConnection
                    ) -> QueryResult<()>
                        where
                        A: SqliteAggregateFunction<(#(#arg_name,)*)>
                            + Send
                            + 'static
                            + ::std::panic::UnwindSafe
                            + ::std::panic::RefUnwindSafe,
                        A::Output: ToSql<#return_type, Sqlite>,
                        (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
                            StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
                            ::std::panic::UnwindSafe,
                    {
                        conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
                    }
                };
            }
            // Exactly one argument: it is passed bare rather than as a tuple.
            1 => {
                let arg_name = arg_name[0];
                let arg_type = arg_type[0];

                tokens = quote! {
                    #tokens

                    #[allow(dead_code)]
                    pub fn register_impl<A, #arg_name>(
                        conn: &mut SqliteConnection
                    ) -> QueryResult<()>
                        where
                        A: SqliteAggregateFunction<#arg_name>
                            + Send
                            + 'static
                            + std::panic::UnwindSafe
                            + std::panic::RefUnwindSafe,
                        A::Output: ToSql<#return_type, Sqlite>,
                        #arg_name: FromSqlRow<#arg_type, Sqlite> +
                            StaticallySizedRow<#arg_type, Sqlite> +
                            ::std::panic::UnwindSafe,
                        {
                            conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
                        }
                };
            }
            // No arguments: no registration helper is generated.
            _ => (),
        }
    }
    tokens
}
968
969fn function_cannot_be_called_directly(
970 attributes: &[AttributeSpanWrapper<SqlFunctionAttribute>],
971) -> bool {
972 let mut has_aggregate = false;
973 let mut has_window = false;
974 for attr in attributes {
975 has_aggregate = has_aggregate || matches!(attr.item, SqlFunctionAttribute::Aggregate(..));
976 has_window = has_window || matches!(attr.item, SqlFunctionAttribute::Window { .. });
977 }
978 has_window && !has_aggregate
979}
980
/// A parsed `extern "SQL" { ... }` block.
pub(crate) struct ExternSqlBlock {
    /// The function declarations of the block, each already merged with the
    /// attributes placed on the block itself.
    pub(crate) function_decls: Vec<SqlFunctionDecl>,
}
984
985impl Parse for ExternSqlBlock {
986 fn parse(input: ParseStream) -> Result<Self> {
987 let mut error = None::<syn::Error>;
988
989 let mut combine_error = |e: syn::Error| {
990 error = Some(
991 error
992 .take()
993 .map(|mut o| {
994 o.combine(e.clone());
995 o
996 })
997 .unwrap_or(e),
998 )
999 };
1000
1001 let block = syn::ItemForeignMod::parse(input)?;
1002 if block.abi.name.as_ref().map(|n| n.value()) != Some("SQL".into()) {
1003 return Err(syn::Error::new(block.abi.span(), "expect `SQL` as ABI"));
1004 }
1005 if block.unsafety.is_some() {
1006 return Err(syn::Error::new(
1007 block.unsafety.unwrap().span(),
1008 "expect `SQL` function blocks to be safe",
1009 ));
1010 }
1011
1012 let parsed_block_attrs = parse_attributes(&mut combine_error, block.attrs);
1013
1014 let item_count = block.items.len();
1015 let function_decls_input = block
1016 .items
1017 .into_iter()
1018 .map(|i| syn::parse2::<SqlFunctionDecl>(quote! { #i }));
1019
1020 let mut function_decls = Vec::with_capacity(item_count);
1021 for decl in function_decls_input {
1022 match decl {
1023 Ok(mut decl) => {
1024 decl.attributes = merge_attributes(&parsed_block_attrs, decl.attributes);
1025 function_decls.push(decl)
1026 }
1027 Err(e) => {
1028 error = Some(
1029 error
1030 .take()
1031 .map(|mut o| {
1032 o.combine(e.clone());
1033 o
1034 })
1035 .unwrap_or(e),
1036 );
1037 }
1038 }
1039 }
1040
1041 error
1042 .map(Err)
1043 .unwrap_or(Ok(ExternSqlBlock { function_decls }))
1044 }
1045}
1046
1047fn merge_attributes(
1048 parsed_block_attrs: &[AttributeSpanWrapper<SqlFunctionAttribute>],
1049 mut attributes: Vec<AttributeSpanWrapper<SqlFunctionAttribute>>,
1050) -> Vec<AttributeSpanWrapper<SqlFunctionAttribute>> {
1051 for attr in parsed_block_attrs {
1052 if attributes.iter().all(|a| match (&a.item, &attr.item) {
1053 (SqlFunctionAttribute::Aggregate(_), SqlFunctionAttribute::Aggregate(_)) => todo!(),
1054 (SqlFunctionAttribute::Window { .. }, SqlFunctionAttribute::Window { .. })
1055 | (SqlFunctionAttribute::SqlName(_, _), SqlFunctionAttribute::SqlName(_, _))
1056 | (SqlFunctionAttribute::Restriction(_), SqlFunctionAttribute::Restriction(_))
1057 | (SqlFunctionAttribute::Variadic(_, _), SqlFunctionAttribute::Variadic(_, _))
1058 | (
1059 SqlFunctionAttribute::SkipReturnTypeHelper(_),
1060 SqlFunctionAttribute::SkipReturnTypeHelper(_),
1061 ) => false,
1062 _ => true,
1063 }) {
1064 attributes.push(attr.clone());
1065 }
1066 }
1067 attributes
1068}
1069
/// A single SQL function declaration of the form
/// `#[attrs] fn name<generics>(arg: Type, ...) -> ReturnType;`.
#[derive(Clone)]
pub(crate) struct SqlFunctionDecl {
    /// Parsed attributes, both those interpreted by this macro and all others.
    attributes: Vec<AttributeSpanWrapper<SqlFunctionAttribute>>,
    /// The `fn` keyword.
    fn_token: Token![fn],
    /// Name of the function on the Rust side.
    fn_name: Ident,
    /// Generic parameters declared on the function.
    generics: Generics,
    /// The `name: Type` arguments.
    args: Punctuated<StrictFnArg, Token![,]>,
    /// Declared return type; the parser substitutes a `NotSelectable` marker
    /// type when no `->` is present.
    return_type: Type,
}
1079
1080impl Parse for SqlFunctionDecl {
1081 fn parse(input: ParseStream) -> Result<Self> {
1082 let mut error = None::<syn::Error>;
1083 let mut combine_error = |e: syn::Error| {
1084 error = Some(
1085 error
1086 .take()
1087 .map(|mut o| {
1088 o.combine(e.clone());
1089 o
1090 })
1091 .unwrap_or(e),
1092 )
1093 };
1094
1095 let attributes = Attribute::parse_outer(input).unwrap_or_else(|e| {
1096 combine_error(e);
1097 Vec::new()
1098 });
1099 let attributes_collected = parse_attributes(&mut combine_error, attributes);
1100
1101 let fn_token: Token![fn] = input.parse().unwrap_or_else(|e| {
1102 combine_error(e);
1103 Default::default()
1104 });
1105 let fn_name = Ident::parse(input).unwrap_or_else(|e| {
1106 combine_error(e);
1107 Ident::new("dummy", Span::call_site())
1108 });
1109 let generics = Generics::parse(input).unwrap_or_else(|e| {
1110 combine_error(e);
1111 Generics {
1112 lt_token: None,
1113 params: Punctuated::new(),
1114 gt_token: None,
1115 where_clause: None,
1116 }
1117 });
1118 let args;
1119 let _paren = parenthesized!(args in input);
1120 let args = args
1121 .parse_terminated(StrictFnArg::parse, Token![,])
1122 .unwrap_or_else(|e| {
1123 combine_error(e);
1124 Punctuated::new()
1125 });
1126 let rarrow = Option::<Token![->]>::parse(input).unwrap_or_else(|e| {
1127 combine_error(e);
1128 None
1129 });
1130 let return_type = if rarrow.is_some() {
1131 Type::parse(input).unwrap_or_else(|e| {
1132 combine_error(e);
1133 Type::Never(syn::TypeNever {
1134 bang_token: Default::default(),
1135 })
1136 })
1137 } else {
1138 parse_quote!(diesel::expression::expression_types::NotSelectable)
1139 };
1140 let _semi = Option::<Token![;]>::parse(input).unwrap_or_else(|e| {
1141 combine_error(e);
1142 None
1143 });
1144
1145 error.map(Err).unwrap_or(Ok(Self {
1146 attributes: attributes_collected,
1147 fn_token,
1148 fn_name,
1149 generics,
1150 args,
1151 return_type,
1152 }))
1153 }
1154}
1155
1156fn parse_attribute(
1157 attr: syn::Attribute,
1158) -> Result<Option<AttributeSpanWrapper<SqlFunctionAttribute>>> {
1159 match &attr.meta {
1160 syn::Meta::NameValue(syn::MetaNameValue {
1161 path,
1162 value:
1163 syn::Expr::Lit(syn::ExprLit {
1164 lit: syn::Lit::Str(sql_name),
1165 ..
1166 }),
1167 ..
1168 }) if path.is_ident("sql_name") => Ok(Some(AttributeSpanWrapper {
1169 attribute_span: attr.span(),
1170 ident_span: sql_name.span(),
1171 item: SqlFunctionAttribute::SqlName(path.require_ident()?.clone(), sql_name.clone()),
1172 })),
1173 syn::Meta::Path(path) if path.is_ident("aggregate") => Ok(Some(AttributeSpanWrapper {
1174 attribute_span: attr.span(),
1175 ident_span: path.span(),
1176 item: SqlFunctionAttribute::Aggregate(
1177 path.require_ident()
1178 .map_err(|e| {
1179 syn::Error::new(
1180 e.span(),
1181 format!("{e}, the correct format is `#[aggregate]`"),
1182 )
1183 })?
1184 .clone(),
1185 ),
1186 })),
1187 syn::Meta::Path(path) if path.is_ident("skip_return_type_helper") => {
1188 Ok(Some(AttributeSpanWrapper {
1189 ident_span: attr.span(),
1190 attribute_span: path.span(),
1191 item: SqlFunctionAttribute::SkipReturnTypeHelper(
1192 path.require_ident()
1193 .map_err(|e| {
1194 syn::Error::new(
1195 e.span(),
1196 format!("{e}, the correct format is `#[skip_return_type_helper]`"),
1197 )
1198 })?
1199 .clone(),
1200 ),
1201 }))
1202 }
1203 syn::Meta::Path(path) if path.is_ident("window") => Ok(Some(AttributeSpanWrapper {
1204 attribute_span: attr.span(),
1205 ident_span: path.span(),
1206 item: SqlFunctionAttribute::Window {
1207 ident: path
1208 .require_ident()
1209 .map_err(|e| {
1210 syn::Error::new(e.span(), format!("{e}, the correct format is `#[window]`"))
1211 })?
1212 .clone(),
1213 restrictions: BackendRestriction::None,
1214 require_order: None,
1215 },
1216 })),
1217 syn::Meta::List(syn::MetaList {
1218 path,
1219 delimiter: syn::MacroDelimiter::Paren(_),
1220 tokens,
1221 }) if path.is_ident("variadic") => {
1222 let count: syn::LitInt = syn::parse2(tokens.clone()).map_err(|e| {
1223 syn::Error::new(
1224 e.span(),
1225 format!("{e}, the correct format is `#[variadic(3)]`"),
1226 )
1227 })?;
1228 Ok(Some(AttributeSpanWrapper {
1229 item: SqlFunctionAttribute::Variadic(
1230 path.require_ident()
1231 .map_err(|e| {
1232 syn::Error::new(
1233 e.span(),
1234 format!("{e}, the correct format is `#[variadic(3)]`"),
1235 )
1236 })?
1237 .clone(),
1238 count.clone(),
1239 ),
1240 attribute_span: attr.span(),
1241 ident_span: path.require_ident()?.span(),
1242 }))
1243 }
1244 syn::Meta::NameValue(_) | syn::Meta::Path(_) => Ok(Some(AttributeSpanWrapper {
1245 attribute_span: attr.span(),
1246 ident_span: attr.span(),
1247 item: SqlFunctionAttribute::Other(attr),
1248 })),
1249 syn::Meta::List(_) => {
1250 let name = attr.meta.path().require_ident()?;
1251 let attribute_span = attr.meta.span();
1252 attr.clone()
1253 .parse_args_with(|input: &syn::parse::ParseBuffer| {
1254 SqlFunctionAttribute::parse_attr(
1255 name.clone(),
1256 input,
1257 attr.clone(),
1258 attribute_span,
1259 )
1260 })
1261 }
1262 }
1263}
1264
1265fn parse_attributes(
1266 combine_error: &mut impl FnMut(syn::Error),
1267 attributes: Vec<Attribute>,
1268) -> Vec<AttributeSpanWrapper<SqlFunctionAttribute>> {
1269 let attribute_count = attributes.len();
1270
1271 let attributes = attributes
1272 .into_iter()
1273 .filter_map(|attr| parse_attribute(attr).transpose());
1274
1275 let mut attributes_collected = Vec::with_capacity(attribute_count);
1276 for attr in attributes {
1277 match attr {
1278 Ok(attr) => attributes_collected.push(attr),
1279 Err(e) => {
1280 combine_error(e);
1281 }
1282 }
1283 }
1284 attributes_collected
1285}
1286
/// A single function argument written as `name: Type`.
///
/// The `Parse` impl reads exactly these three pieces in this order.
#[derive(Clone)]
struct StrictFnArg {
    /// Identifier of the argument.
    name: Ident,
    /// The `:` separating the name from the type.
    colon_token: Token![:],
    /// Declared type of the argument.
    ty: Type,
}
1294
1295impl Parse for StrictFnArg {
1296 fn parse(input: ParseStream) -> Result<Self> {
1297 let name = input.parse()?;
1298 let colon_token = input.parse()?;
1299 let ty = input.parse()?;
1300 Ok(Self {
1301 name,
1302 colon_token,
1303 ty,
1304 })
1305 }
1306}
1307
1308impl ToTokens for StrictFnArg {
1309 fn to_tokens(&self, tokens: &mut TokenStream) {
1310 self.name.to_tokens(tokens);
1311 self.colon_token.to_tokens(tokens);
1312 self.name.to_tokens(tokens);
1313 }
1314}
1315
1316fn is_sqlite_type(ty: &Type) -> bool {
1317 let last_segment = if let Type::Path(tp) = ty {
1318 if let Some(segment) = tp.path.segments.last() {
1319 segment
1320 } else {
1321 return false;
1322 }
1323 } else {
1324 return false;
1325 };
1326
1327 let ident = last_segment.ident.to_string();
1328 if ident == "Nullable" {
1329 if let PathArguments::AngleBracketed(ref ab) = last_segment.arguments {
1330 if let Some(GenericArgument::Type(ty)) = ab.args.first() {
1331 return is_sqlite_type(ty);
1332 }
1333 }
1334 return false;
1335 }
1336
1337 [
1338 "BigInt",
1339 "Binary",
1340 "Bool",
1341 "Date",
1342 "Double",
1343 "Float",
1344 "Integer",
1345 "Numeric",
1346 "SmallInt",
1347 "Text",
1348 "Time",
1349 "Timestamp",
1350 ]
1351 .contains(&ident.as_str())
1352}
1353
/// Describes for which backends the generated trait impls are emitted.
///
/// In every non-`None` variant the first field is the identifier of the
/// option that produced it (`dialect`, `backend_bounds` or `backends`).
#[derive(Default, Clone, Debug)]
enum BackendRestriction {
    /// No restriction: impls are generic over any `diesel::backend::Backend`.
    #[default]
    None,
    /// `dialect(Name, Path)`: impls require
    /// `diesel::backend::SqlDialect<Name = Path>` on the backend.
    SqlDialect(syn::Ident, syn::Ident, syn::Path),
    /// `backend_bounds(A + B)`: impls add these bounds to the backend
    /// type parameter.
    BackendBound(
        syn::Ident,
        syn::punctuated::Punctuated<syn::TypeParamBound, syn::Token![+]>,
    ),
    /// `backends(A, B)`: one concrete impl is generated per listed backend.
    Backends(
        syn::Ident,
        syn::punctuated::Punctuated<syn::Path, syn::Token![,]>,
    ),
}
1368
1369impl BackendRestriction {
1370 fn parse_from(input: &syn::parse::ParseBuffer<'_>) -> Result<Self> {
1371 if input.is_empty() {
1372 return Ok(Self::None);
1373 }
1374 Self::parse(input)
1375 }
1376
1377 fn parse_backends(
1378 input: &syn::parse::ParseBuffer<'_>,
1379 name: Ident,
1380 ) -> Result<BackendRestriction> {
1381 let backends = Punctuated::parse_terminated(input)?;
1382 Ok(Self::Backends(name, backends))
1383 }
1384
1385 fn parse_sql_dialect(
1386 content: &syn::parse::ParseBuffer<'_>,
1387 name: Ident,
1388 ) -> Result<BackendRestriction> {
1389 let dialect = content.parse()?;
1390 let _del: syn::Token![,] = content.parse()?;
1391 let dialect_variant = content.parse()?;
1392
1393 Ok(Self::SqlDialect(name, dialect, dialect_variant))
1394 }
1395
1396 fn parse_backend_bounds(
1397 input: &syn::parse::ParseBuffer<'_>,
1398 name: Ident,
1399 ) -> Result<BackendRestriction> {
1400 let restrictions = Punctuated::parse_terminated(input)?;
1401 Ok(Self::BackendBound(name, restrictions))
1402 }
1403
1404 fn generate_all_window_fragment_impls(
1405 &self,
1406 mut generics: Generics,
1407 ty_generics: &TypeGenerics<'_>,
1408 fn_name: &syn::Ident,
1409 require_order: bool,
1410 ) -> TokenStream {
1411 generics.params.push(parse_quote!(__P));
1412 generics.params.push(parse_quote!(__O));
1413 generics.params.push(parse_quote!(__F));
1414 let order = if require_order {
1415 quote::quote! {
1416 diesel::internal::sql_functions::Order<__O, true>
1417 }
1418 } else {
1419 quote::quote! {__O}
1420 };
1421 match *self {
1422 BackendRestriction::None => {
1423 generics.params.push(parse_quote!(__DieselInternal));
1424 let (impl_generics, _, _) = generics.split_for_impl();
1425 Self::generate_window_fragment_impl(
1426 parse_quote!(__DieselInternal),
1427 Some(parse_quote!(__DieselInternal: diesel::backend::Backend,)),
1428 &impl_generics,
1429 ty_generics,
1430 fn_name,
1431 None,
1432 &order,
1433 )
1434 }
1435 BackendRestriction::SqlDialect(_, ref dialect, ref dialect_type) => {
1436 generics.params.push(parse_quote!(__DieselInternal));
1437 let (impl_generics, _, _) = generics.split_for_impl();
1438 let mut out = quote::quote! {
1439 impl #impl_generics WindowFunctionFragment<#fn_name #ty_generics, __DieselInternal>
1440 for OverClause<__P, #order, __F>
1441 where
1442 Self: WindowFunctionFragment<#fn_name #ty_generics, __DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>,
1443 __DieselInternal: diesel::backend::Backend,
1444 {
1445 }
1446
1447 };
1448 let specific_impl = Self::generate_window_fragment_impl(
1449 parse_quote!(__DieselInternal),
1450 Some(
1451 parse_quote!(__DieselInternal: diesel::backend::Backend + diesel::backend::SqlDialect<#dialect = #dialect_type>,),
1452 ),
1453 &impl_generics,
1454 ty_generics,
1455 fn_name,
1456 Some(dialect_type),
1457 &order,
1458 );
1459 out.extend(specific_impl);
1460 out
1461 }
1462 BackendRestriction::BackendBound(_, ref restriction) => {
1463 generics.params.push(parse_quote!(__DieselInternal));
1464 let (impl_generics, _, _) = generics.split_for_impl();
1465 Self::generate_window_fragment_impl(
1466 parse_quote!(__DieselInternal),
1467 Some(parse_quote!(__DieselInternal: diesel::backend::Backend + #restriction,)),
1468 &impl_generics,
1469 ty_generics,
1470 fn_name,
1471 None,
1472 &order,
1473 )
1474 }
1475 BackendRestriction::Backends(_, ref backends) => {
1476 let (impl_generics, _, _) = generics.split_for_impl();
1477 let backends = backends.iter().map(|b| {
1478 Self::generate_window_fragment_impl(
1479 quote! {#b},
1480 None,
1481 &impl_generics,
1482 ty_generics,
1483 fn_name,
1484 None,
1485 &order,
1486 )
1487 });
1488
1489 parse_quote!(#(#backends)*)
1490 }
1491 }
1492 }
1493
1494 fn generate_window_fragment_impl(
1495 backend: TokenStream,
1496 backend_bound: Option<proc_macro2::TokenStream>,
1497 impl_generics: &ImplGenerics<'_>,
1498 ty_generics: &TypeGenerics<'_>,
1499 fn_name: &syn::Ident,
1500 dialect: Option<&syn::Path>,
1501 order: &TokenStream,
1502 ) -> TokenStream {
1503 quote::quote! {
1504 impl #impl_generics WindowFunctionFragment<#fn_name #ty_generics, #backend, #dialect> for OverClause<__P, #order, __F>
1505 where #backend_bound
1506 {
1507
1508 }
1509 }
1510 }
1511
1512 fn generate_all_queryfragment_impls(
1513 &self,
1514 mut generics: Generics,
1515 ty_generics: &TypeGenerics<'_>,
1516 arg_name: &[&syn::Ident],
1517 fn_name: &syn::Ident,
1518 ) -> proc_macro2::TokenStream {
1519 match *self {
1520 BackendRestriction::None => {
1521 generics.params.push(parse_quote!(__DieselInternal));
1522 let (impl_generics, _, _) = generics.split_for_impl();
1523 Self::generate_queryfragment_impl(
1524 parse_quote!(__DieselInternal),
1525 Some(parse_quote!(__DieselInternal: diesel::backend::Backend,)),
1526 &impl_generics,
1527 ty_generics,
1528 arg_name,
1529 fn_name,
1530 None,
1531 )
1532 }
1533 BackendRestriction::BackendBound(_, ref restriction) => {
1534 generics.params.push(parse_quote!(__DieselInternal));
1535 let (impl_generics, _, _) = generics.split_for_impl();
1536 Self::generate_queryfragment_impl(
1537 parse_quote!(__DieselInternal),
1538 Some(parse_quote!(__DieselInternal: diesel::backend::Backend + #restriction,)),
1539 &impl_generics,
1540 ty_generics,
1541 arg_name,
1542 fn_name,
1543 None,
1544 )
1545 }
1546 BackendRestriction::SqlDialect(_, ref dialect, ref dialect_type) => {
1547 generics.params.push(parse_quote!(__DieselInternal));
1548 let (impl_generics, _, _) = generics.split_for_impl();
1549 let specific_impl = Self::generate_queryfragment_impl(
1550 parse_quote!(__DieselInternal),
1551 Some(
1552 parse_quote!(__DieselInternal: diesel::backend::Backend + diesel::backend::SqlDialect<#dialect = #dialect_type>,),
1553 ),
1554 &impl_generics,
1555 ty_generics,
1556 arg_name,
1557 fn_name,
1558 Some(dialect_type),
1559 );
1560 quote::quote! {
1561 impl #impl_generics QueryFragment<__DieselInternal>
1562 for #fn_name #ty_generics
1563 where
1564 Self: QueryFragment<__DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>,
1565 __DieselInternal: diesel::backend::Backend,
1566 {
1567 fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> {
1568 <Self as QueryFragment<__DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>>::walk_ast(self, out)
1569 }
1570
1571 }
1572
1573 #specific_impl
1574 }
1575 }
1576 BackendRestriction::Backends(_, ref backends) => {
1577 let (impl_generics, _, _) = generics.split_for_impl();
1578 let backends = backends.iter().map(|b| {
1579 Self::generate_queryfragment_impl(
1580 quote! {#b},
1581 None,
1582 &impl_generics,
1583 ty_generics,
1584 arg_name,
1585 fn_name,
1586 None,
1587 )
1588 });
1589
1590 parse_quote!(#(#backends)*)
1591 }
1592 }
1593 }
1594
1595 fn generate_queryfragment_impl(
1596 backend: proc_macro2::TokenStream,
1597 backend_bound: Option<proc_macro2::TokenStream>,
1598 impl_generics: &ImplGenerics<'_>,
1599 ty_generics: &TypeGenerics<'_>,
1600 arg_name: &[&syn::Ident],
1601 fn_name: &syn::Ident,
1602 dialect: Option<&syn::Path>,
1603 ) -> proc_macro2::TokenStream {
1604 quote::quote! {
1605 impl #impl_generics QueryFragment<#backend, #dialect>
1606 for #fn_name #ty_generics
1607 where
1608 #backend_bound
1609 #(#arg_name: QueryFragment<#backend>,)*
1610 {
1611 fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, #backend>) -> QueryResult<()>{
1612 out.push_sql(<Self as FunctionFragment<#backend>>::FUNCTION_NAME);
1613 out.push_sql("(");
1614 self.walk_arguments(out.reborrow())?;
1615 out.push_sql(")");
1616 Ok(())
1617 }
1618 }
1619 }
1620 }
1621}
1622
1623impl Parse for BackendRestriction {
1624 fn parse(input: ParseStream) -> Result<Self> {
1625 let name: syn::Ident = input.parse()?;
1626 let name_str = name.to_string();
1627 let content;
1628 parenthesized!(content in input);
1629 match &*name_str {
1630 "backends" => Self::parse_backends(&content, name),
1631 "dialect" => Self::parse_sql_dialect(&content, name),
1632 "backend_bounds" => Self::parse_backend_bounds(&content, name),
1633 _ => Err(syn::Error::new(
1634 name.span(),
1635 format!("unexpected option `{name_str}`"),
1636 )),
1637 }
1638 }
1639}
1640
/// A single parsed attribute attached to a SQL function declaration.
///
/// Where a variant stores an `Ident`, it is the identifier naming the
/// attribute; `MySpanned::span` uses it as the span of the whole item.
#[derive(Debug, Clone)]
enum SqlFunctionAttribute {
    /// `#[aggregate]`
    Aggregate(Ident),
    /// `#[window]` or `#[window(...)]`, optionally carrying a backend
    /// restriction and/or `require_order = <bool>`.
    Window {
        ident: Ident,
        restrictions: BackendRestriction,
        require_order: Option<bool>,
    },
    /// `sql_name` together with its string literal value.
    SqlName(Ident, LitStr),
    /// `backends(...)`, `dialect(...)` or `backend_bounds(...)`.
    Restriction(BackendRestriction),
    /// `#[variadic(<int>)]`
    Variadic(Ident, LitInt),
    /// `#[skip_return_type_helper]`
    SkipReturnTypeHelper(Ident),
    /// Any attribute not recognised above, kept unchanged.
    Other(Attribute),
}
1655
1656impl MySpanned for SqlFunctionAttribute {
1657 fn span(&self) -> proc_macro2::Span {
1658 match self {
1659 SqlFunctionAttribute::Restriction(BackendRestriction::Backends(ref ident, ..))
1660 | SqlFunctionAttribute::Restriction(BackendRestriction::SqlDialect(ref ident, ..))
1661 | SqlFunctionAttribute::Restriction(BackendRestriction::BackendBound(ref ident, ..))
1662 | SqlFunctionAttribute::Aggregate(ref ident, ..)
1663 | SqlFunctionAttribute::Window { ref ident, .. }
1664 | SqlFunctionAttribute::Variadic(ref ident, ..)
1665 | SqlFunctionAttribute::SkipReturnTypeHelper(ref ident)
1666 | SqlFunctionAttribute::SqlName(ref ident, ..) => ident.span(),
1667 SqlFunctionAttribute::Restriction(BackendRestriction::None) => {
1668 unreachable!("We do not construct that")
1669 }
1670 SqlFunctionAttribute::Other(ref attribute) => attribute.span(),
1671 }
1672 }
1673}
1674
1675fn parse_require_order(input: &syn::parse::ParseBuffer<'_>) -> Result<bool> {
1676 let ident = input.parse::<Ident>()?;
1677 if ident == "require_order" {
1678 let _ = input.parse::<Token![=]>()?;
1679 let value = input.parse::<LitBool>()?;
1680 Ok(value.value)
1681 } else {
1682 Err(syn::Error::new(
1683 ident.span(),
1684 format!("Expected `require_order` but got `{ident}`"),
1685 ))
1686 }
1687}
1688
1689impl SqlFunctionAttribute {
1690 fn parse_attr(
1691 name: Ident,
1692 input: &syn::parse::ParseBuffer<'_>,
1693 attr: Attribute,
1694 attribute_span: proc_macro2::Span,
1695 ) -> Result<Option<AttributeSpanWrapper<Self>>> {
1696 if name == "cfg_attr" {
1699 let ident = input.parse::<Ident>()?;
1700 if ident != "feature" {
1701 return Err(syn::Error::new(
1702 ident.span(),
1703 format!(
1704 "only single feature `cfg_attr` attributes are supported. \
1705 Got `{ident}` but expected `feature = \"foo\"`"
1706 ),
1707 ));
1708 }
1709 let _ = input.parse::<Token![=]>()?;
1710 let feature = input.parse::<LitStr>()?;
1711 let feature_value = feature.value();
1712 let _ = input.parse::<Token![,]>()?;
1713 let ignore = match feature_value.as_str() {
1714 "postgres_backend" => !cfg!(feature = "postgres"),
1715 "sqlite" => !cfg!(feature = "sqlite"),
1716 "mysql_backend" => !cfg!(feature = "mysql"),
1717 feature => {
1718 return Err(syn::Error::new(
1719 feature.span(),
1720 format!(
1721 "only `mysql_backend`, `postgres_backend` and `sqlite` \
1722 are supported features, but got `{feature}`"
1723 ),
1724 ));
1725 }
1726 };
1727 let name = input.parse::<Ident>()?;
1728 let inner;
1729 let _paren = parenthesized!(inner in input);
1730 let ret = SqlFunctionAttribute::parse_attr(name, &inner, attr, attribute_span)?;
1731 if ignore {
1732 Ok(None)
1733 } else {
1734 Ok(ret)
1735 }
1736 } else {
1737 let name_str = name.to_string();
1738 let parsed_attr = match &*name_str {
1739 "window" => {
1740 let restrictions = if BackendRestriction::parse_from(&input.fork()).is_ok() {
1741 BackendRestriction::parse_from(input).map(Ok).ok()
1742 } else {
1743 None
1744 };
1745 if input.fork().parse::<Token![,]>().is_ok() {
1746 let _ = input.parse::<Token![,]>()?;
1747 }
1748 let require_order = if parse_require_order(&input.fork()).is_ok() {
1749 Some(parse_require_order(input)?)
1750 } else {
1751 None
1752 };
1753 if input.fork().parse::<Token![,]>().is_ok() {
1754 let _ = input.parse::<Token![,]>()?;
1755 }
1756 let restrictions =
1757 restrictions.unwrap_or_else(|| BackendRestriction::parse_from(input))?;
1758 Self::Window {
1759 ident: name,
1760 restrictions,
1761 require_order,
1762 }
1763 }
1764 "sql_name" => {
1765 parse_eq(input, "sql_name = \"SUM\"").map(|v| Self::SqlName(name, v))?
1766 }
1767 "backends" => {
1768 BackendRestriction::parse_backends(input, name).map(Self::Restriction)?
1769 }
1770 "dialect" => {
1771 BackendRestriction::parse_sql_dialect(input, name).map(Self::Restriction)?
1772 }
1773 "backend_bounds" => {
1774 BackendRestriction::parse_backend_bounds(input, name).map(Self::Restriction)?
1775 }
1776 "variadic" => Self::Variadic(name, input.parse()?),
1777 _ => {
1778 let _ = input.step(|cursor| {
1780 let mut rest = *cursor;
1781 while let Some((_, next)) = rest.token_tree() {
1782 rest = next;
1783 }
1784 Ok(((), rest))
1785 });
1786 SqlFunctionAttribute::Other(attr)
1787 }
1788 };
1789 Ok(Some(AttributeSpanWrapper {
1790 ident_span: parsed_attr.span(),
1791 item: parsed_attr,
1792 attribute_span,
1793 }))
1794 }
1795 }
1796}
1797
/// Options accepted as input of the macro, see
/// `DeclareSqlFunctionArgs::parse_from_macro_input`.
#[derive(Default)]
pub(crate) struct DeclareSqlFunctionArgs {
    /// Value of `generate_return_type_helpers = true/false`; `false` when the
    /// macro input is empty.
    pub(crate) generate_return_type_helpers: bool,
}
1802
1803impl DeclareSqlFunctionArgs {
1804 pub(crate) fn parse_from_macro_input(input: TokenStream) -> syn::Result<Self> {
1805 if input.is_empty() {
1806 return Ok(Self::default());
1807 }
1808 let input_span = input.span();
1809 let parsed: syn::MetaNameValue = syn::parse2(input).map_err(|e| {
1810 let span = e.span();
1811 syn::Error::new(
1812 span,
1813 format!("{e}, the correct format is `generate_return_type_helpers = true/false`"),
1814 )
1815 })?;
1816 match parsed {
1817 syn::MetaNameValue {
1818 path,
1819 value:
1820 syn::Expr::Lit(syn::ExprLit {
1821 lit: syn::Lit::Bool(b),
1822 ..
1823 }),
1824 ..
1825 } if path.is_ident("generate_return_type_helpers") => Ok(Self {
1826 generate_return_type_helpers: b.value,
1827 }),
1828 _ => Err(syn::Error::new(input_span, "Invalid config")),
1829 }
1830 }
1831}