1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4
5use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
6use proc_macro::TokenStream;
7use proc_macro2::{Ident, Span};
8use quote::quote;
9use syn::*;
10
11#[proc_macro_derive(
33 DbEnum,
34 attributes(PgType, DieselType, ExistingTypePath, DbValueStyle, db_rename)
35)]
36pub fn derive(input: TokenStream) -> TokenStream {
37 let input: DeriveInput = parse_macro_input!(input as DeriveInput);
38
39 let existing_mapping_path = val_from_attrs(&input.attrs, "ExistingTypePath");
40 if !cfg!(feature = "postgres") && existing_mapping_path.is_some() {
41 panic!("ExistingTypePath attribute only applies when the 'postgres' feature is enabled");
42 }
43
44 let pg_internal_type = val_from_attrs(&input.attrs, "PgType");
49
50 if existing_mapping_path.is_some() && pg_internal_type.is_some() {
51 panic!("Cannot specify both `ExistingTypePath` and `PgType` attributes");
52 }
53
54 let pg_internal_type = pg_internal_type.unwrap_or(input.ident.to_string().to_snake_case());
55
56 let new_diesel_mapping = val_from_attrs(&input.attrs, "DieselType");
57 if existing_mapping_path.is_some() && new_diesel_mapping.is_some() {
58 panic!("Cannot specify both `ExistingTypePath` and `DieselType` attributes");
59 }
60 let new_diesel_mapping =
61 new_diesel_mapping.unwrap_or_else(|| format!("{}Mapping", input.ident));
62
63 let case_style =
65 val_from_attrs(&input.attrs, "DbValueStyle").unwrap_or_else(|| "snake_case".to_string());
66 let case_style = CaseStyle::from_string(&case_style);
67
68 let existing_mapping_path = existing_mapping_path.map(|v| {
69 v.parse::<proc_macro2::TokenStream>()
70 .expect("ExistingTypePath is not a valid token")
71 });
72 let new_diesel_mapping = Ident::new(new_diesel_mapping.as_ref(), Span::call_site());
73 if let Data::Enum(syn::DataEnum {
74 variants: data_variants,
75 ..
76 }) = input.data
77 {
78 generate_derive_enum_impls(
79 &existing_mapping_path,
80 &new_diesel_mapping,
81 &pg_internal_type,
82 case_style,
83 &input.ident,
84 &data_variants,
85 )
86 } else {
87 syn::Error::new(
88 Span::call_site(),
89 "derive(DbEnum) can only be applied to enums",
90 )
91 .to_compile_error()
92 .into()
93 }
94}
95
96fn val_from_attrs(attrs: &[Attribute], attrname: &str) -> Option<String> {
97 for attr in attrs {
98 if attr.path().is_ident(attrname) {
99 match &attr.meta {
100 Meta::NameValue(MetaNameValue {
101 value:
102 Expr::Lit(ExprLit {
103 lit: Lit::Str(lit_str),
104 ..
105 }),
106 ..
107 }) => return Some(lit_str.value()),
108 _ => panic!(
109 "Attribute '{}' must have form: {} = \"value\"",
110 attrname, attrname
111 ),
112 }
113 }
114 }
115 None
116}
117
/// Casing rule used to turn Rust variant names into database values.
///
/// Selected through the `DbValueStyle` attribute; `CaseStyle::from_string` lists the
/// accepted spellings.
#[derive(Copy, Clone, Debug, PartialEq)]
enum CaseStyle {
    /// `camelCase`
    Camel,
    /// `kebab-case`
    Kebab,
    /// `PascalCase`
    Pascal,
    /// `UPPERCASE`
    Upper,
    /// `SCREAMING_SNAKE_CASE`
    ScreamingSnake,
    /// `snake_case`
    Snake,
    /// `verbatim` / `verbatimcase`: the variant name is used unchanged.
    Verbatim,
}

impl CaseStyle {
    /// Parses a `DbValueStyle` attribute value.
    ///
    /// # Panics
    ///
    /// Panics with ``unsupported casing: `<name>` `` for any spelling not in the table.
    fn from_string(name: &str) -> Self {
        const SPELLINGS: &[(&str, CaseStyle)] = &[
            ("camelCase", CaseStyle::Camel),
            ("kebab-case", CaseStyle::Kebab),
            ("PascalCase", CaseStyle::Pascal),
            ("SCREAMING_SNAKE_CASE", CaseStyle::ScreamingSnake),
            ("UPPERCASE", CaseStyle::Upper),
            ("snake_case", CaseStyle::Snake),
            ("verbatim", CaseStyle::Verbatim),
            ("verbatimcase", CaseStyle::Verbatim),
        ];

        SPELLINGS
            .iter()
            .find(|(spelling, _)| *spelling == name)
            .map(|&(_, style)| style)
            .unwrap_or_else(|| panic!("unsupported casing: `{}`", name))
    }
}
144
145fn generate_derive_enum_impls(
146 existing_mapping_path: &Option<proc_macro2::TokenStream>,
147 new_diesel_mapping: &Ident,
148 pg_internal_type: &str,
149 case_style: CaseStyle,
150 enum_ty: &Ident,
151 variants: &syn::punctuated::Punctuated<Variant, syn::token::Comma>,
152) -> TokenStream {
153 let modname = Ident::new(&format!("db_enum_impl_{}", enum_ty), Span::call_site());
154 let variant_ids: Vec<proc_macro2::TokenStream> = variants
155 .iter()
156 .map(|variant| {
157 if let Fields::Unit = variant.fields {
158 let id = &variant.ident;
159 quote! {
160 #enum_ty::#id
161 }
162 } else {
163 panic!("Variants must be fieldless")
164 }
165 })
166 .collect();
167
168 let variants_db: Vec<String> = variants
169 .iter()
170 .map(|variant| {
171 val_from_attrs(&variant.attrs, "db_rename")
172 .unwrap_or_else(|| stylize_value(&variant.ident.to_string(), case_style))
173 })
174 .collect();
175 let variants_db_bytes: Vec<LitByteStr> = variants_db
176 .iter()
177 .map(|variant_str| LitByteStr::new(variant_str.as_bytes(), Span::call_site()))
178 .collect();
179
180 let common = generate_common(enum_ty, &variant_ids, &variants_db, &variants_db_bytes);
181 let (diesel_mapping_def, diesel_mapping_use) =
182 if existing_mapping_path.is_some() {
184 (None, None)
185 } else {
186 let new_diesel_mapping_def = generate_new_diesel_mapping(new_diesel_mapping, pg_internal_type);
187 let common_impls_on_new_diesel_mapping =
188 generate_common_impls("e! { #new_diesel_mapping }, enum_ty);
189 (
190 Some(quote! {
191 #new_diesel_mapping_def
192 #common_impls_on_new_diesel_mapping
193 }),
194 Some(quote! {
195 pub use self::#modname::#new_diesel_mapping;
196 }),
197 )
198 };
199
200 let pg_impl = if cfg!(feature = "postgres") {
201 match existing_mapping_path {
202 Some(path) => {
203 let common_impls_on_existing_diesel_mapping = generate_common_impls(path, enum_ty);
204 let postgres_impl = generate_postgres_impl(path, enum_ty, true);
205 Some(quote! {
206 #common_impls_on_existing_diesel_mapping
207 #postgres_impl
208 })
209 }
210 None => Some(generate_postgres_impl(
211 "e! { #new_diesel_mapping },
212 enum_ty,
213 false,
214 )),
215 }
216 } else {
217 None
218 };
219
220 let mysql_impl = if cfg!(feature = "mysql") {
221 Some(generate_mysql_impl(new_diesel_mapping, enum_ty))
222 } else {
223 None
224 };
225
226 let sqlite_impl = if cfg!(feature = "sqlite") {
227 Some(generate_sqlite_impl(new_diesel_mapping, enum_ty))
228 } else {
229 None
230 };
231
232 let imports = quote! {
233 use super::*;
234 use diesel::{
235 backend::{self, Backend},
236 deserialize::{self, FromSql},
237 expression::AsExpression,
238 internal::derives::as_expression::Bound,
239 query_builder::{bind_collector::RawBytesBindCollector},
240 row::Row,
241 serialize::{self, IsNull, Output, ToSql},
242 sql_types::*,
243 Queryable,
244 };
245 use std::io::Write;
246 };
247
248 let quoted = quote! {
249 #diesel_mapping_use
250 #[allow(non_snake_case)]
251 mod #modname {
252 #imports
253
254 #common
255 #diesel_mapping_def
256 #pg_impl
257 #mysql_impl
258 #sqlite_impl
259 }
260 };
261
262 quoted.into()
263}
264
265fn stylize_value(value: &str, style: CaseStyle) -> String {
266 match style {
267 CaseStyle::Camel => value.to_lower_camel_case(),
268 CaseStyle::Kebab => value.to_kebab_case(),
269 CaseStyle::Pascal => value.to_upper_camel_case(),
270 CaseStyle::Upper => value.to_uppercase(),
271 CaseStyle::ScreamingSnake => value.to_shouty_snake_case(),
272 CaseStyle::Snake => value.to_snake_case(),
273 CaseStyle::Verbatim => value.to_string(),
274 }
275}
276
277fn generate_common(
278 enum_ty: &Ident,
279 variants_rs: &[proc_macro2::TokenStream],
280 variants_db: &[String],
281 variants_db_bytes: &[LitByteStr],
282) -> proc_macro2::TokenStream {
283 quote! {
284 fn db_str_representation(e: &#enum_ty) -> &'static str {
285 match *e {
286 #(#variants_rs => #variants_db,)*
287 }
288 }
289
290 fn from_db_binary_representation(bytes: &[u8]) -> deserialize::Result<#enum_ty> {
291 match bytes {
292 #(#variants_db_bytes => Ok(#variants_rs),)*
293 v => Err(format!("Unrecognized enum variant: '{}'",
294 String::from_utf8_lossy(v)).into()),
295 }
296 }
297 }
298}
299
300fn generate_new_diesel_mapping(
301 new_diesel_mapping: &Ident,
302 pg_internal_type: &str,
303) -> proc_macro2::TokenStream {
304 quote! {
307 #[derive(Clone, SqlType, diesel::query_builder::QueryId)]
308 #[diesel(mysql_type(name = "Enum"))]
309 #[diesel(sqlite_type(name = "Text"))]
310 #[diesel(postgres_type(name = #pg_internal_type))]
311 pub struct #new_diesel_mapping;
312 }
313}
314
315fn generate_common_impls(
316 diesel_mapping: &proc_macro2::TokenStream,
317 enum_ty: &Ident,
318) -> proc_macro2::TokenStream {
319 quote! {
320 impl AsExpression<#diesel_mapping> for #enum_ty {
321 type Expression = Bound<#diesel_mapping, Self>;
322
323 fn as_expression(self) -> Self::Expression {
324 Bound::new(self)
325 }
326 }
327
328 impl AsExpression<Nullable<#diesel_mapping>> for #enum_ty {
329 type Expression = Bound<Nullable<#diesel_mapping>, Self>;
330
331 fn as_expression(self) -> Self::Expression {
332 Bound::new(self)
333 }
334 }
335
336 impl<'a> AsExpression<#diesel_mapping> for &'a #enum_ty {
337 type Expression = Bound<#diesel_mapping, Self>;
338
339 fn as_expression(self) -> Self::Expression {
340 Bound::new(self)
341 }
342 }
343
344 impl<'a> AsExpression<Nullable<#diesel_mapping>> for &'a #enum_ty {
345 type Expression = Bound<Nullable<#diesel_mapping>, Self>;
346
347 fn as_expression(self) -> Self::Expression {
348 Bound::new(self)
349 }
350 }
351
352 impl<'a, 'b> AsExpression<#diesel_mapping> for &'a &'b #enum_ty {
353 type Expression = Bound<#diesel_mapping, Self>;
354
355 fn as_expression(self) -> Self::Expression {
356 Bound::new(self)
357 }
358 }
359
360 impl<'a, 'b> AsExpression<Nullable<#diesel_mapping>> for &'a &'b #enum_ty {
361 type Expression = Bound<Nullable<#diesel_mapping>, Self>;
362
363 fn as_expression(self) -> Self::Expression {
364 Bound::new(self)
365 }
366 }
367
368 impl<DB> ToSql<Nullable<#diesel_mapping>, DB> for #enum_ty
369 where
370 DB: Backend,
371 Self: ToSql<#diesel_mapping, DB>,
372 {
373 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, DB>) -> serialize::Result {
374 ToSql::<#diesel_mapping, DB>::to_sql(self, out)
375 }
376 }
377 }
378}
379
380fn generate_postgres_impl(
381 diesel_mapping: &proc_macro2::TokenStream,
382 enum_ty: &Ident,
383 with_clone: bool,
384) -> proc_macro2::TokenStream {
385 let clone_impl = if with_clone {
388 Some(quote! {
389 impl Clone for #diesel_mapping {
390 fn clone(&self) -> Self {
391 #diesel_mapping
392 }
393 }
394 })
395 } else {
396 None
397 };
398
399 quote! {
400 mod pg_impl {
401 use super::*;
402 use diesel::pg::{Pg, PgValue};
403
404 #clone_impl
405
406 impl FromSql<#diesel_mapping, Pg> for #enum_ty {
407 fn from_sql(raw: PgValue) -> deserialize::Result<Self> {
408 from_db_binary_representation(raw.as_bytes())
409 }
410 }
411
412 impl ToSql<#diesel_mapping, Pg> for #enum_ty
413 {
414 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
415 out.write_all(db_str_representation(self).as_bytes())?;
416 Ok(IsNull::No)
417 }
418 }
419
420 impl Queryable<#diesel_mapping, Pg> for #enum_ty {
421 type Row = Self;
422
423 fn build(row: Self::Row) -> deserialize::Result<Self> {
424 Ok(row)
425 }
426 }
427 }
428 }
429}
430
431fn generate_mysql_impl(diesel_mapping: &Ident, enum_ty: &Ident) -> proc_macro2::TokenStream {
432 quote! {
433 mod mysql_impl {
434 use super::*;
435 use diesel;
436 use diesel::mysql::{Mysql, MysqlValue};
437
438 impl FromSql<#diesel_mapping, Mysql> for #enum_ty {
439 fn from_sql(raw: MysqlValue) -> deserialize::Result<Self> {
440 from_db_binary_representation(raw.as_bytes())
441 }
442 }
443
444 impl ToSql<#diesel_mapping, Mysql> for #enum_ty
445 {
446 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Mysql>) -> serialize::Result {
447 out.write_all(db_str_representation(self).as_bytes())?;
448 Ok(IsNull::No)
449 }
450 }
451
452 impl Queryable<#diesel_mapping, Mysql> for #enum_ty {
453 type Row = Self;
454
455 fn build(row: Self::Row) -> deserialize::Result<Self> {
456 Ok(row)
457 }
458 }
459 }
460 }
461}
462
463fn generate_sqlite_impl(diesel_mapping: &Ident, enum_ty: &Ident) -> proc_macro2::TokenStream {
464 quote! {
465 mod sqlite_impl {
466 use super::*;
467 use diesel;
468 use diesel::sql_types;
469 use diesel::sqlite::Sqlite;
470
471 impl FromSql<#diesel_mapping, Sqlite> for #enum_ty {
472 fn from_sql(value: backend::RawValue<Sqlite>) -> deserialize::Result<Self> {
473 let bytes = <Vec<u8> as FromSql<sql_types::Binary, Sqlite>>::from_sql(value)?;
474 from_db_binary_representation(bytes.as_slice())
475 }
476 }
477
478 impl ToSql<#diesel_mapping, Sqlite> for #enum_ty {
479 fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
480 <str as ToSql<sql_types::Text, Sqlite>>::to_sql(db_str_representation(self), out)
481 }
482 }
483
484 impl Queryable<#diesel_mapping, Sqlite> for #enum_ty {
485 type Row = Self;
486
487 fn build(row: Self::Row) -> deserialize::Result<Self> {
488 Ok(row)
489 }
490 }
491 }
492 }
493}