1use std::borrow::Cow;
2use std::collections::HashSet;
3use std::fmt;
4use std::marker::PhantomData;
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use serde::de::{self, Visitor};
8use serde::{Deserialize, Deserializer};
9
10use crate::algorithms::Algorithm;
11use crate::errors::{new_error, ErrorKind, Result};
12
/// Options controlling which registered claims a token must carry and how they are checked.
///
/// [`Validation::new`] starts from: `exp` required and validated, `nbf` not validated,
/// 60 seconds of leeway, no `iss`/`sub`/`aud` constraints, signature validation enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Validation {
    /// Registered claims that must be present and successfully parsed. Only `exp`, `nbf`,
    /// `aud`, `iss` and `sub` are checked; any other name in this set is ignored.
    pub required_spec_claims: HashSet<String>,
    /// Clock-skew allowance, in seconds, applied to the `exp` and `nbf` checks.
    pub leeway: u64,
    /// Reject tokens whose `exp` is earlier than `now - leeway`. Only applies when the
    /// token's `exp` parsed successfully; a missing `exp` is rejected only if it is required.
    pub validate_exp: bool,
    /// Reject tokens whose `nbf` is later than `now + leeway`. Like `exp`, only applies when
    /// the token's `nbf` parsed successfully.
    pub validate_nbf: bool,
    /// Accepted audiences. When set, a token's `aud` (a string or an array of strings) must
    /// contain at least one of these values. A token without a parsed `aud` passes this
    /// check unless `aud` is in `required_spec_claims`.
    pub aud: Option<HashSet<String>>,
    /// Accepted issuers, checked the same way as `aud`.
    pub iss: Option<HashSet<String>>,
    /// Expected subject. When set, a token's `sub` must equal it exactly; a token without a
    /// parsed `sub` passes this check unless `sub` is required.
    pub sub: Option<String>,
    /// Algorithms accepted for the token.
    /// NOTE(review): not read by the claim checks in this file section — presumably consulted
    /// when the token header/signature is processed; confirm.
    pub algorithms: Vec<Algorithm>,

    /// Whether the signature should be verified. Cleared by
    /// [`Validation::insecure_disable_signature_validation`].
    /// NOTE(review): read outside this file section; confirm where it is enforced.
    pub(crate) validate_signature: bool,
}
81
82impl Validation {
83 pub fn new(alg: Algorithm) -> Validation {
85 let mut required_claims = HashSet::with_capacity(1);
86 required_claims.insert("exp".to_owned());
87
88 Validation {
89 required_spec_claims: required_claims,
90 algorithms: vec![alg],
91 leeway: 60,
92
93 validate_exp: true,
94 validate_nbf: false,
95
96 iss: None,
97 sub: None,
98 aud: None,
99
100 validate_signature: true,
101 }
102 }
103
104 pub fn set_audience<T: ToString>(&mut self, items: &[T]) {
107 self.aud = Some(items.iter().map(|x| x.to_string()).collect())
108 }
109
110 pub fn set_issuer<T: ToString>(&mut self, items: &[T]) {
113 self.iss = Some(items.iter().map(|x| x.to_string()).collect())
114 }
115
116 pub fn set_required_spec_claims<T: ToString>(&mut self, items: &[T]) {
122 self.required_spec_claims = items.iter().map(|x| x.to_string()).collect();
123 }
124
125 pub fn insecure_disable_signature_validation(&mut self) {
129 self.validate_signature = false;
130 }
131}
132
133impl Default for Validation {
134 fn default() -> Self {
135 Self::new(Algorithm::HS256)
136 }
137}
138
/// Returns the current Unix time in whole seconds, the unit `exp`/`nbf` are compared in.
///
/// # Panics
///
/// Panics with "Time went backwards" if the system clock reports a time before the Unix
/// epoch.
pub fn get_current_timestamp() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    elapsed.as_secs()
}
144
/// The registered claims that `validate` inspects, deserialized leniently.
///
/// Every field is a `TryParse`, so a claim that is absent, or present with an unexpected
/// scalar type, does not fail deserialization of the whole struct. Instead it is recorded
/// as `NotPresent` or `FailedToParse`, and `validate` decides whether that matters. Claims
/// not listed here are ignored.
/// NOTE(review): with streaming deserializers, an array/object where a scalar is expected
/// may not be fully consumed after the inner error is swallowed. The enclosing parse may
/// then fail. Confirm before relying on leniency for such values.
#[derive(Deserialize)]
pub(crate) struct ClaimsForValidation<'a> {
    /// Expiration time in seconds. Parsed by `numeric_type`, which accepts integers and
    /// finite non-negative floats (rounded). `default` makes a missing key `NotPresent`.
    #[serde(deserialize_with = "numeric_type", default)]
    exp: TryParse<u64>,
    /// Not-before time in seconds, parsed the same way as `exp`.
    #[serde(deserialize_with = "numeric_type", default)]
    nbf: TryParse<u64>,
    /// Subject. `borrow` lets the `'a` lifetime be tied to the deserializer's input.
    #[serde(borrow)]
    sub: TryParse<Cow<'a, str>>,
    /// Issuer, as a single string or an array of strings.
    #[serde(borrow)]
    iss: TryParse<Issuer<'a>>,
    /// Audience, as a single string or an array of strings.
    #[serde(borrow)]
    aud: TryParse<Audience<'a>>,
}
/// Outcome of leniently deserializing one optional claim.
#[derive(Debug)]
enum TryParse<T> {
    /// The claim was present and deserialized into `T`.
    Parsed(T),
    /// The claim was present but could not be deserialized into `T`; the error is discarded.
    FailedToParse,
    /// The claim was absent. Through this type's `Deserialize` impl, a value that
    /// deserializes as `None` (such as JSON `null`) also maps here.
    NotPresent,
}
164impl<'de, T: Deserialize<'de>> Deserialize<'de> for TryParse<T> {
165 fn deserialize<D: serde::Deserializer<'de>>(
166 deserializer: D,
167 ) -> std::result::Result<Self, D::Error> {
168 Ok(match Option::<T>::deserialize(deserializer) {
169 Ok(Some(value)) => TryParse::Parsed(value),
170 Ok(None) => TryParse::NotPresent,
171 Err(_) => TryParse::FailedToParse,
172 })
173 }
174}
175impl<T> Default for TryParse<T> {
176 fn default() -> Self {
177 Self::NotPresent
178 }
179}
180
/// The `aud` claim, which may be a single string or an array of strings.
///
/// With `untagged`, serde tries the variants in declaration order, so keep `Single` first.
/// A value matching neither variant is an error, which the enclosing `TryParse` field
/// records as `FailedToParse`. Duplicate entries in the array collapse in the `HashSet`.
#[derive(Deserialize)]
#[serde(untagged)]
enum Audience<'a> {
    Single(#[serde(borrow)] Cow<'a, str>),
    Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
}
187
/// The `iss` claim, which may be a single string or an array of strings.
///
/// With `untagged`, serde tries the variants in declaration order, so keep `Single` first.
/// A value matching neither variant is an error, which the enclosing `TryParse` field
/// records as `FailedToParse`.
/// NOTE(review): RFC 7519 defines `iss` as a single StringOrURI. The array form is an
/// extension, and it passes validation when any one element is an accepted issuer.
#[derive(Deserialize)]
#[serde(untagged)]
enum Issuer<'a> {
    Single(#[serde(borrow)] Cow<'a, str>),
    Multiple(#[serde(borrow)] HashSet<BorrowedCowIfPossible<'a>>),
}
194
/// String newtype used as the element type of the `aud`/`iss` sets.
///
/// `#[serde(borrow)]` applies its borrowing behaviour only when the field type is itself
/// `Cow`. It therefore does not reach the elements of e.g. `HashSet<Cow<str>>`; wrapping
/// each element in this struct makes the attribute apply.
///
/// The derived `Hash`/`Eq` only look at the inner string, so they agree with `str`'s, as the
/// `Borrow<str>` impl requires.
#[derive(Deserialize, PartialEq, Eq, Hash)]
struct BorrowedCowIfPossible<'a>(#[serde(borrow)] Cow<'a, str>);
200impl std::borrow::Borrow<str> for BorrowedCowIfPossible<'_> {
201 fn borrow(&self) -> &str {
202 &self.0
203 }
204}
205
206fn is_subset(reference: &HashSet<String>, given: &HashSet<BorrowedCowIfPossible<'_>>) -> bool {
207 if reference.len() < given.len() {
209 reference.iter().any(|a| given.contains(&**a))
210 } else {
211 given.iter().any(|a| reference.contains(&*a.0))
212 }
213}
214
215pub(crate) fn validate(claims: ClaimsForValidation, options: &Validation) -> Result<()> {
216 let now = get_current_timestamp();
217
218 for required_claim in &options.required_spec_claims {
219 let present = match required_claim.as_str() {
220 "exp" => matches!(claims.exp, TryParse::Parsed(_)),
221 "sub" => matches!(claims.sub, TryParse::Parsed(_)),
222 "iss" => matches!(claims.iss, TryParse::Parsed(_)),
223 "aud" => matches!(claims.aud, TryParse::Parsed(_)),
224 "nbf" => matches!(claims.nbf, TryParse::Parsed(_)),
225 _ => continue,
226 };
227
228 if !present {
229 return Err(new_error(ErrorKind::MissingRequiredClaim(required_claim.clone())));
230 }
231 }
232
233 if matches!(claims.exp, TryParse::Parsed(exp) if options.validate_exp && exp < now - options.leeway)
234 {
235 return Err(new_error(ErrorKind::ExpiredSignature));
236 }
237
238 if matches!(claims.nbf, TryParse::Parsed(nbf) if options.validate_nbf && nbf > now + options.leeway)
239 {
240 return Err(new_error(ErrorKind::ImmatureSignature));
241 }
242
243 if let (TryParse::Parsed(sub), Some(correct_sub)) = (claims.sub, options.sub.as_deref()) {
244 if sub != correct_sub {
245 return Err(new_error(ErrorKind::InvalidSubject));
246 }
247 }
248
249 match (claims.iss, options.iss.as_ref()) {
250 (TryParse::Parsed(Issuer::Single(iss)), Some(correct_iss)) => {
251 if !correct_iss.contains(&*iss) {
252 return Err(new_error(ErrorKind::InvalidIssuer));
253 }
254 }
255 (TryParse::Parsed(Issuer::Multiple(iss)), Some(correct_iss)) => {
256 if !is_subset(correct_iss, &iss) {
257 return Err(new_error(ErrorKind::InvalidIssuer));
258 }
259 }
260 _ => {}
261 }
262
263 match (claims.aud, options.aud.as_ref()) {
264 (TryParse::Parsed(Audience::Single(aud)), Some(correct_aud)) => {
265 if !correct_aud.contains(&*aud) {
266 return Err(new_error(ErrorKind::InvalidAudience));
267 }
268 }
269 (TryParse::Parsed(Audience::Multiple(aud)), Some(correct_aud)) => {
270 if !is_subset(correct_aud, &aud) {
271 return Err(new_error(ErrorKind::InvalidAudience));
272 }
273 }
274 _ => {}
275 }
276
277 Ok(())
278}
279
280fn numeric_type<'de, D>(deserializer: D) -> std::result::Result<TryParse<u64>, D::Error>
281where
282 D: Deserializer<'de>,
283{
284 struct NumericType(PhantomData<fn() -> TryParse<u64>>);
285
286 impl<'de> Visitor<'de> for NumericType {
287 type Value = TryParse<u64>;
288
289 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
290 formatter.write_str("A NumericType that can be reasonably coerced into a u64")
291 }
292
293 fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
294 where
295 E: de::Error,
296 {
297 if value.is_finite() && value >= 0.0 && value < (u64::MAX as f64) {
298 Ok(TryParse::Parsed(value.round() as u64))
299 } else {
300 Err(serde::de::Error::custom("NumericType must be representable as a u64"))
301 }
302 }
303
304 fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
305 where
306 E: de::Error,
307 {
308 Ok(TryParse::Parsed(value))
309 }
310 }
311
312 match deserializer.deserialize_any(NumericType(PhantomData)) {
313 Ok(ok) => Ok(ok),
314 Err(_) => Ok(TryParse::FailedToParse),
315 }
316}
317
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{get_current_timestamp, validate, ClaimsForValidation, Validation};

    use crate::errors::ErrorKind;
    use crate::Algorithm;
    use std::collections::HashSet;

    /// Deserializes claims from an in-memory JSON value into the form `validate` takes.
    fn deserialize_claims(claims: &serde_json::Value) -> ClaimsForValidation {
        serde::Deserialize::deserialize(claims).unwrap()
    }

    #[test]
    fn exp_in_future_ok() {
        let claims = json!({ "exp": get_current_timestamp() + 10000 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_ok());
    }

    #[test]
    fn exp_float_in_future_ok() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) + 10000.123 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_ok());
    }

    #[test]
    fn exp_in_past_fails() {
        let claims = json!({ "exp": get_current_timestamp() - 100000 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ExpiredSignature => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn exp_float_in_past_fails() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ExpiredSignature => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn exp_in_past_but_in_leeway_ok() {
        let claims = json!({ "exp": get_current_timestamp() - 500 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.leeway = 1000 * 60;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn exp_wrong_type_counts_as_missing_when_required() {
        // A non-numeric `exp` is recorded as unparsed rather than failing deserialization,
        // so the default `exp` requirement reports it as missing.
        let claims = json!({ "exp": "tomorrow" });
        let res = validate(deserialize_claims(&claims), &Validation::new(Algorithm::HS256));

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "exp"),
            t => panic!("{:?}", t),
        };
    }

    #[test]
    fn validate_required_fields_are_present() {
        for spec_claim in ["exp", "nbf", "aud", "iss", "sub"] {
            let claims = json!({});
            let mut validation = Validation::new(Algorithm::HS256);
            validation.set_required_spec_claims(&[spec_claim]);
            let res = validate(deserialize_claims(&claims), &validation).unwrap_err();
            assert_eq!(res.kind(), &ErrorKind::MissingRequiredClaim(spec_claim.to_owned()));
        }
    }

    #[test]
    fn exp_validated_but_not_required_ok() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn exp_validated_but_not_required_fails() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());
    }

    #[test]
    fn exp_required_but_not_validated_ok() {
        let claims = json!({ "exp": (get_current_timestamp() as f64) - 100000.1234 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["exp"]);
        validation.validate_exp = false;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn exp_required_but_not_validated_fails() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["exp"]);
        validation.validate_exp = false;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());
    }

    #[test]
    fn nbf_in_past_ok() {
        let claims = json!({ "nbf": get_current_timestamp() - 10000 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn nbf_float_in_past_ok() {
        let claims = json!({ "nbf": (get_current_timestamp() as f64) - 10000.1234 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn nbf_in_future_fails() {
        let claims = json!({ "nbf": get_current_timestamp() + 100000 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::ImmatureSignature => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn nbf_in_future_but_in_leeway_ok() {
        let claims = json!({ "nbf": get_current_timestamp() + 500 });
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.validate_nbf = true;
        validation.leeway = 1000 * 60;
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_string_ok() {
        // A plain string exercises the single-issuer path (the array form has its own test).
        let claims = json!({"iss": "Keats"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.set_issuer(&["Keats"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_array_of_string_ok() {
        let claims = json!({"iss": ["UserA", "UserB"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.set_issuer(&["UserA", "UserB"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn iss_not_matching_fails() {
        let claims = json!({"iss": "Hacked"});

        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.set_issuer(&["Keats"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidIssuer => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn iss_missing_fails() {
        let claims = json!({});

        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["iss"]);
        validation.validate_exp = false;
        validation.set_issuer(&["Keats"]);
        let res = validate(deserialize_claims(&claims), &validation);

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "iss"),
            _ => unreachable!(),
        };
    }

    #[test]
    fn sub_ok() {
        let claims = json!({"sub": "Keats"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.sub = Some("Keats".to_owned());
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn sub_not_matching_fails() {
        let claims = json!({"sub": "Hacked"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.required_spec_claims = HashSet::new();
        validation.validate_exp = false;
        validation.sub = Some("Keats".to_owned());
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidSubject => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn sub_missing_fails() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.set_required_spec_claims(&["sub"]);
        validation.sub = Some("Keats".to_owned());
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "sub"),
            _ => unreachable!(),
        };
    }

    #[test]
    fn aud_string_ok() {
        let claims = json!({"aud": "Everyone"});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["Everyone"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_array_of_string_ok() {
        let claims = json!({"aud": ["UserA", "UserB"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_array_with_one_matching_value_ok() {
        // An audience array passes as soon as any one of its entries is accepted.
        let claims = json!({"aud": ["Everyone", "UserA"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["UserA"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }

    #[test]
    fn aud_type_mismatch_fails() {
        let claims = json!({"aud": ["Everyone"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["UserA", "UserB"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn aud_correct_type_not_matching_fails() {
        let claims = json!({"aud": ["Everyone"]});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["None"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::InvalidAudience => (),
            _ => unreachable!(),
        };
    }

    #[test]
    fn aud_missing_fails() {
        let claims = json!({});
        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.set_required_spec_claims(&["aud"]);
        validation.set_audience(&["None"]);
        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "aud"),
            _ => unreachable!(),
        };
    }

    #[test]
    fn does_validation_in_right_order() {
        let claims = json!({ "exp": get_current_timestamp() + 10000 });

        let mut validation = Validation::new(Algorithm::HS256);
        validation.set_required_spec_claims(&["exp", "iss"]);
        validation.leeway = 5;
        validation.set_issuer(&["iss no check"]);
        validation.set_audience(&["iss no check"]);

        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_err());

        match res.unwrap_err().kind() {
            ErrorKind::MissingRequiredClaim(claim) => assert_eq!(claim, "iss"),
            t => panic!("{:?}", t),
        };
    }

    #[test]
    fn aud_use_validation_struct() {
        let claims = json!({"aud": "my-googleclientid1234.apps.googleusercontent.com"});

        let mut validation = Validation::new(Algorithm::HS256);
        validation.validate_exp = false;
        validation.required_spec_claims = HashSet::new();
        validation.set_audience(&["my-googleclientid1234.apps.googleusercontent.com"]);

        let res = validate(deserialize_claims(&claims), &validation);
        assert!(res.is_ok());
    }
}
685}