displaydoc/
expand.rs

1use super::attr::AttrsHelper;
2use proc_macro2::{Span, TokenStream};
3use quote::{format_ident, quote};
4use syn::{
5    punctuated::Punctuated,
6    token::{Colon, Comma, PathSep, Plus, Where},
7    Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Generics, Ident, Path, PathArguments,
8    PathSegment, PredicateType, Result, TraitBound, TraitBoundModifier, Type, TypeParam,
9    TypeParamBound, TypePath, WhereClause, WherePredicate,
10};
11
12use std::collections::HashMap;
13
14pub(crate) fn derive(input: &DeriveInput) -> Result<TokenStream> {
15    let impls = match &input.data {
16        Data::Struct(data) => impl_struct(input, data),
17        Data::Enum(data) => impl_enum(input, data),
18        Data::Union(_) => Err(Error::new_spanned(input, "Unions are not supported")),
19    }?;
20
21    let helpers = specialization();
22    Ok(quote! {
23        #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
24        const _: () = {
25            #helpers
26            #impls
27        };
28    })
29}
30
/// Emit the helper traits behind `__displaydoc_display()` (build with the `std` feature).
///
/// Two traits are produced, each exposing a `__displaydoc_display(&self)` method:
/// - `DisplayToDisplayDoc`, implemented for every `&T` where `T: ::core::fmt::Display`, which
///   hands the reference back unchanged;
/// - `PathToDisplayDoc`, implemented for `std::path::Path` and `std::path::PathBuf`, which
///   returns `self.display()`.
///
/// NOTE(review): the code that calls `__displaydoc_display()` is generated outside this file
/// (presumably by the `attr` module) — confirm there how the two traits are selected.
#[cfg(feature = "std")]
fn specialization() -> TokenStream {
    let display_passthrough = quote! {
        trait DisplayToDisplayDoc {
            fn __displaydoc_display(&self) -> Self;
        }

        impl<T: ::core::fmt::Display> DisplayToDisplayDoc for &T {
            fn __displaydoc_display(&self) -> Self {
                self
            }
        }
    };

    // Crates that derive `Display` may themselves be `no_std`. This function is only compiled
    // when displaydoc's `std` feature is enabled, and whoever enabled that feature is already
    // building `std`, so the expansion can always name it with `extern crate std;`.
    let std_crate = quote! {
        extern crate std;
    };

    let path_display = quote! {
        trait PathToDisplayDoc {
            fn __displaydoc_display(&self) -> std::path::Display<'_>;
        }

        impl PathToDisplayDoc for std::path::Path {
            fn __displaydoc_display(&self) -> std::path::Display<'_> {
                self.display()
            }
        }

        impl PathToDisplayDoc for std::path::PathBuf {
            fn __displaydoc_display(&self) -> std::path::Display<'_> {
                self.display()
            }
        }
    };

    quote! {
        #display_passthrough
        #std_crate
        #path_display
    }
}
67
68#[cfg(not(feature = "std"))]
69fn specialization() -> TokenStream {
70    quote! {}
71}
72
73fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result<TokenStream> {
74    let ty = &input.ident;
75    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
76    let where_clause = generate_where_clause(&input.generics, where_clause);
77
78    let helper = AttrsHelper::new(&input.attrs);
79
80    let display = helper.display(&input.attrs)?.map(|display| {
81        let pat = match &data.fields {
82            Fields::Named(fields) => {
83                let var = fields.named.iter().map(|field| &field.ident);
84                quote!(Self { #(#var),* })
85            }
86            Fields::Unnamed(fields) => {
87                let var = (0..fields.unnamed.len()).map(|i| format_ident!("_{}", i));
88                quote!(Self(#(#var),*))
89            }
90            Fields::Unit => quote!(_),
91        };
92        quote! {
93            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
94                fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
95                    // NB: This destructures the fields of `self` into named variables (for unnamed
96                    // fields, it uses _0, _1, etc as above). The `#[allow(unused_variables)]`
97                    // section means it doesn't have to parse the individual field references out of
98                    // the docstring.
99                    #[allow(unused_variables)]
100                    let #pat = self;
101                    #display
102                }
103            }
104        }
105    });
106
107    Ok(quote! { #display })
108}
109
110/// Create a `where` predicate for `ident`, without any [bound][TypeParamBound]s yet.
111fn new_empty_where_type_predicate(ident: Ident) -> PredicateType {
112    let mut path_segments = Punctuated::<PathSegment, PathSep>::new();
113    path_segments.push_value(PathSegment {
114        ident,
115        arguments: PathArguments::None,
116    });
117    PredicateType {
118        lifetimes: None,
119        bounded_ty: Type::Path(TypePath {
120            qself: None,
121            path: Path {
122                leading_colon: None,
123                segments: path_segments,
124            },
125        }),
126        colon_token: Colon {
127            spans: [Span::call_site()],
128        },
129        bounds: Punctuated::<TypeParamBound, Plus>::new(),
130    }
131}
132
133/// Create a `where` clause that we can add [WherePredicate]s to.
134fn new_empty_where_clause() -> WhereClause {
135    WhereClause {
136        where_token: Where {
137            span: Span::call_site(),
138        },
139        predicates: Punctuated::<WherePredicate, Comma>::new(),
140    }
141}
142
143enum UseGlobalPrefix {
144    LeadingColon,
145    #[allow(dead_code)]
146    NoLeadingColon,
147}
148
149/// Create a path with segments composed of [Idents] *without* any [PathArguments].
150fn join_paths(name_segments: &[&str], use_global_prefix: UseGlobalPrefix) -> Path {
151    let mut segments = Punctuated::<PathSegment, PathSep>::new();
152    assert!(!name_segments.is_empty());
153    segments.push_value(PathSegment {
154        ident: Ident::new(name_segments[0], Span::call_site()),
155        arguments: PathArguments::None,
156    });
157    for name in name_segments[1..].iter() {
158        segments.push_punct(PathSep {
159            spans: [Span::call_site(), Span::mixed_site()],
160        });
161        segments.push_value(PathSegment {
162            ident: Ident::new(name, Span::call_site()),
163            arguments: PathArguments::None,
164        });
165    }
166    Path {
167        leading_colon: match use_global_prefix {
168            UseGlobalPrefix::LeadingColon => Some(PathSep {
169                spans: [Span::call_site(), Span::mixed_site()],
170            }),
171            UseGlobalPrefix::NoLeadingColon => None,
172        },
173        segments,
174    }
175}
176
177/// Push `new_type_predicate` onto the end of `where_clause`.
178fn append_where_clause_type_predicate(
179    where_clause: &mut WhereClause,
180    new_type_predicate: PredicateType,
181) {
182    // Push a comma at the end if there are already any `where` predicates.
183    if !where_clause.predicates.is_empty() {
184        where_clause.predicates.push_punct(Comma {
185            spans: [Span::call_site()],
186        });
187    }
188    where_clause
189        .predicates
190        .push_value(WherePredicate::Type(new_type_predicate));
191}
192
193/// Add a requirement for [core::fmt::Display] to a `where` predicate for some type.
194fn add_display_constraint_to_type_predicate(
195    predicate_that_needs_a_display_impl: &mut PredicateType,
196) {
197    // Create a `Path` of `::core::fmt::Display`.
198    let display_path = join_paths(&["core", "fmt", "Display"], UseGlobalPrefix::LeadingColon);
199
200    let display_bound = TypeParamBound::Trait(TraitBound {
201        paren_token: None,
202        modifier: TraitBoundModifier::None,
203        lifetimes: None,
204        path: display_path,
205    });
206    if !predicate_that_needs_a_display_impl.bounds.is_empty() {
207        predicate_that_needs_a_display_impl.bounds.push_punct(Plus {
208            spans: [Span::call_site()],
209        });
210    }
211
212    predicate_that_needs_a_display_impl
213        .bounds
214        .push_value(display_bound);
215}
216
217/// Map each declared generic type parameter to the set of all trait boundaries declared on it.
218///
219/// These boundaries may come from the declaration site:
220///     pub enum E<T: MyTrait> { ... }
221/// or a `where` clause after the parameter declarations:
222///     pub enum E<T> where T: MyTrait { ... }
223/// This method will return the boundaries from both of those cases.
224fn extract_trait_constraints_from_source(
225    where_clause: &WhereClause,
226    type_params: &[&TypeParam],
227) -> HashMap<Ident, Vec<TraitBound>> {
228    // Add trait bounds provided at the declaration site of type parameters for the struct/enum.
229    let mut param_constraint_mapping: HashMap<Ident, Vec<TraitBound>> = type_params
230        .iter()
231        .map(|type_param| {
232            let trait_bounds: Vec<TraitBound> = type_param
233                .bounds
234                .iter()
235                .flat_map(|bound| match bound {
236                    TypeParamBound::Trait(trait_bound) => Some(trait_bound),
237                    _ => None,
238                })
239                .cloned()
240                .collect();
241            (type_param.ident.clone(), trait_bounds)
242        })
243        .collect();
244
245    // Add trait bounds from `where` clauses, which may be type parameters or types containing
246    // those parameters.
247    for predicate in where_clause.predicates.iter() {
248        // We only care about type and not lifetime constraints here.
249        if let WherePredicate::Type(ref pred_ty) = predicate {
250            let ident = match &pred_ty.bounded_ty {
251                Type::Path(TypePath { path, qself: None }) => match path.get_ident() {
252                    None => continue,
253                    Some(ident) => ident,
254                },
255                _ => continue,
256            };
257            // We ignore any type constraints that aren't direct references to type
258            // parameters of the current enum of struct definition. No types can be
259            // constrained in a `where` clause unless they are a type parameter or a generic
260            // type instantiated with one of the type parameters, so by only allowing single
261            // identifiers, we can be sure that the constrained type is a type parameter
262            // that is contained in `param_constraint_mapping`.
263            if let Some((_, ref mut known_bounds)) = param_constraint_mapping
264                .iter_mut()
265                .find(|(id, _)| *id == ident)
266            {
267                for bound in pred_ty.bounds.iter() {
268                    // We only care about trait bounds here.
269                    if let TypeParamBound::Trait(ref bound) = bound {
270                        known_bounds.push(bound.clone());
271                    }
272                }
273            }
274        }
275    }
276
277    param_constraint_mapping
278}
279
280/// Hygienically add `where _: Display` to the set of [TypeParamBound]s for `ident`, creating such
281/// a set if necessary.
282fn ensure_display_in_where_clause_for_type(where_clause: &mut WhereClause, ident: Ident) {
283    for pred_ty in where_clause
284        .predicates
285        .iter_mut()
286        // Find the `where` predicate constraining the current type param, if it exists.
287        .flat_map(|predicate| match predicate {
288            WherePredicate::Type(pred_ty) => Some(pred_ty),
289            // We're looking through type constraints, not lifetime constraints.
290            _ => None,
291        })
292    {
293        // Do a complicated destructuring in order to check if the type being constrained in this
294        // `where` clause is the type we're looking for, so we can use the mutable reference to
295        // `pred_ty` if so.
296        let matches_desired_type = matches!(
297            &pred_ty.bounded_ty,
298            Type::Path(TypePath { path, .. }) if Some(&ident) == path.get_ident());
299        if matches_desired_type {
300            add_display_constraint_to_type_predicate(pred_ty);
301            return;
302        }
303    }
304
305    // If there is no `where` predicate for the current type param, we will construct one.
306    let mut new_type_predicate = new_empty_where_type_predicate(ident);
307    add_display_constraint_to_type_predicate(&mut new_type_predicate);
308    append_where_clause_type_predicate(where_clause, new_type_predicate);
309}
310
311/// For all declared type parameters, add a [core::fmt::Display] constraint, unless the type
312/// parameter already has any type constraint.
313fn ensure_where_clause_has_display_for_all_unconstrained_members(
314    where_clause: &mut WhereClause,
315    type_params: &[&TypeParam],
316) {
317    let param_constraint_mapping = extract_trait_constraints_from_source(where_clause, type_params);
318
319    for (ident, known_bounds) in param_constraint_mapping.into_iter() {
320        // If the type parameter has any constraints already, we don't want to touch it, to avoid
321        // breaking use cases where a type parameter only needs to impl `Debug`, for example.
322        if known_bounds.is_empty() {
323            ensure_display_in_where_clause_for_type(where_clause, ident);
324        }
325    }
326}
327
328/// Generate a `where` clause that ensures all generic type parameters `impl`
329/// [core::fmt::Display] unless already constrained.
330///
331/// This approach allows struct/enum definitions deriving [crate::Display] to avoid hardcoding
332/// a [core::fmt::Display] constraint into every type parameter.
333///
334/// If the type parameter isn't already constrained, we add a `where _: Display` clause to our
335/// display implementation to expect to be able to format every enum case or struct member.
336///
337/// In fact, we would preferably only require `where _: Display` or `where _: Debug` where the
338/// format string actually requires it. However, while [`std::fmt` defines a formal syntax for
339/// `format!()`][format syntax], it *doesn't* expose the actual logic to parse the format string,
340/// which appears to live in [`rustc_parse_format`]. While we use the [`syn`] crate to parse rust
341/// syntax, it also doesn't currently provide any method to introspect a `format!()` string. It
342/// would be nice to contribute this upstream in [`syn`].
343///
344/// [format syntax]: std::fmt#syntax
345/// [`rustc_parse_format`]: https://doc.rust-lang.org/nightly/nightly-rustc/rustc_parse_format/index.html
346fn generate_where_clause(generics: &Generics, where_clause: Option<&WhereClause>) -> WhereClause {
347    let mut where_clause = where_clause.cloned().unwrap_or_else(new_empty_where_clause);
348    let type_params: Vec<&TypeParam> = generics.type_params().collect();
349    ensure_where_clause_has_display_for_all_unconstrained_members(&mut where_clause, &type_params);
350    where_clause
351}
352
353fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {
354    let ty = &input.ident;
355    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
356    let where_clause = generate_where_clause(&input.generics, where_clause);
357
358    let helper = AttrsHelper::new(&input.attrs);
359
360    let displays = data
361        .variants
362        .iter()
363        .map(|variant| helper.display_with_input(&input.attrs, &variant.attrs))
364        .collect::<Result<Vec<_>>>()?;
365
366    if data.variants.is_empty() {
367        Ok(quote! {
368            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
369                fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
370                    unreachable!("empty enums cannot be instantiated and thus cannot be printed")
371                }
372            }
373        })
374    } else if displays.iter().any(Option::is_some) {
375        let arms = data
376            .variants
377            .iter()
378            .zip(displays)
379            .map(|(variant, display)| {
380                let display =
381                    display.ok_or_else(|| Error::new_spanned(variant, "missing doc comment"))?;
382                let ident = &variant.ident;
383                Ok(match &variant.fields {
384                    Fields::Named(fields) => {
385                        let var = fields.named.iter().map(|field| &field.ident);
386                        quote!(Self::#ident { #(#var),* } => { #display })
387                    }
388                    Fields::Unnamed(fields) => {
389                        let var = (0..fields.unnamed.len()).map(|i| format_ident!("_{}", i));
390                        quote!(Self::#ident(#(#var),*) => { #display })
391                    }
392                    Fields::Unit => quote!(Self::#ident => { #display }),
393                })
394            })
395            .collect::<Result<Vec<_>>>()?;
396        Ok(quote! {
397            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
398                fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
399                    #[allow(unused_variables)]
400                    match self {
401                        #(#arms,)*
402                    }
403                }
404            }
405        })
406    } else {
407        Err(Error::new_spanned(input, "Missing doc comments"))
408    }
409}