diesel_derives/
util.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::parse::{Parse, ParseStream, Peek, Result};
4use syn::token::Eq;
5use syn::{parenthesized, parse_quote, Data, DeriveInput, GenericArgument, Ident, Type};
6
7use crate::model::Model;
8
// Example attribute bodies. They are presumably passed as the `help`
// argument of the parsing helpers in this file, which render them as
// "The correct format looks like `#[diesel(<note>)]`".

// Simple `key = value` options.

/// Example of a `column_name` option.
pub const COLUMN_NAME_NOTE: &str = "column_name = foo";
/// Example of a `table_name` option.
pub const TABLE_NAME_NOTE: &str = "table_name = foo";
/// Example of a `sql_type` option.
pub const SQL_TYPE_NOTE: &str = "sql_type = Foo";
/// Example of a `serialize_as` option.
pub const SERIALIZE_AS_NOTE: &str = "serialize_as = Foo";
/// Example of a `deserialize_as` option.
pub const DESERIALIZE_AS_NOTE: &str = "deserialize_as = Foo";
/// Example of a `treat_none_as_default_value` option.
pub const TREAT_NONE_AS_DEFAULT_VALUE_NOTE: &str = "treat_none_as_default_value = true";
/// Example of a `treat_none_as_null` option.
pub const TREAT_NONE_AS_NULL_NOTE: &str = "treat_none_as_null = true";

// Parenthesized options.

/// Example of a `belongs_to(...)` option.
pub const BELONGS_TO_NOTE: &str = "belongs_to(Foo, foreign_key = foo_id)";
/// Example of a `mysql_type(...)` option.
pub const MYSQL_TYPE_NOTE: &str = r#"mysql_type(name = "foo")"#;
/// Example of a `sqlite_type(...)` option.
pub const SQLITE_TYPE_NOTE: &str = r#"sqlite_type(name = "foo")"#;
/// Example of a `postgres_type(...)` option given by name and schema.
pub const POSTGRES_TYPE_NOTE: &str = r#"postgres_type(name = "foo", schema = "public")"#;
/// Example of a `postgres_type(...)` option given by oid and array oid.
pub const POSTGRES_TYPE_NOTE_ID: &str = "postgres_type(oid = 37, array_oid = 54)";

// Select expression options.

/// Example of a `select_expression` option.
pub const SELECT_EXPRESSION_NOTE: &str =
    "select_expression = schema::table_name::column_name.is_not_null()";
/// Example of a `select_expression_type` option.
pub const SELECT_EXPRESSION_TYPE_NOTE: &str =
    "select_expression_type = dsl::IsNotNull<schema::table_name::column_name>";

/// Example backend path (presumably for a backend-check option — confirm at call sites).
pub const CHECK_FOR_BACKEND_NOTE: &str = "diesel::pg::Pg";
26
27pub fn unknown_attribute(name: &Ident, valid: &[&str]) -> syn::Error {
28    let prefix = if valid.len() == 1 { "" } else { " one of" };
29
30    syn::Error::new(
31        name.span(),
32        format!(
33            "unknown attribute, expected{prefix} `{}`",
34            valid.join("`, `")
35        ),
36    )
37}
38
39pub fn parse_eq<T: Parse>(input: ParseStream, help: &str) -> Result<T> {
40    if input.is_empty() {
41        return Err(syn::Error::new(
42            input.span(),
43            format!(
44                "unexpected end of input, expected `=`\n\
45                 help: The correct format looks like `#[diesel({help})]`",
46            ),
47        ));
48    }
49
50    input.parse::<Eq>()?;
51    input.parse()
52}
53
54pub fn parse_paren<T: Parse>(input: ParseStream, help: &str) -> Result<T> {
55    if input.is_empty() {
56        return Err(syn::Error::new(
57            input.span(),
58            format!(
59                "unexpected end of input, expected parentheses\n\
60                 help: The correct format looks like `#[diesel({help})]`",
61            ),
62        ));
63    }
64
65    let content;
66    parenthesized!(content in input);
67    content.parse()
68}
69
70pub fn parse_paren_list<T, D>(
71    input: ParseStream,
72    help: &str,
73    sep: D,
74) -> Result<syn::punctuated::Punctuated<T, <D as Peek>::Token>>
75where
76    T: Parse,
77    D: Peek,
78    D::Token: Parse,
79{
80    if input.is_empty() {
81        return Err(syn::Error::new(
82            input.span(),
83            format!(
84                "unexpected end of input, expected parentheses\n\
85                 help: The correct format looks like `#[diesel({help})]`",
86            ),
87        ));
88    }
89
90    let content;
91    parenthesized!(content in input);
92    content.parse_terminated(T::parse, sep)
93}
94
/// Wraps the generated `item` tokens in an anonymous `const _: () = { ... };`
/// block and returns the resulting token stream.
///
/// A `use diesel;` import is emitted inside that block so the generated code
/// can refer to `diesel::...` paths. The comment on the import explains why
/// this resolves both inside and outside the `diesel` crate.
///
/// The `unused_imports` and `unused_qualifications` lints are allowed on the
/// wrapper. The generated code may not need the import, and it may use
/// fully qualified paths.
pub fn wrap_in_dummy_mod(item: TokenStream) -> TokenStream {
    // #[allow(unused_qualifications)] can be removed if https://github.com/rust-lang/rust/issues/130277 gets done
    quote! {
        #[allow(unused_imports)]
        #[allow(unused_qualifications)]
        const _: () = {
            // This import is not actually redundant. When using diesel_derives
            // inside of diesel, `diesel` doesn't exist as an extern crate, and
            // to work around that it contains a private
            // `mod diesel { pub use super::*; }` that this import will then
            // refer to. In all other cases, this imports refers to the extern
            // crate diesel.
            use diesel;

            #item
        };
    }
}
113
114pub fn inner_of_option_ty(ty: &Type) -> &Type {
115    option_ty_arg(ty).unwrap_or(ty)
116}
117
118pub fn is_option_ty(ty: &Type) -> bool {
119    option_ty_arg(ty).is_some()
120}
121
122fn option_ty_arg(mut ty: &Type) -> Option<&Type> {
123    use syn::PathArguments::AngleBracketed;
124
125    // Check the inner equivalent type
126    loop {
127        match ty {
128            Type::Group(group) => ty = &group.elem,
129            Type::Paren(paren) => ty = &paren.elem,
130            _ => break,
131        }
132    }
133
134    match *ty {
135        Type::Path(ref ty) => {
136            let last_segment = ty.path.segments.iter().next_back().unwrap();
137            match last_segment.arguments {
138                AngleBracketed(ref args) if last_segment.ident == "Option" => {
139                    match args.args.iter().next_back() {
140                        Some(GenericArgument::Type(ty)) => Some(ty),
141                        _ => None,
142                    }
143                }
144                _ => None,
145            }
146        }
147        _ => None,
148    }
149}
150
151pub fn ty_for_foreign_derive(item: &DeriveInput, model: &Model) -> Result<Type> {
152    if model.foreign_derive {
153        match item.data {
154            Data::Struct(ref body) => match body.fields.iter().next() {
155                Some(field) => Ok(field.ty.clone()),
156                None => Err(syn::Error::new(
157                    proc_macro2::Span::call_site(),
158                    "foreign_derive requires at least one field",
159                )),
160            },
161            _ => Err(syn::Error::new(
162                proc_macro2::Span::call_site(),
163                "foreign_derive can only be used with structs",
164            )),
165        }
166    } else {
167        let ident = &item.ident;
168        let (_, ty_generics, ..) = item.generics.split_for_impl();
169        Ok(parse_quote!(#ident #ty_generics))
170    }
171}
172
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// The first character is lowercased. Every later uppercase character is
/// replaced by `_` followed by its lowercase form. Consecutive capitals
/// therefore each get their own underscore (`"HTTPServer"` becomes
/// `"h_t_t_p_server"`).
///
/// Works on `char`s rather than byte offsets. This means an empty string or a
/// name starting with a multi-byte character (e.g. `"Éclair"`) no longer
/// panics on an invalid slice boundary.
pub fn camel_to_snake(name: &str) -> String {
    let mut result = String::with_capacity(name.len());
    let mut chars = name.chars();
    if let Some(first) = chars.next() {
        result.extend(first.to_lowercase());
    }
    for character in chars {
        if character.is_uppercase() {
            result.push('_');
            // `to_lowercase` may yield more than one char for some code points.
            result.extend(character.to_lowercase());
        } else {
            result.push(character);
        }
    }
    result
}