diesel_derives/
sql_function.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4use syn::parse::{Parse, ParseStream, Result};
5use syn::punctuated::Punctuated;
6use syn::{
7    parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue,
8    PathArguments, Token, Type,
9};
10
11pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool) -> TokenStream {
12    let SqlFunctionDecl {
13        mut attributes,
14        fn_token,
15        fn_name,
16        mut generics,
17        args,
18        return_type,
19    } = input;
20
21    let sql_name = attributes
22        .iter()
23        .find(|attr| attr.meta.path().is_ident("sql_name"))
24        .and_then(|attr| {
25            if let Meta::NameValue(MetaNameValue {
26                value:
27                    syn::Expr::Lit(syn::ExprLit {
28                        lit: syn::Lit::Str(ref lit),
29                        ..
30                    }),
31                ..
32            }) = attr.meta
33            {
34                Some(lit.value())
35            } else {
36                None
37            }
38        })
39        .unwrap_or_else(|| fn_name.to_string());
40
41    let is_aggregate = attributes
42        .iter()
43        .any(|attr| attr.meta.path().is_ident("aggregate"));
44
45    attributes.retain(|attr| {
46        !attr.meta.path().is_ident("sql_name") && !attr.meta.path().is_ident("aggregate")
47    });
48
49    let args = &args;
50    let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
51        .iter()
52        .map(|StrictFnArg { name, ty, .. }| (name, ty))
53        .unzip();
54    let arg_struct_assign = args.iter().map(
55        |StrictFnArg {
56             name, colon_token, ..
57         }| {
58            let name2 = name.clone();
59            quote!(#name #colon_token #name2.as_expression())
60        },
61    );
62
63    let type_args = &generics
64        .type_params()
65        .map(|type_param| type_param.ident.clone())
66        .collect::<Vec<_>>();
67
68    for StrictFnArg { name, .. } in args {
69        generics.params.push(parse_quote!(#name));
70    }
71
72    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
73    // Even if we force an empty where clause, it still won't print the where
74    // token with no bounds.
75    let where_clause = where_clause
76        .map(|w| quote!(#w))
77        .unwrap_or_else(|| quote!(where));
78
79    let mut generics_with_internal = generics.clone();
80    generics_with_internal
81        .params
82        .push(parse_quote!(__DieselInternal));
83    let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();
84
85    let sql_type;
86    let numeric_derive;
87
88    if arg_name.is_empty() {
89        sql_type = None;
90        // FIXME: We can always derive once trivial bounds are stable
91        numeric_derive = None;
92    } else {
93        sql_type = Some(quote!((#(#arg_name),*): Expression,));
94        numeric_derive = Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)]));
95    }
96
97    let helper_type_doc = format!("The return type of [`{fn_name}()`](super::fn_name)");
98
99    let args_iter = args.iter();
100    let mut tokens = quote! {
101        use diesel::{self, QueryResult};
102        use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
103        use diesel::query_builder::{QueryFragment, AstPass};
104        use diesel::sql_types::*;
105        use super::*;
106
107        #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
108        #numeric_derive
109        pub struct #fn_name #ty_generics {
110            #(pub(in super) #args_iter,)*
111            #(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
112        }
113
114        #[doc = #helper_type_doc]
115        pub type HelperType #ty_generics = #fn_name <
116            #(#type_args,)*
117            #(<#arg_name as AsExpression<#arg_type>>::Expression,)*
118        >;
119
120        impl #impl_generics Expression for #fn_name #ty_generics
121        #where_clause
122            #sql_type
123        {
124            type SqlType = #return_type;
125        }
126
127        // __DieselInternal is what we call QS normally
128        impl #impl_generics_internal SelectableExpression<__DieselInternal>
129            for #fn_name #ty_generics
130        #where_clause
131            #(#arg_name: SelectableExpression<__DieselInternal>,)*
132            Self: AppearsOnTable<__DieselInternal>,
133        {
134        }
135
136        // __DieselInternal is what we call QS normally
137        impl #impl_generics_internal AppearsOnTable<__DieselInternal>
138            for #fn_name #ty_generics
139        #where_clause
140            #(#arg_name: AppearsOnTable<__DieselInternal>,)*
141            Self: Expression,
142        {
143        }
144
145        // __DieselInternal is what we call DB normally
146        impl #impl_generics_internal QueryFragment<__DieselInternal>
147            for #fn_name #ty_generics
148        where
149            __DieselInternal: diesel::backend::Backend,
150            #(#arg_name: QueryFragment<__DieselInternal>,)*
151        {
152            #[allow(unused_assignments)]
153            fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{
154                out.push_sql(concat!(#sql_name, "("));
155                // we unroll the arguments manually here, to prevent borrow check issues
156                let mut needs_comma = false;
157                #(
158                    if !self.#arg_name.is_noop(out.backend())? {
159                        if needs_comma {
160                            out.push_sql(", ");
161                        }
162                        self.#arg_name.walk_ast(out.reborrow())?;
163                        needs_comma = true;
164                    }
165                )*
166                out.push_sql(")");
167                Ok(())
168            }
169        }
170    };
171
172    let is_supported_on_sqlite = cfg!(feature = "sqlite")
173        && type_args.is_empty()
174        && is_sqlite_type(&return_type)
175        && arg_type.iter().all(|a| is_sqlite_type(a));
176
177    if is_aggregate {
178        tokens = quote! {
179            #tokens
180
181            impl #impl_generics_internal ValidGrouping<__DieselInternal>
182                for #fn_name #ty_generics
183            {
184                type IsAggregate = diesel::expression::is_aggregate::Yes;
185            }
186        };
187        if is_supported_on_sqlite {
188            tokens = quote! {
189                #tokens
190
191                use diesel::sqlite::{Sqlite, SqliteConnection};
192                use diesel::serialize::ToSql;
193                use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
194                use diesel::sqlite::SqliteAggregateFunction;
195                use diesel::sql_types::IntoNullable;
196            };
197
198            match arg_name.len() {
199                x if x > 1 => {
200                    tokens = quote! {
201                        #tokens
202
203                        #[allow(dead_code)]
204                        /// Registers an implementation for this aggregate function on the given connection
205                        ///
206                        /// This function must be called for every `SqliteConnection` before
207                        /// this SQL function can be used on SQLite. The implementation must be
208                        /// deterministic (returns the same result given the same arguments).
209                        pub fn register_impl<A, #(#arg_name,)*>(
210                            conn: &mut SqliteConnection
211                        ) -> QueryResult<()>
212                            where
213                            A: SqliteAggregateFunction<(#(#arg_name,)*)>
214                                + Send
215                                + 'static
216                                + ::std::panic::UnwindSafe
217                                + ::std::panic::RefUnwindSafe,
218                            A::Output: ToSql<#return_type, Sqlite>,
219                            (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
220                                StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
221                                ::std::panic::UnwindSafe,
222                        {
223                            conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
224                        }
225                    };
226                }
227                1 => {
228                    let arg_name = arg_name[0];
229                    let arg_type = arg_type[0];
230
231                    tokens = quote! {
232                        #tokens
233
234                        #[allow(dead_code)]
235                        /// Registers an implementation for this aggregate function on the given connection
236                        ///
237                        /// This function must be called for every `SqliteConnection` before
238                        /// this SQL function can be used on SQLite. The implementation must be
239                        /// deterministic (returns the same result given the same arguments).
240                        pub fn register_impl<A, #arg_name>(
241                            conn: &mut SqliteConnection
242                        ) -> QueryResult<()>
243                            where
244                            A: SqliteAggregateFunction<#arg_name>
245                                + Send
246                                + 'static
247                                + std::panic::UnwindSafe
248                                + std::panic::RefUnwindSafe,
249                            A::Output: ToSql<#return_type, Sqlite>,
250                            #arg_name: FromSqlRow<#arg_type, Sqlite> +
251                                StaticallySizedRow<#arg_type, Sqlite> +
252                                ::std::panic::UnwindSafe,
253                            {
254                                conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
255                            }
256                    };
257                }
258                _ => (),
259            }
260        }
261    } else {
262        tokens = quote! {
263            #tokens
264
265            #[derive(ValidGrouping)]
266            pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);
267
268            impl #impl_generics_internal ValidGrouping<__DieselInternal>
269                for #fn_name #ty_generics
270            where
271                __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
272            {
273                type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
274            }
275        };
276
277        if is_supported_on_sqlite && !arg_name.is_empty() {
278            tokens = quote! {
279                #tokens
280
281                use diesel::sqlite::{Sqlite, SqliteConnection};
282                use diesel::serialize::ToSql;
283                use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
284
285                #[allow(dead_code)]
286                /// Registers an implementation for this function on the given connection
287                ///
288                /// This function must be called for every `SqliteConnection` before
289                /// this SQL function can be used on SQLite. The implementation must be
290                /// deterministic (returns the same result given the same arguments). If
291                /// the function is nondeterministic, call
292                /// `register_nondeterministic_impl` instead.
293                pub fn register_impl<F, Ret, #(#arg_name,)*>(
294                    conn: &mut SqliteConnection,
295                    f: F,
296                ) -> QueryResult<()>
297                where
298                    F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
299                    (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
300                        StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
301                    Ret: ToSql<#return_type, Sqlite>,
302                {
303                    conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
304                        #sql_name,
305                        true,
306                        move |(#(#arg_name,)*)| f(#(#arg_name,)*),
307                    )
308                }
309
310                #[allow(dead_code)]
311                /// Registers an implementation for this function on the given connection
312                ///
313                /// This function must be called for every `SqliteConnection` before
314                /// this SQL function can be used on SQLite.
315                /// `register_nondeterministic_impl` should only be used if your
316                /// function can return different results with the same arguments (e.g.
317                /// `random`). If your function is deterministic, you should call
318                /// `register_impl` instead.
319                pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
320                    conn: &mut SqliteConnection,
321                    mut f: F,
322                ) -> QueryResult<()>
323                where
324                    F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
325                    (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
326                        StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
327                    Ret: ToSql<#return_type, Sqlite>,
328                {
329                    conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
330                        #sql_name,
331                        false,
332                        move |(#(#arg_name,)*)| f(#(#arg_name,)*),
333                    )
334                }
335            };
336        }
337
338        if is_supported_on_sqlite && arg_name.is_empty() {
339            tokens = quote! {
340                #tokens
341
342                use diesel::sqlite::{Sqlite, SqliteConnection};
343                use diesel::serialize::ToSql;
344
345                #[allow(dead_code)]
346                /// Registers an implementation for this function on the given connection
347                ///
348                /// This function must be called for every `SqliteConnection` before
349                /// this SQL function can be used on SQLite. The implementation must be
350                /// deterministic (returns the same result given the same arguments). If
351                /// the function is nondeterministic, call
352                /// `register_nondeterministic_impl` instead.
353                pub fn register_impl<F, Ret>(
354                    conn: &SqliteConnection,
355                    f: F,
356                ) -> QueryResult<()>
357                where
358                    F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static,
359                    Ret: ToSql<#return_type, Sqlite>,
360                {
361                    conn.register_noarg_sql_function::<#return_type, _, _>(
362                        #sql_name,
363                        true,
364                        f,
365                    )
366                }
367
368                #[allow(dead_code)]
369                /// Registers an implementation for this function on the given connection
370                ///
371                /// This function must be called for every `SqliteConnection` before
372                /// this SQL function can be used on SQLite.
373                /// `register_nondeterministic_impl` should only be used if your
374                /// function can return different results with the same arguments (e.g.
375                /// `random`). If your function is deterministic, you should call
376                /// `register_impl` instead.
377                pub fn register_nondeterministic_impl<F, Ret>(
378                    conn: &SqliteConnection,
379                    mut f: F,
380                ) -> QueryResult<()>
381                where
382                    F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
383                    Ret: ToSql<#return_type, Sqlite>,
384                {
385                    conn.register_noarg_sql_function::<#return_type, _, _>(
386                        #sql_name,
387                        false,
388                        f,
389                    )
390                }
391            };
392        }
393    }
394
395    let args_iter = args.iter();
396
397    let (outside_of_module_helper_type, return_type_path, internals_module_name) =
398        if legacy_helper_type_and_module {
399            (None, quote! { #fn_name::HelperType }, fn_name.clone())
400        } else {
401            let internals_module_name = Ident::new(&format!("{fn_name}_utils"), fn_name.span());
402            (
403                Some(quote! {
404                    #[allow(non_camel_case_types, non_snake_case)]
405                    #[doc = #helper_type_doc]
406                    pub type #fn_name #ty_generics = #internals_module_name::#fn_name <
407                        #(#type_args,)*
408                        #(<#arg_name as ::diesel::expression::AsExpression<#arg_type>>::Expression,)*
409                    >;
410                }),
411                quote! { #fn_name },
412                internals_module_name,
413            )
414        };
415
416    quote! {
417        #(#attributes)*
418        #[allow(non_camel_case_types)]
419        pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
420            -> #return_type_path #ty_generics
421        #where_clause
422            #(#arg_name: ::diesel::expression::AsExpression<#arg_type>,)*
423        {
424            #internals_module_name::#fn_name {
425                #(#arg_struct_assign,)*
426                #(#type_args: ::std::marker::PhantomData,)*
427            }
428        }
429
430        #outside_of_module_helper_type
431
432        #[doc(hidden)]
433        #[allow(non_camel_case_types, non_snake_case, unused_imports)]
434        pub(crate) mod #internals_module_name {
435            #tokens
436        }
437    }
438}
439
440pub(crate) struct SqlFunctionDecl {
441    attributes: Vec<Attribute>,
442    fn_token: Token![fn],
443    fn_name: Ident,
444    generics: Generics,
445    args: Punctuated<StrictFnArg, Token![,]>,
446    return_type: Type,
447}
448
449impl Parse for SqlFunctionDecl {
450    fn parse(input: ParseStream) -> Result<Self> {
451        let attributes = Attribute::parse_outer(input)?;
452        let fn_token: Token![fn] = input.parse()?;
453        let fn_name = Ident::parse(input)?;
454        let generics = Generics::parse(input)?;
455        let args;
456        let _paren = parenthesized!(args in input);
457        let args = args.parse_terminated(StrictFnArg::parse, Token![,])?;
458        let return_type = if Option::<Token![->]>::parse(input)?.is_some() {
459            Type::parse(input)?
460        } else {
461            parse_quote!(diesel::expression::expression_types::NotSelectable)
462        };
463        let _semi = Option::<Token![;]>::parse(input)?;
464
465        Ok(Self {
466            attributes,
467            fn_token,
468            fn_name,
469            generics,
470            args,
471            return_type,
472        })
473    }
474}
475
476/// Essentially the same as ArgCaptured, but only allowing ident patterns
477struct StrictFnArg {
478    name: Ident,
479    colon_token: Token![:],
480    ty: Type,
481}
482
483impl Parse for StrictFnArg {
484    fn parse(input: ParseStream) -> Result<Self> {
485        let name = input.parse()?;
486        let colon_token = input.parse()?;
487        let ty = input.parse()?;
488        Ok(Self {
489            name,
490            colon_token,
491            ty,
492        })
493    }
494}
495
496impl ToTokens for StrictFnArg {
497    fn to_tokens(&self, tokens: &mut TokenStream) {
498        self.name.to_tokens(tokens);
499        self.colon_token.to_tokens(tokens);
500        self.name.to_tokens(tokens);
501    }
502}
503
504fn is_sqlite_type(ty: &Type) -> bool {
505    let last_segment = if let Type::Path(tp) = ty {
506        if let Some(segment) = tp.path.segments.last() {
507            segment
508        } else {
509            return false;
510        }
511    } else {
512        return false;
513    };
514
515    let ident = last_segment.ident.to_string();
516    if ident == "Nullable" {
517        if let PathArguments::AngleBracketed(ref ab) = last_segment.arguments {
518            if let Some(GenericArgument::Type(ty)) = ab.args.first() {
519                return is_sqlite_type(ty);
520            }
521        }
522        return false;
523    }
524
525    [
526        "BigInt",
527        "Binary",
528        "Bool",
529        "Date",
530        "Double",
531        "Float",
532        "Integer",
533        "Numeric",
534        "SmallInt",
535        "Text",
536        "Time",
537        "Timestamp",
538    ]
539    .contains(&ident.as_str())
540}