diesel_derives/
sql_function.rs

1use proc_macro2::Span;
2use proc_macro2::TokenStream;
3use quote::format_ident;
4use quote::quote;
5use quote::ToTokens;
6use quote::TokenStreamExt;
7use std::iter;
8use syn::parse::{Parse, ParseStream, Result};
9use syn::punctuated::Pair;
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{
13    parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, ImplGenerics, LitStr,
14    PathArguments, Token, Type, TypeGenerics,
15};
16use syn::{GenericParam, Meta};
17use syn::{LitBool, Path};
18use syn::{LitInt, MetaNameValue};
19
20use crate::attrs::{AttributeSpanWrapper, MySpanned};
21use crate::util::parse_eq;
22
/// Highest number of variadic-argument repetitions to generate variants for
/// when `DIESEL_VARIADIC_FUNCTION_ARGS` is unset or does not parse as `usize`.
const VARIADIC_VARIANTS_DEFAULT: usize = 2;
/// Build-time override of [`VARIADIC_VARIANTS_DEFAULT`], captured from the
/// `DIESEL_VARIADIC_FUNCTION_ARGS` environment variable when this crate is compiled.
const VARIADIC_ARG_COUNT_ENV: Option<&'static str> = option_env!("DIESEL_VARIADIC_FUNCTION_ARGS");
25
26pub(crate) fn expand(
27    input: Vec<SqlFunctionDecl>,
28    legacy_helper_type_and_module: bool,
29    generate_return_type_helpers: bool,
30) -> TokenStream {
31    let mut result = TokenStream::new();
32    let mut return_type_helper_module_paths = vec![];
33
34    for decl in input {
35        let expanded = expand_one(
36            decl,
37            legacy_helper_type_and_module,
38            generate_return_type_helpers,
39        );
40        let expanded = match expanded {
41            Err(err) => err.into_compile_error(),
42            Ok(expanded) => {
43                if let Some(return_type_helper_module_path) =
44                    expanded.return_type_helper_module_path
45                {
46                    return_type_helper_module_paths.push(return_type_helper_module_path);
47                }
48
49                expanded.tokens
50            }
51        };
52
53        result.append_all(expanded.into_iter());
54    }
55
56    if !generate_return_type_helpers {
57        return result;
58    }
59
60    quote! {
61        #result
62
63        #[allow(unused_imports)]
64        #[doc(hidden)]
65        mod return_type_helpers {
66            #(
67                #[doc(inline)]
68                pub use super:: #return_type_helper_module_paths ::*;
69            )*
70        }
71    }
72}
73
74struct ExpandedSqlFunction {
75    tokens: TokenStream,
76    return_type_helper_module_path: Option<Path>,
77}
78
79fn expand_one(
80    mut input: SqlFunctionDecl,
81    legacy_helper_type_and_module: bool,
82    generate_return_type_helpers: bool,
83) -> syn::Result<ExpandedSqlFunction> {
84    let attributes = &mut input.attributes;
85
86    let variadic_argument_count = attributes.iter().find_map(|attr| {
87        if let SqlFunctionAttribute::Variadic(_, c) = &attr.item {
88            Some((c.base10_parse(), c.span()))
89        } else {
90            None
91        }
92    });
93
94    let Some((variadic_argument_count, variadic_span)) = variadic_argument_count else {
95        let sql_name = parse_sql_name_attr(&mut input);
96
97        return expand_nonvariadic(
98            input,
99            sql_name,
100            legacy_helper_type_and_module,
101            generate_return_type_helpers,
102        );
103    };
104
105    let variadic_argument_count = variadic_argument_count?;
106
107    let variadic_variants = VARIADIC_ARG_COUNT_ENV
108        .and_then(|arg_count| arg_count.parse::<usize>().ok())
109        .unwrap_or(VARIADIC_VARIANTS_DEFAULT);
110
111    let mut result = TokenStream::new();
112    let mut helper_type_modules = vec![];
113    for variant_no in 0..=variadic_variants {
114        let expanded = expand_variadic(
115            input.clone(),
116            legacy_helper_type_and_module,
117            generate_return_type_helpers,
118            variadic_argument_count,
119            variant_no,
120            variadic_span,
121        )?;
122
123        if let Some(return_type_helper_module_path) = expanded.return_type_helper_module_path {
124            helper_type_modules.push(return_type_helper_module_path);
125        }
126
127        result.append_all(expanded.tokens.into_iter());
128    }
129
130    if generate_return_type_helpers {
131        let return_types_module_name = Ident::new(
132            &format!("__{}_return_types", input.fn_name),
133            input.fn_name.span(),
134        );
135        let result = quote! {
136            #result
137
138            #[allow(unused_imports)]
139            #[doc(inline)]
140            mod #return_types_module_name {
141                #(
142                    #[doc(inline)]
143                    pub use super:: #helper_type_modules ::*;
144                )*
145            }
146        };
147
148        let return_type_helper_module_path = Some(parse_quote! {
149            #return_types_module_name
150        });
151
152        Ok(ExpandedSqlFunction {
153            tokens: result,
154            return_type_helper_module_path,
155        })
156    } else {
157        Ok(ExpandedSqlFunction {
158            tokens: result,
159            return_type_helper_module_path: None,
160        })
161    }
162}
163
164fn expand_variadic(
165    mut input: SqlFunctionDecl,
166    legacy_helper_type_and_module: bool,
167    generate_return_type_helpers: bool,
168    variadic_argument_count: usize,
169    variant_no: usize,
170    variadic_span: Span,
171) -> syn::Result<ExpandedSqlFunction> {
172    add_variadic_doc_comments(&mut input.attributes, &input.fn_name.to_string());
173
174    let sql_name = parse_sql_name_attr(&mut input);
175
176    input.fn_name = format_ident!("{}_{}", input.fn_name, variant_no);
177
178    let nonvariadic_args_count = input
179        .args
180        .len()
181        .checked_sub(variadic_argument_count)
182        .ok_or_else(|| {
183            syn::Error::new(
184                variadic_span,
185                "invalid variadic argument count: not enough function arguments",
186            )
187        })?;
188
189    let mut variadic_generic_indexes = vec![];
190    let mut arguments_with_generic_types = vec![];
191    for (arg_idx, arg) in input.args.iter().skip(nonvariadic_args_count).enumerate() {
192        // If argument is of type that definitely cannot be a generic then we skip it.
193        let Type::Path(ty_path) = arg.ty.clone() else {
194            continue;
195        };
196        let Ok(ty_ident) = ty_path.path.require_ident() else {
197            continue;
198        };
199
200        let idx = input.generics.params.iter().position(|param| match param {
201            GenericParam::Type(type_param) => type_param.ident == *ty_ident,
202            _ => false,
203        });
204
205        if let Some(idx) = idx {
206            variadic_generic_indexes.push(idx);
207            arguments_with_generic_types.push(arg_idx);
208        }
209    }
210
211    let mut args: Vec<_> = input.args.into_pairs().collect();
212    let variadic_args = args.split_off(nonvariadic_args_count);
213    let nonvariadic_args = args;
214
215    let variadic_args: Vec<_> = iter::repeat_n(variadic_args, variant_no)
216        .enumerate()
217        .flat_map(|(arg_group_idx, arg_group)| {
218            let mut resulting_args = vec![];
219
220            for (arg_idx, arg) in arg_group.into_iter().enumerate() {
221                let mut arg = arg.into_value();
222
223                arg.name = format_ident!("{}_{}", arg.name, arg_group_idx + 1);
224
225                if arguments_with_generic_types.contains(&arg_idx) {
226                    let Type::Path(mut ty_path) = arg.ty.clone() else {
227                        unreachable!("This argument should have path type as checked earlier")
228                    };
229                    let Ok(ident) = ty_path.path.require_ident() else {
230                        unreachable!("This argument should have ident type as checked earlier")
231                    };
232
233                    ty_path.path.segments[0].ident =
234                        format_ident!("{}{}", ident, arg_group_idx + 1);
235                    arg.ty = Type::Path(ty_path);
236                }
237
238                let pair = Pair::new(arg, Some(Token![,]([Span::call_site()])));
239                resulting_args.push(pair);
240            }
241
242            resulting_args
243        })
244        .collect();
245
246    input.args = nonvariadic_args.into_iter().chain(variadic_args).collect();
247
248    let generics: Vec<_> = input.generics.params.into_pairs().collect();
249    input.generics.params = if variant_no == 0 {
250        generics
251            .into_iter()
252            .enumerate()
253            .filter_map(|(generic_idx, generic)| {
254                (!variadic_generic_indexes.contains(&generic_idx)).then_some(generic)
255            })
256            .collect()
257    } else {
258        iter::repeat_n(generics, variant_no)
259            .enumerate()
260            .flat_map(|(generic_group_idx, generic_group)| {
261                let mut resulting_generics = vec![];
262
263                for (generic_idx, generic) in generic_group.into_iter().enumerate() {
264                    if !variadic_generic_indexes.contains(&generic_idx) {
265                        if generic_group_idx == 0 {
266                            resulting_generics.push(generic);
267                        }
268
269                        continue;
270                    }
271
272                    let mut generic = generic.into_value();
273
274                    if let GenericParam::Type(type_param) = &mut generic {
275                        type_param.ident =
276                            format_ident!("{}{}", type_param.ident, generic_group_idx + 1);
277                    } else {
278                        unreachable!("This generic should be a type param as checked earlier")
279                    }
280
281                    let pair = Pair::new(generic, Some(Token![,]([Span::call_site()])));
282                    resulting_generics.push(pair);
283                }
284
285                resulting_generics
286            })
287            .collect()
288    };
289
290    expand_nonvariadic(
291        input,
292        sql_name,
293        legacy_helper_type_and_module,
294        generate_return_type_helpers,
295    )
296}
297
298fn add_variadic_doc_comments(
299    attributes: &mut Vec<AttributeSpanWrapper<SqlFunctionAttribute>>,
300    fn_name: &str,
301) {
302    let mut doc_comments_end = attributes.len()
303        - attributes
304            .iter()
305            .rev()
306            .position(|attr| matches!(&attr.item, SqlFunctionAttribute::Other(syn::Attribute{ meta: Meta::NameValue(MetaNameValue { path, .. }), ..}) if path.is_ident("doc")))
307            .unwrap_or(attributes.len());
308
309    let fn_family = format!("`{fn_name}_0`, `{fn_name}_1`, ... `{fn_name}_n`");
310
311    let doc_comments: Vec<Attribute> = parse_quote! {
312        ///
313        /// # Variadic functions
314        ///
315        /// This function is variadic in SQL, so there's a family of functions
316        /// on a diesel side:
317        ///
318        #[doc = #fn_family]
319        ///
320        /// Here, the postfix number indicates repetitions of variadic arguments.
321        /// To use this function, the appropriate version with the correct
322        /// argument count must be selected.
323        ///
324        /// ## Controlling the generation of variadic function variants
325        ///
326        /// By default, only variants with 0, 1, and 2 repetitions of variadic
327        /// arguments are generated. To generate more variants, set the
328        /// `DIESEL_VARIADIC_FUNCTION_ARGS` environment variable to the desired
329        /// number of variants.
330        ///
331        /// For a greater convenience this environment variable can also be set
332        /// in a `.cargo/config.toml` file as described in the
333        /// [cargo documentation](https://doc.rust-lang.org/cargo/reference/config.html#env).
334        #[doc(alias = #fn_name)]
335    };
336
337    for new_attribute in doc_comments {
338        attributes.insert(
339            doc_comments_end,
340            AttributeSpanWrapper {
341                item: SqlFunctionAttribute::Other(new_attribute),
342                attribute_span: Span::mixed_site(),
343                ident_span: Span::mixed_site(),
344            },
345        );
346        doc_comments_end += 1;
347    }
348}
349
350fn parse_sql_name_attr(input: &mut SqlFunctionDecl) -> String {
351    let result = input
352        .attributes
353        .iter()
354        .find_map(|attr| match attr.item {
355            SqlFunctionAttribute::SqlName(_, ref value) => Some(value.value()),
356            _ => None,
357        })
358        .unwrap_or_else(|| input.fn_name.to_string());
359
360    result
361}
362
363fn expand_nonvariadic(
364    input: SqlFunctionDecl,
365    sql_name: String,
366    legacy_helper_type_and_module: bool,
367    generate_return_type_helpers: bool,
368) -> syn::Result<ExpandedSqlFunction> {
369    let SqlFunctionDecl {
370        attributes,
371        fn_token,
372        fn_name,
373        mut generics,
374        args,
375        return_type,
376    } = input;
377
378    let is_aggregate = attributes
379        .iter()
380        .any(|attr| matches!(attr.item, SqlFunctionAttribute::Aggregate(..)));
381
382    let can_be_called_directly = !function_cannot_be_called_directly(&attributes);
383
384    let skip_return_type_helper = attributes
385        .iter()
386        .any(|attr| matches!(attr.item, SqlFunctionAttribute::SkipReturnTypeHelper(..)));
387
388    let window_attrs = attributes
389        .iter()
390        .filter(|a| matches!(a.item, SqlFunctionAttribute::Window { .. }))
391        .cloned()
392        .collect::<Vec<_>>();
393
394    let restrictions = attributes
395        .iter()
396        .find_map(|a| match a.item {
397            SqlFunctionAttribute::Restriction(ref r) => Some(r.clone()),
398            _ => None,
399        })
400        .unwrap_or_default();
401
402    let attributes = attributes
403        .into_iter()
404        .filter_map(|a| match a.item {
405            SqlFunctionAttribute::Other(a) => Some(a),
406            _ => None,
407        })
408        .collect::<Vec<_>>();
409
410    let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
411        .iter()
412        .map(|StrictFnArg { name, ty, .. }| (name, ty))
413        .unzip();
414    let arg_struct_assign = args.iter().map(
415        |StrictFnArg {
416             name, colon_token, ..
417         }| {
418            let name2 = name.clone();
419            quote!(#name #colon_token #name2.as_expression())
420        },
421    );
422
423    let type_args = &generics
424        .type_params()
425        .map(|type_param| type_param.ident.clone())
426        .collect::<Vec<_>>();
427
428    for StrictFnArg { name, .. } in &args {
429        generics.params.push(parse_quote!(#name));
430    }
431
432    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
433    // Even if we force an empty where clause, it still won't print the where
434    // token with no bounds.
435    let where_clause = where_clause
436        .map(|w| quote!(#w))
437        .unwrap_or_else(|| quote!(where));
438
439    let mut generics_with_internal = generics.clone();
440    generics_with_internal
441        .params
442        .push(parse_quote!(__DieselInternal));
443    let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();
444
445    let sql_type;
446    let numeric_derive;
447
448    if arg_name.is_empty() {
449        sql_type = None;
450        // FIXME: We can always derive once trivial bounds are stable
451        numeric_derive = None;
452    } else {
453        sql_type = Some(quote!((#(#arg_name),*): Expression,));
454        numeric_derive = Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)]));
455    }
456
457    let helper_type_doc = format!("The return type of [`{fn_name}()`](super::fn_name)");
458    let query_fragment_impl =
459        can_be_called_directly.then_some(restrictions.generate_all_queryfragment_impls(
460            generics.clone(),
461            &ty_generics,
462            arg_name,
463            &fn_name,
464        ));
465
466    let args_iter = args.iter();
467    let mut tokens = quote! {
468        use diesel::{self, QueryResult};
469        use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
470        use diesel::query_builder::{QueryFragment, AstPass};
471        use diesel::sql_types::*;
472        use diesel::internal::sql_functions::*;
473        use super::*;
474
475        #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
476        #numeric_derive
477        pub struct #fn_name #ty_generics {
478            #(pub(in super) #args_iter,)*
479            #(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
480        }
481
482        #[doc = #helper_type_doc]
483        pub type HelperType #ty_generics = #fn_name <
484            #(#type_args,)*
485            #(<#arg_name as AsExpression<#arg_type>>::Expression,)*
486        >;
487
488        impl #impl_generics Expression for #fn_name #ty_generics
489        #where_clause
490            #sql_type
491        {
492            type SqlType = #return_type;
493        }
494
495        // __DieselInternal is what we call QS normally
496        impl #impl_generics_internal SelectableExpression<__DieselInternal>
497            for #fn_name #ty_generics
498        #where_clause
499            #(#arg_name: SelectableExpression<__DieselInternal>,)*
500            Self: AppearsOnTable<__DieselInternal>,
501        {
502        }
503
504        // __DieselInternal is what we call QS normally
505        impl #impl_generics_internal AppearsOnTable<__DieselInternal>
506            for #fn_name #ty_generics
507        #where_clause
508            #(#arg_name: AppearsOnTable<__DieselInternal>,)*
509            Self: Expression,
510        {
511        }
512
513        impl #impl_generics_internal FunctionFragment<__DieselInternal>
514            for #fn_name #ty_generics
515        where
516            __DieselInternal: diesel::backend::Backend,
517            #(#arg_name: QueryFragment<__DieselInternal>,)*
518        {
519            const FUNCTION_NAME: &'static str = #sql_name;
520
521            #[allow(unused_assignments)]
522            fn walk_arguments<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> {
523                // we unroll the arguments manually here, to prevent borrow check issues
524                let mut needs_comma = false;
525                #(
526                    if !self.#arg_name.is_noop(out.backend())? {
527                        if needs_comma {
528                            out.push_sql(", ");
529                        }
530                        self.#arg_name.walk_ast(out.reborrow())?;
531                        needs_comma = true;
532                    }
533                )*
534                Ok(())
535            }
536        }
537
538        #query_fragment_impl
539    };
540
541    let is_supported_on_sqlite = cfg!(feature = "sqlite")
542        && type_args.is_empty()
543        && is_sqlite_type(&return_type)
544        && arg_type.iter().all(|a| is_sqlite_type(a));
545
546    for window in &window_attrs {
547        tokens.extend(generate_window_function_tokens(
548            window,
549            generics.clone(),
550            &ty_generics,
551            &fn_name,
552        ));
553    }
554    if !window_attrs.is_empty() {
555        tokens.extend(quote::quote! {
556            impl #impl_generics IsWindowFunction for #fn_name #ty_generics {
557                type ArgTypes = (#(#arg_name,)*);
558            }
559        });
560    }
561
562    if is_aggregate {
563        tokens = generate_tokens_for_aggregate_functions(
564            tokens,
565            &impl_generics_internal,
566            &impl_generics,
567            &fn_name,
568            &ty_generics,
569            arg_name,
570            arg_type,
571            is_supported_on_sqlite,
572            !window_attrs.is_empty(),
573            &return_type,
574            &sql_name,
575        );
576    } else if window_attrs.is_empty() {
577        tokens = generate_tokens_for_non_aggregate_functions(
578            tokens,
579            &impl_generics_internal,
580            &fn_name,
581            &ty_generics,
582            arg_name,
583            arg_type,
584            is_supported_on_sqlite,
585            &return_type,
586            &sql_name,
587        );
588    }
589
590    let args_iter = args.iter();
591
592    let (outside_of_module_helper_type, return_type_path, internals_module_name) =
593        if legacy_helper_type_and_module {
594            (None, quote! { #fn_name::HelperType }, fn_name.clone())
595        } else {
596            let internals_module_name = Ident::new(&format!("{fn_name}_utils"), fn_name.span());
597            (
598                Some(quote! {
599                    #[allow(non_camel_case_types, non_snake_case)]
600                    #[doc = #helper_type_doc]
601                    pub type #fn_name #ty_generics = #internals_module_name::#fn_name <
602                        #(#type_args,)*
603                        #(<#arg_name as diesel::expression::AsExpression<#arg_type>>::Expression,)*
604                    >;
605                }),
606                quote! { #fn_name },
607                internals_module_name,
608            )
609        };
610
611    let (return_type_helper_module, return_type_helper_module_path) =
612        if !generate_return_type_helpers || skip_return_type_helper {
613            (None, None)
614        } else {
615            let auto_derived_types = type_args
616                .iter()
617                .map(|type_arg| {
618                    for arg in &args {
619                        let Type::Path(path) = &arg.ty else {
620                            continue;
621                        };
622
623                        let Some(path_ident) = path.path.get_ident() else {
624                            continue;
625                        };
626
627                        if path_ident == type_arg {
628                            return Ok(arg.name.clone());
629                        }
630                    }
631
632                    Err(syn::Error::new(
633                        type_arg.span(),
634                        "cannot find argument corresponding to the generic",
635                    ))
636                })
637                .collect::<Result<Vec<_>>>()?;
638
639            let arg_names_iter: Vec<_> = args.iter().map(|arg| arg.name.clone()).collect();
640
641            let return_type_module_name =
642                Ident::new(&format!("__{fn_name}_return_type"), fn_name.span());
643
644            let doc =
645                format!("Return type of the [`{fn_name}()`](fn@super::{fn_name}) SQL function.");
646            let return_type_helper_module = quote! {
647                #[allow(non_camel_case_types, non_snake_case, unused_imports)]
648                #[doc(inline)]
649                mod #return_type_module_name {
650                    #[doc = #doc]
651                    pub type #fn_name<
652                        #(#arg_names_iter,)*
653                    > = super::#fn_name<
654                        #( <#auto_derived_types as diesel::expression::Expression>::SqlType, )*
655                        #(#arg_names_iter,)*
656                    >;
657                }
658            };
659
660            let module_path = parse_quote!(
661                #return_type_module_name
662            );
663
664            (Some(return_type_helper_module), Some(module_path))
665        };
666
667    let tokens = quote! {
668        #(#attributes)*
669        #[allow(non_camel_case_types)]
670        pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
671            -> #return_type_path #ty_generics
672        #where_clause
673            #(#arg_name: diesel::expression::AsExpression<#arg_type>,)*
674        {
675            #internals_module_name::#fn_name {
676                #(#arg_struct_assign,)*
677                #(#type_args: ::std::marker::PhantomData,)*
678            }
679        }
680
681        #outside_of_module_helper_type
682
683        #return_type_helper_module
684
685        #[doc(hidden)]
686        #[allow(non_camel_case_types, non_snake_case, unused_imports)]
687        pub(crate) mod #internals_module_name {
688            #tokens
689        }
690    };
691
692    Ok(ExpandedSqlFunction {
693        tokens,
694        return_type_helper_module_path,
695    })
696}
697
698fn generate_window_function_tokens(
699    window: &AttributeSpanWrapper<SqlFunctionAttribute>,
700    generics: Generics,
701    ty_generics: &TypeGenerics<'_>,
702    fn_name: &Ident,
703) -> TokenStream {
704    let SqlFunctionAttribute::Window {
705        restrictions,
706        require_order,
707        ..
708    } = &window.item
709    else {
710        unreachable!("We filtered for window attributes above")
711    };
712    restrictions.generate_all_window_fragment_impls(
713        generics,
714        ty_generics,
715        fn_name,
716        require_order.unwrap_or_default(),
717    )
718}
719
720#[allow(clippy::too_many_arguments)]
721fn generate_tokens_for_non_aggregate_functions(
722    mut tokens: TokenStream,
723    impl_generics_internal: &syn::ImplGenerics<'_>,
724    fn_name: &syn::Ident,
725    ty_generics: &syn::TypeGenerics<'_>,
726    arg_name: &[&syn::Ident],
727    arg_type: &[&syn::Type],
728    is_supported_on_sqlite: bool,
729    return_type: &syn::Type,
730    sql_name: &str,
731) -> TokenStream {
732    tokens = quote! {
733        #tokens
734
735        #[derive(ValidGrouping)]
736        pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);
737
738        impl #impl_generics_internal ValidGrouping<__DieselInternal>
739            for #fn_name #ty_generics
740        where
741            __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
742        {
743            type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
744        }
745    };
746
747    if is_supported_on_sqlite && !arg_name.is_empty() {
748        tokens = quote! {
749            #tokens
750
751            use diesel::sqlite::{Sqlite, SqliteConnection};
752            use diesel::serialize::ToSql;
753            use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
754
755            #[allow(dead_code)]
756            /// Registers an implementation for this function on the given connection
757            ///
758            /// This function must be called for every `SqliteConnection` before
759            /// this SQL function can be used on SQLite. The implementation must be
760            /// deterministic (returns the same result given the same arguments). If
761            /// the function is nondeterministic, call
762            /// `register_nondeterministic_impl` instead.
763            pub fn register_impl<F, Ret, #(#arg_name,)*>(
764                conn: &mut SqliteConnection,
765                f: F,
766            ) -> QueryResult<()>
767            where
768                F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
769                (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
770                    StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
771                Ret: ToSql<#return_type, Sqlite>,
772            {
773                conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
774                    #sql_name,
775                    true,
776                    move |(#(#arg_name,)*)| f(#(#arg_name,)*),
777                )
778            }
779
780            #[allow(dead_code)]
781            /// Registers an implementation for this function on the given connection
782            ///
783            /// This function must be called for every `SqliteConnection` before
784            /// this SQL function can be used on SQLite.
785            /// `register_nondeterministic_impl` should only be used if your
786            /// function can return different results with the same arguments (e.g.
787            /// `random`). If your function is deterministic, you should call
788            /// `register_impl` instead.
789            pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
790                conn: &mut SqliteConnection,
791                mut f: F,
792            ) -> QueryResult<()>
793            where
794                F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
795                (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
796                    StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
797                Ret: ToSql<#return_type, Sqlite>,
798            {
799                conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
800                    #sql_name,
801                    false,
802                    move |(#(#arg_name,)*)| f(#(#arg_name,)*),
803                )
804            }
805        };
806    }
807
808    if is_supported_on_sqlite && arg_name.is_empty() {
809        tokens = quote! {
810            #tokens
811
812            use diesel::sqlite::{Sqlite, SqliteConnection};
813            use diesel::serialize::ToSql;
814
815            #[allow(dead_code)]
816            /// Registers an implementation for this function on the given connection
817            ///
818            /// This function must be called for every `SqliteConnection` before
819            /// this SQL function can be used on SQLite. The implementation must be
820            /// deterministic (returns the same result given the same arguments). If
821            /// the function is nondeterministic, call
822            /// `register_nondeterministic_impl` instead.
823            pub fn register_impl<F, Ret>(
824                conn: &SqliteConnection,
825                f: F,
826            ) -> QueryResult<()>
827            where
828                F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static,
829                Ret: ToSql<#return_type, Sqlite>,
830            {
831                conn.register_noarg_sql_function::<#return_type, _, _>(
832                    #sql_name,
833                    true,
834                    f,
835                )
836            }
837
838            #[allow(dead_code)]
839            /// Registers an implementation for this function on the given connection
840            ///
841            /// This function must be called for every `SqliteConnection` before
842            /// this SQL function can be used on SQLite.
843            /// `register_nondeterministic_impl` should only be used if your
844            /// function can return different results with the same arguments (e.g.
845            /// `random`). If your function is deterministic, you should call
846            /// `register_impl` instead.
847            pub fn register_nondeterministic_impl<F, Ret>(
848                conn: &SqliteConnection,
849                mut f: F,
850            ) -> QueryResult<()>
851            where
852                F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
853                Ret: ToSql<#return_type, Sqlite>,
854            {
855                conn.register_noarg_sql_function::<#return_type, _, _>(
856                    #sql_name,
857                    false,
858                    f,
859                )
860            }
861        };
862    }
863    tokens
864}
865
866#[allow(clippy::too_many_arguments)]
867fn generate_tokens_for_aggregate_functions(
868    mut tokens: TokenStream,
869    impl_generics_internal: &syn::ImplGenerics<'_>,
870    impl_generics: &syn::ImplGenerics<'_>,
871    fn_name: &syn::Ident,
872    ty_generics: &syn::TypeGenerics<'_>,
873    arg_name: &[&syn::Ident],
874    arg_type: &[&syn::Type],
875    is_supported_on_sqlite: bool,
876    is_window: bool,
877    return_type: &syn::Type,
878    sql_name: &str,
879) -> TokenStream {
880    tokens = quote! {
881        #tokens
882
883        impl #impl_generics_internal ValidGrouping<__DieselInternal>
884            for #fn_name #ty_generics
885        {
886            type IsAggregate = diesel::expression::is_aggregate::Yes;
887        }
888
889        impl #impl_generics IsAggregateFunction for #fn_name #ty_generics {}
890    };
891    // we do not support custom window functions for sqlite yet
892    if is_supported_on_sqlite && !is_window {
893        tokens = quote! {
894            #tokens
895
896            use diesel::sqlite::{Sqlite, SqliteConnection};
897            use diesel::serialize::ToSql;
898            use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
899            use diesel::sqlite::SqliteAggregateFunction;
900            use diesel::sql_types::IntoNullable;
901        };
902
903        match arg_name.len() {
904            x if x > 1 => {
905                tokens = quote! {
906                    #tokens
907
908                    #[allow(dead_code)]
909                    /// Registers an implementation for this aggregate function on the given connection
910                    ///
911                    /// This function must be called for every `SqliteConnection` before
912                    /// this SQL function can be used on SQLite. The implementation must be
913                    /// deterministic (returns the same result given the same arguments).
914                    pub fn register_impl<A, #(#arg_name,)*>(
915                        conn: &mut SqliteConnection
916                    ) -> QueryResult<()>
917                        where
918                        A: SqliteAggregateFunction<(#(#arg_name,)*)>
919                            + Send
920                            + 'static
921                            + ::std::panic::UnwindSafe
922                            + ::std::panic::RefUnwindSafe,
923                        A::Output: ToSql<#return_type, Sqlite>,
924                        (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
925                            StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
926                            ::std::panic::UnwindSafe,
927                    {
928                        conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
929                    }
930                };
931            }
932            1 => {
933                let arg_name = arg_name[0];
934                let arg_type = arg_type[0];
935
936                tokens = quote! {
937                    #tokens
938
939                    #[allow(dead_code)]
940                    /// Registers an implementation for this aggregate function on the given connection
941                    ///
942                    /// This function must be called for every `SqliteConnection` before
943                    /// this SQL function can be used on SQLite. The implementation must be
944                    /// deterministic (returns the same result given the same arguments).
945                    pub fn register_impl<A, #arg_name>(
946                        conn: &mut SqliteConnection
947                    ) -> QueryResult<()>
948                        where
949                        A: SqliteAggregateFunction<#arg_name>
950                            + Send
951                            + 'static
952                            + std::panic::UnwindSafe
953                            + std::panic::RefUnwindSafe,
954                        A::Output: ToSql<#return_type, Sqlite>,
955                        #arg_name: FromSqlRow<#arg_type, Sqlite> +
956                            StaticallySizedRow<#arg_type, Sqlite> +
957                            ::std::panic::UnwindSafe,
958                        {
959                            conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
960                        }
961                };
962            }
963            _ => (),
964        }
965    }
966    tokens
967}
968
969fn function_cannot_be_called_directly(
970    attributes: &[AttributeSpanWrapper<SqlFunctionAttribute>],
971) -> bool {
972    let mut has_aggregate = false;
973    let mut has_window = false;
974    for attr in attributes {
975        has_aggregate = has_aggregate || matches!(attr.item, SqlFunctionAttribute::Aggregate(..));
976        has_window = has_window || matches!(attr.item, SqlFunctionAttribute::Window { .. });
977    }
978    has_window && !has_aggregate
979}
980
981pub(crate) struct ExternSqlBlock {
982    pub(crate) function_decls: Vec<SqlFunctionDecl>,
983}
984
985impl Parse for ExternSqlBlock {
986    fn parse(input: ParseStream) -> Result<Self> {
987        let mut error = None::<syn::Error>;
988
989        let mut combine_error = |e: syn::Error| {
990            error = Some(
991                error
992                    .take()
993                    .map(|mut o| {
994                        o.combine(e.clone());
995                        o
996                    })
997                    .unwrap_or(e),
998            )
999        };
1000
1001        let block = syn::ItemForeignMod::parse(input)?;
1002        if block.abi.name.as_ref().map(|n| n.value()) != Some("SQL".into()) {
1003            return Err(syn::Error::new(block.abi.span(), "expect `SQL` as ABI"));
1004        }
1005        if block.unsafety.is_some() {
1006            return Err(syn::Error::new(
1007                block.unsafety.unwrap().span(),
1008                "expect `SQL` function blocks to be safe",
1009            ));
1010        }
1011
1012        let parsed_block_attrs = parse_attributes(&mut combine_error, block.attrs);
1013
1014        let item_count = block.items.len();
1015        let function_decls_input = block
1016            .items
1017            .into_iter()
1018            .map(|i| syn::parse2::<SqlFunctionDecl>(quote! { #i }));
1019
1020        let mut function_decls = Vec::with_capacity(item_count);
1021        for decl in function_decls_input {
1022            match decl {
1023                Ok(mut decl) => {
1024                    decl.attributes = merge_attributes(&parsed_block_attrs, decl.attributes);
1025                    function_decls.push(decl)
1026                }
1027                Err(e) => {
1028                    error = Some(
1029                        error
1030                            .take()
1031                            .map(|mut o| {
1032                                o.combine(e.clone());
1033                                o
1034                            })
1035                            .unwrap_or(e),
1036                    );
1037                }
1038            }
1039        }
1040
1041        error
1042            .map(Err)
1043            .unwrap_or(Ok(ExternSqlBlock { function_decls }))
1044    }
1045}
1046
1047fn merge_attributes(
1048    parsed_block_attrs: &[AttributeSpanWrapper<SqlFunctionAttribute>],
1049    mut attributes: Vec<AttributeSpanWrapper<SqlFunctionAttribute>>,
1050) -> Vec<AttributeSpanWrapper<SqlFunctionAttribute>> {
1051    for attr in parsed_block_attrs {
1052        if attributes.iter().all(|a| match (&a.item, &attr.item) {
1053            (SqlFunctionAttribute::Aggregate(_), SqlFunctionAttribute::Aggregate(_)) => todo!(),
1054            (SqlFunctionAttribute::Window { .. }, SqlFunctionAttribute::Window { .. })
1055            | (SqlFunctionAttribute::SqlName(_, _), SqlFunctionAttribute::SqlName(_, _))
1056            | (SqlFunctionAttribute::Restriction(_), SqlFunctionAttribute::Restriction(_))
1057            | (SqlFunctionAttribute::Variadic(_, _), SqlFunctionAttribute::Variadic(_, _))
1058            | (
1059                SqlFunctionAttribute::SkipReturnTypeHelper(_),
1060                SqlFunctionAttribute::SkipReturnTypeHelper(_),
1061            ) => false,
1062            _ => true,
1063        }) {
1064            attributes.push(attr.clone());
1065        }
1066    }
1067    attributes
1068}
1069
1070#[derive(Clone)]
1071pub(crate) struct SqlFunctionDecl {
1072    attributes: Vec<AttributeSpanWrapper<SqlFunctionAttribute>>,
1073    fn_token: Token![fn],
1074    fn_name: Ident,
1075    generics: Generics,
1076    args: Punctuated<StrictFnArg, Token![,]>,
1077    return_type: Type,
1078}
1079
1080impl Parse for SqlFunctionDecl {
1081    fn parse(input: ParseStream) -> Result<Self> {
1082        let mut error = None::<syn::Error>;
1083        let mut combine_error = |e: syn::Error| {
1084            error = Some(
1085                error
1086                    .take()
1087                    .map(|mut o| {
1088                        o.combine(e.clone());
1089                        o
1090                    })
1091                    .unwrap_or(e),
1092            )
1093        };
1094
1095        let attributes = Attribute::parse_outer(input).unwrap_or_else(|e| {
1096            combine_error(e);
1097            Vec::new()
1098        });
1099        let attributes_collected = parse_attributes(&mut combine_error, attributes);
1100
1101        let fn_token: Token![fn] = input.parse().unwrap_or_else(|e| {
1102            combine_error(e);
1103            Default::default()
1104        });
1105        let fn_name = Ident::parse(input).unwrap_or_else(|e| {
1106            combine_error(e);
1107            Ident::new("dummy", Span::call_site())
1108        });
1109        let generics = Generics::parse(input).unwrap_or_else(|e| {
1110            combine_error(e);
1111            Generics {
1112                lt_token: None,
1113                params: Punctuated::new(),
1114                gt_token: None,
1115                where_clause: None,
1116            }
1117        });
1118        let args;
1119        let _paren = parenthesized!(args in input);
1120        let args = args
1121            .parse_terminated(StrictFnArg::parse, Token![,])
1122            .unwrap_or_else(|e| {
1123                combine_error(e);
1124                Punctuated::new()
1125            });
1126        let rarrow = Option::<Token![->]>::parse(input).unwrap_or_else(|e| {
1127            combine_error(e);
1128            None
1129        });
1130        let return_type = if rarrow.is_some() {
1131            Type::parse(input).unwrap_or_else(|e| {
1132                combine_error(e);
1133                Type::Never(syn::TypeNever {
1134                    bang_token: Default::default(),
1135                })
1136            })
1137        } else {
1138            parse_quote!(diesel::expression::expression_types::NotSelectable)
1139        };
1140        let _semi = Option::<Token![;]>::parse(input).unwrap_or_else(|e| {
1141            combine_error(e);
1142            None
1143        });
1144
1145        error.map(Err).unwrap_or(Ok(Self {
1146            attributes: attributes_collected,
1147            fn_token,
1148            fn_name,
1149            generics,
1150            args,
1151            return_type,
1152        }))
1153    }
1154}
1155
1156fn parse_attribute(
1157    attr: syn::Attribute,
1158) -> Result<Option<AttributeSpanWrapper<SqlFunctionAttribute>>> {
1159    match &attr.meta {
1160        syn::Meta::NameValue(syn::MetaNameValue {
1161            path,
1162            value:
1163                syn::Expr::Lit(syn::ExprLit {
1164                    lit: syn::Lit::Str(sql_name),
1165                    ..
1166                }),
1167            ..
1168        }) if path.is_ident("sql_name") => Ok(Some(AttributeSpanWrapper {
1169            attribute_span: attr.span(),
1170            ident_span: sql_name.span(),
1171            item: SqlFunctionAttribute::SqlName(path.require_ident()?.clone(), sql_name.clone()),
1172        })),
1173        syn::Meta::Path(path) if path.is_ident("aggregate") => Ok(Some(AttributeSpanWrapper {
1174            attribute_span: attr.span(),
1175            ident_span: path.span(),
1176            item: SqlFunctionAttribute::Aggregate(
1177                path.require_ident()
1178                    .map_err(|e| {
1179                        syn::Error::new(
1180                            e.span(),
1181                            format!("{e}, the correct format is `#[aggregate]`"),
1182                        )
1183                    })?
1184                    .clone(),
1185            ),
1186        })),
1187        syn::Meta::Path(path) if path.is_ident("skip_return_type_helper") => {
1188            Ok(Some(AttributeSpanWrapper {
1189                ident_span: attr.span(),
1190                attribute_span: path.span(),
1191                item: SqlFunctionAttribute::SkipReturnTypeHelper(
1192                    path.require_ident()
1193                        .map_err(|e| {
1194                            syn::Error::new(
1195                                e.span(),
1196                                format!("{e}, the correct format is `#[skip_return_type_helper]`"),
1197                            )
1198                        })?
1199                        .clone(),
1200                ),
1201            }))
1202        }
1203        syn::Meta::Path(path) if path.is_ident("window") => Ok(Some(AttributeSpanWrapper {
1204            attribute_span: attr.span(),
1205            ident_span: path.span(),
1206            item: SqlFunctionAttribute::Window {
1207                ident: path
1208                    .require_ident()
1209                    .map_err(|e| {
1210                        syn::Error::new(e.span(), format!("{e}, the correct format is `#[window]`"))
1211                    })?
1212                    .clone(),
1213                restrictions: BackendRestriction::None,
1214                require_order: None,
1215            },
1216        })),
1217        syn::Meta::List(syn::MetaList {
1218            path,
1219            delimiter: syn::MacroDelimiter::Paren(_),
1220            tokens,
1221        }) if path.is_ident("variadic") => {
1222            let count: syn::LitInt = syn::parse2(tokens.clone()).map_err(|e| {
1223                syn::Error::new(
1224                    e.span(),
1225                    format!("{e}, the correct format is `#[variadic(3)]`"),
1226                )
1227            })?;
1228            Ok(Some(AttributeSpanWrapper {
1229                item: SqlFunctionAttribute::Variadic(
1230                    path.require_ident()
1231                        .map_err(|e| {
1232                            syn::Error::new(
1233                                e.span(),
1234                                format!("{e}, the correct format is `#[variadic(3)]`"),
1235                            )
1236                        })?
1237                        .clone(),
1238                    count.clone(),
1239                ),
1240                attribute_span: attr.span(),
1241                ident_span: path.require_ident()?.span(),
1242            }))
1243        }
1244        syn::Meta::NameValue(_) | syn::Meta::Path(_) => Ok(Some(AttributeSpanWrapper {
1245            attribute_span: attr.span(),
1246            ident_span: attr.span(),
1247            item: SqlFunctionAttribute::Other(attr),
1248        })),
1249        syn::Meta::List(_) => {
1250            let name = attr.meta.path().require_ident()?;
1251            let attribute_span = attr.meta.span();
1252            attr.clone()
1253                .parse_args_with(|input: &syn::parse::ParseBuffer| {
1254                    SqlFunctionAttribute::parse_attr(
1255                        name.clone(),
1256                        input,
1257                        attr.clone(),
1258                        attribute_span,
1259                    )
1260                })
1261        }
1262    }
1263}
1264
1265fn parse_attributes(
1266    combine_error: &mut impl FnMut(syn::Error),
1267    attributes: Vec<Attribute>,
1268) -> Vec<AttributeSpanWrapper<SqlFunctionAttribute>> {
1269    let attribute_count = attributes.len();
1270
1271    let attributes = attributes
1272        .into_iter()
1273        .filter_map(|attr| parse_attribute(attr).transpose());
1274
1275    let mut attributes_collected = Vec::with_capacity(attribute_count);
1276    for attr in attributes {
1277        match attr {
1278            Ok(attr) => attributes_collected.push(attr),
1279            Err(e) => {
1280                combine_error(e);
1281            }
1282        }
1283    }
1284    attributes_collected
1285}
1286
1287/// Essentially the same as ArgCaptured, but only allowing ident patterns
1288#[derive(Clone)]
1289struct StrictFnArg {
1290    name: Ident,
1291    colon_token: Token![:],
1292    ty: Type,
1293}
1294
1295impl Parse for StrictFnArg {
1296    fn parse(input: ParseStream) -> Result<Self> {
1297        let name = input.parse()?;
1298        let colon_token = input.parse()?;
1299        let ty = input.parse()?;
1300        Ok(Self {
1301            name,
1302            colon_token,
1303            ty,
1304        })
1305    }
1306}
1307
1308impl ToTokens for StrictFnArg {
1309    fn to_tokens(&self, tokens: &mut TokenStream) {
1310        self.name.to_tokens(tokens);
1311        self.colon_token.to_tokens(tokens);
1312        self.name.to_tokens(tokens);
1313    }
1314}
1315
1316fn is_sqlite_type(ty: &Type) -> bool {
1317    let last_segment = if let Type::Path(tp) = ty {
1318        if let Some(segment) = tp.path.segments.last() {
1319            segment
1320        } else {
1321            return false;
1322        }
1323    } else {
1324        return false;
1325    };
1326
1327    let ident = last_segment.ident.to_string();
1328    if ident == "Nullable" {
1329        if let PathArguments::AngleBracketed(ref ab) = last_segment.arguments {
1330            if let Some(GenericArgument::Type(ty)) = ab.args.first() {
1331                return is_sqlite_type(ty);
1332            }
1333        }
1334        return false;
1335    }
1336
1337    [
1338        "BigInt",
1339        "Binary",
1340        "Bool",
1341        "Date",
1342        "Double",
1343        "Float",
1344        "Integer",
1345        "Numeric",
1346        "SmallInt",
1347        "Text",
1348        "Time",
1349        "Timestamp",
1350    ]
1351    .contains(&ident.as_str())
1352}
1353
1354#[derive(Default, Clone, Debug)]
1355enum BackendRestriction {
1356    #[default]
1357    None,
1358    SqlDialect(syn::Ident, syn::Ident, syn::Path),
1359    BackendBound(
1360        syn::Ident,
1361        syn::punctuated::Punctuated<syn::TypeParamBound, syn::Token![+]>,
1362    ),
1363    Backends(
1364        syn::Ident,
1365        syn::punctuated::Punctuated<syn::Path, syn::Token![,]>,
1366    ),
1367}
1368
impl BackendRestriction {
    /// Parses an optional restriction: empty input yields `BackendRestriction::None`,
    /// anything else is parsed by the `Parse` impl (`backends(...)`, `dialect(...)`,
    /// `backend_bounds(...)`).
    fn parse_from(input: &syn::parse::ParseBuffer<'_>) -> Result<Self> {
        if input.is_empty() {
            return Ok(Self::None);
        }
        Self::parse(input)
    }

    /// Parses the contents of `backends(...)`: comma-separated backend type paths.
    fn parse_backends(
        input: &syn::parse::ParseBuffer<'_>,
        name: Ident,
    ) -> Result<BackendRestriction> {
        let backends = Punctuated::parse_terminated(input)?;
        Ok(Self::Backends(name, backends))
    }

    /// Parses the contents of `dialect(Dialect, Variant)`.
    ///
    /// `Dialect` is used as an associated type of `diesel::backend::SqlDialect` and
    /// `Variant` as the type it must equal (see the generated bounds below).
    fn parse_sql_dialect(
        content: &syn::parse::ParseBuffer<'_>,
        name: Ident,
    ) -> Result<BackendRestriction> {
        let dialect = content.parse()?;
        let _del: syn::Token![,] = content.parse()?;
        let dialect_variant = content.parse()?;

        Ok(Self::SqlDialect(name, dialect, dialect_variant))
    }

    /// Parses the contents of `backend_bounds(...)`: `+`-separated trait bounds that the
    /// generated impls add to the backend type parameter.
    fn parse_backend_bounds(
        input: &syn::parse::ParseBuffer<'_>,
        name: Ident,
    ) -> Result<BackendRestriction> {
        let restrictions = Punctuated::parse_terminated(input)?;
        Ok(Self::BackendBound(name, restrictions))
    }

    /// Generates the `WindowFunctionFragment` impls for `OverClause<__P, order, __F>`
    /// according to this restriction.
    ///
    /// The generic parameters `__P`, `__O` and `__F` are appended to `generics`. Every
    /// variant except `Backends` also appends `__DieselInternal` as the backend parameter.
    /// When `require_order` is true, the order slot is
    /// `diesel::internal::sql_functions::Order<__O, true>`; otherwise it is plain `__O`.
    fn generate_all_window_fragment_impls(
        &self,
        mut generics: Generics,
        ty_generics: &TypeGenerics<'_>,
        fn_name: &syn::Ident,
        require_order: bool,
    ) -> TokenStream {
        generics.params.push(parse_quote!(__P));
        generics.params.push(parse_quote!(__O));
        generics.params.push(parse_quote!(__F));
        let order = if require_order {
            quote::quote! {
                diesel::internal::sql_functions::Order<__O, true>
            }
        } else {
            quote::quote! {__O}
        };
        match *self {
            BackendRestriction::None => {
                generics.params.push(parse_quote!(__DieselInternal));
                let (impl_generics, _, _) = generics.split_for_impl();
                Self::generate_window_fragment_impl(
                    parse_quote!(__DieselInternal),
                    Some(parse_quote!(__DieselInternal: diesel::backend::Backend,)),
                    &impl_generics,
                    ty_generics,
                    fn_name,
                    None,
                    &order,
                )
            }
            BackendRestriction::SqlDialect(_, ref dialect, ref dialect_type) => {
                generics.params.push(parse_quote!(__DieselInternal));
                let (impl_generics, _, _) = generics.split_for_impl();
                // Generic impl that forwards to the impl selected by the backend's
                // `SqlDialect::<dialect>` associated type.
                let mut out = quote::quote! {
                    impl #impl_generics WindowFunctionFragment<#fn_name #ty_generics, __DieselInternal>
                        for OverClause<__P, #order, __F>
                    where
                        Self: WindowFunctionFragment<#fn_name #ty_generics, __DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>,
                        __DieselInternal: diesel::backend::Backend,
                    {
                    }

                };
                // Impl for backends whose `<dialect>` associated type is `dialect_type`.
                let specific_impl = Self::generate_window_fragment_impl(
                    parse_quote!(__DieselInternal),
                    Some(
                        parse_quote!(__DieselInternal: diesel::backend::Backend + diesel::backend::SqlDialect<#dialect = #dialect_type>,),
                    ),
                    &impl_generics,
                    ty_generics,
                    fn_name,
                    Some(dialect_type),
                    &order,
                );
                out.extend(specific_impl);
                out
            }
            BackendRestriction::BackendBound(_, ref restriction) => {
                generics.params.push(parse_quote!(__DieselInternal));
                let (impl_generics, _, _) = generics.split_for_impl();
                Self::generate_window_fragment_impl(
                    parse_quote!(__DieselInternal),
                    Some(parse_quote!(__DieselInternal: diesel::backend::Backend + #restriction,)),
                    &impl_generics,
                    ty_generics,
                    fn_name,
                    None,
                    &order,
                )
            }
            BackendRestriction::Backends(_, ref backends) => {
                // One impl per listed concrete backend, so no `__DieselInternal`
                // parameter is added here.
                let (impl_generics, _, _) = generics.split_for_impl();
                let backends = backends.iter().map(|b| {
                    Self::generate_window_fragment_impl(
                        quote! {#b},
                        None,
                        &impl_generics,
                        ty_generics,
                        fn_name,
                        None,
                        &order,
                    )
                });

                parse_quote!(#(#backends)*)
            }
        }
    }

    /// Emits a single empty
    /// `impl WindowFunctionFragment<fn_name, backend, dialect> for OverClause<__P, order, __F>`.
    ///
    /// A `None` dialect is omitted from the trait arguments, and a `None` backend bound
    /// leaves the `where` clause empty.
    fn generate_window_fragment_impl(
        backend: TokenStream,
        backend_bound: Option<proc_macro2::TokenStream>,
        impl_generics: &ImplGenerics<'_>,
        ty_generics: &TypeGenerics<'_>,
        fn_name: &syn::Ident,
        dialect: Option<&syn::Path>,
        order: &TokenStream,
    ) -> TokenStream {
        quote::quote! {
            impl #impl_generics WindowFunctionFragment<#fn_name #ty_generics, #backend, #dialect> for OverClause<__P, #order, __F>
                where #backend_bound
            {

            }
        }
    }

    /// Generates the `QueryFragment` impls for the SQL function type according to this
    /// restriction.
    ///
    /// Every variant except `Backends` appends `__DieselInternal` to `generics` as the
    /// backend parameter. `Backends` emits one impl per listed concrete backend type.
    fn generate_all_queryfragment_impls(
        &self,
        mut generics: Generics,
        ty_generics: &TypeGenerics<'_>,
        arg_name: &[&syn::Ident],
        fn_name: &syn::Ident,
    ) -> proc_macro2::TokenStream {
        match *self {
            BackendRestriction::None => {
                generics.params.push(parse_quote!(__DieselInternal));
                let (impl_generics, _, _) = generics.split_for_impl();
                Self::generate_queryfragment_impl(
                    parse_quote!(__DieselInternal),
                    Some(parse_quote!(__DieselInternal: diesel::backend::Backend,)),
                    &impl_generics,
                    ty_generics,
                    arg_name,
                    fn_name,
                    None,
                )
            }
            BackendRestriction::BackendBound(_, ref restriction) => {
                generics.params.push(parse_quote!(__DieselInternal));
                let (impl_generics, _, _) = generics.split_for_impl();
                Self::generate_queryfragment_impl(
                    parse_quote!(__DieselInternal),
                    Some(parse_quote!(__DieselInternal: diesel::backend::Backend + #restriction,)),
                    &impl_generics,
                    ty_generics,
                    arg_name,
                    fn_name,
                    None,
                )
            }
            BackendRestriction::SqlDialect(_, ref dialect, ref dialect_type) => {
                generics.params.push(parse_quote!(__DieselInternal));
                let (impl_generics, _, _) = generics.split_for_impl();
                // Impl for backends whose `<dialect>` associated type is `dialect_type`.
                let specific_impl = Self::generate_queryfragment_impl(
                    parse_quote!(__DieselInternal),
                    Some(
                        parse_quote!(__DieselInternal: diesel::backend::Backend + diesel::backend::SqlDialect<#dialect = #dialect_type>,),
                    ),
                    &impl_generics,
                    ty_generics,
                    arg_name,
                    fn_name,
                    Some(dialect_type),
                );
                // Generic impl whose `walk_ast` delegates to the dialect-specific impl
                // selected by the backend's `SqlDialect::<dialect>` associated type.
                quote::quote! {
                    impl #impl_generics QueryFragment<__DieselInternal>
                        for #fn_name #ty_generics
                    where
                        Self: QueryFragment<__DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>,
                        __DieselInternal: diesel::backend::Backend,
                    {
                        fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()> {
                            <Self as QueryFragment<__DieselInternal, <__DieselInternal as diesel::backend::SqlDialect>::#dialect>>::walk_ast(self, out)
                        }

                    }

                    #specific_impl
                }
            }
            BackendRestriction::Backends(_, ref backends) => {
                // One impl per listed concrete backend, so no `__DieselInternal`
                // parameter is added here.
                let (impl_generics, _, _) = generics.split_for_impl();
                let backends = backends.iter().map(|b| {
                    Self::generate_queryfragment_impl(
                        quote! {#b},
                        None,
                        &impl_generics,
                        ty_generics,
                        arg_name,
                        fn_name,
                        None,
                    )
                });

                parse_quote!(#(#backends)*)
            }
        }
    }

    /// Emits `impl QueryFragment<backend, dialect> for fn_name`. Its `walk_ast` writes
    /// `FUNCTION_NAME(` (taken from `FunctionFragment<backend>`), then the arguments via
    /// `walk_arguments`, then `)`.
    ///
    /// Each `arg_name` gets a `QueryFragment<backend>` bound. Presumably the argument names
    /// double as the function type's generic parameters; confirm against the caller.
    fn generate_queryfragment_impl(
        backend: proc_macro2::TokenStream,
        backend_bound: Option<proc_macro2::TokenStream>,
        impl_generics: &ImplGenerics<'_>,
        ty_generics: &TypeGenerics<'_>,
        arg_name: &[&syn::Ident],
        fn_name: &syn::Ident,
        dialect: Option<&syn::Path>,
    ) -> proc_macro2::TokenStream {
        quote::quote! {
            impl #impl_generics QueryFragment<#backend, #dialect>
                for #fn_name #ty_generics
            where
                #backend_bound
            #(#arg_name: QueryFragment<#backend>,)*
            {
                fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, #backend>) -> QueryResult<()>{
                    out.push_sql(<Self as FunctionFragment<#backend>>::FUNCTION_NAME);
                    out.push_sql("(");
                    self.walk_arguments(out.reborrow())?;
                    out.push_sql(")");
                    Ok(())
                }
            }
        }
    }
}
1622
1623impl Parse for BackendRestriction {
1624    fn parse(input: ParseStream) -> Result<Self> {
1625        let name: syn::Ident = input.parse()?;
1626        let name_str = name.to_string();
1627        let content;
1628        parenthesized!(content in input);
1629        match &*name_str {
1630            "backends" => Self::parse_backends(&content, name),
1631            "dialect" => Self::parse_sql_dialect(&content, name),
1632            "backend_bounds" => Self::parse_backend_bounds(&content, name),
1633            _ => Err(syn::Error::new(
1634                name.span(),
1635                format!("unexpected option `{name_str}`"),
1636            )),
1637        }
1638    }
1639}
1640
/// One attribute attached to an SQL function declaration or an `extern "SQL"` block.
#[derive(Debug, Clone)]
enum SqlFunctionAttribute {
    /// `#[aggregate]`; holds the attribute ident.
    Aggregate(Ident),
    /// `#[window]`. The path form sets `restrictions` to `BackendRestriction::None` and
    /// `require_order` to `None`; presumably the list form parsed by `parse_attr` fills
    /// them in.
    Window {
        ident: Ident,
        restrictions: BackendRestriction,
        require_order: Option<bool>,
    },
    /// `#[sql_name = "..."]`: the attribute ident and the SQL-side function name.
    SqlName(Ident, LitStr),
    /// A backend restriction; see `BackendRestriction`.
    Restriction(BackendRestriction),
    /// `#[variadic(N)]`: the attribute ident and the literal `N`.
    Variadic(Ident, LitInt),
    /// `#[skip_return_type_helper]`; holds the attribute ident.
    SkipReturnTypeHelper(Ident),
    /// Any other attribute, kept verbatim.
    Other(Attribute),
}
1655
1656impl MySpanned for SqlFunctionAttribute {
1657    fn span(&self) -> proc_macro2::Span {
1658        match self {
1659            SqlFunctionAttribute::Restriction(BackendRestriction::Backends(ref ident, ..))
1660            | SqlFunctionAttribute::Restriction(BackendRestriction::SqlDialect(ref ident, ..))
1661            | SqlFunctionAttribute::Restriction(BackendRestriction::BackendBound(ref ident, ..))
1662            | SqlFunctionAttribute::Aggregate(ref ident, ..)
1663            | SqlFunctionAttribute::Window { ref ident, .. }
1664            | SqlFunctionAttribute::Variadic(ref ident, ..)
1665            | SqlFunctionAttribute::SkipReturnTypeHelper(ref ident)
1666            | SqlFunctionAttribute::SqlName(ref ident, ..) => ident.span(),
1667            SqlFunctionAttribute::Restriction(BackendRestriction::None) => {
1668                unreachable!("We do not construct that")
1669            }
1670            SqlFunctionAttribute::Other(ref attribute) => attribute.span(),
1671        }
1672    }
1673}
1674
1675fn parse_require_order(input: &syn::parse::ParseBuffer<'_>) -> Result<bool> {
1676    let ident = input.parse::<Ident>()?;
1677    if ident == "require_order" {
1678        let _ = input.parse::<Token![=]>()?;
1679        let value = input.parse::<LitBool>()?;
1680        Ok(value.value)
1681    } else {
1682        Err(syn::Error::new(
1683            ident.span(),
1684            format!("Expected `require_order` but got `{ident}`"),
1685        ))
1686    }
1687}
1688
1689impl SqlFunctionAttribute {
1690    fn parse_attr(
1691        name: Ident,
1692        input: &syn::parse::ParseBuffer<'_>,
1693        attr: Attribute,
1694        attribute_span: proc_macro2::Span,
1695    ) -> Result<Option<AttributeSpanWrapper<Self>>> {
1696        // rustc doesn't resolve cfg attrs for us :(
1697        // This is hacky, but mostly for internal use
1698        if name == "cfg_attr" {
1699            let ident = input.parse::<Ident>()?;
1700            if ident != "feature" {
1701                return Err(syn::Error::new(
1702                    ident.span(),
1703                    format!(
1704                        "only single feature `cfg_attr` attributes are supported. \
1705                             Got `{ident}` but expected `feature = \"foo\"`"
1706                    ),
1707                ));
1708            }
1709            let _ = input.parse::<Token![=]>()?;
1710            let feature = input.parse::<LitStr>()?;
1711            let feature_value = feature.value();
1712            let _ = input.parse::<Token![,]>()?;
1713            let ignore = match feature_value.as_str() {
1714                "postgres_backend" => !cfg!(feature = "postgres"),
1715                "sqlite" => !cfg!(feature = "sqlite"),
1716                "mysql_backend" => !cfg!(feature = "mysql"),
1717                feature => {
1718                    return Err(syn::Error::new(
1719                        feature.span(),
1720                        format!(
1721                            "only `mysql_backend`, `postgres_backend` and `sqlite` \
1722                                 are supported features, but got `{feature}`"
1723                        ),
1724                    ));
1725                }
1726            };
1727            let name = input.parse::<Ident>()?;
1728            let inner;
1729            let _paren = parenthesized!(inner in input);
1730            let ret = SqlFunctionAttribute::parse_attr(name, &inner, attr, attribute_span)?;
1731            if ignore {
1732                Ok(None)
1733            } else {
1734                Ok(ret)
1735            }
1736        } else {
1737            let name_str = name.to_string();
1738            let parsed_attr = match &*name_str {
1739                "window" => {
1740                    let restrictions = if BackendRestriction::parse_from(&input.fork()).is_ok() {
1741                        BackendRestriction::parse_from(input).map(Ok).ok()
1742                    } else {
1743                        None
1744                    };
1745                    if input.fork().parse::<Token![,]>().is_ok() {
1746                        let _ = input.parse::<Token![,]>()?;
1747                    }
1748                    let require_order = if parse_require_order(&input.fork()).is_ok() {
1749                        Some(parse_require_order(input)?)
1750                    } else {
1751                        None
1752                    };
1753                    if input.fork().parse::<Token![,]>().is_ok() {
1754                        let _ = input.parse::<Token![,]>()?;
1755                    }
1756                    let restrictions =
1757                        restrictions.unwrap_or_else(|| BackendRestriction::parse_from(input))?;
1758                    Self::Window {
1759                        ident: name,
1760                        restrictions,
1761                        require_order,
1762                    }
1763                }
1764                "sql_name" => {
1765                    parse_eq(input, "sql_name = \"SUM\"").map(|v| Self::SqlName(name, v))?
1766                }
1767                "backends" => {
1768                    BackendRestriction::parse_backends(input, name).map(Self::Restriction)?
1769                }
1770                "dialect" => {
1771                    BackendRestriction::parse_sql_dialect(input, name).map(Self::Restriction)?
1772                }
1773                "backend_bounds" => {
1774                    BackendRestriction::parse_backend_bounds(input, name).map(Self::Restriction)?
1775                }
1776                "variadic" => Self::Variadic(name, input.parse()?),
1777                _ => {
1778                    // empty the parse buffer otherwise syn will return an error
1779                    let _ = input.step(|cursor| {
1780                        let mut rest = *cursor;
1781                        while let Some((_, next)) = rest.token_tree() {
1782                            rest = next;
1783                        }
1784                        Ok(((), rest))
1785                    });
1786                    SqlFunctionAttribute::Other(attr)
1787                }
1788            };
1789            Ok(Some(AttributeSpanWrapper {
1790                ident_span: parsed_attr.span(),
1791                item: parsed_attr,
1792                attribute_span,
1793            }))
1794        }
1795    }
1796}
1797
1798#[derive(Default)]
1799pub(crate) struct DeclareSqlFunctionArgs {
1800    pub(crate) generate_return_type_helpers: bool,
1801}
1802
1803impl DeclareSqlFunctionArgs {
1804    pub(crate) fn parse_from_macro_input(input: TokenStream) -> syn::Result<Self> {
1805        if input.is_empty() {
1806            return Ok(Self::default());
1807        }
1808        let input_span = input.span();
1809        let parsed: syn::MetaNameValue = syn::parse2(input).map_err(|e| {
1810            let span = e.span();
1811            syn::Error::new(
1812                span,
1813                format!("{e}, the correct format is `generate_return_type_helpers = true/false`"),
1814            )
1815        })?;
1816        match parsed {
1817            syn::MetaNameValue {
1818                path,
1819                value:
1820                    syn::Expr::Lit(syn::ExprLit {
1821                        lit: syn::Lit::Bool(b),
1822                        ..
1823                    }),
1824                ..
1825            } if path.is_ident("generate_return_type_helpers") => Ok(Self {
1826                generate_return_type_helpers: b.value,
1827            }),
1828            _ => Err(syn::Error::new(input_span, "Invalid config")),
1829        }
1830    }
1831}