diesel_derive_enum/
lib.rs

1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4
5use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
6use proc_macro::TokenStream;
7use proc_macro2::{Ident, Span};
8use quote::quote;
9use syn::*;
10
11/// Implement the traits necessary for inserting the enum directly into a database
12///
13/// # Attributes
14///
15/// ## Type attributes
16///
17/// * `#[ExistingTypePath = "crate::schema::sql_types::NewEnum"]` specifies
18///   the path to a corresponding diesel type that was already created by the
19///   diesel CLI. If omitted, the type will be generated by this macro.
20///   *Note*: Only applies to `postgres`, will error if specified for other databases
21/// * `#[DieselType = "NewEnumMapping"]` specifies the name for the diesel type
22///   to create. If omitted, uses `<enum name>Mapping`.
23///   *Note*: Cannot be specified alongside `ExistingTypePath`
24/// * `#[DbValueStyle = "snake_case"]` specifies a renaming style from each of
25///   the rust enum variants to each of the database variants. Either `camelCase`,
26///   `kebab-case`, `PascalCase`, `SCREAMING_SNAKE_CASE`, `snake_case`,
27///   `verbatim`. If omitted, uses `snake_case`.
28///
29/// ## Variant attributes
30///
31/// * `#[db_rename = "variant"]` specifies the db name for a specific variant.
32#[proc_macro_derive(
33    DbEnum,
34    attributes(PgType, DieselType, ExistingTypePath, DbValueStyle, db_rename)
35)]
36pub fn derive(input: TokenStream) -> TokenStream {
37    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
38
39    let existing_mapping_path = val_from_attrs(&input.attrs, "ExistingTypePath");
40    if !cfg!(feature = "postgres") && existing_mapping_path.is_some() {
41        panic!("ExistingTypePath attribute only applies when the 'postgres' feature is enabled");
42    }
43
44    // we could allow a default value here but... I'm not very keen
45    // let existing_mapping_path = existing_mapping_path
46    //     .unwrap_or_else(|| format!("crate::schema::sql_types::{}", input.ident));
47
48    let pg_internal_type = val_from_attrs(&input.attrs, "PgType");
49
50    if existing_mapping_path.is_some() && pg_internal_type.is_some() {
51        panic!("Cannot specify both `ExistingTypePath` and `PgType` attributes");
52    }
53
54    let pg_internal_type = pg_internal_type.unwrap_or(input.ident.to_string().to_snake_case());
55
56    let new_diesel_mapping = val_from_attrs(&input.attrs, "DieselType");
57    if existing_mapping_path.is_some() && new_diesel_mapping.is_some() {
58        panic!("Cannot specify both `ExistingTypePath` and `DieselType` attributes");
59    }
60    let new_diesel_mapping =
61        new_diesel_mapping.unwrap_or_else(|| format!("{}Mapping", input.ident));
62
63    // Maintain backwards compatibility by defaulting to snake case.
64    let case_style =
65        val_from_attrs(&input.attrs, "DbValueStyle").unwrap_or_else(|| "snake_case".to_string());
66    let case_style = CaseStyle::from_string(&case_style);
67
68    let existing_mapping_path = existing_mapping_path.map(|v| {
69        v.parse::<proc_macro2::TokenStream>()
70            .expect("ExistingTypePath is not a valid token")
71    });
72    let new_diesel_mapping = Ident::new(new_diesel_mapping.as_ref(), Span::call_site());
73    if let Data::Enum(syn::DataEnum {
74        variants: data_variants,
75        ..
76    }) = input.data
77    {
78        generate_derive_enum_impls(
79            &existing_mapping_path,
80            &new_diesel_mapping,
81            &pg_internal_type,
82            case_style,
83            &input.ident,
84            &data_variants,
85        )
86    } else {
87        syn::Error::new(
88            Span::call_site(),
89            "derive(DbEnum) can only be applied to enums",
90        )
91        .to_compile_error()
92        .into()
93    }
94}
95
96fn val_from_attrs(attrs: &[Attribute], attrname: &str) -> Option<String> {
97    for attr in attrs {
98        if attr.path().is_ident(attrname) {
99            match &attr.meta {
100                Meta::NameValue(MetaNameValue {
101                    value:
102                        Expr::Lit(ExprLit {
103                            lit: Lit::Str(lit_str),
104                            ..
105                        }),
106                    ..
107                }) => return Some(lit_str.value()),
108                _ => panic!(
109                    "Attribute '{}' must have form: {} = \"value\"",
110                    attrname, attrname
111                ),
112            }
113        }
114    }
115    None
116}
117
/// Renaming style applied to every Rust variant name to obtain its database
/// value. Spellings follow serde's `rename_all` naming, plus `verbatim`.
#[derive(Copy, Clone, Debug, PartialEq)]
enum CaseStyle {
    Camel,
    Kebab,
    Pascal,
    Upper,
    ScreamingSnake,
    Snake,
    Verbatim,
}

impl CaseStyle {
    /// Parses a `DbValueStyle` attribute value into a [`CaseStyle`].
    ///
    /// # Panics
    ///
    /// Panics with `unsupported casing` for any spelling not listed below.
    fn from_string(name: &str) -> Self {
        const SPELLINGS: &[(&str, CaseStyle)] = &[
            ("camelCase", CaseStyle::Camel),
            ("kebab-case", CaseStyle::Kebab),
            ("PascalCase", CaseStyle::Pascal),
            ("SCREAMING_SNAKE_CASE", CaseStyle::ScreamingSnake),
            ("UPPERCASE", CaseStyle::Upper),
            ("snake_case", CaseStyle::Snake),
            // Two accepted spellings for the identity style.
            ("verbatim", CaseStyle::Verbatim),
            ("verbatimcase", CaseStyle::Verbatim),
        ];

        SPELLINGS
            .iter()
            .find(|(spelling, _)| *spelling == name)
            .map(|&(_, style)| style)
            .unwrap_or_else(|| panic!("unsupported casing: `{}`", name))
    }
}
144
145fn generate_derive_enum_impls(
146    existing_mapping_path: &Option<proc_macro2::TokenStream>,
147    new_diesel_mapping: &Ident,
148    pg_internal_type: &str,
149    case_style: CaseStyle,
150    enum_ty: &Ident,
151    variants: &syn::punctuated::Punctuated<Variant, syn::token::Comma>,
152) -> TokenStream {
153    let modname = Ident::new(&format!("db_enum_impl_{}", enum_ty), Span::call_site());
154    let variant_ids: Vec<proc_macro2::TokenStream> = variants
155        .iter()
156        .map(|variant| {
157            if let Fields::Unit = variant.fields {
158                let id = &variant.ident;
159                quote! {
160                    #enum_ty::#id
161                }
162            } else {
163                panic!("Variants must be fieldless")
164            }
165        })
166        .collect();
167
168    let variants_db: Vec<String> = variants
169        .iter()
170        .map(|variant| {
171            val_from_attrs(&variant.attrs, "db_rename")
172                .unwrap_or_else(|| stylize_value(&variant.ident.to_string(), case_style))
173        })
174        .collect();
175    let variants_db_bytes: Vec<LitByteStr> = variants_db
176        .iter()
177        .map(|variant_str| LitByteStr::new(variant_str.as_bytes(), Span::call_site()))
178        .collect();
179
180    let common = generate_common(enum_ty, &variant_ids, &variants_db, &variants_db_bytes);
181    let (diesel_mapping_def, diesel_mapping_use) =
182        // Skip this part if we already have an existing mapping
183        if existing_mapping_path.is_some() {
184            (None, None)
185        } else {
186            let new_diesel_mapping_def = generate_new_diesel_mapping(new_diesel_mapping, pg_internal_type);
187            let common_impls_on_new_diesel_mapping =
188                generate_common_impls(&quote! { #new_diesel_mapping }, enum_ty);
189            (
190                Some(quote! {
191                    #new_diesel_mapping_def
192                    #common_impls_on_new_diesel_mapping
193                }),
194                Some(quote! {
195                    pub use self::#modname::#new_diesel_mapping;
196                }),
197            )
198        };
199
200    let pg_impl = if cfg!(feature = "postgres") {
201        match existing_mapping_path {
202            Some(path) => {
203                let common_impls_on_existing_diesel_mapping = generate_common_impls(path, enum_ty);
204                let postgres_impl = generate_postgres_impl(path, enum_ty, true);
205                Some(quote! {
206                    #common_impls_on_existing_diesel_mapping
207                    #postgres_impl
208                })
209            }
210            None => Some(generate_postgres_impl(
211                &quote! { #new_diesel_mapping },
212                enum_ty,
213                false,
214            )),
215        }
216    } else {
217        None
218    };
219
220    let mysql_impl = if cfg!(feature = "mysql") {
221        Some(generate_mysql_impl(new_diesel_mapping, enum_ty))
222    } else {
223        None
224    };
225
226    let sqlite_impl = if cfg!(feature = "sqlite") {
227        Some(generate_sqlite_impl(new_diesel_mapping, enum_ty))
228    } else {
229        None
230    };
231
232    let imports = quote! {
233        use super::*;
234        use diesel::{
235            backend::{self, Backend},
236            deserialize::{self, FromSql},
237            expression::AsExpression,
238            internal::derives::as_expression::Bound,
239            query_builder::{bind_collector::RawBytesBindCollector},
240            row::Row,
241            serialize::{self, IsNull, Output, ToSql},
242            sql_types::*,
243            Queryable,
244        };
245        use std::io::Write;
246    };
247
248    let quoted = quote! {
249        #diesel_mapping_use
250        #[allow(non_snake_case)]
251        mod #modname {
252            #imports
253
254            #common
255            #diesel_mapping_def
256            #pg_impl
257            #mysql_impl
258            #sqlite_impl
259        }
260    };
261
262    quoted.into()
263}
264
265fn stylize_value(value: &str, style: CaseStyle) -> String {
266    match style {
267        CaseStyle::Camel => value.to_lower_camel_case(),
268        CaseStyle::Kebab => value.to_kebab_case(),
269        CaseStyle::Pascal => value.to_upper_camel_case(),
270        CaseStyle::Upper => value.to_uppercase(),
271        CaseStyle::ScreamingSnake => value.to_shouty_snake_case(),
272        CaseStyle::Snake => value.to_snake_case(),
273        CaseStyle::Verbatim => value.to_string(),
274    }
275}
276
277fn generate_common(
278    enum_ty: &Ident,
279    variants_rs: &[proc_macro2::TokenStream],
280    variants_db: &[String],
281    variants_db_bytes: &[LitByteStr],
282) -> proc_macro2::TokenStream {
283    quote! {
284        fn db_str_representation(e: &#enum_ty) -> &'static str {
285            match *e {
286                #(#variants_rs => #variants_db,)*
287            }
288        }
289
290        fn from_db_binary_representation(bytes: &[u8]) -> deserialize::Result<#enum_ty> {
291            match bytes {
292                #(#variants_db_bytes => Ok(#variants_rs),)*
293                v => Err(format!("Unrecognized enum variant: '{}'",
294                    String::from_utf8_lossy(v)).into()),
295            }
296        }
297    }
298}
299
300fn generate_new_diesel_mapping(
301    new_diesel_mapping: &Ident,
302    pg_internal_type: &str,
303) -> proc_macro2::TokenStream {
304    // Note - we only generate a new mapping for mysql and sqlite, postgres
305    // should already have one
306    quote! {
307        #[derive(Clone, SqlType, diesel::query_builder::QueryId)]
308        #[diesel(mysql_type(name = "Enum"))]
309        #[diesel(sqlite_type(name = "Text"))]
310        #[diesel(postgres_type(name = #pg_internal_type))]
311        pub struct #new_diesel_mapping;
312    }
313}
314
315fn generate_common_impls(
316    diesel_mapping: &proc_macro2::TokenStream,
317    enum_ty: &Ident,
318) -> proc_macro2::TokenStream {
319    quote! {
320        impl AsExpression<#diesel_mapping> for #enum_ty {
321            type Expression = Bound<#diesel_mapping, Self>;
322
323            fn as_expression(self) -> Self::Expression {
324                Bound::new(self)
325            }
326        }
327
328        impl AsExpression<Nullable<#diesel_mapping>> for #enum_ty {
329            type Expression = Bound<Nullable<#diesel_mapping>, Self>;
330
331            fn as_expression(self) -> Self::Expression {
332                Bound::new(self)
333            }
334        }
335
336        impl<'a> AsExpression<#diesel_mapping> for &'a #enum_ty {
337            type Expression = Bound<#diesel_mapping, Self>;
338
339            fn as_expression(self) -> Self::Expression {
340                Bound::new(self)
341            }
342        }
343
344        impl<'a> AsExpression<Nullable<#diesel_mapping>> for &'a #enum_ty {
345            type Expression = Bound<Nullable<#diesel_mapping>, Self>;
346
347            fn as_expression(self) -> Self::Expression {
348                Bound::new(self)
349            }
350        }
351
352        impl<'a, 'b> AsExpression<#diesel_mapping> for &'a &'b #enum_ty {
353            type Expression = Bound<#diesel_mapping, Self>;
354
355            fn as_expression(self) -> Self::Expression {
356                Bound::new(self)
357            }
358        }
359
360        impl<'a, 'b> AsExpression<Nullable<#diesel_mapping>> for &'a &'b #enum_ty {
361            type Expression = Bound<Nullable<#diesel_mapping>, Self>;
362
363            fn as_expression(self) -> Self::Expression {
364                Bound::new(self)
365            }
366        }
367
368        impl<DB> ToSql<Nullable<#diesel_mapping>, DB> for #enum_ty
369        where
370            DB: Backend,
371            Self: ToSql<#diesel_mapping, DB>,
372        {
373            fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, DB>) -> serialize::Result {
374                ToSql::<#diesel_mapping, DB>::to_sql(self, out)
375            }
376        }
377    }
378}
379
380fn generate_postgres_impl(
381    diesel_mapping: &proc_macro2::TokenStream,
382    enum_ty: &Ident,
383    with_clone: bool,
384) -> proc_macro2::TokenStream {
385    // If the type was generated by postgres, we have to manually add a clone impl,
386    // if generated by 'us' it has already been done
387    let clone_impl = if with_clone {
388        Some(quote! {
389            impl Clone for #diesel_mapping {
390                fn clone(&self) -> Self {
391                    #diesel_mapping
392                }
393            }
394        })
395    } else {
396        None
397    };
398
399    quote! {
400        mod pg_impl {
401            use super::*;
402            use diesel::pg::{Pg, PgValue};
403
404            #clone_impl
405
406            impl FromSql<#diesel_mapping, Pg> for #enum_ty {
407                fn from_sql(raw: PgValue) -> deserialize::Result<Self> {
408                    from_db_binary_representation(raw.as_bytes())
409                }
410            }
411
412            impl ToSql<#diesel_mapping, Pg> for #enum_ty
413            {
414                fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
415                    out.write_all(db_str_representation(self).as_bytes())?;
416                    Ok(IsNull::No)
417                }
418            }
419
420            impl Queryable<#diesel_mapping, Pg> for #enum_ty {
421                type Row = Self;
422
423                fn build(row: Self::Row) -> deserialize::Result<Self> {
424                    Ok(row)
425                }
426            }
427        }
428    }
429}
430
431fn generate_mysql_impl(diesel_mapping: &Ident, enum_ty: &Ident) -> proc_macro2::TokenStream {
432    quote! {
433        mod mysql_impl {
434            use super::*;
435            use diesel;
436            use diesel::mysql::{Mysql, MysqlValue};
437
438            impl FromSql<#diesel_mapping, Mysql> for #enum_ty {
439                fn from_sql(raw: MysqlValue) -> deserialize::Result<Self> {
440                    from_db_binary_representation(raw.as_bytes())
441                }
442            }
443
444            impl ToSql<#diesel_mapping, Mysql> for #enum_ty
445            {
446                fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
447                    out.write_all(db_str_representation(self).as_bytes())?;
448                    Ok(IsNull::No)
449                }
450            }
451
452            impl Queryable<#diesel_mapping, Mysql> for #enum_ty {
453                type Row = Self;
454
455                fn build(row: Self::Row) -> deserialize::Result<Self> {
456                    Ok(row)
457                }
458            }
459        }
460    }
461}
462
463fn generate_sqlite_impl(diesel_mapping: &Ident, enum_ty: &Ident) -> proc_macro2::TokenStream {
464    quote! {
465        mod sqlite_impl {
466            use super::*;
467            use diesel;
468            use diesel::sql_types;
469            use diesel::sqlite::Sqlite;
470
471            impl FromSql<#diesel_mapping, Sqlite> for #enum_ty {
472                fn from_sql(value: backend::RawValue<Sqlite>) -> deserialize::Result<Self> {
473                    let bytes = <Vec<u8> as FromSql<sql_types::Binary, Sqlite>>::from_sql(value)?;
474                    from_db_binary_representation(bytes.as_slice())
475                }
476            }
477
478            impl ToSql<#diesel_mapping, Sqlite> for #enum_ty {
479                fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
480                    <str as ToSql<sql_types::Text, Sqlite>>::to_sql(db_str_representation(self), out)
481                }
482            }
483
484            impl Queryable<#diesel_mapping, Sqlite> for #enum_ty {
485                type Row = Self;
486
487                fn build(row: Self::Row) -> deserialize::Result<Self> {
488                    Ok(row)
489                }
490            }
491        }
492    }
493}