use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
DecodeError, DecodeSliceError, PAD_BYTE,
};
/// Decoded-length estimate derived from the length of an encoded input.
#[doc(hidden)]
pub struct GeneralPurposeEstimate {
    /// Encoded length modulo 4, i.e. the number of bytes after the last complete 4-byte group.
    rem: usize,
    /// Conservative decoded length: 3 bytes for every started group of 4 encoded bytes.
    conservative_decoded_len: usize,
}

impl GeneralPurposeEstimate {
    /// Builds the estimate for an input that is `encoded_len` bytes long.
    pub(crate) fn new(encoded_len: usize) -> Self {
        let full_quads = encoded_len / 4;
        let rem = encoded_len % 4;
        // A partial trailing group still reserves a full 3 output bytes.
        let started_quads = full_quads + usize::from(rem != 0);

        Self {
            rem,
            conservative_decoded_len: started_quads * 3,
        }
    }
}
impl DecodeEstimate for GeneralPurposeEstimate {
    /// Reports the conservative decoded length stored in this estimate.
    fn decoded_len_estimate(&self) -> usize {
        let Self {
            conservative_decoded_len,
            ..
        } = *self;
        conservative_decoded_len
    }
}
/// Decodes `input` into `output` using `decode_table`.
///
/// The work is done in three phases:
/// 1. 32-byte input chunks, each decoded as four 8-symbol groups (24 output bytes per chunk);
/// 2. the remaining complete 4-symbol groups that `complete_quads_len` reported;
/// 3. everything after that, which is handed to `decode_suffix` together with
///    `decode_allow_trailing_bits` and `padding_mode`.
///
/// # Errors
///
/// Propagates any error from `complete_quads_len`, from decoding a chunk (the first invalid
/// byte in input order), or from `decode_suffix`. Output bytes written before an error was
/// detected are left in place.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeSliceError> {
    const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
    const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;

    // Input/output lengths covered by phases 1 and 2 combined.
    let quads_in_end = complete_quads_len(input, estimate.rem, output.len(), decode_table)?;
    let quads_out_end = quads_in_end / 4 * 3;

    // Input/output lengths covered by phase 1 alone: the largest multiple of the chunk size.
    let unrolled_in_end = quads_in_end / UNROLLED_INPUT_CHUNK_SIZE * UNROLLED_INPUT_CHUNK_SIZE;
    let unrolled_out_end = unrolled_in_end / 4 * 3;

    // Phase 1: pair every 32-byte input chunk with its 24-byte output window.
    let wide_in = input[..unrolled_in_end].chunks_exact(UNROLLED_INPUT_CHUNK_SIZE);
    let wide_out = output[..unrolled_out_end].chunks_exact_mut(UNROLLED_OUTPUT_CHUNK_SIZE);
    for (chunk_index, (chunk, chunk_output)) in wide_in.zip(wide_out).enumerate() {
        // Absolute input position of this chunk, used for error reporting.
        let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;

        decode_chunk_8(
            &chunk[0..8],
            input_index,
            decode_table,
            &mut chunk_output[0..6],
        )?;
        decode_chunk_8(
            &chunk[8..16],
            input_index + 8,
            decode_table,
            &mut chunk_output[6..12],
        )?;
        decode_chunk_8(
            &chunk[16..24],
            input_index + 16,
            decode_table,
            &mut chunk_output[12..18],
        )?;
        decode_chunk_8(
            &chunk[24..32],
            input_index + 24,
            decode_table,
            &mut chunk_output[18..24],
        )?;
    }

    // Phase 2: the complete groups of 4 that did not fill a whole 32-byte chunk.
    let quads_in = input[unrolled_in_end..quads_in_end].chunks_exact(4);
    let quads_out = output[unrolled_out_end..quads_out_end].chunks_exact_mut(3);
    for (quad_index, (quad, quad_output)) in quads_in.zip(quads_out).enumerate() {
        decode_chunk_4(
            quad,
            unrolled_in_end + quad_index * 4,
            decode_table,
            quad_output,
        )?;
    }

    // Phase 3: whatever follows the groups decoded above.
    super::decode_suffix::decode_suffix(
        input,
        quads_in_end,
        output,
        quads_out_end,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
/// Returns how many leading bytes of `input` form complete 4-byte groups that are *not* the
/// final group, and checks that `output_len` can hold their decoded form.
///
/// The final group is always excluded, even when it is a complete group of 4: the excluded
/// tail is `input_len_rem` bytes long, or 4 bytes when `input_len_rem` is 0.
///
/// `input_len_rem` must equal `input.len() % 4` (checked in debug builds only).
///
/// # Errors
///
/// * `DecodeError::InvalidByte` (converted into `DecodeSliceError`) when exactly one byte
///   trails the last complete group and that byte is neither `PAD_BYTE` nor a symbol known to
///   `decode_table`. This check runs before the output-size check.
/// * `DecodeSliceError::OutputSliceTooSmall` when `output_len` is smaller than 3 bytes per
///   returned group.
pub(crate) fn complete_quads_len(
    input: &[u8],
    input_len_rem: usize,
    output_len: usize,
    decode_table: &[u8; 256],
) -> Result<usize, DecodeSliceError> {
    debug_assert!(input.len() % 4 == input_len_rem);

    if input_len_rem == 1 {
        let last_index = input.len() - 1;
        let last_byte = input[last_index];
        let is_padding = last_byte == PAD_BYTE;
        let is_symbol = decode_table[usize::from(last_byte)] != INVALID_VALUE;
        if !is_padding && !is_symbol {
            return Err(DecodeError::InvalidByte(last_index, last_byte).into());
        }
    }

    // Length of the tail that is left out of the returned prefix.
    let terminal_len = match input_len_rem {
        0 => 4,
        partial => partial,
    };
    let nonterminal_len = input.len().saturating_sub(terminal_len);
    debug_assert!(input.is_empty() || (1..=4).contains(&(input.len() - nonterminal_len)));

    let required_output_len = nonterminal_len / 4 * 3;
    if output_len < required_output_len {
        Err(DecodeSliceError::OutputSliceTooSmall)
    } else {
        Ok(nonterminal_len)
    }
}
/// Decodes the first 8 symbols of `input` into the first 6 bytes of `output`.
///
/// Each symbol is mapped through `decode_table` and packed, 6 bits apart, into the top 48 bits
/// of a `u64`, which is then written big-endian. Bytes of `output` beyond the first 6 are not
/// touched, and nothing is written if an error is returned.
///
/// `index_at_start_of_input` is the absolute position of `input[0]` and is only used to
/// report error positions.
///
/// # Errors
///
/// Returns `DecodeError::InvalidByte` for the first symbol (in input order) whose table entry
/// is `INVALID_VALUE`.
///
/// # Panics
///
/// Panics if `input` is shorter than 8 bytes or `output` is shorter than 6 bytes.
#[inline(always)]
fn decode_chunk_8(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    // Looks up the symbol at `offset` and widens its table value so it can be shifted into the
    // accumulator; an invalid symbol is reported at its absolute position.
    let sextet = |offset: usize| -> Result<u64, DecodeError> {
        let symbol = input[offset];
        let morsel = decode_table[usize::from(symbol)];
        if morsel == INVALID_VALUE {
            Err(DecodeError::InvalidByte(
                index_at_start_of_input + offset,
                symbol,
            ))
        } else {
            Ok(u64::from(morsel))
        }
    };

    // Operands are evaluated left to right, so the earliest invalid symbol wins.
    let accum = (sextet(0)? << 58)
        | (sextet(1)? << 52)
        | (sextet(2)? << 46)
        | (sextet(3)? << 40)
        | (sextet(4)? << 34)
        | (sextet(5)? << 28)
        | (sextet(6)? << 22)
        | (sextet(7)? << 16);

    output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);
    Ok(())
}
/// Decodes the first 4 symbols of `input` into the first 3 bytes of `output`.
///
/// Each symbol is mapped through `decode_table` and packed, 6 bits apart, into the top 24 bits
/// of a `u32`, which is then written big-endian. Bytes of `output` beyond the first 3 are not
/// touched, and nothing is written if an error is returned.
///
/// `index_at_start_of_input` is the absolute position of `input[0]` and is only used to
/// report error positions.
///
/// # Errors
///
/// Returns `DecodeError::InvalidByte` for the first symbol (in input order) whose table entry
/// is `INVALID_VALUE`.
///
/// # Panics
///
/// Panics if `input` is shorter than 4 bytes or `output` is shorter than 3 bytes.
#[inline(always)]
fn decode_chunk_4(
    input: &[u8],
    index_at_start_of_input: usize,
    decode_table: &[u8; 256],
    output: &mut [u8],
) -> Result<(), DecodeError> {
    // Looks up the symbol at `offset` and widens its table value so it can be shifted into the
    // accumulator; an invalid symbol is reported at its absolute position.
    let sextet = |offset: usize| -> Result<u32, DecodeError> {
        let symbol = input[offset];
        let morsel = decode_table[usize::from(symbol)];
        if morsel == INVALID_VALUE {
            Err(DecodeError::InvalidByte(
                index_at_start_of_input + offset,
                symbol,
            ))
        } else {
            Ok(u32::from(morsel))
        }
    };

    // Operands are evaluated left to right, so the earliest invalid symbol wins.
    let accum = (sextet(0)? << 26)
        | (sextet(1)? << 20)
        | (sextet(2)? << 14)
        | (sextet(3)? << 8);

    output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::general_purpose::STANDARD;

    /// The 8-symbol decoder must leave everything past the 6th output byte untouched.
    #[test]
    fn decode_chunk_8_writes_only_6_bytes() {
        let input = b"Zm9vYmFy";
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();

        let expected = [b'f', b'o', b'o', b'b', b'a', b'r', 6, 7];
        assert_eq!(expected, output);
    }

    /// The 4-symbol decoder must leave everything past the 3rd output byte untouched.
    #[test]
    fn decode_chunk_4_writes_only_3_bytes() {
        let input = b"Zm9v";
        let mut output = [0_u8, 1, 2, 3];

        decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();

        let expected = [b'f', b'o', b'o', 3];
        assert_eq!(expected, output);
    }

    /// Every encoded length within a started group of 4 maps to the same estimate.
    #[test]
    fn estimate_short_lengths() {
        let cases = [
            (0..=0, 0),
            (1..=4, 3),
            (5..=8, 6),
            (9..=12, 9),
            (13..=16, 12),
            (17..=20, 15),
        ];

        for (encoded_lens, expected) in cases {
            for encoded_len in encoded_lens {
                let actual = GeneralPurposeEstimate::new(encoded_len).decoded_len_estimate();
                assert_eq!(expected, actual);
            }
        }
    }

    /// Cross-checks the estimate against the simple formula evaluated in `u128`, at both ends
    /// of the `usize` range.
    #[test]
    fn estimate_via_u128_inflation() {
        let low_end = 0..1000;
        let high_end = usize::MAX - 1000..=usize::MAX;

        for encoded_len in low_end.chain(high_end) {
            // Widened so that `+ 3` cannot overflow.
            let expected = (encoded_len as u128 + 3) / 4 * 3;
            let estimate = GeneralPurposeEstimate::new(encoded_len);
            assert_eq!(expected, estimate.conservative_decoded_len as u128);
        }
    }
}