use std::io;
pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf};
use crate::dict::{DecoderDictionary, EncoderDictionary};
use crate::map_error_code;
/// A single streaming (de)compression step that moves bytes from an input
/// buffer into an output buffer.
///
/// Implemented in this module by [`Encoder`], [`Decoder`] and [`NoOp`].
pub trait Operation {
    /// Performs one step, consuming bytes from `input` and writing bytes into
    /// `output`; both buffers have their positions advanced accordingly.
    ///
    /// The returned `usize` is an implementation-specific hint about remaining
    /// work (the zstd-backed implementations forward zstd's return value;
    /// `NoOp` always returns 0).
    fn run<C>(
        &mut self,
        input: &mut InBuffer<'_>,
        output: &mut OutBuffer<'_, C>,
    ) -> io::Result<usize>
    where
        C: WriteBuf + ?Sized;

    /// Runs [`Operation::run`] on plain byte slices.
    ///
    /// Returns a [`Status`] holding `run`'s return value together with the
    /// number of bytes read from `input` and written to `output`.
    fn run_on_buffers(
        &mut self,
        input: &[u8],
        output: &mut [u8],
    ) -> io::Result<Status> {
        let mut in_buf = InBuffer::around(input);
        let mut out_buf = OutBuffer::around(output);
        let remaining = self.run(&mut in_buf, &mut out_buf)?;
        let bytes_read = in_buf.pos();
        let bytes_written = out_buf.pos();
        Ok(Status {
            remaining,
            bytes_read,
            bytes_written,
        })
    }

    /// Flushes any internally buffered data into `output`.
    ///
    /// Returns a hint of how much may still be pending (0 meaning nothing).
    /// The default implementation writes nothing and returns 0.
    fn flush<C>(&mut self, output: &mut OutBuffer<'_, C>) -> io::Result<usize>
    where
        C: WriteBuf + ?Sized,
    {
        // Nothing is buffered by default; the parameter is intentionally unused.
        let _ = output;
        Ok(0)
    }

    /// Prepares the operation to start a new stream.
    ///
    /// The default implementation does nothing.
    fn reinit(&mut self) -> io::Result<()> {
        Ok(())
    }

    /// Finishes the stream, writing any remaining data into `output`.
    ///
    /// `finished_frame` tells the operation whether the caller observed the end
    /// of a frame. Returns a hint of how much may still be pending; the default
    /// implementation writes nothing and returns 0.
    fn finish<C>(
        &mut self,
        output: &mut OutBuffer<'_, C>,
        finished_frame: bool,
    ) -> io::Result<usize>
    where
        C: WriteBuf + ?Sized,
    {
        // Both arguments are intentionally unused by the default.
        let _ = (output, finished_frame);
        Ok(0)
    }
}
/// Pass-through operation: copies input bytes to the output unchanged.
pub struct NoOp;
impl Operation for NoOp {
    /// Copies as many unread input bytes as fit into the free space of
    /// `output`, advancing both buffers. Always returns 0.
    fn run<C: WriteBuf + ?Sized>(
        &mut self,
        input: &mut InBuffer<'_>,
        output: &mut OutBuffer<'_, C>,
    ) -> io::Result<usize> {
        // Only the part of the input that has not been consumed yet.
        let src = &input.src[input.pos..];
        let output_pos = output.pos();
        // SAFETY: relies on the `OutBuffer` invariant `pos() <= capacity()`,
        // so the offset stays within (or one past the end of) the output
        // allocation.
        let dst = unsafe { output.as_mut_ptr().add(output_pos) };
        // Copy no more than the input has left and the output has room for.
        let len = usize::min(src.len(), output.capacity() - output_pos);
        let src = &src[..len];
        // SAFETY: `len` is at most both the remaining input length and the
        // free output space. `src` and `dst` do not overlap because `output`
        // is held through a distinct `&mut` borrow.
        unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len) };
        input.set_pos(input.pos() + len);
        // SAFETY: bytes `output_pos..output_pos + len` were just written above
        // and `output_pos + len <= capacity()`. NOTE(review): assumes
        // `set_pos` requires the first `pos` bytes to be initialized; confirm
        // against the zstd_safe docs.
        unsafe { output.set_pos(output_pos + len) };
        // A pass-through holds no internal state, so nothing is ever pending.
        Ok(0)
    }
}
/// Outcome of a single `Operation::run_on_buffers` call.
///
/// Derives the common traits so callers can print, compare and copy results
/// (per the Rust API guidelines on common traits).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Status {
    /// Value returned by `Operation::run`: an implementation-specific hint
    /// about remaining work (see the implementing type).
    pub remaining: usize,
    /// Number of bytes consumed from the input slice.
    pub bytes_read: usize,
    /// Number of bytes written into the output slice.
    pub bytes_written: usize,
}
/// Streaming decompressor implementing [`Operation`] on top of a zstd
/// decompression context.
///
/// `'a` bounds anything the decoder borrows: a caller-provided context
/// (`with_context`), a prepared dictionary (`with_prepared_dictionary`) or a
/// reference prefix (`with_ref_prefix`).
pub struct Decoder<'a> {
    // Either a context created and owned by this decoder, or one borrowed
    // from the caller.
    context: MaybeOwnedDCtx<'a>,
}
impl Decoder<'static> {
    /// Creates a decoder that owns a fresh decompression context and uses no
    /// dictionary.
    ///
    /// # Errors
    ///
    /// Same as [`Decoder::with_dictionary`].
    pub fn new() -> io::Result<Self> {
        let no_dictionary: &[u8] = &[];
        Self::with_dictionary(no_dictionary)
    }

    /// Creates a decoder that owns a fresh decompression context with
    /// `dictionary` loaded into it.
    ///
    /// The returned decoder is `'static`, so it does not keep borrowing
    /// `dictionary` after this call.
    ///
    /// # Errors
    ///
    /// Returns an error if initializing the context or loading the dictionary
    /// fails; zstd error codes are converted with `map_error_code`.
    pub fn with_dictionary(dictionary: &[u8]) -> io::Result<Self> {
        let mut dctx = zstd_safe::DCtx::create();
        dctx.init().map_err(map_error_code)?;
        dctx.load_dictionary(dictionary).map_err(map_error_code)?;
        let context = MaybeOwnedDCtx::Owned(dctx);
        Ok(Self { context })
    }
}
impl<'a> Decoder<'a> {
    /// Wraps a caller-owned decompression context.
    ///
    /// The context is used as-is: no parameters are set and no dictionary is
    /// loaded by this constructor.
    pub fn with_context(context: &'a mut zstd_safe::DCtx<'static>) -> Self {
        let context = MaybeOwnedDCtx::Borrowed(context);
        Decoder { context }
    }

    /// Creates a decoder whose owned context references a prepared dictionary.
    ///
    /// The dictionary must outlive the decoder (`'b: 'a`).
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the dictionary reference.
    pub fn with_prepared_dictionary<'b>(
        dictionary: &DecoderDictionary<'b>,
    ) -> io::Result<Self>
    where
        'b: 'a,
    {
        let mut dctx = zstd_safe::DCtx::create();
        dctx.ref_ddict(dictionary.as_ddict())
            .map_err(map_error_code)?;
        let context = MaybeOwnedDCtx::Owned(dctx);
        Ok(Decoder { context })
    }

    /// Creates a decoder whose owned context references `ref_prefix` as a
    /// raw-content prefix.
    ///
    /// The prefix must outlive the decoder (`'b: 'a`).
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the prefix.
    pub fn with_ref_prefix<'b>(ref_prefix: &'b [u8]) -> io::Result<Self>
    where
        'b: 'a,
    {
        let mut dctx = zstd_safe::DCtx::create();
        dctx.ref_prefix(ref_prefix).map_err(map_error_code)?;
        let context = MaybeOwnedDCtx::Owned(dctx);
        Ok(Decoder { context })
    }

    /// Sets a decompression parameter on the underlying context, whether it is
    /// owned or borrowed.
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the parameter.
    pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
        // Both arms are needed: the borrowed context has a different (and,
        // behind `&mut`, invariant) lifetime from the owned one.
        let outcome = match &mut self.context {
            MaybeOwnedDCtx::Owned(dctx) => dctx.set_parameter(parameter),
            MaybeOwnedDCtx::Borrowed(dctx) => dctx.set_parameter(parameter),
        };
        outcome.map(|_| ()).map_err(map_error_code)
    }
}
impl Operation for Decoder<'_> {
    /// Performs one `decompress_stream` step and forwards its return value.
    ///
    /// Per the zstd manual, 0 means the current frame is fully decoded and
    /// flushed; any other value means more work remains.
    fn run<C: WriteBuf + ?Sized>(
        &mut self,
        input: &mut InBuffer<'_>,
        output: &mut OutBuffer<'_, C>,
    ) -> io::Result<usize> {
        let hint = match &mut self.context {
            MaybeOwnedDCtx::Owned(dctx) => dctx.decompress_stream(output, input),
            MaybeOwnedDCtx::Borrowed(dctx) => {
                dctx.decompress_stream(output, input)
            }
        };
        hint.map_err(map_error_code)
    }

    /// Drains decoded data by running the decoder with no new input.
    ///
    /// Returns 0 if the output buffer still has room afterwards (nothing left
    /// to flush), or 1 if it is full (more data may be pending).
    fn flush<C: WriteBuf + ?Sized>(
        &mut self,
        output: &mut OutBuffer<'_, C>,
    ) -> io::Result<usize> {
        let mut no_input = InBuffer::around(&[]);
        self.run(&mut no_input, output)?;
        // The exact pending amount is unknown; a full buffer is the only signal
        // that more may remain, so it is reported as "1 byte".
        let output_is_full = output.pos() >= output.capacity();
        Ok(usize::from(output_is_full))
    }

    /// Resets the decompression session, keeping parameters unchanged
    /// (`ResetDirective::SessionOnly`).
    fn reinit(&mut self) -> io::Result<()> {
        let directive = zstd_safe::ResetDirective::SessionOnly;
        let outcome = match &mut self.context {
            MaybeOwnedDCtx::Owned(dctx) => dctx.reset(directive),
            MaybeOwnedDCtx::Borrowed(dctx) => dctx.reset(directive),
        };
        outcome.map(|_| ()).map_err(map_error_code)
    }

    /// Ends decoding. Succeeds with 0 only if the caller saw a complete frame.
    ///
    /// # Errors
    ///
    /// Returns `UnexpectedEof` ("incomplete frame") when `finished_frame` is
    /// false.
    fn finish<C: WriteBuf + ?Sized>(
        &mut self,
        _output: &mut OutBuffer<'_, C>,
        finished_frame: bool,
    ) -> io::Result<usize> {
        if !finished_frame {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "incomplete frame",
            ));
        }
        Ok(0)
    }
}
/// Streaming compressor implementing [`Operation`] on top of a zstd
/// compression context.
///
/// `'a` bounds anything the encoder borrows: a caller-provided context
/// (`with_context`), a prepared dictionary (`with_prepared_dictionary`) or a
/// reference prefix (`with_ref_prefix`).
pub struct Encoder<'a> {
    // Either a context created and owned by this encoder, or one borrowed
    // from the caller.
    context: MaybeOwnedCCtx<'a>,
}
impl Encoder<'static> {
    /// Creates an encoder at compression `level` with no dictionary.
    ///
    /// # Errors
    ///
    /// Same as [`Encoder::with_dictionary`].
    pub fn new(level: i32) -> io::Result<Self> {
        let no_dictionary: &[u8] = &[];
        Self::with_dictionary(level, no_dictionary)
    }

    /// Creates an encoder that owns a fresh compression context configured
    /// with `level` and with `dictionary` loaded into it.
    ///
    /// The level is set before the dictionary is loaded. The returned encoder
    /// is `'static`, so it does not keep borrowing `dictionary`.
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the level or the dictionary.
    pub fn with_dictionary(level: i32, dictionary: &[u8]) -> io::Result<Self> {
        let mut cctx = zstd_safe::CCtx::create();
        cctx.set_parameter(CParameter::CompressionLevel(level))
            .map_err(map_error_code)?;
        cctx.load_dictionary(dictionary).map_err(map_error_code)?;
        let context = MaybeOwnedCCtx::Owned(cctx);
        Ok(Self { context })
    }
}
impl<'a> Encoder<'a> {
    /// Wraps a caller-owned compression context.
    ///
    /// The context is used as configured by the caller; this constructor sets
    /// no parameters.
    pub fn with_context(context: &'a mut zstd_safe::CCtx<'static>) -> Self {
        let context = MaybeOwnedCCtx::Borrowed(context);
        Encoder { context }
    }

    /// Creates an encoder whose owned context references a prepared dictionary.
    ///
    /// No compression level is set here. The dictionary must outlive the
    /// encoder (`'b: 'a`).
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the dictionary reference.
    pub fn with_prepared_dictionary<'b>(
        dictionary: &EncoderDictionary<'b>,
    ) -> io::Result<Self>
    where
        'b: 'a,
    {
        let mut cctx = zstd_safe::CCtx::create();
        cctx.ref_cdict(dictionary.as_cdict())
            .map_err(map_error_code)?;
        let context = MaybeOwnedCCtx::Owned(cctx);
        Ok(Encoder { context })
    }

    /// Creates an encoder at compression `level` whose owned context
    /// references `ref_prefix` as a raw-content prefix.
    ///
    /// The level is set before the prefix is referenced. The prefix must
    /// outlive the encoder (`'b: 'a`).
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the level or the prefix.
    pub fn with_ref_prefix<'b>(
        level: i32,
        ref_prefix: &'b [u8],
    ) -> io::Result<Self>
    where
        'b: 'a,
    {
        let mut cctx = zstd_safe::CCtx::create();
        cctx.set_parameter(CParameter::CompressionLevel(level))
            .map_err(map_error_code)?;
        cctx.ref_prefix(ref_prefix).map_err(map_error_code)?;
        let context = MaybeOwnedCCtx::Owned(cctx);
        Ok(Encoder { context })
    }

    /// Sets a compression parameter on the underlying context, whether it is
    /// owned or borrowed.
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the parameter.
    pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
        let outcome = match &mut self.context {
            MaybeOwnedCCtx::Owned(cctx) => cctx.set_parameter(parameter),
            MaybeOwnedCCtx::Borrowed(cctx) => cctx.set_parameter(parameter),
        };
        outcome.map(|_| ()).map_err(map_error_code)
    }

    /// Declares the total uncompressed size of the upcoming frame, or `None`
    /// when it is unknown.
    ///
    /// # Errors
    ///
    /// Returns an error if zstd rejects the pledged size.
    pub fn set_pledged_src_size(
        &mut self,
        pledged_src_size: Option<u64>,
    ) -> io::Result<()> {
        let outcome = match &mut self.context {
            MaybeOwnedCCtx::Owned(cctx) => {
                cctx.set_pledged_src_size(pledged_src_size)
            }
            MaybeOwnedCCtx::Borrowed(cctx) => {
                cctx.set_pledged_src_size(pledged_src_size)
            }
        };
        outcome.map(|_| ()).map_err(map_error_code)
    }
}
impl Operation for Encoder<'_> {
    /// Performs one `compress_stream` step and forwards zstd's return value.
    fn run<C: WriteBuf + ?Sized>(
        &mut self,
        input: &mut InBuffer<'_>,
        output: &mut OutBuffer<'_, C>,
    ) -> io::Result<usize> {
        let hint = match &mut self.context {
            MaybeOwnedCCtx::Owned(cctx) => cctx.compress_stream(output, input),
            MaybeOwnedCCtx::Borrowed(cctx) => {
                cctx.compress_stream(output, input)
            }
        };
        hint.map_err(map_error_code)
    }

    /// Flushes buffered compressed data into `output` via `flush_stream`.
    ///
    /// Per the zstd manual, the result is the amount still buffered
    /// (0 = fully flushed).
    fn flush<C: WriteBuf + ?Sized>(
        &mut self,
        output: &mut OutBuffer<'_, C>,
    ) -> io::Result<usize> {
        let pending = match &mut self.context {
            MaybeOwnedCCtx::Owned(cctx) => cctx.flush_stream(output),
            MaybeOwnedCCtx::Borrowed(cctx) => cctx.flush_stream(output),
        };
        pending.map_err(map_error_code)
    }

    /// Ends the current frame via `end_stream`.
    ///
    /// `_finished_frame` is ignored by the encoder. Per the zstd manual, the
    /// result is 0 once the frame is complete and fully flushed.
    fn finish<C: WriteBuf + ?Sized>(
        &mut self,
        output: &mut OutBuffer<'_, C>,
        _finished_frame: bool,
    ) -> io::Result<usize> {
        let pending = match &mut self.context {
            MaybeOwnedCCtx::Owned(cctx) => cctx.end_stream(output),
            MaybeOwnedCCtx::Borrowed(cctx) => cctx.end_stream(output),
        };
        pending.map_err(map_error_code)
    }

    /// Resets the compression session, keeping parameters unchanged
    /// (`ResetDirective::SessionOnly`).
    fn reinit(&mut self) -> io::Result<()> {
        let directive = zstd_safe::ResetDirective::SessionOnly;
        let outcome = match &mut self.context {
            MaybeOwnedCCtx::Owned(cctx) => cctx.reset(directive),
            MaybeOwnedCCtx::Borrowed(cctx) => cctx.reset(directive),
        };
        outcome.map(|_| ()).map_err(map_error_code)
    }
}
/// A compression context that is either owned by the [`Encoder`] or borrowed
/// from the caller for `'a`.
///
/// The borrowed variant is typed `CCtx<'static>`. Because `&mut` makes its
/// target type invariant, that reference cannot be coerced into the owned
/// variant's `CCtx<'a>`. This is presumably why callers match on both variants
/// with duplicated arms.
enum MaybeOwnedCCtx<'a> {
    // Context created (and dropped) by the encoder itself.
    Owned(zstd_safe::CCtx<'a>),
    // Context supplied by the caller via `Encoder::with_context`.
    Borrowed(&'a mut zstd_safe::CCtx<'static>),
}
/// A decompression context that is either owned by the [`Decoder`] or
/// borrowed from the caller for `'a`.
///
/// Mirrors `MaybeOwnedCCtx`; the borrowed variant is typed `DCtx<'static>`.
enum MaybeOwnedDCtx<'a> {
    // Context created (and dropped) by the decoder itself.
    Owned(zstd_safe::DCtx<'a>),
    // Context supplied by the caller via `Decoder::with_context`.
    Borrowed(&'a mut zstd_safe::DCtx<'static>),
}
#[cfg(test)]
mod tests {
    /// Compresses a short payload with `Encoder`, decompresses it with
    /// `Decoder`, and checks the round trip is lossless. Fixed-size arrays are
    /// used as output storage, hence the `arrays` feature gate.
    #[cfg(feature = "arrays")]
    #[test]
    fn test_cycle() {
        use super::{Decoder, Encoder, InBuffer, Operation, OutBuffer};

        let mut encoder = Encoder::new(1).unwrap();
        let mut decoder = Decoder::new().unwrap();

        // Compression: call `run` (at least once) until the payload is consumed.
        let mut plain = InBuffer::around(b"AbcdefAbcdefabcdef");
        let mut compressed_storage = [0u8; 128];
        let mut compressed = OutBuffer::around(&mut compressed_storage);
        loop {
            encoder.run(&mut plain, &mut compressed).unwrap();
            let consumed_all = plain.pos == plain.src.len();
            if consumed_all {
                break;
            }
        }
        encoder.finish(&mut compressed, true).unwrap();
        let initial_data = plain.src;

        // Decompression: feed the compressed bytes back through the decoder.
        let mut packed = InBuffer::around(compressed.as_slice());
        let mut restored_storage = [0u8; 128];
        let mut restored = OutBuffer::around(&mut restored_storage);
        loop {
            decoder.run(&mut packed, &mut restored).unwrap();
            let consumed_all = packed.pos == packed.src.len();
            if consumed_all {
                break;
            }
        }

        assert_eq!(initial_data, restored.as_slice());
    }
}