// diesel_derives/sql_function.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4use syn::parse::{Parse, ParseStream, Result};
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::{
8    parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue,
9    PathArguments, Token, Type,
10};
11
12pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool) -> TokenStream {
13    let SqlFunctionDecl {
14        mut attributes,
15        fn_token,
16        fn_name,
17        mut generics,
18        args,
19        return_type,
20    } = input;
21
22    let sql_name = attributes
23        .iter()
24        .find(|attr| attr.meta.path().is_ident("sql_name"))
25        .and_then(|attr| {
26            if let Meta::NameValue(MetaNameValue {
27                value:
28                    syn::Expr::Lit(syn::ExprLit {
29                        lit: syn::Lit::Str(ref lit),
30                        ..
31                    }),
32                ..
33            }) = attr.meta
34            {
35                Some(lit.value())
36            } else {
37                None
38            }
39        })
40        .unwrap_or_else(|| fn_name.to_string());
41
42    let is_aggregate = attributes
43        .iter()
44        .any(|attr| attr.meta.path().is_ident("aggregate"));
45
46    attributes.retain(|attr| {
47        !attr.meta.path().is_ident("sql_name") && !attr.meta.path().is_ident("aggregate")
48    });
49
50    let args = &args;
51    let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
52        .iter()
53        .map(|StrictFnArg { name, ty, .. }| (name, ty))
54        .unzip();
55    let arg_struct_assign = args.iter().map(
56        |StrictFnArg {
57             name, colon_token, ..
58         }| {
59            let name2 = name.clone();
60            quote!(#name #colon_token #name2.as_expression())
61        },
62    );
63
64    let type_args = &generics
65        .type_params()
66        .map(|type_param| type_param.ident.clone())
67        .collect::<Vec<_>>();
68
69    for StrictFnArg { name, .. } in args {
70        generics.params.push(parse_quote!(#name));
71    }
72
73    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
74    // Even if we force an empty where clause, it still won't print the where
75    // token with no bounds.
76    let where_clause = where_clause
77        .map(|w| quote!(#w))
78        .unwrap_or_else(|| quote!(where));
79
80    let mut generics_with_internal = generics.clone();
81    generics_with_internal
82        .params
83        .push(parse_quote!(__DieselInternal));
84    let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();
85
86    let sql_type;
87    let numeric_derive;
88
89    if arg_name.is_empty() {
90        sql_type = None;
91        // FIXME: We can always derive once trivial bounds are stable
92        numeric_derive = None;
93    } else {
94        sql_type = Some(quote!((#(#arg_name),*): Expression,));
95        numeric_derive = Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)]));
96    }
97
98    let helper_type_doc = format!("The return type of [`{fn_name}()`](super::fn_name)");
99
100    let args_iter = args.iter();
101    let mut tokens = quote! {
102        use diesel::{self, QueryResult};
103        use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
104        use diesel::query_builder::{QueryFragment, AstPass};
105        use diesel::sql_types::*;
106        use super::*;
107
108        #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
109        #numeric_derive
110        pub struct #fn_name #ty_generics {
111            #(pub(in super) #args_iter,)*
112            #(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
113        }
114
115        #[doc = #helper_type_doc]
116        pub type HelperType #ty_generics = #fn_name <
117            #(#type_args,)*
118            #(<#arg_name as AsExpression<#arg_type>>::Expression,)*
119        >;
120
121        impl #impl_generics Expression for #fn_name #ty_generics
122        #where_clause
123            #sql_type
124        {
125            type SqlType = #return_type;
126        }
127
128        // __DieselInternal is what we call QS normally
129        impl #impl_generics_internal SelectableExpression<__DieselInternal>
130            for #fn_name #ty_generics
131        #where_clause
132            #(#arg_name: SelectableExpression<__DieselInternal>,)*
133            Self: AppearsOnTable<__DieselInternal>,
134        {
135        }
136
137        // __DieselInternal is what we call QS normally
138        impl #impl_generics_internal AppearsOnTable<__DieselInternal>
139            for #fn_name #ty_generics
140        #where_clause
141            #(#arg_name: AppearsOnTable<__DieselInternal>,)*
142            Self: Expression,
143        {
144        }
145
146        // __DieselInternal is what we call DB normally
147        impl #impl_generics_internal QueryFragment<__DieselInternal>
148            for #fn_name #ty_generics
149        where
150            __DieselInternal: diesel::backend::Backend,
151            #(#arg_name: QueryFragment<__DieselInternal>,)*
152        {
153            #[allow(unused_assignments)]
154            fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{
155                out.push_sql(concat!(#sql_name, "("));
156                // we unroll the arguments manually here, to prevent borrow check issues
157                let mut needs_comma = false;
158                #(
159                    if !self.#arg_name.is_noop(out.backend())? {
160                        if needs_comma {
161                            out.push_sql(", ");
162                        }
163                        self.#arg_name.walk_ast(out.reborrow())?;
164                        needs_comma = true;
165                    }
166                )*
167                out.push_sql(")");
168                Ok(())
169            }
170        }
171    };
172
173    let is_supported_on_sqlite = cfg!(feature = "sqlite")
174        && type_args.is_empty()
175        && is_sqlite_type(&return_type)
176        && arg_type.iter().all(|a| is_sqlite_type(a));
177
178    if is_aggregate {
179        tokens = quote! {
180            #tokens
181
182            impl #impl_generics_internal ValidGrouping<__DieselInternal>
183                for #fn_name #ty_generics
184            {
185                type IsAggregate = diesel::expression::is_aggregate::Yes;
186            }
187        };
188        if is_supported_on_sqlite {
189            tokens = quote! {
190                #tokens
191
192                use diesel::sqlite::{Sqlite, SqliteConnection};
193                use diesel::serialize::ToSql;
194                use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
195                use diesel::sqlite::SqliteAggregateFunction;
196                use diesel::sql_types::IntoNullable;
197            };
198
199            match arg_name.len() {
200                x if x > 1 => {
201                    tokens = quote! {
202                        #tokens
203
204                        #[allow(dead_code)]
205                        /// Registers an implementation for this aggregate function on the given connection
206                        ///
207                        /// This function must be called for every `SqliteConnection` before
208                        /// this SQL function can be used on SQLite. The implementation must be
209                        /// deterministic (returns the same result given the same arguments).
210                        pub fn register_impl<A, #(#arg_name,)*>(
211                            conn: &mut SqliteConnection
212                        ) -> QueryResult<()>
213                            where
214                            A: SqliteAggregateFunction<(#(#arg_name,)*)>
215                                + Send
216                                + 'static
217                                + ::std::panic::UnwindSafe
218                                + ::std::panic::RefUnwindSafe,
219                            A::Output: ToSql<#return_type, Sqlite>,
220                            (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
221                                StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
222                                ::std::panic::UnwindSafe,
223                        {
224                            conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
225                        }
226                    };
227                }
228                1 => {
229                    let arg_name = arg_name[0];
230                    let arg_type = arg_type[0];
231
232                    tokens = quote! {
233                        #tokens
234
235                        #[allow(dead_code)]
236                        /// Registers an implementation for this aggregate function on the given connection
237                        ///
238                        /// This function must be called for every `SqliteConnection` before
239                        /// this SQL function can be used on SQLite. The implementation must be
240                        /// deterministic (returns the same result given the same arguments).
241                        pub fn register_impl<A, #arg_name>(
242                            conn: &mut SqliteConnection
243                        ) -> QueryResult<()>
244                            where
245                            A: SqliteAggregateFunction<#arg_name>
246                                + Send
247                                + 'static
248                                + std::panic::UnwindSafe
249                                + std::panic::RefUnwindSafe,
250                            A::Output: ToSql<#return_type, Sqlite>,
251                            #arg_name: FromSqlRow<#arg_type, Sqlite> +
252                                StaticallySizedRow<#arg_type, Sqlite> +
253                                ::std::panic::UnwindSafe,
254                            {
255                                conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
256                            }
257                    };
258                }
259                _ => (),
260            }
261        }
262    } else {
263        tokens = quote! {
264            #tokens
265
266            #[derive(ValidGrouping)]
267            pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);
268
269            impl #impl_generics_internal ValidGrouping<__DieselInternal>
270                for #fn_name #ty_generics
271            where
272                __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
273            {
274                type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
275            }
276        };
277
278        if is_supported_on_sqlite && !arg_name.is_empty() {
279            tokens = quote! {
280                #tokens
281
282                use diesel::sqlite::{Sqlite, SqliteConnection};
283                use diesel::serialize::ToSql;
284                use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
285
286                #[allow(dead_code)]
287                /// Registers an implementation for this function on the given connection
288                ///
289                /// This function must be called for every `SqliteConnection` before
290                /// this SQL function can be used on SQLite. The implementation must be
291                /// deterministic (returns the same result given the same arguments). If
292                /// the function is nondeterministic, call
293                /// `register_nondeterministic_impl` instead.
294                pub fn register_impl<F, Ret, #(#arg_name,)*>(
295                    conn: &mut SqliteConnection,
296                    f: F,
297                ) -> QueryResult<()>
298                where
299                    F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
300                    (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
301                        StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
302                    Ret: ToSql<#return_type, Sqlite>,
303                {
304                    conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
305                        #sql_name,
306                        true,
307                        move |(#(#arg_name,)*)| f(#(#arg_name,)*),
308                    )
309                }
310
311                #[allow(dead_code)]
312                /// Registers an implementation for this function on the given connection
313                ///
314                /// This function must be called for every `SqliteConnection` before
315                /// this SQL function can be used on SQLite.
316                /// `register_nondeterministic_impl` should only be used if your
317                /// function can return different results with the same arguments (e.g.
318                /// `random`). If your function is deterministic, you should call
319                /// `register_impl` instead.
320                pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
321                    conn: &mut SqliteConnection,
322                    mut f: F,
323                ) -> QueryResult<()>
324                where
325                    F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
326                    (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
327                        StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
328                    Ret: ToSql<#return_type, Sqlite>,
329                {
330                    conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
331                        #sql_name,
332                        false,
333                        move |(#(#arg_name,)*)| f(#(#arg_name,)*),
334                    )
335                }
336            };
337        }
338
339        if is_supported_on_sqlite && arg_name.is_empty() {
340            tokens = quote! {
341                #tokens
342
343                use diesel::sqlite::{Sqlite, SqliteConnection};
344                use diesel::serialize::ToSql;
345
346                #[allow(dead_code)]
347                /// Registers an implementation for this function on the given connection
348                ///
349                /// This function must be called for every `SqliteConnection` before
350                /// this SQL function can be used on SQLite. The implementation must be
351                /// deterministic (returns the same result given the same arguments). If
352                /// the function is nondeterministic, call
353                /// `register_nondeterministic_impl` instead.
354                pub fn register_impl<F, Ret>(
355                    conn: &SqliteConnection,
356                    f: F,
357                ) -> QueryResult<()>
358                where
359                    F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static,
360                    Ret: ToSql<#return_type, Sqlite>,
361                {
362                    conn.register_noarg_sql_function::<#return_type, _, _>(
363                        #sql_name,
364                        true,
365                        f,
366                    )
367                }
368
369                #[allow(dead_code)]
370                /// Registers an implementation for this function on the given connection
371                ///
372                /// This function must be called for every `SqliteConnection` before
373                /// this SQL function can be used on SQLite.
374                /// `register_nondeterministic_impl` should only be used if your
375                /// function can return different results with the same arguments (e.g.
376                /// `random`). If your function is deterministic, you should call
377                /// `register_impl` instead.
378                pub fn register_nondeterministic_impl<F, Ret>(
379                    conn: &SqliteConnection,
380                    mut f: F,
381                ) -> QueryResult<()>
382                where
383                    F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
384                    Ret: ToSql<#return_type, Sqlite>,
385                {
386                    conn.register_noarg_sql_function::<#return_type, _, _>(
387                        #sql_name,
388                        false,
389                        f,
390                    )
391                }
392            };
393        }
394    }
395
396    let args_iter = args.iter();
397
398    let (outside_of_module_helper_type, return_type_path, internals_module_name) =
399        if legacy_helper_type_and_module {
400            (None, quote! { #fn_name::HelperType }, fn_name.clone())
401        } else {
402            let internals_module_name = Ident::new(&format!("{fn_name}_utils"), fn_name.span());
403            (
404                Some(quote! {
405                    #[allow(non_camel_case_types, non_snake_case)]
406                    #[doc = #helper_type_doc]
407                    pub type #fn_name #ty_generics = #internals_module_name::#fn_name <
408                        #(#type_args,)*
409                        #(<#arg_name as diesel::expression::AsExpression<#arg_type>>::Expression,)*
410                    >;
411                }),
412                quote! { #fn_name },
413                internals_module_name,
414            )
415        };
416
417    quote! {
418        #(#attributes)*
419        #[allow(non_camel_case_types)]
420        pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
421            -> #return_type_path #ty_generics
422        #where_clause
423            #(#arg_name: diesel::expression::AsExpression<#arg_type>,)*
424        {
425            #internals_module_name::#fn_name {
426                #(#arg_struct_assign,)*
427                #(#type_args: ::std::marker::PhantomData,)*
428            }
429        }
430
431        #outside_of_module_helper_type
432
433        #[doc(hidden)]
434        #[allow(non_camel_case_types, non_snake_case, unused_imports)]
435        pub(crate) mod #internals_module_name {
436            #tokens
437        }
438    }
439}
440
441pub(crate) struct ExternSqlBlock {
442    pub(crate) function_decls: Vec<SqlFunctionDecl>,
443}
444
445impl Parse for ExternSqlBlock {
446    fn parse(input: ParseStream) -> Result<Self> {
447        let block = syn::ItemForeignMod::parse(input)?;
448        if block.abi.name.as_ref().map(|n| n.value()) != Some("SQL".into()) {
449            return Err(syn::Error::new(block.abi.span(), "expect `SQL` as ABI"));
450        }
451        if block.unsafety.is_some() {
452            return Err(syn::Error::new(
453                block.unsafety.unwrap().span(),
454                "expect `SQL` function blocks to be safe",
455            ));
456        }
457        let function_decls = block
458            .items
459            .into_iter()
460            .map(|i| syn::parse2(quote! { #i }))
461            .collect::<Result<Vec<_>>>()?;
462
463        Ok(ExternSqlBlock { function_decls })
464    }
465}
466
467pub(crate) struct SqlFunctionDecl {
468    attributes: Vec<Attribute>,
469    fn_token: Token![fn],
470    fn_name: Ident,
471    generics: Generics,
472    args: Punctuated<StrictFnArg, Token![,]>,
473    return_type: Type,
474}
475
476impl Parse for SqlFunctionDecl {
477    fn parse(input: ParseStream) -> Result<Self> {
478        let attributes = Attribute::parse_outer(input)?;
479        let fn_token: Token![fn] = input.parse()?;
480        let fn_name = Ident::parse(input)?;
481        let generics = Generics::parse(input)?;
482        let args;
483        let _paren = parenthesized!(args in input);
484        let args = args.parse_terminated(StrictFnArg::parse, Token![,])?;
485        let return_type = if Option::<Token![->]>::parse(input)?.is_some() {
486            Type::parse(input)?
487        } else {
488            parse_quote!(diesel::expression::expression_types::NotSelectable)
489        };
490        let _semi = Option::<Token![;]>::parse(input)?;
491
492        Ok(Self {
493            attributes,
494            fn_token,
495            fn_name,
496            generics,
497            args,
498            return_type,
499        })
500    }
501}
502
503/// Essentially the same as ArgCaptured, but only allowing ident patterns
504struct StrictFnArg {
505    name: Ident,
506    colon_token: Token![:],
507    ty: Type,
508}
509
510impl Parse for StrictFnArg {
511    fn parse(input: ParseStream) -> Result<Self> {
512        let name = input.parse()?;
513        let colon_token = input.parse()?;
514        let ty = input.parse()?;
515        Ok(Self {
516            name,
517            colon_token,
518            ty,
519        })
520    }
521}
522
523impl ToTokens for StrictFnArg {
524    fn to_tokens(&self, tokens: &mut TokenStream) {
525        self.name.to_tokens(tokens);
526        self.colon_token.to_tokens(tokens);
527        self.name.to_tokens(tokens);
528    }
529}
530
531fn is_sqlite_type(ty: &Type) -> bool {
532    let last_segment = if let Type::Path(tp) = ty {
533        if let Some(segment) = tp.path.segments.last() {
534            segment
535        } else {
536            return false;
537        }
538    } else {
539        return false;
540    };
541
542    let ident = last_segment.ident.to_string();
543    if ident == "Nullable" {
544        if let PathArguments::AngleBracketed(ref ab) = last_segment.arguments {
545            if let Some(GenericArgument::Type(ty)) = ab.args.first() {
546                return is_sqlite_type(ty);
547            }
548        }
549        return false;
550    }
551
552    [
553        "BigInt",
554        "Binary",
555        "Bool",
556        "Date",
557        "Double",
558        "Float",
559        "Integer",
560        "Numeric",
561        "SmallInt",
562        "Text",
563        "Time",
564        "Timestamp",
565    ]
566    .contains(&ident.as_str())
567}