use crate::{
arithmetic::montgomery::*,
bits, bssl, c, error,
limb::{self, Limb, LimbMask, LIMB_BITS, LIMB_BYTES},
};
use alloc::{borrow::ToOwned as _, boxed::Box, vec, vec::Vec};
use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
/// Marker for moduli that are prime.
///
/// # Safety
///
/// `PrivateExponent::for_flt` and `elem_inverse_consttime` compute inverses
/// as `a**(p - 2)` (Fermat's little theorem), which is only correct when the
/// modulus really is prime. Implement this only for such moduli.
pub unsafe trait Prime {}

/// A limb count tagged with the modulus type `M` it belongs to.
struct Width<M> {
    num_limbs: usize,
    m: PhantomData<M>,
}
/// Heap-allocated limbs of a value associated with the modulus type `M`.
///
/// Limbs are stored least-significant first (index 0 is the low limb; see
/// e.g. `Modulus::one`, which sets `limbs[0] = 1`).
struct BoxedLimbs<M> {
    limbs: Box<[Limb]>,
    m: PhantomData<M>,
}
impl<M> Deref for BoxedLimbs<M> {
    type Target = [Limb];

    /// Exposes the limbs as a plain, immutable limb slice.
    #[inline]
    fn deref(&self) -> &[Limb] {
        &self.limbs[..]
    }
}
impl<M> DerefMut for BoxedLimbs<M> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.limbs
}
}
impl<M> Clone for BoxedLimbs<M> {
fn clone(&self) -> Self {
Self {
limbs: self.limbs.clone(),
m: self.m,
}
}
}
impl<M> BoxedLimbs<M> {
fn positive_minimal_width_from_be_bytes(
input: untrusted::Input,
) -> Result<Self, error::KeyRejected> {
if untrusted::Reader::new(input).peek(0) {
return Err(error::KeyRejected::invalid_encoding());
}
let num_limbs = (input.len() + LIMB_BYTES - 1) / LIMB_BYTES;
let mut r = Self::zero(Width {
num_limbs,
m: PhantomData,
});
limb::parse_big_endian_and_pad_consttime(input, &mut r)
.map_err(|error::Unspecified| error::KeyRejected::unexpected_error())?;
Ok(r)
}
fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
debug_assert_ne!(limbs.last(), Some(&0));
Self {
limbs: limbs.to_owned().into_boxed_slice(),
m: PhantomData,
}
}
fn from_be_bytes_padded_less_than(
input: untrusted::Input,
m: &Modulus<M>,
) -> Result<Self, error::Unspecified> {
let mut r = Self::zero(m.width());
limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
if limb::limbs_less_than_limbs_consttime(&r, &m.limbs) != LimbMask::True {
return Err(error::Unspecified);
}
Ok(r)
}
#[inline]
fn is_zero(&self) -> bool {
limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
}
fn zero(width: Width<M>) -> Self {
Self {
limbs: vec![0; width.num_limbs].into_boxed_slice(),
m: PhantomData,
}
}
fn width(&self) -> Width<M> {
Width {
num_limbs: self.limbs.len(),
m: PhantomData,
}
}
}
/// Marker asserting that modulus `Self` is smaller than modulus `L`.
///
/// # Safety
///
/// `elem_widen` copies an element mod `Self` into an element mod `L` without
/// any reduction. `Modulus::to_elem` reinterprets the modulus `Self` as an
/// element mod `L`. Both are only correct if `Self < L`.
pub unsafe trait SmallerModulus<L> {}

/// Marker used by `elem_reduced_once`.
///
/// # Safety
///
/// Presumably asserts that any value below `L` can be reduced mod `Self`
/// with a single conditional subtraction. TODO confirm against the contract
/// of `limb::limbs_reduce_once_constant_time`.
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}

/// Marker used by `elem_reduced`.
///
/// # Safety
///
/// Presumably asserts that elements mod `L` are small enough to be
/// Montgomery-reduced mod `Self` in one step by `limbs_from_mont_in_place`.
/// TODO confirm the exact bound.
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}

/// Marker enabling `Debug` for `Modulus<M>`. Presumably meant only for moduli
/// whose value is not secret.
pub unsafe trait PublicModulus {}

/// Smallest accepted modulus size, in limbs (`Modulus::from_boxed_limbs`).
pub const MODULUS_MIN_LIMBS: usize = 4;

/// Largest accepted modulus size, in limbs: 8192 bits' worth.
pub const MODULUS_MAX_LIMBS: usize = 8192 / LIMB_BITS;
/// An odd modulus (oddness is checked by its only constructor,
/// `from_boxed_limbs`) together with its precomputed Montgomery constants.
pub struct Modulus<M> {
    // The modulus value, least-significant limb first.
    limbs: BoxedLimbs<M>,
    // Montgomery constant derived from the low 64 bits of `limbs` in
    // `from_boxed_limbs`.
    n0: N0,
    // `R**2 mod m` (see `One::newRR`), used to convert values into Montgomery
    // form.
    oneRR: One<M, RR>,
}
impl<M: PublicModulus> core::fmt::Debug for Modulus<M> {
    /// Writes just the type name `Modulus`; no field values are printed.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        f.debug_struct("Modulus").finish()
    }
}
impl<M> Modulus<M> {
    /// Parses a big-endian modulus with no leading zero bytes and validates
    /// it (see `from_boxed_limbs`). Returns the modulus and its bit length.
    pub fn from_be_bytes_with_bit_length(
        input: untrusted::Input,
    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        let limbs = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
        Self::from_boxed_limbs(limbs)
    }

    /// Builds a modulus from a parsed `Nonnegative`, taking ownership of its
    /// (already minimal-width) limbs, and validates it (see
    /// `from_boxed_limbs`).
    pub fn from_nonnegative_with_bit_length(
        n: Nonnegative,
    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        let limbs = BoxedLimbs {
            limbs: n.limbs.into_boxed_slice(),
            m: PhantomData,
        };
        Self::from_boxed_limbs(limbs)
    }

    /// Validates `n` and precomputes `n0` and `R**2 mod n`.
    ///
    /// Checks are performed in this order:
    /// - more than `MODULUS_MAX_LIMBS` limbs: `too_large`;
    /// - fewer than `MODULUS_MIN_LIMBS` limbs: `unexpected_error`;
    /// - `n` not odd: `invalid_component`;
    /// - `n < 3`: `unexpected_error`.
    ///
    /// On success returns the modulus and its minimal bit length.
    fn from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        if n.len() > MODULUS_MAX_LIMBS {
            return Err(error::KeyRejected::too_large());
        }
        if n.len() < MODULUS_MIN_LIMBS {
            return Err(error::KeyRejected::unexpected_error());
        }
        if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
            return Err(error::KeyRejected::invalid_component());
        }
        if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
            return Err(error::KeyRejected::unexpected_error());
        }
        // n0 is computed from n mod 2**64: the lowest limb on 64-bit targets,
        // or the lowest two limbs on 32-bit targets (N0_LIMBS_USED == 2).
        // On 64-bit targets `u64::from(n[0])` is an identity conversion, hence
        // the clippy allowance.
        #[allow(clippy::useless_conversion)]
        let n0 = {
            extern "C" {
                fn GFp_bn_neg_inv_mod_r_u64(n: u64) -> u64;
            }
            let mut n_mod_r: u64 = u64::from(n[0]);
            if N0_LIMBS_USED == 2 {
                // Written as `<< 32` rather than `<< LIMB_BITS`; the branch is
                // only taken when limbs are 32 bits wide.
                debug_assert_eq!(LIMB_BITS, 32);
                n_mod_r |= u64::from(n[1]) << 32;
            }
            N0::from(unsafe { GFp_bn_neg_inv_mod_r_u64(n_mod_r) })
        };
        let bits = limb::limbs_minimal_bits(&n.limbs);
        // `Self` does not exist yet, so borrow the limbs through a
        // `PartialModulus` to compute R**2 mod n.
        let oneRR = {
            let partial = PartialModulus {
                limbs: &n.limbs,
                n0: n0.clone(),
                m: PhantomData,
            };
            One::newRR(&partial, bits)
        };
        Ok((
            Self {
                limbs: n,
                n0,
                oneRR,
            },
            bits,
        ))
    }

    /// Returns the modulus width in limbs.
    #[inline]
    fn width(&self) -> Width<M> {
        self.limbs.width()
    }

    /// Returns the zero element (any encoding) at the modulus width.
    fn zero<E>(&self) -> Elem<M, E> {
        Elem {
            limbs: BoxedLimbs::zero(self.width()),
            encoding: PhantomData,
        }
    }

    /// Returns the unencoded value 1 at the modulus width.
    fn one(&self) -> Elem<M, Unencoded> {
        let mut r = self.zero();
        r.limbs[0] = 1;
        r
    }

    /// Returns the precomputed `R**2 mod m` constant.
    pub fn oneRR(&self) -> &One<M, RR> {
        &self.oneRR
    }

    /// Reinterprets the value of this modulus as an (unencoded) element of
    /// the larger modulus `l`.
    ///
    /// Panics unless both moduli have the same number of limbs.
    pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
    where
        M: SmallerModulus<L>,
    {
        assert_eq!(self.width().num_limbs, l.width().num_limbs);
        let limbs = self.limbs.clone();
        Elem {
            limbs: BoxedLimbs {
                limbs: limbs.limbs,
                m: PhantomData,
            },
            encoding: PhantomData,
        }
    }

    /// Borrows the limbs and a copy of `n0` as a `PartialModulus`.
    fn as_partial(&self) -> PartialModulus<M> {
        PartialModulus {
            limbs: &self.limbs,
            n0: self.n0.clone(),
            m: PhantomData,
        }
    }
}
/// A borrowed view of a modulus: its limbs and `n0`, without `R**2 mod m`.
///
/// Used while a `Modulus` is being built (before `oneRR` exists) and, via
/// `Modulus::as_partial`, by the low-level multiplication helpers.
struct PartialModulus<'a, M> {
    limbs: &'a [Limb],
    n0: N0,
    m: PhantomData<M>,
}
impl<M> PartialModulus<'_, M> {
    /// Returns an all-zero element, typed as `R`-encoded, with one limb for
    /// every limb of the modulus.
    fn zero(&self) -> Elem<M, R> {
        let limbs = BoxedLimbs::zero(Width {
            num_limbs: self.limbs.len(),
            m: PhantomData,
        });
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }
}
/// An element modulo `M`.
///
/// The encoding (`Unencoded`, `R`, `RR`, `RInverse`, ... from
/// `arithmetic::montgomery`) is tracked in the type parameter `E`. Elements
/// built through `Modulus` are padded to the modulus width.
pub struct Elem<M, E = Unencoded> {
    limbs: BoxedLimbs<M>,
    encoding: PhantomData<E>,
}
impl<M, E> Clone for Elem<M, E> {
fn clone(&self) -> Self {
Self {
limbs: self.limbs.clone(),
encoding: self.encoding,
}
}
}
impl<M, E> Elem<M, E> {
#[inline]
pub fn is_zero(&self) -> bool {
self.limbs.is_zero()
}
}
impl<M, E: ReductionEncoding> Elem<M, E> {
    /// Removes one factor of `R` from the encoding.
    ///
    /// This Montgomery-multiplies by the plain integer 1, i.e. computes
    /// `self * 1 * R**-1 (mod m)`. The output encoding is
    /// `<E as ReductionEncoding>::Output`.
    fn decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output> {
        let mut limbs = self.limbs;
        let num_limbs = m.width().num_limbs;
        // Stack buffer holding the integer 1, truncated to the modulus width.
        // A `Modulus` never has more than `MODULUS_MAX_LIMBS` limbs (checked in
        // `Modulus::from_boxed_limbs`), so this slice is always in bounds.
        let mut one = [0; MODULUS_MAX_LIMBS];
        one[0] = 1;
        let one = &one[..num_limbs];
        // `one` is already a slice; pass it directly instead of `&one`.
        limbs_mont_mul(&mut limbs, one, &m.limbs, &m.n0);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }
}
impl<M> Elem<M, R> {
    /// Converts an `R`-encoded (Montgomery form) element back to its plain
    /// value by removing one factor of `R`.
    #[inline]
    pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
        Self::decode_once(self, m)
    }
}
impl<M> Elem<M, Unencoded> {
    /// Parses a big-endian value, padded to the width of `m`, that must be
    /// strictly less than `m`.
    pub fn from_be_bytes_padded(
        input: untrusted::Input,
        m: &Modulus<M>,
    ) -> Result<Self, error::Unspecified> {
        let limbs = BoxedLimbs::from_be_bytes_padded_less_than(input, m)?;
        Ok(Self {
            limbs,
            encoding: PhantomData,
        })
    }

    /// Writes the value into `out` as big-endian bytes via
    /// `limb::big_endian_from_limbs`.
    #[inline]
    pub fn fill_be_bytes(&self, out: &mut [u8]) {
        limb::big_endian_from_limbs(&self.limbs, out);
    }

    /// Treats this value as a new modulus of type `MM`, running the usual
    /// modulus validation.
    ///
    /// The top limb must be nonzero; this is only debug-asserted.
    pub fn into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected> {
        let limbs = BoxedLimbs::minimal_width_from_unpadded(&self.limbs);
        Modulus::from_boxed_limbs(limbs).map(|(modulus, _bits)| modulus)
    }

    /// Returns whether the value equals 1, using
    /// `limb::limbs_equal_limb_constant_time`.
    fn is_one(&self) -> bool {
        let mask = limb::limbs_equal_limb_constant_time(&self.limbs, 1);
        mask == LimbMask::True
    }
}
/// Montgomery product `a * b * R**-1 (mod m)`.
///
/// The result's encoding is derived from the operand encodings via
/// `ProductEncoding`. `b`'s buffer is reused for the result.
pub fn elem_mul<M, AF, BF>(
    a: &Elem<M, AF>,
    b: Elem<M, BF>,
    m: &Modulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
    (AF, BF): ProductEncoding,
{
    let partial = m.as_partial();
    elem_mul_(a, b, &partial)
}
/// `elem_mul` over a `PartialModulus`. Computes the Montgomery product of
/// `a` and `b` in place in `b`'s limbs.
fn elem_mul_<M, AF, BF>(
    a: &Elem<M, AF>,
    b: Elem<M, BF>,
    m: &PartialModulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
    (AF, BF): ProductEncoding,
{
    let mut product = b.limbs;
    limbs_mont_mul(&mut product, &a.limbs, m.limbs, &m.n0);
    Elem {
        limbs: product,
        encoding: PhantomData,
    }
}
/// Doubles `a` modulo `m` in place (`a = 2a mod m`) using the C routine
/// `LIMBS_shl_mod`.
///
/// `One::newRR` uses this for the doubling steps of its `R**2` computation.
fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
    extern "C" {
        fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
    }
    // SAFETY (review note): output and input alias the same buffer, so this
    // presumably relies on `LIMBS_shl_mod` supporting `r == a`. `a.limbs` is
    // assumed to have `m.limbs.len()` limbs, since only `m`'s length is
    // passed. TODO confirm both against the C implementation.
    unsafe {
        LIMBS_shl_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        );
    }
}
/// Reduces an element of the larger modulus into the slightly smaller
/// modulus `m` with one conditional subtraction
/// (`limb::limbs_reduce_once_constant_time`).
///
/// Panics if `a` has more limbs than `m`.
pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
    a: &Elem<Larger, Unencoded>,
    m: &Modulus<Smaller>,
) -> Elem<Smaller, Unencoded> {
    let mut reduced = a.limbs.clone();
    assert!(reduced.len() <= m.limbs.len());
    limb::limbs_reduce_once_constant_time(&mut reduced, &m.limbs);
    // Retag the limbs with the smaller modulus type.
    let limbs = BoxedLimbs {
        limbs: reduced.limbs,
        m: PhantomData,
    };
    Elem {
        limbs,
        encoding: PhantomData,
    }
}
/// Montgomery-reduces an element of the larger modulus modulo `m`.
///
/// The result is `a * R**-1 (mod m)`, hence the `RInverse` encoding.
/// `a` is copied into a `MODULUS_MAX_LIMBS`-limb stack buffer that the
/// reduction may clobber. Panics if `a` has more limbs than that.
#[inline]
pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
    a: &Elem<Larger, Unencoded>,
    m: &Modulus<Smaller>,
) -> Elem<Smaller, RInverse> {
    let num_limbs = a.limbs.len();
    let mut scratch = [0; MODULUS_MAX_LIMBS];
    let scratch = &mut scratch[..num_limbs];
    scratch.copy_from_slice(&a.limbs);
    let mut reduced = m.zero();
    limbs_from_mont_in_place(&mut reduced.limbs, scratch, &m.limbs, &m.n0);
    reduced
}
/// Montgomery square `a * a * R**-1 (mod m)`, computed in `a`'s buffer.
fn elem_squared<M, E>(
    a: Elem<M, E>,
    m: &PartialModulus<M>,
) -> Elem<M, <(E, E) as ProductEncoding>::Output>
where
    (E, E): ProductEncoding,
{
    let mut limbs = a.limbs;
    limbs_mont_square(&mut limbs, m.limbs, &m.n0);
    Elem {
        limbs,
        encoding: PhantomData,
    }
}
/// Re-expresses an element of the smaller modulus as an element of the
/// larger modulus `m`.
///
/// The low limbs are copied and the rest are zero-filled; no reduction is
/// needed because `Smaller < Larger`.
pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
    a: Elem<Smaller, Unencoded>,
    m: &Modulus<Larger>,
) -> Elem<Larger, Unencoded> {
    let mut widened = m.zero();
    let low = &mut widened.limbs[..a.limbs.len()];
    low.copy_from_slice(&a.limbs);
    widened
}
/// Modular addition `(a + b) mod m`, computed in place in `a`'s buffer by the
/// C routine `LIMBS_add_mod`.
///
/// NOTE(review): presumably both inputs must already be reduced mod `m`.
/// TODO confirm against the C implementation.
pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    extern "C" {
        fn LIMBS_add_mod(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            m: *const Limb,
            num_limbs: c::size_t,
        );
    }
    // SAFETY (review note): `r` aliases `a`. All three buffers are assumed to
    // have `m.limbs.len()` limbs.
    unsafe {
        LIMBS_add_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            b.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        )
    }
    a
}
/// Modular subtraction `(a - b) mod m`, computed in place in `a`'s buffer by
/// the C routine `LIMBS_sub_mod`.
///
/// NOTE(review): presumably both inputs must already be reduced mod `m`.
/// TODO confirm against the C implementation.
pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    extern "C" {
        fn LIMBS_sub_mod(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            m: *const Limb,
            num_limbs: c::size_t,
        );
    }
    // SAFETY (review note): `r` aliases `a`. All three buffers are assumed to
    // have `m.limbs.len()` limbs.
    unsafe {
        LIMBS_sub_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            b.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        );
    }
    a
}
/// The value 1 stored in encoding `E`.
///
/// `One<M, RR>` holds `R**2 mod m` (see `One::newRR`) and converts values
/// into Montgomery form by Montgomery multiplication.
pub struct One<M, E>(Elem<M, E>);
impl<M> One<M, RR> {
    /// Computes `R**2 mod m`, where `R = 2**r` and `r` is the bit length of
    /// `m` rounded up to a whole number of limbs.
    ///
    /// The computation has three steps:
    /// - Start from `2**(m_bits - 1)`, the top bit of `m`, which is below `m`.
    /// - Double it modulo `m` until it equals `2**(r + lg_base) mod m`. That is
    ///   `2**lg_base` in Montgomery form.
    /// - Raise it to `r / lg_base` in the Montgomery domain. The result is
    ///   `2**r = R`, whose Montgomery representation is `R * R mod m`.
    ///
    /// `lg_base` trades doublings for squarings.
    fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
        let m_bits = m_bits.as_usize_bits();
        let r = ((m_bits + LIMB_BITS - 1) / LIMB_BITS) * LIMB_BITS;

        let top_bit = m_bits - 1;
        let mut base = m.zero();
        base.limbs[top_bit / LIMB_BITS] = 1 << (top_bit % LIMB_BITS);

        let lg_base = 2usize;
        // `lg_base` must be a power of two.
        debug_assert_eq!(lg_base.count_ones(), 1);
        let doublings = r - top_bit + lg_base;
        let exponent = (r / lg_base) as u64;

        (0..doublings).for_each(|_| elem_mul_by_2(&mut base, m));

        let rr = elem_exp_vartime_(base, exponent, m);
        Self(Elem {
            limbs: rr.limbs,
            encoding: PhantomData,
        })
    }
}
impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
fn as_ref(&self) -> &Elem<M, E> {
&self.0
}
}
/// A public exponent validated by `PublicExponent::from_be_bytes`.
///
/// The value is odd, at least the caller-supplied minimum, and at most
/// `PUBLIC_EXPONENT_MAX_VALUE`.
#[derive(Clone, Copy, Debug)]
pub struct PublicExponent(u64);
impl PublicExponent {
pub fn from_be_bytes(
input: untrusted::Input,
min_value: u64,
) -> Result<Self, error::KeyRejected> {
if input.len() > 5 {
return Err(error::KeyRejected::too_large());
}
let value = input.read_all(error::KeyRejected::invalid_encoding(), |input| {
if input.peek(0) {
return Err(error::KeyRejected::invalid_encoding());
}
let mut value = 0u64;
loop {
let byte = input
.read_byte()
.map_err(|untrusted::EndOfInput| error::KeyRejected::invalid_encoding())?;
value = (value << 8) | u64::from(byte);
if input.at_end() {
return Ok(value);
}
}
})?;
if value & 1 != 1 {
return Err(error::KeyRejected::invalid_component());
}
debug_assert!(min_value & 1 == 1);
debug_assert!(min_value <= PUBLIC_EXPONENT_MAX_VALUE);
if min_value < 3 {
return Err(error::KeyRejected::invalid_component());
}
if value < min_value {
return Err(error::KeyRejected::too_small());
}
if value > PUBLIC_EXPONENT_MAX_VALUE {
return Err(error::KeyRejected::too_large());
}
Ok(Self(value))
}
}
/// Largest accepted public exponent, `2**33 - 1`. It also bounds the
/// exponent accepted by `elem_exp_vartime_`.
const PUBLIC_EXPONENT_MAX_VALUE: u64 = (1u64 << 33) - 1;
/// Computes `base**exponent (mod m)` and returns the result in Montgomery
/// (`R`) encoding.
///
/// Runs in time that depends on `exponent` (see `elem_exp_vartime_`), which
/// is acceptable only because the exponent is public.
pub fn elem_exp_vartime<M>(
    base: Elem<M, Unencoded>,
    PublicExponent(exponent): PublicExponent,
    m: &Modulus<M>,
) -> Elem<M, R> {
    // Montgomery-multiplying by R**2 yields base * R (mod m), i.e. `base` in
    // Montgomery form, which `elem_exp_vartime_` requires. `m` is already a
    // reference, so it is passed through without re-borrowing.
    let base = elem_mul(m.oneRR().as_ref(), base, m);
    elem_exp_vartime_(base, exponent, &m.as_partial())
}
/// Left-to-right square-and-multiply exponentiation in the Montgomery
/// domain.
///
/// Branches on each exponent bit, so timing depends on `exponent`. Panics
/// unless `1 <= exponent <= PUBLIC_EXPONENT_MAX_VALUE`.
fn elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R> {
    assert!(exponent >= 1);
    assert!(exponent <= PUBLIC_EXPONENT_MAX_VALUE);
    // The most significant set bit is accounted for by starting the
    // accumulator at `base`. Each lower bit then costs a squaring, plus a
    // multiplication by `base` when the bit is set.
    let top_bit = 63 - exponent.leading_zeros();
    debug_assert!((exponent >> top_bit) & 1 == 1);
    let mut acc = base.clone();
    for i in (0..top_bit).rev() {
        acc = elem_squared(acc, m);
        if (exponent >> i) & 1 == 1 {
            acc = elem_mul_(&base, acc, m);
        }
    }
    acc
}
/// A secret exponent stored at the width of the modulus `M`.
///
/// When built by `from_be_bytes_padded`, it is odd and less than the
/// modulus.
pub struct PrivateExponent<M> {
    limbs: BoxedLimbs<M>,
}
impl<M> PrivateExponent<M> {
pub fn from_be_bytes_padded(
input: untrusted::Input,
p: &Modulus<M>,
) -> Result<Self, error::Unspecified> {
let dP = BoxedLimbs::from_be_bytes_padded_less_than(input, p)?;
if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
return Err(error::Unspecified);
}
Ok(Self { limbs: dP })
}
}
impl<M: Prime> PrivateExponent<M> {
fn for_flt(p: &Modulus<M>) -> Self {
let two = elem_add(p.one(), p.one(), p);
let p_minus_2 = elem_sub(p.zero(), &two, p);
Self {
limbs: p_minus_2.limbs,
}
}
}
/// Computes `base**exponent (mod m)` with a fixed 5-bit window and returns
/// the result unencoded.
///
/// Portable (non-x86_64) variant. A 32-entry table with `table[i] =
/// base**i` (Montgomery form) is built first. The exponent is then consumed
/// five bits at a time by `limb::fold_5_bit_windows`: each step does five
/// squarings followed by a multiplication by the gathered table entry.
/// Entries are read through `LIMBS_select_512_32`, presumably so that the
/// memory access pattern does not depend on the (secret) window value.
/// TODO confirm against the C implementation.
#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;
    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
    let num_limbs = m.limbs.len();
    // TABLE_ENTRIES consecutive entries of `num_limbs` limbs each.
    let mut table = vec![0; TABLE_ENTRIES * num_limbs];
    // Copies table entry `i` into `r` via `LIMBS_select_512_32`. Panics if the
    // C routine reports failure.
    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        extern "C" {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        Result::from(unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        })
        .unwrap();
    }
    // One window step: acc = acc**(2**WINDOW_BITS) * table[i]. `tmp` is a
    // scratch element that is threaded through to avoid reallocating.
    fn power<M>(
        table: &[Limb],
        i: Window,
        mut acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        for _ in 0..WINDOW_BITS {
            acc = elem_squared(acc, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        let acc = elem_mul(&tmp, acc, m);
        (acc, tmp)
    }
    // 1 in Montgomery form: 1 * R**2 * R**-1 == R (mod m).
    let tmp = m.one();
    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }
    let num_limbs = m.limbs.len();
    // table[0] = base**0, table[1] = base**1.
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
    // table[i] = table[i/2]**2 for even i, or table[i-1] * table[1] for odd i.
    for i in 2..TABLE_ENTRIES {
        let (src1, src2) = if i % 2 == 0 {
            (i / 2, i / 2)
        } else {
            (i - 1, 1)
        };
        // Split so the already-filled entries can be read while entry `i`
        // (the first entry of `rest`) is written.
        let (previous, rest) = table.split_at_mut(num_limbs * i);
        let src1 = entry(previous, src1, num_limbs);
        let src2 = entry(previous, src2, num_limbs);
        let dst = entry_mut(rest, 0, num_limbs);
        limbs_mont_product(dst, src1, src2, &m.limbs, &m.n0);
    }
    let (r, _) = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            // Reuse `base`'s allocation for the accumulator; the table already
            // holds a copy of `base`.
            let mut r = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut r);
            (r, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );
    let r = r.into_unencoded(m);
    Ok(r)
}
/// Computes `a**-1 (mod m)` for a prime modulus, returned unencoded.
///
/// Uses Fermat's little theorem: `a**(m - 2) == a**-1 (mod m)`, evaluated
/// with `elem_exp_consttime`. For `a == 0 (mod m)` the result is 0, which is
/// not an inverse, so callers must ensure `a` is nonzero.
pub fn elem_inverse_consttime<M: Prime>(
    a: Elem<M, R>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    // `m` is already a reference; pass it through without re-borrowing.
    let exponent = PrivateExponent::for_flt(m);
    elem_exp_consttime(a, &exponent, m)
}
/// Computes `base**exponent (mod m)` with a fixed 5-bit window and returns
/// the result unencoded.
///
/// x86_64 variant: the table and window steps are done by the assembly
/// routines `GFp_bn_scatter5`, `GFp_bn_gather5`, `GFp_bn_mul_mont_gather5`
/// and `GFp_bn_power5`. The 32-entry table and a 3-entry working state
/// (accumulator, base, modulus) share one heap buffer, and the table start
/// is aligned to 64 bytes (asserted below).
///
/// Errors if `GFp_bn_from_montgomery` reports failure.
#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;
    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
    let num_limbs = m.limbs.len();
    const ALIGNMENT: usize = 64;
    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
    // TABLE_ENTRIES table entries + 3 state entries, plus slack for moving
    // the start forward to a 64-byte boundary.
    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
    let (table, state) = {
        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
        // Skip to the next 64-byte boundary (a full ALIGNMENT bytes if already
        // aligned). The extra ALIGNMENT limbs allocated above cover this skip.
        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
        table.split_at_mut(TABLE_ENTRIES * num_limbs)
    };
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }
    // Entry indices within `state`:
    // ACC holds the accumulator / result.
    const ACC: usize = 0;
    // BASE holds a copy of `base`.
    const BASE: usize = ACC + 1;
    // M holds a copy of the modulus limbs.
    const M: usize = BASE + 1;
    entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
    entry_mut(state, M, num_limbs).copy_from_slice(&m.limbs);
    // Stores ACC into table slot `i` (in the assembly routine's layout).
    fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
        }
        unsafe {
            GFp_bn_scatter5(
                entry(state, ACC, num_limbs).as_ptr(),
                num_limbs,
                table.as_mut_ptr(),
                i,
            )
        }
    }
    // Loads table slot `i` into ACC.
    fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
        }
        unsafe {
            GFp_bn_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                num_limbs,
                table.as_ptr(),
                i,
            )
        }
    }
    // ACC = table[i]**2 (Montgomery square).
    fn gather_square(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        gather(table, state, i, num_limbs);
        assert_eq!(ACC, 0);
        // Split ACC off so it can be mutated while the modulus entry is read;
        // the modulus is at index `M - 1` of `rest`.
        let (acc, rest) = state.split_at_mut(num_limbs);
        let m = entry(rest, M - 1, num_limbs);
        limbs_mont_square(acc, m, n0);
    }
    // ACC = BASE * table[i] (Montgomery product).
    fn gather_mul_base(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_mul_mont_gather5(
                rp: *mut Limb,
                ap: *const Limb,
                table: *const Limb,
                np: *const Limb,
                n0: &N0,
                num: c::size_t,
                power: Window,
            );
        }
        unsafe {
            GFp_bn_mul_mont_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry(state, BASE, num_limbs).as_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }
    // One window step on ACC (in and out aliased): presumably five squarings
    // followed by a multiplication by table[i]. TODO confirm with the
    // assembly.
    fn power(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        extern "C" {
            fn GFp_bn_power5(
                r: *mut Limb,
                a: *const Limb,
                table: *const Limb,
                n: *const Limb,
                n0: &N0,
                num: c::size_t,
                i: Window,
            );
        }
        unsafe {
            GFp_bn_power5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }
    // table[0] = 1 in Montgomery form. ACC starts zeroed (fresh vec), so this
    // sets it to 1 and Montgomery-multiplies by R**2, giving R mod m.
    {
        let acc = entry_mut(state, ACC, num_limbs);
        acc[0] = 1;
        limbs_mont_mul(acc, &m.oneRR.0.limbs, &m.limbs, &m.n0);
    }
    scatter(table, state, 0, num_limbs);
    // table[1] = base.
    entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
    scatter(table, state, 1, num_limbs);
    // table[i] = table[i/2]**2 for even i, or base * table[i-1] for odd i.
    for i in 2..(TABLE_ENTRIES as Window) {
        if i % 2 == 0 {
            gather_square(table, state, &m.n0, i / 2, num_limbs);
        } else {
            gather_mul_base(table, state, &m.n0, i - 1, num_limbs)
        };
        scatter(table, state, i, num_limbs);
    }
    let state = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            gather(table, state, initial_window, num_limbs);
            state
        },
        |state, window| {
            power(table, state, &m.n0, window, num_limbs);
            state
        },
    );
    // Convert ACC out of Montgomery form in place.
    extern "C" {
        fn GFp_bn_from_montgomery(
            r: *mut Limb,
            a: *const Limb,
            not_used: *const Limb,
            n: *const Limb,
            n0: &N0,
            num: c::size_t,
        ) -> bssl::Result;
    }
    Result::from(unsafe {
        GFp_bn_from_montgomery(
            entry_mut(state, ACC, num_limbs).as_mut_ptr(),
            entry(state, ACC, num_limbs).as_ptr(),
            core::ptr::null(),
            entry(state, M, num_limbs).as_ptr(),
            &m.n0,
            num_limbs,
        )
    })?;
    // Reuse `base`'s allocation for the returned element.
    let mut r = Elem {
        limbs: base.limbs,
        encoding: PhantomData,
    };
    r.limbs.copy_from_slice(entry(state, ACC, num_limbs));
    Ok(r)
}
pub fn verify_inverses_consttime<M>(
a: &Elem<M, R>,
b: Elem<M, Unencoded>,
m: &Modulus<M>,
) -> Result<(), error::Unspecified> {
if elem_mul(a, b, m).is_one() {
Ok(())
} else {
Err(error::Unspecified)
}
}
#[inline]
pub fn elem_verify_equal_consttime<M, E>(
a: &Elem<M, E>,
b: &Elem<M, E>,
) -> Result<(), error::Unspecified> {
if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
Ok(())
} else {
Err(error::Unspecified)
}
}
/// A nonnegative integer not tied to any modulus.
///
/// Limbs are least-significant first, with high-order zero limbs trimmed,
/// so the value zero has no limbs.
pub struct Nonnegative {
    limbs: Vec<Limb>,
}
impl Nonnegative {
    /// Parses a big-endian integer and returns it with its bit length.
    ///
    /// Leading zero bytes are accepted and high-order zero limbs are trimmed,
    /// so the value zero has an empty limb vector. Errors are those of
    /// `limb::parse_big_endian_and_pad_consttime` (presumably including
    /// empty input; TODO confirm).
    pub fn from_be_bytes_with_bit_length(
        input: untrusted::Input,
    ) -> Result<(Self, bits::BitLength), error::Unspecified> {
        let mut limbs = vec![0; (input.len() + LIMB_BYTES - 1) / LIMB_BYTES];
        limb::parse_big_endian_and_pad_consttime(input, &mut limbs)?;
        // Trim to minimal width. The popped zero limb is intentionally
        // discarded.
        while limbs.last() == Some(&0) {
            let _ = limbs.pop();
        }
        let r_bits = limb::limbs_minimal_bits(&limbs);
        Ok((Self { limbs }, r_bits))
    }

    /// Returns whether the value is odd.
    ///
    /// NOTE(review): for the value zero `limbs` is empty. Whether
    /// `limb::limbs_are_even_constant_time` accepts an empty slice is not
    /// visible here; TODO confirm.
    #[inline]
    pub fn is_odd(&self) -> bool {
        limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True
    }

    /// Returns `Ok(())` iff `self < other`.
    ///
    /// The comparison uses `greater_than`, whose same-length case is
    /// variable-time.
    pub fn verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified> {
        if !greater_than(other, self) {
            return Err(error::Unspecified);
        }
        Ok(())
    }

    /// Copies the value into an unencoded element mod `m`, after checking
    /// that it is less than `m`.
    pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
        // `m` is already a reference; pass it through without re-borrowing.
        self.verify_less_than_modulus(m)?;
        let mut r = m.zero();
        r.limbs[0..self.limbs.len()].copy_from_slice(&self.limbs);
        Ok(r)
    }

    /// Returns `Ok(())` iff `self < m`.
    ///
    /// Because `self` is minimal-width, having fewer limbs than `m` means it
    /// is smaller (assuming `m`'s top limb is nonzero, as with the parsing
    /// constructors of `Modulus`). Values of equal width are compared limb
    /// by limb with `limb::limbs_less_than_limbs_consttime`.
    pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
        if self.limbs.len() > m.limbs.len() {
            return Err(error::Unspecified);
        }
        if self.limbs.len() == m.limbs.len() {
            if limb::limbs_less_than_limbs_consttime(&self.limbs, &m.limbs) != LimbMask::True {
                return Err(error::Unspecified);
            }
        }
        Ok(())
    }
}
/// Returns `a > b` for minimal-width values.
///
/// A longer limb vector is the larger value. Equal lengths are compared with
/// `limb::limbs_less_than_limbs_vartime`, which is variable-time.
fn greater_than(a: &Nonnegative, b: &Nonnegative) -> bool {
    let (a_len, b_len) = (a.limbs.len(), b.limbs.len());
    if a_len != b_len {
        return a_len > b_len;
    }
    limb::limbs_less_than_limbs_vartime(&b.limbs, &a.limbs)
}
/// The per-modulus Montgomery constant passed to the C/assembly routines.
///
/// It is computed in `Modulus::from_boxed_limbs` by
/// `GFp_bn_neg_inv_mod_r_u64` from the low 64 bits of the modulus. From the
/// function name, presumably this is `-n**-1 mod 2**64`; TODO confirm.
/// Always two limbs so the layout is fixed. On 64-bit targets the second
/// limb is zero (see `From<u64>`).
#[derive(Clone)]
#[repr(transparent)]
struct N0([Limb; 2]);

/// Number of limbs needed to hold 64 bits: 1 on 64-bit targets, 2 on 32-bit
/// targets.
const N0_LIMBS_USED: usize = 64 / LIMB_BITS;
impl From<u64> for N0 {
    /// Splits a 64-bit constant into `N0`'s limb array.
    ///
    /// 64-bit targets store `[n0, 0]`. 32-bit targets store the low and high
    /// 32-bit halves.
    #[inline]
    fn from(n0: u64) -> Self {
        #[cfg(target_pointer_width = "64")]
        {
            Self([n0, 0])
        }
        #[cfg(target_pointer_width = "32")]
        {
            // Low half first, then high half.
            Self([n0 as Limb, (n0 >> LIMB_BITS) as Limb])
        }
    }
}
/// In-place Montgomery multiplication: `r = r * a * R**-1 (mod m)`.
///
/// `r`, `a` and `m` must have the same length (debug-asserted only).
/// - On aarch64/arm/x86/x86_64 this calls `GFp_bn_mul_mont` with `r` as both
///   the output and the first input.
/// - On other targets it forms the double-width product in a
///   `2 * MODULUS_MAX_LIMBS` stack buffer and Montgomery-reduces it with
///   `limbs_from_mont_in_place`.
fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    // SAFETY (review note): output aliases the first input, so this relies on
    // `GFp_bn_mul_mont` supporting `r == a`.
    unsafe {
        GFp_bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            a.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
        let tmp = &mut tmp[..(2 * a.len())];
        limbs_mul(tmp, r, a);
        limbs_from_mont_in_place(r, tmp, m, n0);
    }
}
/// Montgomery-reduces `tmp` modulo `m`, writing `tmp * R**-1 (mod m)` into
/// `r`.
///
/// `tmp` is handed to the C routine as a mutable pointer and may be
/// clobbered, so callers must not rely on its contents afterwards.
///
/// Panics if `GFp_bn_from_montgomery_in_place` reports failure.
fn limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0) {
    extern "C" {
        fn GFp_bn_from_montgomery_in_place(
            r: *mut Limb,
            num_r: c::size_t,
            a: *mut Limb,
            num_a: c::size_t,
            n: *const Limb,
            num_n: c::size_t,
            n0: &N0,
        ) -> bssl::Result;
    }
    // `n0` is already `&N0`. It is passed straight through instead of being
    // borrowed a second time (`&n0` relied on auto-deref of `&&N0`).
    Result::from(unsafe {
        GFp_bn_from_montgomery_in_place(
            r.as_mut_ptr(),
            r.len(),
            tmp.as_mut_ptr(),
            tmp.len(),
            m.as_ptr(),
            m.len(),
            n0,
        )
    })
    .unwrap()
}
/// Schoolbook multiplication `r = a * b`, where `r` has `2 * a.len()` limbs.
///
/// Only compiled on targets without the assembly `GFp_bn_mul_mont`. Only the
/// low half of `r` is zeroed up front. Each high limb `r[ab_len + i]` is
/// written with the carry of round `i` before any later round reads it.
#[cfg(not(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "x86_64",
    target_arch = "x86"
)))]
fn limbs_mul(r: &mut [Limb], a: &[Limb], b: &[Limb]) {
    debug_assert_eq!(r.len(), 2 * a.len());
    debug_assert_eq!(a.len(), b.len());
    let ab_len = a.len();
    crate::polyfill::slice::fill(&mut r[..ab_len], 0);
    for (i, &b_limb) in b.iter().enumerate() {
        // r[i..i + ab_len] += a * b_limb; the returned carry is the next limb.
        r[ab_len + i] = unsafe {
            GFp_limbs_mul_add_limb(
                (&mut r[i..][..ab_len]).as_mut_ptr(),
                a.as_ptr(),
                b_limb,
                ab_len,
            )
        };
    }
}
/// Out-of-place Montgomery multiplication: `r = a * b * R**-1 (mod m)`.
///
/// All slices must have the same length (debug-asserted only). Only the
/// non-x86_64 `elem_exp_consttime` uses this, which is why it is compiled
/// out on x86_64.
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(r: &mut [Limb], a: &[Limb], b: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    unsafe {
        GFp_bn_mul_mont(
            r.as_mut_ptr(),
            a.as_ptr(),
            b.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        // Full double-width product on the stack, then Montgomery reduction.
        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
        let tmp = &mut tmp[..(2 * a.len())];
        limbs_mul(tmp, a, b);
        limbs_from_mont_in_place(r, tmp, m, n0)
    }
}
/// In-place Montgomery squaring: `r = r * r * R**-1 (mod m)`.
///
/// `r` and `m` must have the same length (debug-asserted only). On the
/// assembly targets, `r` is passed as the output and as both inputs of
/// `GFp_bn_mul_mont`.
fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    unsafe {
        GFp_bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            r.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
    #[cfg(not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    )))]
    {
        // Full double-width square on the stack, then Montgomery reduction.
        let mut tmp = [0; 2 * MODULUS_MAX_LIMBS];
        let tmp = &mut tmp[..(2 * r.len())];
        limbs_mul(tmp, r, r);
        limbs_from_mont_in_place(r, tmp, m, n0)
    }
}
// C/assembly primitives used by `limbs_mont_mul`, `limbs_mont_product`,
// `limbs_mont_square`, `limbs_mul`, and the `test_mul_add_words` test.
extern "C" {
    /// Montgomery multiplication `r = a * b * R**-1 (mod n)` over `num_limbs`
    /// limbs. Callers in this file pass `r` aliased with `a` (and with `b`
    /// for squaring). Only available on the listed architectures.
    #[cfg(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))]
    fn GFp_bn_mul_mont(
        r: *mut Limb,
        a: *const Limb,
        b: *const Limb,
        n: *const Limb,
        n0: &N0,
        num_limbs: c::size_t,
    );
    /// `r += a * b` over `num_limbs` limbs; returns the carry-out limb (as
    /// exercised by `test_mul_add_words`). Linked for tests and for targets
    /// without `GFp_bn_mul_mont`.
    #[cfg(any(
        test,
        not(any(
            target_arch = "aarch64",
            target_arch = "arm",
            target_arch = "x86_64",
            target_arch = "x86"
        ))
    ))]
    #[must_use]
    fn GFp_limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use alloc::format;

    // Test modulus type. It is marked `PublicModulus` so that `Modulus<M>`
    // implements `Debug`.
    struct M {}
    unsafe impl PublicModulus for M {}

    /// `elem_exp_consttime` against known `A**E mod M` vectors.
    #[test]
    fn test_elem_exp_consttime() {
        test::run(
            test_file!("bigint_elem_exp_consttime_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");
                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModExp", &m);
                let base = consume_elem(test_case, "A", &m);
                let e = {
                    let bytes = test_case.consume_bytes("E");
                    PrivateExponent::from_be_bytes_padded(untrusted::Input::from(&bytes), &m)
                        .expect("valid exponent")
                };
                let base = into_encoded(base, &m);
                let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
                assert_elem_eq(&actual_result, &expected_result);
                Ok(())
            },
        )
    }

    /// `elem_mul` of two Montgomery-encoded values, decoded back for
    /// comparison.
    #[test]
    fn test_elem_mul() {
        test::run(
            test_file!("bigint_elem_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");
                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModMul", &m);
                let a = consume_elem(test_case, "A", &m);
                let b = consume_elem(test_case, "B", &m);
                let b = into_encoded(b, &m);
                let a = into_encoded(a, &m);
                let actual_result = elem_mul(&a, b, &m);
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);
                Ok(())
            },
        )
    }

    /// `elem_squared` of a Montgomery-encoded value, decoded back for
    /// comparison.
    #[test]
    fn test_elem_squared() {
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");
                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModSquare", &m);
                let a = consume_elem(test_case, "A", &m);
                let a = into_encoded(a, &m);
                let actual_result = elem_squared(a, &m.as_partial());
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);
                Ok(())
            },
        )
    }

    /// `elem_reduced` yields `A * R**-1`; multiplying by `R**2` (Montgomery)
    /// recovers `A mod M`.
    #[test]
    fn test_elem_reduced() {
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");
                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}
                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "R", &m);
                let a =
                    consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);
                let actual_result = elem_reduced(&a, &m);
                let oneRR = m.oneRR();
                let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
                assert_elem_eq(&actual_result, &expected_result);
                Ok(())
            },
        )
    }

    /// `elem_reduced_once` reduces an element mod `N` into `QQ`.
    #[test]
    fn test_elem_reduced_once() {
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");
                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}
                let qq = consume_modulus::<QQ>(test_case, "QQ");
                let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N");
                let a = consume_elem::<N>(test_case, "A", &n);
                let actual_result = elem_reduced_once(&a, &qq);
                assert_elem_eq(&actual_result, &expected_result);
                Ok(())
            },
        )
    }

    /// `Debug` for `Modulus` prints only the type name.
    #[test]
    fn test_modulus_debug() {
        let (modulus, _) = Modulus::<M>::from_be_bytes_with_bit_length(untrusted::Input::from(
            &[0xff; LIMB_BYTES * MODULUS_MIN_LIMBS],
        ))
        .unwrap();
        assert_eq!("Modulus", format!("{:?}", modulus));
    }

    /// `Debug` for `PublicExponent` prints the numeric value.
    #[test]
    fn test_public_exponent_debug() {
        let exponent =
            PublicExponent::from_be_bytes(untrusted::Input::from(&[0x1, 0x00, 0x01]), 65537)
                .unwrap();
        assert_eq!("PublicExponent(65537)", format!("{:?}", exponent));
    }

    /// Reads test attribute `name` as an element that must be less than `m`.
    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let value = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&value), m).unwrap()
    }

    /// Reads test attribute `name` into a zero-padded element of `num_limbs`
    /// limbs without any range check against a modulus.
    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let value = consume_nonnegative(test_case, name);
        let mut limbs = BoxedLimbs::zero(Width {
            num_limbs,
            m: PhantomData,
        });
        limbs[0..value.limbs.len()].copy_from_slice(&value.limbs);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    /// Reads test attribute `name` as a validated modulus.
    fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> Modulus<M> {
        let value = test_case.consume_bytes(name);
        let (value, _) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value)).unwrap();
        value
    }

    /// Reads test attribute `name` as a `Nonnegative`.
    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (r, _r_bits) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        r
    }

    /// Panics with both limb vectors (hex) if `a` and `b` differ.
    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        // `a` is already a reference; no extra borrow needed.
        if elem_verify_equal_consttime(a, b).is_err() {
            panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
        }
    }

    /// Converts an unencoded element into Montgomery (`R`) form.
    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        elem_mul(m.oneRR().as_ref(), a, m)
    }

    /// `GFp_limbs_mul_add_limb` computes `r += a * w` and returns the carry.
    #[test]
    fn test_mul_add_words() {
        const ZERO: Limb = 0;
        const MAX: Limb = ZERO.wrapping_sub(1);
        static TEST_CASES: &[(&[Limb], &[Limb], Limb, Limb, &[Limb])] = &[
            (&[0], &[0], 0, 0, &[0]),
            (&[MAX], &[0], MAX, 0, &[MAX]),
            (&[0], &[MAX], MAX, MAX - 1, &[1]),
            (&[MAX], &[MAX], MAX, MAX, &[0]),
            (&[0, 0], &[MAX, MAX], MAX, MAX - 1, &[1, MAX]),
            (&[1, 0], &[MAX, MAX], MAX, MAX - 1, &[2, MAX]),
            (&[MAX, 0], &[MAX, MAX], MAX, MAX, &[0, 0]),
            (&[0, 1], &[MAX, MAX], MAX, MAX, &[1, 0]),
            (&[MAX, MAX], &[MAX, MAX], MAX, MAX, &[0, MAX]),
        ];
        for (i, (r_input, a, w, expected_retval, expected_r)) in TEST_CASES.iter().enumerate() {
            extern crate std;
            let mut r = std::vec::Vec::from(*r_input);
            assert_eq!(r.len(), a.len());
            let actual_retval =
                unsafe { GFp_limbs_mul_add_limb(r.as_mut_ptr(), a.as_ptr(), *w, a.len()) };
            assert_eq!(&r, expected_r, "{}: {:x?} != {:x?}", i, &r[..], expected_r);
            assert_eq!(
                actual_retval, *expected_retval,
                "{}: {:x?} != {:x?}",
                i, actual_retval, *expected_retval
            );
        }
    }
}