zerovec_derive/
utils.rs

1// This file is part of ICU4X. For terms of use, please see the file
2// called LICENSE at the top level of the ICU4X source tree
3// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
4
5use quote::{quote, ToTokens};
6
7use proc_macro2::Span;
8use proc_macro2::TokenStream as TokenStream2;
9use syn::parse::{Parse, ParseStream};
10use syn::punctuated::Punctuated;
11use syn::spanned::Spanned;
12use syn::{Attribute, Error, Field, Fields, Ident, Index, Result, Token};
13
/// Representation hints collected from a type's `#[repr(..)]` attributes.
///
/// Every flag starts out `false` (via `Default`) and is switched on by
/// [`ReprInfo::compute`] when the matching hint is seen.
#[derive(Default)]
pub struct ReprInfo {
    /// A `C` hint was present (lowercase `c` is accepted as well).
    pub c: bool,
    /// A `transparent` hint was present.
    pub transparent: bool,
    /// A `u8` hint was present.
    pub u8: bool,
    /// A bare `packed` hint was present.
    pub packed: bool,
}
21
22impl ReprInfo {
23    pub fn compute(attrs: &[Attribute]) -> Self {
24        let mut info = ReprInfo::default();
25        for attr in attrs.iter().filter(|a| a.path().is_ident("repr")) {
26            if let Ok(pieces) = attr.parse_args::<IdentListAttribute>() {
27                for piece in pieces.idents.iter() {
28                    if piece == "C" || piece == "c" {
29                        info.c = true;
30                    } else if piece == "transparent" {
31                        info.transparent = true;
32                    } else if piece == "packed" {
33                        info.packed = true;
34                    } else if piece == "u8" {
35                        info.u8 = true;
36                    }
37                }
38            }
39        }
40        info
41    }
42
43    pub fn cpacked_or_transparent(self) -> bool {
44        (self.c && self.packed) || self.transparent
45    }
46}
47
48// An attribute that is a list of idents
49struct IdentListAttribute {
50    idents: Punctuated<Ident, Token![,]>,
51}
52
53impl Parse for IdentListAttribute {
54    fn parse(input: ParseStream) -> Result<Self> {
55        Ok(IdentListAttribute {
56            idents: input.parse_terminated(Ident::parse, Token![,])?,
57        })
58    }
59}
60
61/// Given a set of entries for struct field definitions to go inside a `struct {}` definition,
62/// wrap in a () or {} based on the type of field
63pub fn wrap_field_inits(streams: &[TokenStream2], fields: &Fields) -> TokenStream2 {
64    match *fields {
65        Fields::Named(_) => quote!( { #(#streams),* } ),
66        Fields::Unnamed(_) => quote!( ( #(#streams),* ) ),
67        Fields::Unit => {
68            unreachable!("#[make_(var)ule] should have already checked that there are fields")
69        }
70    }
71}
72
73/// Return a semicolon token if necessary after the struct definition
74pub fn semi_for(f: &Fields) -> TokenStream2 {
75    if let Fields::Unnamed(..) = *f {
76        quote!(;)
77    } else {
78        quote!()
79    }
80}
81
82/// Returns the repr attribute to be applied to the resultant ULE or VarULE type
83pub fn repr_for(f: &Fields) -> TokenStream2 {
84    if f.len() == 1 {
85        quote!(transparent)
86    } else {
87        quote!(C, packed)
88    }
89}
90
91fn suffixed_ident(name: &str, suffix: usize, s: Span) -> Ident {
92    Ident::new(&format!("{name}_{suffix}"), s)
93}
94
95/// Given an iterator over ULE or AsULE struct fields, returns code that calculates field sizes and generates a line
96/// of code per field based on the per_field_code function (whose parameters are the field, the identifier of the const
97/// for the previous offset, the identifier for the const for the next offset, and the field index)
98pub(crate) fn generate_per_field_offsets<'a>(
99    fields: &[FieldInfo<'a>],
100    // Whether the fields are ULE types or AsULE (and need conversion)
101    fields_are_asule: bool,
102    // (field, prev_offset_ident, size_ident)
103    mut per_field_code: impl FnMut(&FieldInfo<'a>, &Ident, &Ident) -> TokenStream2, /* (code, remaining_offset) */
104) -> (TokenStream2, syn::Ident) {
105    let mut prev_offset_ident = Ident::new("ZERO", Span::call_site());
106    let mut code = quote!(
107        const ZERO: usize = 0;
108    );
109
110    for (i, field_info) in fields.iter().enumerate() {
111        let field = &field_info.field;
112        let ty = &field.ty;
113        let ty = if fields_are_asule {
114            quote!(<#ty as zerovec::ule::AsULE>::ULE)
115        } else {
116            quote!(#ty)
117        };
118        let new_offset_ident = suffixed_ident("OFFSET", i, field.span());
119        let size_ident = suffixed_ident("SIZE", i, field.span());
120        let pf_code = per_field_code(field_info, &prev_offset_ident, &size_ident);
121        code = quote! {
122            #code;
123            const #size_ident: usize = ::core::mem::size_of::<#ty>();
124            const #new_offset_ident: usize = #prev_offset_ident + #size_ident;
125            #pf_code;
126        };
127
128        prev_offset_ident = new_offset_ident;
129    }
130
131    (code, prev_offset_ident)
132}
133
134#[derive(Clone, Debug)]
135pub(crate) struct FieldInfo<'a> {
136    pub accessor: TokenStream2,
137    pub field: &'a Field,
138    pub index: usize,
139}
140
141impl<'a> FieldInfo<'a> {
142    pub fn make_list(iter: impl Iterator<Item = &'a Field>) -> Vec<Self> {
143        iter.enumerate()
144            .map(|(i, field)| Self::new_for_field(field, i))
145            .collect()
146    }
147
148    pub fn new_for_field(f: &'a Field, index: usize) -> Self {
149        if let Some(ref i) = f.ident {
150            FieldInfo {
151                accessor: quote!(#i),
152                field: f,
153                index,
154            }
155        } else {
156            let idx = Index::from(index);
157            FieldInfo {
158                accessor: quote!(#idx),
159                field: f,
160                index,
161            }
162        }
163    }
164
165    /// Get the code for setting this field in struct decl/brace syntax
166    ///
167    /// Use self.accessor for dot-notation accesses
168    pub fn setter(&self) -> TokenStream2 {
169        if let Some(ref i) = self.field.ident {
170            quote!(#i: )
171        } else {
172            quote!()
173        }
174    }
175
176    /// Produce a name for a getter for the field
177    pub fn getter(&self) -> TokenStream2 {
178        if let Some(ref i) = self.field.ident {
179            quote!(#i)
180        } else {
181            suffixed_ident("field", self.index, self.field.span()).into_token_stream()
182        }
183    }
184
185    /// Produce a prose name for the field for use in docs
186    pub fn getter_doc_name(&self) -> String {
187        if let Some(ref i) = self.field.ident {
188            format!("the unsized `{i}` field")
189        } else {
190            format!("tuple struct field #{}", self.index)
191        }
192    }
193}
194
195/// Extracts all `zerovec::name(..)` attribute
196pub fn extract_parenthetical_zerovec_attrs(
197    attrs: &mut Vec<Attribute>,
198    name: &str,
199) -> Result<Vec<Ident>> {
200    let mut ret = vec![];
201    let mut error = None;
202    attrs.retain(|a| {
203        // skip the "zerovec" part
204        let second_segment = a.path().segments.iter().nth(1);
205
206        if let Some(second) = second_segment {
207            if second.ident == name {
208                let list = match a.parse_args::<IdentListAttribute>() {
209                    Ok(l) => l,
210                    Err(_) => {
211                        error = Some(Error::new(
212                            a.span(),
213                            format!("#[zerovec::{name}(..)] takes in a comma separated list of identifiers"),
214                        ));
215                        return false;
216                    }
217                };
218                ret.extend(list.idents.iter().cloned());
219                return false;
220            }
221        }
222
223        true
224    });
225
226    if let Some(error) = error {
227        return Err(error);
228    }
229    Ok(ret)
230}
231
232/// Removes all attributes with `zerovec` in the name and places them in a separate vector
233pub fn extract_zerovec_attributes(attrs: &mut Vec<Attribute>) -> Vec<Attribute> {
234    let mut ret = vec![];
235    attrs.retain(|a| {
236        if a.path().segments.len() == 2 && a.path().segments[0].ident == "zerovec" {
237            ret.push(a.clone());
238            return false;
239        }
240        true
241    });
242    ret
243}
244
245/// Extract attributes from field, and return them
246///
247/// Only current field attribute is `zerovec::varule(VarUleType)`
248pub fn extract_field_attributes(attrs: &mut Vec<Attribute>) -> Result<Option<Ident>> {
249    let mut zerovec_attrs = extract_zerovec_attributes(attrs);
250    let varule = extract_parenthetical_zerovec_attrs(&mut zerovec_attrs, "varule")?;
251
252    if varule.len() > 1 {
253        return Err(Error::new(
254            varule[1].span(),
255            "Found multiple #[zerovec::varule()] on one field",
256        ));
257    }
258
259    if !zerovec_attrs.is_empty() {
260        return Err(Error::new(
261            zerovec_attrs[1].span(),
262            "Found unusable #[zerovec::] attrs on field, only #[zerovec::varule()] supported",
263        ));
264    }
265
266    Ok(varule.first().cloned())
267}
268
/// Options taken from the `#[zerovec::derive(..)]` and `#[zerovec::skip_derive(..)]`
/// attributes of a `#[make_ule]` / `#[make_varule]` type.
///
/// All flags default to `false`.
#[derive(Clone, Copy, Default)]
pub struct ZeroVecAttrs {
    /// Set by `#[zerovec::skip_derive(ZeroMapKV)]`.
    pub skip_kv: bool,
    /// Set by `#[zerovec::skip_derive(Ord)]`.
    pub skip_ord: bool,
    /// Set by `#[zerovec::derive(Serialize)]`.
    pub serialize: bool,
    /// Set by `#[zerovec::derive(Deserialize)]`.
    pub deserialize: bool,
    /// Set by `#[zerovec::derive(Debug)]`.
    pub debug: bool,
    /// Set by `#[zerovec::derive(Hash)]`.
    pub hash: bool,
}
278
279/// Removes all known zerovec:: attributes from struct attrs and validates them
280pub fn extract_attributes_common(
281    attrs: &mut Vec<Attribute>,
282    span: Span,
283    is_var: bool,
284) -> Result<ZeroVecAttrs> {
285    let mut zerovec_attrs = extract_zerovec_attributes(attrs);
286
287    let derive = extract_parenthetical_zerovec_attrs(&mut zerovec_attrs, "derive")?;
288    let skip = extract_parenthetical_zerovec_attrs(&mut zerovec_attrs, "skip_derive")?;
289
290    let name = if is_var { "make_varule" } else { "make_ule" };
291
292    if let Some(attr) = zerovec_attrs.first() {
293        return Err(Error::new(
294            attr.span(),
295            format!("Found unknown or duplicate attribute for #[{name}]"),
296        ));
297    }
298
299    let mut attrs = ZeroVecAttrs::default();
300
301    for ident in derive {
302        if ident == "Serialize" {
303            attrs.serialize = true;
304        } else if ident == "Deserialize" {
305            attrs.deserialize = true;
306        } else if ident == "Debug" {
307            attrs.debug = true;
308        } else if ident == "Hash" {
309            attrs.hash = true;
310        } else {
311            return Err(Error::new(
312                ident.span(),
313                format!(
314                    "Found unknown derive attribute for #[{name}]: #[zerovec::derive({ident})]"
315                ),
316            ));
317        }
318    }
319
320    for ident in skip {
321        if ident == "ZeroMapKV" {
322            attrs.skip_kv = true;
323        } else if ident == "Ord" {
324            attrs.skip_ord = true;
325        } else {
326            return Err(Error::new(
327                ident.span(),
328                format!("Found unknown derive attribute for #[{name}]: #[zerovec::skip_derive({ident})]"),
329            ));
330        }
331    }
332
333    if (attrs.serialize || attrs.deserialize) && !is_var {
334        return Err(Error::new(
335            span,
336            "#[make_ule] does not support #[zerovec::derive(Serialize, Deserialize)]",
337        ));
338    }
339
340    Ok(attrs)
341}