diesel_derives/
util.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::parse::{Parse, ParseStream, Peek, Result};
4use syn::token::Eq;
5use syn::{parenthesized, parse_quote, Data, DeriveInput, GenericArgument, Ident, Type};
6
7use crate::model::Model;
8
// Example attribute snippets for user-facing diagnostics.
//
// The `parse_*` helpers in this module render a `help` string as
// `#[diesel({help})]`; most of these constants look intended to be passed as
// that `help` text (NOTE(review): confirm at the call sites).

// --- Field / struct level attributes ---
pub const COLUMN_NAME_NOTE: &str = "column_name = foo";
pub const SQL_TYPE_NOTE: &str = "sql_type = Foo";
pub const SERIALIZE_AS_NOTE: &str = "serialize_as = Foo";
pub const DESERIALIZE_AS_NOTE: &str = "deserialize_as = Foo";
pub const TABLE_NAME_NOTE: &str = "table_name = foo";
pub const TREAT_NONE_AS_DEFAULT_VALUE_NOTE: &str = "treat_none_as_default_value = true";
pub const TREAT_NONE_AS_NULL_NOTE: &str = "treat_none_as_null = true";
pub const BELONGS_TO_NOTE: &str = "belongs_to(Foo, foreign_key = foo_id)";

// --- Backend specific SQL type attributes ---
pub const MYSQL_TYPE_NOTE: &str = r#"mysql_type(name = "foo")"#;
pub const SQLITE_TYPE_NOTE: &str = r#"sqlite_type(name = "foo")"#;
pub const POSTGRES_TYPE_NOTE: &str = r#"postgres_type(name = "foo", schema = "public")"#;
pub const POSTGRES_TYPE_NOTE_ID: &str = "postgres_type(oid = 37, array_oid = 54)";

// --- Selection / query attributes ---
pub const SELECT_EXPRESSION_NOTE: &str =
    "select_expression = schema::table_name::column_name.is_not_null()";
pub const SELECT_EXPRESSION_TYPE_NOTE: &str =
    "select_expression_type = dsl::IsNotNull<schema::table_name::column_name>";
/// Example backend path (a type path, not an attribute body).
pub const CHECK_FOR_BACKEND_NOTE: &str = "diesel::pg::Pg";
pub const BASE_QUERY_NOTE: &str =
    "base_query = schema::table_name::table.order_by(schema::table_name::id)";
pub const BASE_QUERY_TYPE_NOTE: &str =
    "base_query_type = dsl::OrderBy<schema::table_name::table, schema::table_name::id>";
30
31pub fn unknown_attribute(name: &Ident, valid: &[&str]) -> syn::Error {
32    let prefix = if valid.len() == 1 { "" } else { " one of" };
33
34    syn::Error::new(
35        name.span(),
36        format!(
37            "unknown attribute, expected{prefix} `{}`",
38            valid.join("`, `")
39        ),
40    )
41}
42
43pub fn parse_eq<T: Parse>(input: ParseStream, help: &str) -> Result<T> {
44    if input.is_empty() {
45        return Err(syn::Error::new(
46            input.span(),
47            format!(
48                "unexpected end of input, expected `=`\n\
49                 help: the correct format looks like `#[diesel({help})]`",
50            ),
51        ));
52    }
53
54    input.parse::<Eq>()?;
55    input.parse()
56}
57
58pub fn parse_paren<T: Parse>(input: ParseStream, help: &str) -> Result<T> {
59    if input.is_empty() {
60        return Err(syn::Error::new(
61            input.span(),
62            format!(
63                "unexpected end of input, expected parentheses\n\
64                 help: the correct format looks like `#[diesel({help})]`",
65            ),
66        ));
67    }
68
69    let content;
70    parenthesized!(content in input);
71    content.parse()
72}
73
74pub fn parse_paren_list<T, D>(
75    input: ParseStream,
76    help: &str,
77    sep: D,
78) -> Result<syn::punctuated::Punctuated<T, <D as Peek>::Token>>
79where
80    T: Parse,
81    D: Peek,
82    D::Token: Parse,
83{
84    if input.is_empty() {
85        return Err(syn::Error::new(
86            input.span(),
87            format!(
88                "unexpected end of input, expected parentheses\n\
89                 help: the correct format looks like `#[diesel({help})]`",
90            ),
91        ));
92    }
93
94    let content;
95    parenthesized!(content in input);
96    content.parse_terminated(T::parse, sep)
97}
98
99pub fn wrap_in_dummy_mod(item: TokenStream) -> TokenStream {
100    quote! {
101        const _: () = {
102            // This import is not actually redundant. When using diesel_derives
103            // inside of diesel, `diesel` doesn't exist as an extern crate, and
104            // to work around that it contains a private
105            // `mod diesel { pub use super::*; }` that this import will then
106            // refer to. In all other cases, this imports refers to the extern
107            // crate diesel.
108            use diesel;
109
110            #item
111        };
112    }
113}
114
115pub fn inner_of_option_ty(ty: &Type) -> &Type {
116    option_ty_arg(ty).unwrap_or(ty)
117}
118
119pub fn is_option_ty(ty: &Type) -> bool {
120    option_ty_arg(ty).is_some()
121}
122
123fn option_ty_arg(mut ty: &Type) -> Option<&Type> {
124    use syn::PathArguments::AngleBracketed;
125
126    // Check the inner equivalent type
127    loop {
128        match ty {
129            Type::Group(group) => ty = &group.elem,
130            Type::Paren(paren) => ty = &paren.elem,
131            _ => break,
132        }
133    }
134
135    match *ty {
136        Type::Path(ref ty) => {
137            let last_segment = ty.path.segments.iter().next_back().unwrap();
138            match last_segment.arguments {
139                AngleBracketed(ref args) if last_segment.ident == "Option" => {
140                    match args.args.iter().next_back() {
141                        Some(GenericArgument::Type(ty)) => Some(ty),
142                        _ => None,
143                    }
144                }
145                _ => None,
146            }
147        }
148        _ => None,
149    }
150}
151
152pub fn ty_for_foreign_derive(item: &DeriveInput, model: &Model) -> Result<Type> {
153    if model.foreign_derive {
154        match item.data {
155            Data::Struct(ref body) => match body.fields.iter().next() {
156                Some(field) => Ok(field.ty.clone()),
157                None => Err(syn::Error::new(
158                    proc_macro2::Span::mixed_site(),
159                    "foreign_derive requires at least one field",
160                )),
161            },
162            _ => Err(syn::Error::new(
163                proc_macro2::Span::mixed_site(),
164                "foreign_derive can only be used with structs",
165            )),
166        }
167    } else {
168        let ident = &item.ident;
169        let (_, ty_generics, ..) = item.generics.split_for_impl();
170        Ok(parse_quote!(#ident #ty_generics))
171    }
172}
173
/// Converts a `CamelCase` identifier to `snake_case`.
///
/// The first character is lowercased. Every later uppercase character is
/// lowercased and preceded by `_`, so `"FooBar"` becomes `"foo_bar"` and
/// `"HTTPServer"` becomes `"h_t_t_p_server"`. Other characters are copied
/// unchanged.
///
/// Works on `char`s rather than byte offsets. An empty string, or a name that
/// starts with a multi-byte character such as `'É'`, no longer panics.
pub fn camel_to_snake(name: &str) -> String {
    let mut result = String::with_capacity(name.len());
    let mut chars = name.chars();

    if let Some(first) = chars.next() {
        // `char::to_lowercase` may yield more than one char for some scripts.
        result.extend(first.to_lowercase());
    }

    for character in chars {
        if character.is_uppercase() {
            result.push('_');
            result.extend(character.to_lowercase());
        } else {
            result.push(character);
        }
    }

    result
}