diesel_table_macro_syntax/
lib.rs

1use syn::spanned::Spanned;
2use syn::Ident;
3use syn::MetaNameValue;
4
5pub struct TableDecl {
6    pub use_statements: Vec<syn::ItemUse>,
7    pub meta: Vec<syn::Attribute>,
8    pub schema: Option<Ident>,
9    _punct: Option<syn::Token![.]>,
10    pub sql_name: String,
11    pub table_name: Ident,
12    pub primary_keys: Option<PrimaryKey>,
13    _brace_token: syn::token::Brace,
14    pub column_defs: syn::punctuated::Punctuated<ColumnDef, syn::Token![,]>,
15}
16
17#[allow(dead_code)] // paren_token is currently unused
18pub struct PrimaryKey {
19    paren_token: syn::token::Paren,
20    pub keys: syn::punctuated::Punctuated<Ident, syn::Token![,]>,
21}
22
23pub struct ColumnDef {
24    pub meta: Vec<syn::Attribute>,
25    pub column_name: Ident,
26    pub sql_name: String,
27    _arrow: syn::Token![->],
28    pub tpe: syn::TypePath,
29    pub max_length: Option<syn::LitInt>,
30}
31
32impl syn::parse::Parse for TableDecl {
33    fn parse(buf: &syn::parse::ParseBuffer<'_>) -> Result<Self, syn::Error> {
34        let mut use_statements = Vec::new();
35        loop {
36            let fork = buf.fork();
37            if fork.parse::<syn::ItemUse>().is_ok() {
38                use_statements.push(buf.parse()?);
39            } else {
40                break;
41            };
42        }
43        let mut meta = syn::Attribute::parse_outer(buf)?;
44        let fork = buf.fork();
45        let (schema, punct, table_name) = if parse_table_with_schema(&fork).is_ok() {
46            let (schema, punct, table_name) = parse_table_with_schema(buf)?;
47            (Some(schema), Some(punct), table_name)
48        } else {
49            let table_name = buf.parse()?;
50            (None, None, table_name)
51        };
52        let fork = buf.fork();
53        let primary_keys = if fork.parse::<PrimaryKey>().is_ok() {
54            Some(buf.parse()?)
55        } else {
56            None
57        };
58        let content;
59        let brace_token = syn::braced!(content in buf);
60        let column_defs = syn::punctuated::Punctuated::parse_terminated(&content)?;
61        let sql_name = get_sql_name(&mut meta, &table_name)?;
62        Ok(Self {
63            use_statements,
64            meta,
65            table_name,
66            primary_keys,
67            _brace_token: brace_token,
68            column_defs,
69            sql_name,
70            _punct: punct,
71            schema,
72        })
73    }
74}
75
76impl syn::parse::Parse for PrimaryKey {
77    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
78        let content;
79        let paren_token = syn::parenthesized!(content in input);
80        let keys = content.parse_terminated(Ident::parse, syn::Token![,])?;
81        Ok(Self { paren_token, keys })
82    }
83}
84
85impl syn::parse::Parse for ColumnDef {
86    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
87        let mut meta = syn::Attribute::parse_outer(input)?;
88        let column_name: syn::Ident = input.parse()?;
89        let _arrow: syn::Token![->] = input.parse()?;
90        let tpe: syn::TypePath = input.parse()?;
91
92        let sql_name = get_sql_name(&mut meta, &column_name)?;
93        let max_length = take_lit(&mut meta, "max_length", |lit| match lit {
94            syn::Lit::Int(lit_int) => Some(lit_int),
95            _ => None,
96        })?;
97
98        Ok(Self {
99            meta,
100            column_name,
101            _arrow,
102            tpe,
103            max_length,
104            sql_name,
105        })
106    }
107}
108
109pub fn parse_table_with_schema(
110    input: &syn::parse::ParseBuffer<'_>,
111) -> Result<(syn::Ident, syn::Token![.], syn::Ident), syn::Error> {
112    Ok((input.parse()?, input.parse()?, input.parse()?))
113}
114
115fn get_sql_name(
116    meta: &mut Vec<syn::Attribute>,
117    fallback_ident: &syn::Ident,
118) -> Result<String, syn::Error> {
119    Ok(
120        match take_lit(meta, "sql_name", |lit| match lit {
121            syn::Lit::Str(lit_str) => Some(lit_str),
122            _ => None,
123        })? {
124            None => {
125                use syn::ext::IdentExt;
126                fallback_ident.unraw().to_string()
127            }
128            Some(str_lit) => {
129                let mut str_lit = str_lit.value();
130                if str_lit.starts_with("r#") {
131                    str_lit.drain(..2);
132                }
133                str_lit
134            }
135        },
136    )
137}
138
139fn take_lit<O, F>(
140    meta: &mut Vec<syn::Attribute>,
141    attribute_name: &'static str,
142    extraction_fn: F,
143) -> Result<Option<O>, syn::Error>
144where
145    F: FnOnce(syn::Lit) -> Option<O>,
146{
147    if let Some(index) = meta.iter().position(|m| {
148        m.path()
149            .get_ident()
150            .map(|i| i == attribute_name)
151            .unwrap_or(false)
152    }) {
153        let attribute = meta.remove(index);
154        let span = attribute.span();
155        let extraction_after_finding_attr = if let syn::Meta::NameValue(MetaNameValue {
156            value: syn::Expr::Lit(syn::ExprLit { lit, .. }),
157            ..
158        }) = attribute.meta
159        {
160            extraction_fn(lit)
161        } else {
162            None
163        };
164        return Ok(Some(extraction_after_finding_attr.ok_or_else(|| {
165            syn::Error::new(
166                span,
167                format_args!("Invalid `#[sql_name = {attribute_name:?}]` attribute"),
168            )
169        })?));
170    }
171    Ok(None)
172}