utoipa_gen/
schema_type.rs

1use proc_macro2::TokenStream;
2use proc_macro_error::abort_call_site;
3use quote::{quote, ToTokens};
4use syn::{parse::Parse, Error, Ident, LitStr, Path};
5
6/// Tokenizes OpenAPI data type correctly according to the Rust type
7pub struct SchemaType<'a>(pub &'a syn::Path);
8
9impl SchemaType<'_> {
10    fn last_segment_to_string(&self) -> String {
11        self.0
12            .segments
13            .last()
14            .expect("Expected at least one segment is_integer")
15            .ident
16            .to_string()
17    }
18
19    pub fn is_value(&self) -> bool {
20        matches!(&*self.last_segment_to_string(), "Value")
21    }
22
23    /// Check whether type is known to be primitive in which case returns true.
24    pub fn is_primitive(&self) -> bool {
25        let SchemaType(path) = self;
26        let last_segment = match path.segments.last() {
27            Some(segment) => segment,
28            None => return false,
29        };
30        let name = &*last_segment.ident.to_string();
31
32        #[cfg(not(any(
33            feature = "chrono",
34            feature = "decimal",
35            feature = "rocket_extras",
36            feature = "uuid",
37            feature = "ulid",
38            feature = "time",
39        )))]
40        {
41            is_primitive(name)
42        }
43
44        #[cfg(any(
45            feature = "chrono",
46            feature = "decimal",
47            feature = "rocket_extras",
48            feature = "uuid",
49            feature = "ulid",
50            feature = "time",
51        ))]
52        {
53            let mut primitive = is_primitive(name);
54
55            #[cfg(feature = "chrono")]
56            if !primitive {
57                primitive = is_primitive_chrono(name);
58            }
59
60            #[cfg(feature = "decimal")]
61            if !primitive {
62                primitive = is_primitive_rust_decimal(name);
63            }
64
65            #[cfg(feature = "rocket_extras")]
66            if !primitive {
67                primitive = matches!(name, "PathBuf");
68            }
69
70            #[cfg(feature = "uuid")]
71            if !primitive {
72                primitive = matches!(name, "Uuid");
73            }
74
75            #[cfg(feature = "ulid")]
76            if !primitive {
77                primitive = matches!(name, "Ulid");
78            }
79
80            #[cfg(feature = "time")]
81            if !primitive {
82                primitive = matches!(
83                    name,
84                    "Date" | "PrimitiveDateTime" | "OffsetDateTime" | "Duration"
85                );
86            }
87
88            primitive
89        }
90    }
91
92    pub fn is_integer(&self) -> bool {
93        matches!(
94            &*self.last_segment_to_string(),
95            "i8" | "i16"
96                | "i32"
97                | "i64"
98                | "i128"
99                | "isize"
100                | "u8"
101                | "u16"
102                | "u32"
103                | "u64"
104                | "u128"
105                | "usize"
106        )
107    }
108
109    pub fn is_unsigned_integer(&self) -> bool {
110        matches!(
111            &*self.last_segment_to_string(),
112            "u8" | "u16" | "u32" | "u64" | "u128" | "usize"
113        )
114    }
115
116    pub fn is_number(&self) -> bool {
117        match &*self.last_segment_to_string() {
118            "f32" | "f64" => true,
119            _ if self.is_integer() => true,
120            _ => false,
121        }
122    }
123
124    pub fn is_string(&self) -> bool {
125        matches!(&*self.last_segment_to_string(), "str" | "String")
126    }
127
128    pub fn is_byte(&self) -> bool {
129        matches!(&*self.last_segment_to_string(), "u8")
130    }
131}
132
/// Returns `true` if `name` is a Rust built-in scalar type name (`bool`, `char`, the integer
/// and float types) or one of the string types `String`/`str`.
#[inline]
fn is_primitive(name: &str) -> bool {
    const BUILT_IN_PRIMITIVES: [&str; 18] = [
        "String", "str", "char", "bool", "usize", "u8", "u16", "u32", "u64", "u128", "isize",
        "i8", "i16", "i32", "i64", "i128", "f32", "f64",
    ];

    BUILT_IN_PRIMITIVES.contains(&name)
}
157
/// Returns `true` if `name` is one of the chrono type names treated as primitive.
///
/// Only compiled with the `chrono` feature.
#[inline]
#[cfg(feature = "chrono")]
fn is_primitive_chrono(name: &str) -> bool {
    const CHRONO_PRIMITIVES: [&str; 6] = [
        "DateTime",
        "Date",
        "NaiveDate",
        "NaiveTime",
        "Duration",
        "NaiveDateTime",
    ];

    CHRONO_PRIMITIVES.contains(&name)
}
166
/// Returns `true` only for the `Decimal` type name.
///
/// Only compiled with the `decimal` feature.
#[inline]
#[cfg(feature = "decimal")]
fn is_primitive_rust_decimal(name: &str) -> bool {
    name == "Decimal"
}
172
173impl ToTokens for SchemaType<'_> {
174    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
175        let last_segment = self.0.segments.last().unwrap_or_else(|| {
176            abort_call_site!("expected there to be at least one segment in the path")
177        });
178        let name = &*last_segment.ident.to_string();
179
180        match name {
181            "String" | "str" | "char" => {
182                tokens.extend(quote! {utoipa::openapi::SchemaType::String})
183            }
184
185            "bool" => tokens.extend(quote! { utoipa::openapi::SchemaType::Boolean }),
186
187            "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
188            | "u128" | "usize" => tokens.extend(quote! { utoipa::openapi::SchemaType::Integer }),
189            "f32" | "f64" => tokens.extend(quote! { utoipa::openapi::SchemaType::Number }),
190
191            #[cfg(feature = "chrono")]
192            "DateTime" | "NaiveDateTime" | "NaiveDate" | "NaiveTime" => {
193                tokens.extend(quote! { utoipa::openapi::SchemaType::String })
194            }
195
196            #[cfg(any(feature = "chrono", feature = "time"))]
197            "Date" | "Duration" => tokens.extend(quote! { utoipa::openapi::SchemaType::String }),
198
199            #[cfg(feature = "decimal")]
200            "Decimal" => tokens.extend(quote! { utoipa::openapi::SchemaType::String }),
201
202            #[cfg(feature = "rocket_extras")]
203            "PathBuf" => tokens.extend(quote! { utoipa::openapi::SchemaType::String }),
204
205            #[cfg(feature = "uuid")]
206            "Uuid" => tokens.extend(quote! { utoipa::openapi::SchemaType::String }),
207
208            #[cfg(feature = "ulid")]
209            "Ulid" => tokens.extend(quote! { utoipa::openapi::SchemaType::String }),
210
211            #[cfg(feature = "time")]
212            "PrimitiveDateTime" | "OffsetDateTime" => {
213                tokens.extend(quote! { utoipa::openapi::SchemaType::String })
214            }
215            _ => tokens.extend(quote! { utoipa::openapi::SchemaType::Object }),
216        }
217    }
218}
219
220/// Either Rust type component variant or enum variant schema variant.
221#[derive(Clone)]
222#[cfg_attr(feature = "debug", derive(Debug))]
223pub enum SchemaFormat<'c> {
224    /// [`utoipa::openapi::schema::SchemaFormat`] enum variant schema format.
225    Variant(Variant),
226    /// Rust type schema format.
227    Type(Type<'c>),
228}
229
230impl SchemaFormat<'_> {
231    pub fn is_known_format(&self) -> bool {
232        match self {
233            Self::Type(ty) => ty.is_known_format(),
234            Self::Variant(_) => true,
235        }
236    }
237}
238
239impl<'a> From<&'a Path> for SchemaFormat<'a> {
240    fn from(path: &'a Path) -> Self {
241        Self::Type(Type(path))
242    }
243}
244
245impl Parse for SchemaFormat<'_> {
246    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
247        Ok(Self::Variant(input.parse()?))
248    }
249}
250
251impl ToTokens for SchemaFormat<'_> {
252    fn to_tokens(&self, tokens: &mut TokenStream) {
253        match self {
254            Self::Type(ty) => ty.to_tokens(tokens),
255            Self::Variant(variant) => variant.to_tokens(tokens),
256        }
257    }
258}
259
260/// Tokenizes OpenAPI data type format correctly by given Rust type.
261#[derive(Clone)]
262#[cfg_attr(feature = "debug", derive(Debug))]
263pub struct Type<'a>(&'a syn::Path);
264
265impl Type<'_> {
266    /// Check is the format know format. Known formats can be used within `quote! {...}` statements.
267    pub fn is_known_format(&self) -> bool {
268        let last_segment = match self.0.segments.last() {
269            Some(segment) => segment,
270            None => return false,
271        };
272        let name = &*last_segment.ident.to_string();
273
274        #[cfg(not(any(
275            feature = "chrono",
276            feature = "uuid",
277            feature = "ulid",
278            feature = "time"
279        )))]
280        {
281            is_known_format(name)
282        }
283
284        #[cfg(any(
285            feature = "chrono",
286            feature = "uuid",
287            feature = "ulid",
288            feature = "time"
289        ))]
290        {
291            let mut known_format = is_known_format(name);
292
293            #[cfg(feature = "chrono")]
294            if !known_format {
295                known_format = matches!(name, "DateTime" | "Date" | "NaiveDate" | "NaiveDateTime");
296            }
297
298            #[cfg(feature = "uuid")]
299            if !known_format {
300                known_format = matches!(name, "Uuid");
301            }
302
303            #[cfg(feature = "ulid")]
304            if !known_format {
305                known_format = matches!(name, "Ulid");
306            }
307
308            #[cfg(feature = "time")]
309            if !known_format {
310                known_format = matches!(name, "Date" | "PrimitiveDateTime" | "OffsetDateTime");
311            }
312
313            known_format
314        }
315    }
316}
317
/// Returns `true` for the numeric Rust type names that always have a known format,
/// regardless of enabled features.
#[inline]
fn is_known_format(name: &str) -> bool {
    const NUMERIC_WITH_FORMAT: [&str; 10] =
        ["i8", "i16", "i32", "u8", "u16", "u32", "i64", "u64", "f32", "f64"];

    NUMERIC_WITH_FORMAT.contains(&name)
}
325
326impl ToTokens for Type<'_> {
327    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
328        let last_segment = self.0.segments.last().unwrap_or_else(|| {
329            abort_call_site!("expected there to be at least one segment in the path")
330        });
331        let name = &*last_segment.ident.to_string();
332
333        match name {
334            #[cfg(feature="non_strict_integers")]
335            "i8" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Int8) }),
336            #[cfg(feature="non_strict_integers")]
337            "u8" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::UInt8) }),
338            #[cfg(feature="non_strict_integers")]
339            "i16" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Int16) }),
340            #[cfg(feature="non_strict_integers")]
341            "u16" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::UInt16) }),
342            #[cfg(feature="non_strict_integers")]
343            #[cfg(feature="non_strict_integers")]
344            "u32" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::UInt32) }),
345            #[cfg(feature="non_strict_integers")]
346            "u64" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::UInt64) }),
347
348            #[cfg(not(feature="non_strict_integers"))]
349            "i8" | "i16" | "u8" | "u16" | "u32" => {
350                tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Int32) })
351            }
352
353            #[cfg(not(feature="non_strict_integers"))]
354            "u64" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Int64) }),
355
356            "i32" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Int32) }),
357            "i64" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Int64) }),
358            "f32" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Float) }),
359            "f64" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Double) }),
360
361            #[cfg(feature = "chrono")]
362            "NaiveDate" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Date) }),
363
364            #[cfg(feature = "chrono")]
365            "DateTime" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::DateTime) }),
366
367            #[cfg(feature = "chrono")]
368            "NaiveDateTime" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::DateTime) }),
369
370            #[cfg(any(feature = "chrono", feature = "time"))]
371            "Date" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Date) }),
372
373            #[cfg(feature = "uuid")]
374            "Uuid" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Uuid) }),
375
376            #[cfg(feature = "ulid")]
377            "Ulid" => tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::Ulid) }),
378
379            #[cfg(feature = "time")]
380            "PrimitiveDateTime" | "OffsetDateTime" => {
381                tokens.extend(quote! { utoipa::openapi::SchemaFormat::KnownFormat(utoipa::openapi::KnownFormat::DateTime) })
382            }
383            _ => (),
384        }
385    }
386}
387
/// Schema format named explicitly by the user: one of the known
/// [`utoipa::openapi::schema::SchemaFormat`] formats or a custom string.
///
/// Implements [`Parse`] (from an identifier or a string literal) and [`ToTokens`].
#[derive(Clone)]
#[cfg_attr(feature = "debug", derive(Debug))]
pub enum Variant {
    Int32,
    Int64,
    Float,
    Double,
    Byte,
    Binary,
    Date,
    DateTime,
    Password,
    /// Available only with the `uuid` feature.
    #[cfg(feature = "uuid")]
    Uuid,
    /// Available only with the `ulid` feature.
    #[cfg(feature = "ulid")]
    Ulid,
    /// Arbitrary format given as a string literal.
    Custom(String),
}
407
408impl Parse for Variant {
409    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
410        const FORMATS: [&str; 11] = [
411            "Int32", "Int64", "Float", "Double", "Byte", "Binary", "Date", "DateTime", "Password",
412            "Uuid", "Ulid",
413        ];
414        let known_formats = FORMATS
415            .into_iter()
416            .filter(|_format| {
417                #[cfg(all(feature = "uuid", feature = "ulid"))]
418                {
419                    true
420                }
421                #[cfg(all(not(feature = "uuid"), feature = "ulid"))]
422                {
423                    _format != &"Uuid"
424                }
425                #[cfg(all(feature = "uuid", not(feature = "ulid")))]
426                {
427                    _format != &"Ulid"
428                }
429                #[cfg(all(not(feature = "uuid"), not(feature = "ulid")))]
430                {
431                    _format != &"Uuid" && _format != &"Ulid"
432                }
433            })
434            .collect::<Vec<_>>();
435
436        let lookahead = input.lookahead1();
437        if lookahead.peek(Ident) {
438            let format = input.parse::<Ident>()?;
439            let name = &*format.to_string();
440
441            match name {
442                "Int32" => Ok(Self::Int32),
443                "Int64" => Ok(Self::Int64),
444                "Float" => Ok(Self::Float),
445                "Double" => Ok(Self::Double),
446                "Byte" => Ok(Self::Byte),
447                "Binary" => Ok(Self::Binary),
448                "Date" => Ok(Self::Date),
449                "DateTime" => Ok(Self::DateTime),
450                "Password" => Ok(Self::Password),
451                #[cfg(feature = "uuid")]
452                "Uuid" => Ok(Self::Uuid),
453                #[cfg(feature = "ulid")]
454                "Ulid" => Ok(Self::Ulid),
455                _ => Err(Error::new(
456                    format.span(),
457                    format!(
458                        "unexpected format: {name}, expected one of: {}",
459                        known_formats.join(", ")
460                    ),
461                )),
462            }
463        } else if lookahead.peek(LitStr) {
464            let value = input.parse::<LitStr>()?.value();
465            Ok(Self::Custom(value))
466        } else {
467            Err(lookahead.error())
468        }
469    }
470}
471
472impl ToTokens for Variant {
473    fn to_tokens(&self, tokens: &mut TokenStream) {
474        match self {
475            Self::Int32 => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
476                utoipa::openapi::KnownFormat::Int32
477            ))),
478            Self::Int64 => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
479                utoipa::openapi::KnownFormat::Int64
480            ))),
481            Self::Float => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
482                utoipa::openapi::KnownFormat::Float
483            ))),
484            Self::Double => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
485                utoipa::openapi::KnownFormat::Double
486            ))),
487            Self::Byte => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
488                utoipa::openapi::KnownFormat::Byte
489            ))),
490            Self::Binary => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
491                utoipa::openapi::KnownFormat::Binary
492            ))),
493            Self::Date => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
494                utoipa::openapi::KnownFormat::Date
495            ))),
496            Self::DateTime => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
497                utoipa::openapi::KnownFormat::DateTime
498            ))),
499            Self::Password => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
500                utoipa::openapi::KnownFormat::Password
501            ))),
502            #[cfg(feature = "uuid")]
503            Self::Uuid => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
504                utoipa::openapi::KnownFormat::Uuid
505            ))),
506            #[cfg(feature = "ulid")]
507            Self::Ulid => tokens.extend(quote!(utoipa::openapi::SchemaFormat::KnownFormat(
508                utoipa::openapi::KnownFormat::Ulid
509            ))),
510            Self::Custom(value) => tokens.extend(quote!(utoipa::openapi::SchemaFormat::Custom(
511                String::from(#value)
512            ))),
513        };
514    }
515}