1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
use proc_macro2::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream, Peek, Result};
use syn::token::Eq;
use syn::{parenthesized, parse_quote, Data, DeriveInput, GenericArgument, Ident, Type};

use crate::model::Model;

pub const COLUMN_NAME_NOTE: &str = "column_name = foo";
pub const SQL_TYPE_NOTE: &str = "sql_type = Foo";
pub const SERIALIZE_AS_NOTE: &str = "serialize_as = Foo";
pub const DESERIALIZE_AS_NOTE: &str = "deserialize_as = Foo";
pub const TABLE_NAME_NOTE: &str = "table_name = foo";
pub const TREAT_NONE_AS_DEFAULT_VALUE_NOTE: &str = "treat_none_as_default_value = true";
pub const TREAT_NONE_AS_NULL_NOTE: &str = "treat_none_as_null = true";
pub const BELONGS_TO_NOTE: &str = "belongs_to(Foo, foreign_key = foo_id)";
pub const MYSQL_TYPE_NOTE: &str = "mysql_type(name = \"foo\")";
pub const SQLITE_TYPE_NOTE: &str = "sqlite_type(name = \"foo\")";
pub const POSTGRES_TYPE_NOTE: &str = "postgres_type(name = \"foo\", schema = \"public\")";
pub const POSTGRES_TYPE_NOTE_ID: &str = "postgres_type(oid = 37, array_oid = 54)";
pub const SELECT_EXPRESSION_NOTE: &str =
    "select_expression = schema::table_name::column_name.is_not_null()";
pub const SELECT_EXPRESSION_TYPE_NOTE: &str =
    "select_expression_type = dsl::IsNotNull<schema::table_name::column_name>";
pub const CHECK_FOR_BACKEND_NOTE: &str = "diesel::pg::Pg";

pub fn unknown_attribute(name: &Ident, valid: &[&str]) -> syn::Error {
    let prefix = if valid.len() == 1 { "" } else { " one of" };

    syn::Error::new(
        name.span(),
        format!(
            "unknown attribute, expected{prefix} `{}`",
            valid.join("`, `")
        ),
    )
}

pub fn parse_eq<T: Parse>(input: ParseStream, help: &str) -> Result<T> {
    if input.is_empty() {
        return Err(syn::Error::new(
            input.span(),
            format!(
                "unexpected end of input, expected `=`\n\
                 help: The correct format looks like `#[diesel({help})]`",
            ),
        ));
    }

    input.parse::<Eq>()?;
    input.parse()
}

pub fn parse_paren<T: Parse>(input: ParseStream, help: &str) -> Result<T> {
    if input.is_empty() {
        return Err(syn::Error::new(
            input.span(),
            format!(
                "unexpected end of input, expected parentheses\n\
                 help: The correct format looks like `#[diesel({help})]`",
            ),
        ));
    }

    let content;
    parenthesized!(content in input);
    content.parse()
}

pub fn parse_paren_list<T, D>(
    input: ParseStream,
    help: &str,
    sep: D,
) -> Result<syn::punctuated::Punctuated<T, <D as Peek>::Token>>
where
    T: Parse,
    D: Peek,
    D::Token: Parse,
{
    if input.is_empty() {
        return Err(syn::Error::new(
            input.span(),
            format!(
                "unexpected end of input, expected parentheses\n\
                 help: The correct format looks like `#[diesel({help})]`",
            ),
        ));
    }

    let content;
    parenthesized!(content in input);
    content.parse_terminated(T::parse, sep)
}

pub fn wrap_in_dummy_mod(item: TokenStream) -> TokenStream {
    quote! {
        #[allow(unused_imports)]
        const _: () = {
            // This import is not actually redundant. When using diesel_derives
            // inside of diesel, `diesel` doesn't exist as an extern crate, and
            // to work around that it contains a private
            // `mod diesel { pub use super::*; }` that this import will then
            // refer to. In all other cases, this imports refers to the extern
            // crate diesel.
            use diesel;

            #item
        };
    }
}

pub fn inner_of_option_ty(ty: &Type) -> &Type {
    option_ty_arg(ty).unwrap_or(ty)
}

pub fn is_option_ty(ty: &Type) -> bool {
    option_ty_arg(ty).is_some()
}

fn option_ty_arg(ty: &Type) -> Option<&Type> {
    use syn::PathArguments::AngleBracketed;

    match *ty {
        Type::Path(ref ty) => {
            let last_segment = ty.path.segments.iter().last().unwrap();
            match last_segment.arguments {
                AngleBracketed(ref args) if last_segment.ident == "Option" => {
                    match args.args.iter().last() {
                        Some(GenericArgument::Type(ty)) => Some(ty),
                        _ => None,
                    }
                }
                _ => None,
            }
        }
        _ => None,
    }
}

pub fn ty_for_foreign_derive(item: &DeriveInput, model: &Model) -> Result<Type> {
    if model.foreign_derive {
        match item.data {
            Data::Struct(ref body) => match body.fields.iter().next() {
                Some(field) => Ok(field.ty.clone()),
                None => Err(syn::Error::new(
                    proc_macro2::Span::call_site(),
                    "foreign_derive requires at least one field",
                )),
            },
            _ => Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "foreign_derive can only be used with structs",
            )),
        }
    } else {
        let ident = &item.ident;
        let (_, ty_generics, ..) = item.generics.split_for_impl();
        Ok(parse_quote!(#ident #ty_generics))
    }
}

pub fn camel_to_snake(name: &str) -> String {
    let mut result = String::with_capacity(name.len());
    result.push_str(&name[..1].to_lowercase());
    for character in name[1..].chars() {
        if character.is_uppercase() {
            result.push('_');
            for lowercase in character.to_lowercase() {
                result.push(lowercase);
            }
        } else {
            result.push(character);
        }
    }
    result
}