1use proc_macro2::TokenStream;
2use quote::quote;
3use quote::ToTokens;
4use syn::parse::{Parse, ParseStream, Result};
5use syn::punctuated::Punctuated;
6use syn::spanned::Spanned;
7use syn::{
8 parenthesized, parse_quote, Attribute, GenericArgument, Generics, Ident, Meta, MetaNameValue,
9 PathArguments, Token, Type,
10};
11
12pub(crate) fn expand(input: SqlFunctionDecl, legacy_helper_type_and_module: bool) -> TokenStream {
13 let SqlFunctionDecl {
14 mut attributes,
15 fn_token,
16 fn_name,
17 mut generics,
18 args,
19 return_type,
20 } = input;
21
22 let sql_name = attributes
23 .iter()
24 .find(|attr| attr.meta.path().is_ident("sql_name"))
25 .and_then(|attr| {
26 if let Meta::NameValue(MetaNameValue {
27 value:
28 syn::Expr::Lit(syn::ExprLit {
29 lit: syn::Lit::Str(ref lit),
30 ..
31 }),
32 ..
33 }) = attr.meta
34 {
35 Some(lit.value())
36 } else {
37 None
38 }
39 })
40 .unwrap_or_else(|| fn_name.to_string());
41
42 let is_aggregate = attributes
43 .iter()
44 .any(|attr| attr.meta.path().is_ident("aggregate"));
45
46 attributes.retain(|attr| {
47 !attr.meta.path().is_ident("sql_name") && !attr.meta.path().is_ident("aggregate")
48 });
49
50 let args = &args;
51 let (ref arg_name, ref arg_type): (Vec<_>, Vec<_>) = args
52 .iter()
53 .map(|StrictFnArg { name, ty, .. }| (name, ty))
54 .unzip();
55 let arg_struct_assign = args.iter().map(
56 |StrictFnArg {
57 name, colon_token, ..
58 }| {
59 let name2 = name.clone();
60 quote!(#name #colon_token #name2.as_expression())
61 },
62 );
63
64 let type_args = &generics
65 .type_params()
66 .map(|type_param| type_param.ident.clone())
67 .collect::<Vec<_>>();
68
69 for StrictFnArg { name, .. } in args {
70 generics.params.push(parse_quote!(#name));
71 }
72
73 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
74 let where_clause = where_clause
77 .map(|w| quote!(#w))
78 .unwrap_or_else(|| quote!(where));
79
80 let mut generics_with_internal = generics.clone();
81 generics_with_internal
82 .params
83 .push(parse_quote!(__DieselInternal));
84 let (impl_generics_internal, _, _) = generics_with_internal.split_for_impl();
85
86 let sql_type;
87 let numeric_derive;
88
89 if arg_name.is_empty() {
90 sql_type = None;
91 numeric_derive = None;
93 } else {
94 sql_type = Some(quote!((#(#arg_name),*): Expression,));
95 numeric_derive = Some(quote!(#[derive(diesel::sql_types::DieselNumericOps)]));
96 }
97
98 let helper_type_doc = format!("The return type of [`{fn_name}()`](super::fn_name)");
99
100 let args_iter = args.iter();
101 let mut tokens = quote! {
102 use diesel::{self, QueryResult};
103 use diesel::expression::{AsExpression, Expression, SelectableExpression, AppearsOnTable, ValidGrouping};
104 use diesel::query_builder::{QueryFragment, AstPass};
105 use diesel::sql_types::*;
106 use super::*;
107
108 #[derive(Debug, Clone, Copy, diesel::query_builder::QueryId)]
109 #numeric_derive
110 pub struct #fn_name #ty_generics {
111 #(pub(in super) #args_iter,)*
112 #(pub(in super) #type_args: ::std::marker::PhantomData<#type_args>,)*
113 }
114
115 #[doc = #helper_type_doc]
116 pub type HelperType #ty_generics = #fn_name <
117 #(#type_args,)*
118 #(<#arg_name as AsExpression<#arg_type>>::Expression,)*
119 >;
120
121 impl #impl_generics Expression for #fn_name #ty_generics
122 #where_clause
123 #sql_type
124 {
125 type SqlType = #return_type;
126 }
127
128 impl #impl_generics_internal SelectableExpression<__DieselInternal>
130 for #fn_name #ty_generics
131 #where_clause
132 #(#arg_name: SelectableExpression<__DieselInternal>,)*
133 Self: AppearsOnTable<__DieselInternal>,
134 {
135 }
136
137 impl #impl_generics_internal AppearsOnTable<__DieselInternal>
139 for #fn_name #ty_generics
140 #where_clause
141 #(#arg_name: AppearsOnTable<__DieselInternal>,)*
142 Self: Expression,
143 {
144 }
145
146 impl #impl_generics_internal QueryFragment<__DieselInternal>
148 for #fn_name #ty_generics
149 where
150 __DieselInternal: diesel::backend::Backend,
151 #(#arg_name: QueryFragment<__DieselInternal>,)*
152 {
153 #[allow(unused_assignments)]
154 fn walk_ast<'__b>(&'__b self, mut out: AstPass<'_, '__b, __DieselInternal>) -> QueryResult<()>{
155 out.push_sql(concat!(#sql_name, "("));
156 let mut needs_comma = false;
158 #(
159 if !self.#arg_name.is_noop(out.backend())? {
160 if needs_comma {
161 out.push_sql(", ");
162 }
163 self.#arg_name.walk_ast(out.reborrow())?;
164 needs_comma = true;
165 }
166 )*
167 out.push_sql(")");
168 Ok(())
169 }
170 }
171 };
172
173 let is_supported_on_sqlite = cfg!(feature = "sqlite")
174 && type_args.is_empty()
175 && is_sqlite_type(&return_type)
176 && arg_type.iter().all(|a| is_sqlite_type(a));
177
178 if is_aggregate {
179 tokens = quote! {
180 #tokens
181
182 impl #impl_generics_internal ValidGrouping<__DieselInternal>
183 for #fn_name #ty_generics
184 {
185 type IsAggregate = diesel::expression::is_aggregate::Yes;
186 }
187 };
188 if is_supported_on_sqlite {
189 tokens = quote! {
190 #tokens
191
192 use diesel::sqlite::{Sqlite, SqliteConnection};
193 use diesel::serialize::ToSql;
194 use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
195 use diesel::sqlite::SqliteAggregateFunction;
196 use diesel::sql_types::IntoNullable;
197 };
198
199 match arg_name.len() {
200 x if x > 1 => {
201 tokens = quote! {
202 #tokens
203
204 #[allow(dead_code)]
205 pub fn register_impl<A, #(#arg_name,)*>(
211 conn: &mut SqliteConnection
212 ) -> QueryResult<()>
213 where
214 A: SqliteAggregateFunction<(#(#arg_name,)*)>
215 + Send
216 + 'static
217 + ::std::panic::UnwindSafe
218 + ::std::panic::RefUnwindSafe,
219 A::Output: ToSql<#return_type, Sqlite>,
220 (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
221 StaticallySizedRow<(#(#arg_type,)*), Sqlite> +
222 ::std::panic::UnwindSafe,
223 {
224 conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name)
225 }
226 };
227 }
228 1 => {
229 let arg_name = arg_name[0];
230 let arg_type = arg_type[0];
231
232 tokens = quote! {
233 #tokens
234
235 #[allow(dead_code)]
236 pub fn register_impl<A, #arg_name>(
242 conn: &mut SqliteConnection
243 ) -> QueryResult<()>
244 where
245 A: SqliteAggregateFunction<#arg_name>
246 + Send
247 + 'static
248 + std::panic::UnwindSafe
249 + std::panic::RefUnwindSafe,
250 A::Output: ToSql<#return_type, Sqlite>,
251 #arg_name: FromSqlRow<#arg_type, Sqlite> +
252 StaticallySizedRow<#arg_type, Sqlite> +
253 ::std::panic::UnwindSafe,
254 {
255 conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name)
256 }
257 };
258 }
259 _ => (),
260 }
261 }
262 } else {
263 tokens = quote! {
264 #tokens
265
266 #[derive(ValidGrouping)]
267 pub struct __Derived<#(#arg_name,)*>(#(#arg_name,)*);
268
269 impl #impl_generics_internal ValidGrouping<__DieselInternal>
270 for #fn_name #ty_generics
271 where
272 __Derived<#(#arg_name,)*>: ValidGrouping<__DieselInternal>,
273 {
274 type IsAggregate = <__Derived<#(#arg_name,)*> as ValidGrouping<__DieselInternal>>::IsAggregate;
275 }
276 };
277
278 if is_supported_on_sqlite && !arg_name.is_empty() {
279 tokens = quote! {
280 #tokens
281
282 use diesel::sqlite::{Sqlite, SqliteConnection};
283 use diesel::serialize::ToSql;
284 use diesel::deserialize::{FromSqlRow, StaticallySizedRow};
285
286 #[allow(dead_code)]
287 pub fn register_impl<F, Ret, #(#arg_name,)*>(
295 conn: &mut SqliteConnection,
296 f: F,
297 ) -> QueryResult<()>
298 where
299 F: Fn(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
300 (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
301 StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
302 Ret: ToSql<#return_type, Sqlite>,
303 {
304 conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
305 #sql_name,
306 true,
307 move |(#(#arg_name,)*)| f(#(#arg_name,)*),
308 )
309 }
310
311 #[allow(dead_code)]
312 pub fn register_nondeterministic_impl<F, Ret, #(#arg_name,)*>(
321 conn: &mut SqliteConnection,
322 mut f: F,
323 ) -> QueryResult<()>
324 where
325 F: FnMut(#(#arg_name,)*) -> Ret + std::panic::UnwindSafe + Send + 'static,
326 (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> +
327 StaticallySizedRow<(#(#arg_type,)*), Sqlite>,
328 Ret: ToSql<#return_type, Sqlite>,
329 {
330 conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>(
331 #sql_name,
332 false,
333 move |(#(#arg_name,)*)| f(#(#arg_name,)*),
334 )
335 }
336 };
337 }
338
339 if is_supported_on_sqlite && arg_name.is_empty() {
340 tokens = quote! {
341 #tokens
342
343 use diesel::sqlite::{Sqlite, SqliteConnection};
344 use diesel::serialize::ToSql;
345
346 #[allow(dead_code)]
347 pub fn register_impl<F, Ret>(
355 conn: &SqliteConnection,
356 f: F,
357 ) -> QueryResult<()>
358 where
359 F: Fn() -> Ret + std::panic::UnwindSafe + Send + 'static,
360 Ret: ToSql<#return_type, Sqlite>,
361 {
362 conn.register_noarg_sql_function::<#return_type, _, _>(
363 #sql_name,
364 true,
365 f,
366 )
367 }
368
369 #[allow(dead_code)]
370 pub fn register_nondeterministic_impl<F, Ret>(
379 conn: &SqliteConnection,
380 mut f: F,
381 ) -> QueryResult<()>
382 where
383 F: FnMut() -> Ret + std::panic::UnwindSafe + Send + 'static,
384 Ret: ToSql<#return_type, Sqlite>,
385 {
386 conn.register_noarg_sql_function::<#return_type, _, _>(
387 #sql_name,
388 false,
389 f,
390 )
391 }
392 };
393 }
394 }
395
396 let args_iter = args.iter();
397
398 let (outside_of_module_helper_type, return_type_path, internals_module_name) =
399 if legacy_helper_type_and_module {
400 (None, quote! { #fn_name::HelperType }, fn_name.clone())
401 } else {
402 let internals_module_name = Ident::new(&format!("{fn_name}_utils"), fn_name.span());
403 (
404 Some(quote! {
405 #[allow(non_camel_case_types, non_snake_case)]
406 #[doc = #helper_type_doc]
407 pub type #fn_name #ty_generics = #internals_module_name::#fn_name <
408 #(#type_args,)*
409 #(<#arg_name as diesel::expression::AsExpression<#arg_type>>::Expression,)*
410 >;
411 }),
412 quote! { #fn_name },
413 internals_module_name,
414 )
415 };
416
417 quote! {
418 #(#attributes)*
419 #[allow(non_camel_case_types)]
420 pub #fn_token #fn_name #impl_generics (#(#args_iter,)*)
421 -> #return_type_path #ty_generics
422 #where_clause
423 #(#arg_name: diesel::expression::AsExpression<#arg_type>,)*
424 {
425 #internals_module_name::#fn_name {
426 #(#arg_struct_assign,)*
427 #(#type_args: ::std::marker::PhantomData,)*
428 }
429 }
430
431 #outside_of_module_helper_type
432
433 #[doc(hidden)]
434 #[allow(non_camel_case_types, non_snake_case, unused_imports)]
435 pub(crate) mod #internals_module_name {
436 #tokens
437 }
438 }
439}
440
441pub(crate) struct ExternSqlBlock {
442 pub(crate) function_decls: Vec<SqlFunctionDecl>,
443}
444
445impl Parse for ExternSqlBlock {
446 fn parse(input: ParseStream) -> Result<Self> {
447 let block = syn::ItemForeignMod::parse(input)?;
448 if block.abi.name.as_ref().map(|n| n.value()) != Some("SQL".into()) {
449 return Err(syn::Error::new(block.abi.span(), "expect `SQL` as ABI"));
450 }
451 if block.unsafety.is_some() {
452 return Err(syn::Error::new(
453 block.unsafety.unwrap().span(),
454 "expect `SQL` function blocks to be safe",
455 ));
456 }
457 let function_decls = block
458 .items
459 .into_iter()
460 .map(|i| syn::parse2(quote! { #i }))
461 .collect::<Result<Vec<_>>>()?;
462
463 Ok(ExternSqlBlock { function_decls })
464 }
465}
466
467pub(crate) struct SqlFunctionDecl {
468 attributes: Vec<Attribute>,
469 fn_token: Token![fn],
470 fn_name: Ident,
471 generics: Generics,
472 args: Punctuated<StrictFnArg, Token![,]>,
473 return_type: Type,
474}
475
476impl Parse for SqlFunctionDecl {
477 fn parse(input: ParseStream) -> Result<Self> {
478 let attributes = Attribute::parse_outer(input)?;
479 let fn_token: Token![fn] = input.parse()?;
480 let fn_name = Ident::parse(input)?;
481 let generics = Generics::parse(input)?;
482 let args;
483 let _paren = parenthesized!(args in input);
484 let args = args.parse_terminated(StrictFnArg::parse, Token![,])?;
485 let return_type = if Option::<Token![->]>::parse(input)?.is_some() {
486 Type::parse(input)?
487 } else {
488 parse_quote!(diesel::expression::expression_types::NotSelectable)
489 };
490 let _semi = Option::<Token![;]>::parse(input)?;
491
492 Ok(Self {
493 attributes,
494 fn_token,
495 fn_name,
496 generics,
497 args,
498 return_type,
499 })
500 }
501}
502
503struct StrictFnArg {
505 name: Ident,
506 colon_token: Token![:],
507 ty: Type,
508}
509
510impl Parse for StrictFnArg {
511 fn parse(input: ParseStream) -> Result<Self> {
512 let name = input.parse()?;
513 let colon_token = input.parse()?;
514 let ty = input.parse()?;
515 Ok(Self {
516 name,
517 colon_token,
518 ty,
519 })
520 }
521}
522
523impl ToTokens for StrictFnArg {
524 fn to_tokens(&self, tokens: &mut TokenStream) {
525 self.name.to_tokens(tokens);
526 self.colon_token.to_tokens(tokens);
527 self.name.to_tokens(tokens);
528 }
529}
530
531fn is_sqlite_type(ty: &Type) -> bool {
532 let last_segment = if let Type::Path(tp) = ty {
533 if let Some(segment) = tp.path.segments.last() {
534 segment
535 } else {
536 return false;
537 }
538 } else {
539 return false;
540 };
541
542 let ident = last_segment.ident.to_string();
543 if ident == "Nullable" {
544 if let PathArguments::AngleBracketed(ref ab) = last_segment.arguments {
545 if let Some(GenericArgument::Type(ty)) = ab.args.first() {
546 return is_sqlite_type(ty);
547 }
548 }
549 return false;
550 }
551
552 [
553 "BigInt",
554 "Binary",
555 "Bool",
556 "Date",
557 "Double",
558 "Float",
559 "Integer",
560 "Numeric",
561 "SmallInt",
562 "Text",
563 "Time",
564 "Timestamp",
565 ]
566 .contains(&ident.as_str())
567}