dsl_auto_type/auto_type/
mod.rs

1mod case;
2pub mod expression_type_inference;
3mod local_variables_map;
4mod referenced_generics;
5mod settings_builder;
6
7use {
8    darling::{util::SpannedValue, FromMeta},
9    either::Either,
10    proc_macro2::{Span, TokenStream},
11    quote::quote,
12    std::{collections::HashMap, rc::Rc},
13    syn::{parse_quote, parse_quote_spanned, spanned::Spanned, Ident, ItemFn, Token, Type},
14};
15
16use local_variables_map::*;
17
18pub use {
19    case::Case, expression_type_inference::InferrerSettings,
20    settings_builder::DeriveSettingsBuilder,
21};
22
23pub struct DeriveSettings {
24    default_dsl_path: syn::Path,
25    default_method_type_case: Case,
26    default_function_type_case: Case,
27    default_generate_type_alias: bool,
28}
29
30#[derive(darling::FromMeta)]
31struct DeriveParameters {
32    /// Can be overridden to provide custom DSLs
33    dsl_path: Option<syn::Path>,
34    type_alias: darling::util::Flag,
35    no_type_alias: darling::util::Flag,
36    type_name: Option<syn::Ident>,
37    type_case: Option<SpannedValue<String>>,
38}
39
40pub(crate) fn auto_type_impl(
41    attr: TokenStream,
42    input: &TokenStream,
43    derive_settings: DeriveSettings,
44) -> Result<TokenStream, crate::Error> {
45    let settings_input: DeriveParameters =
46        DeriveParameters::from_list(&darling::ast::NestedMeta::parse_meta_list(attr)?)?;
47
48    let mut input_function = syn::parse2::<ItemFn>(input.clone())?;
49
50    let inferrer_settings = InferrerSettings {
51        dsl_path: settings_input
52            .dsl_path
53            .unwrap_or(derive_settings.default_dsl_path),
54        method_types_case: derive_settings.default_method_type_case,
55        function_types_case: derive_settings.default_function_type_case,
56    };
57
58    let function_name = &input_function.sig.ident;
59    let type_alias = match (
60        settings_input.type_alias.is_present(),
61        settings_input.no_type_alias.is_present(),
62        derive_settings.default_generate_type_alias,
63    ) {
64        (false, false, b) => b,
65        (true, false, _) => true,
66        (false, true, _) => false,
67        (true, true, _) => {
68            return Err(syn::Error::new(
69                Span::call_site(),
70                "type_alias and no_type_alias are mutually exclusive",
71            )
72            .into())
73        }
74    };
75    let type_alias: Option<syn::Ident> = match (
76        type_alias,
77        settings_input.type_name,
78        settings_input.type_case,
79    ) {
80        (false, None, None) => None,
81        (true, None, None) => {
82            // By default be consistent with call expressions, for when other will refer
83            // this query fragment in another auto_type function
84            Some(
85                inferrer_settings
86                    .function_types_case
87                    .ident_with_case(function_name),
88            )
89        }
90        (_, Some(ident), None) => Some(ident),
91        (_, None, Some(case)) => {
92            let case = Case::from_str(case.as_str(), case.span())?;
93            Some(case.ident_with_case(function_name))
94        }
95        (_, Some(_), Some(type_case)) => {
96            return Err(syn::Error::new(
97                type_case.span(),
98                "type_name and type_case are mutually exclusive",
99            )
100            .into())
101        }
102    };
103
104    let last_statement = input_function.block.stmts.last().ok_or_else(|| {
105        syn::Error::new(
106            input_function.span(),
107            "function body should not be empty for auto_type",
108        )
109    })?;
110    let mut errors = Vec::new();
111    let return_type = match input_function.sig.output {
112        syn::ReturnType::Type(_, return_type) => {
113            let return_expression = match last_statement {
114                syn::Stmt::Expr(expr, None) => expr,
115                syn::Stmt::Expr(
116                    syn::Expr::Return(syn::ExprReturn {
117                        expr: Some(expr), ..
118                    }),
119                    _,
120                ) => &**expr,
121                _ => {
122                    return Err(syn::Error::new(
123                        last_statement.span(),
124                        "last statement should be an expression for auto_type",
125                    )
126                    .into())
127                }
128            };
129
130            // Build a map of local variables, and get the function parameters in there
131            let mut local_variables_map = LocalVariablesMap {
132                inferrer_settings: &inferrer_settings,
133                inner: LocalVariablesMapInner {
134                    map: Default::default(),
135                    parent: None,
136                },
137            };
138            for const_generic in input_function.sig.generics.const_params() {
139                local_variables_map.process_const_generic(const_generic);
140            }
141            for function_param in &input_function.sig.inputs {
142                if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = function_param {
143                    match local_variables_map.process_pat(pat, Some(ty), None) {
144                        Ok(()) => {}
145                        Err(e) => errors.push(Rc::new(e)),
146                    }
147                };
148            }
149
150            // Add local variables from the function body, and finally infer the type
151            local_variables_map.infer_block_expression_type(
152                return_expression,
153                Some(&return_type),
154                &input_function.block,
155                &mut errors,
156            )
157        }
158        _ => {
159            // This error message is not strictly correct: we also support
160            // partially-specified return types that involve `_`, but for simplicity we just
161            // put the overwhelmingly most common case in this error message
162            return Err(syn::Error::new(
163                input_function.sig.output.span(),
164                "Function return type should be explicitly specified as `-> _` for auto_type",
165            )
166            .into());
167        }
168    };
169
170    let type_alias = match type_alias {
171        Some(type_alias) => {
172            // We're generating a type alias so we need to extract the necessary lifetimes and
173            // generic type parameters for that type alias
174            let type_alias_generics = referenced_generics::extract_referenced_generics(
175                &return_type,
176                &input_function.sig.generics,
177                &mut errors,
178            );
179
180            let vis = &input_function.vis;
181            input_function.sig.output = parse_quote!(-> #type_alias #type_alias_generics);
182            quote! {
183                #[allow(non_camel_case_types)]
184                #vis type #type_alias #type_alias_generics = #return_type;
185            }
186        }
187        None => {
188            input_function.sig.output = parse_quote!(-> #return_type);
189            quote! {}
190        }
191    };
192
193    let mut res = quote! {
194        #type_alias
195        #[allow(clippy::needless_lifetimes)]
196        #input_function
197    };
198
199    for error in errors {
200        // Extracting from the `Rc` only if it's the last reference is an elegant way to
201        // deduplicate errors. For this to work it is necessary that the rest of
202        // the errors (those from the local variables map that weren't used) are
203        // dropped before, which is the case here, and that we are iterating on the
204        // errors in an owned manner.
205        if let Ok(error) = Rc::try_unwrap(error) {
206            res.extend(error.into_compile_error());
207        }
208    }
209
210    Ok(res)
211}