1use simd_adler32::Adler32;
2
3use crate::tables::{
4 self, CLCL_ORDER, DIST_SYM_TO_DIST_BASE, DIST_SYM_TO_DIST_EXTRA, FDEFLATE_DIST_DECODE_TABLE,
5 FDEFLATE_LITLEN_DECODE_TABLE, FIXED_CODE_LENGTHS, LEN_SYM_TO_LEN_BASE, LEN_SYM_TO_LEN_EXTRA,
6};
7
8#[derive(Debug, PartialEq)]
10pub enum DecompressionError {
11 BadZlibHeader,
13 InsufficientInput,
15 InvalidBlockType,
17 InvalidUncompressedBlockLength,
19 InvalidHlit,
21 InvalidHdist,
23 InvalidCodeLengthRepeat,
26 BadCodeLengthHuffmanTree,
28 BadLiteralLengthHuffmanTree,
30 BadDistanceHuffmanTree,
32 InvalidLiteralLengthCode,
34 InvalidDistanceCode,
36 InputStartsWithRun,
38 DistanceTooFarBack,
40 WrongChecksum,
42 ExtraInput,
44}
45
46struct BlockHeader {
47 hlit: usize,
48 hdist: usize,
49 hclen: usize,
50 num_lengths_read: usize,
51
52 table: [u8; 128],
54 code_lengths: [u8; 320],
55}
56
57const LITERAL_ENTRY: u32 = 0x8000;
58const EXCEPTIONAL_ENTRY: u32 = 0x4000;
59const SECONDARY_TABLE_ENTRY: u32 = 0x2000;
60
61#[repr(align(64))]
77#[derive(Eq, PartialEq, Debug)]
78struct CompressedBlock {
79 litlen_table: [u32; 4096],
80 dist_table: [u32; 512],
81
82 dist_symbol_lengths: [u8; 30],
83 dist_symbol_masks: [u16; 30],
84 dist_symbol_codes: [u16; 30],
85
86 secondary_table: Vec<u16>,
87 eof_code: u16,
88 eof_mask: u16,
89 eof_bits: u8,
90}
91
92const FDEFLATE_COMPRESSED_BLOCK: CompressedBlock = CompressedBlock {
93 litlen_table: FDEFLATE_LITLEN_DECODE_TABLE,
94 dist_table: FDEFLATE_DIST_DECODE_TABLE,
95 dist_symbol_lengths: [
96 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
97 ],
98 dist_symbol_masks: [
99 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
100 ],
101 dist_symbol_codes: [
102 0, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
103 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
104 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
105 ],
106 secondary_table: Vec::new(),
107 eof_code: 0x8ff,
108 eof_mask: 0xfff,
109 eof_bits: 0xc,
110};
111
112#[derive(Debug, Copy, Clone, Eq, PartialEq)]
113enum State {
114 ZlibHeader,
115 BlockHeader,
116 CodeLengthCodes,
117 CodeLengths,
118 CompressedData,
119 UncompressedData,
120 Checksum,
121 Done,
122}
123
124pub struct Decompressor {
126 compression: CompressedBlock,
128 header: BlockHeader,
130 uncompressed_bytes_left: u16,
132
133 buffer: u64,
134 nbits: u8,
135
136 queued_rle: Option<(u8, usize)>,
137 queued_backref: Option<(usize, usize)>,
138 last_block: bool,
139
140 state: State,
141 checksum: Adler32,
142 ignore_adler32: bool,
143}
144
145impl Default for Decompressor {
146 fn default() -> Self {
147 Self::new()
148 }
149}
150
151impl Decompressor {
152 pub fn new() -> Self {
154 Self {
155 buffer: 0,
156 nbits: 0,
157 compression: CompressedBlock {
158 litlen_table: [0; 4096],
159 dist_table: [0; 512],
160 secondary_table: Vec::new(),
161 dist_symbol_lengths: [0; 30],
162 dist_symbol_masks: [0; 30],
163 dist_symbol_codes: [0xffff; 30],
164 eof_code: 0,
165 eof_mask: 0,
166 eof_bits: 0,
167 },
168 header: BlockHeader {
169 hlit: 0,
170 hdist: 0,
171 hclen: 0,
172 table: [0; 128],
173 num_lengths_read: 0,
174 code_lengths: [0; 320],
175 },
176 uncompressed_bytes_left: 0,
177 queued_rle: None,
178 queued_backref: None,
179 checksum: Adler32::new(),
180 state: State::ZlibHeader,
181 last_block: false,
182 ignore_adler32: false,
183 }
184 }
185
186 pub fn ignore_adler32(&mut self) {
188 self.ignore_adler32 = true;
189 }
190
191 fn fill_buffer(&mut self, input: &mut &[u8]) {
192 if input.len() >= 8 {
193 self.buffer |= u64::from_le_bytes(input[..8].try_into().unwrap()) << self.nbits;
194 *input = &input[(63 - self.nbits as usize) / 8..];
195 self.nbits |= 56;
196 } else {
197 let nbytes = input.len().min((63 - self.nbits as usize) / 8);
198 let mut input_data = [0; 8];
199 input_data[..nbytes].copy_from_slice(&input[..nbytes]);
200 self.buffer |= u64::from_le_bytes(input_data)
201 .checked_shl(self.nbits as u32)
202 .unwrap_or(0);
203 self.nbits += nbytes as u8 * 8;
204 *input = &input[nbytes..];
205 }
206 }
207
208 fn peak_bits(&mut self, nbits: u8) -> u64 {
209 debug_assert!(nbits <= 56 && nbits <= self.nbits);
210 self.buffer & ((1u64 << nbits) - 1)
211 }
212 fn consume_bits(&mut self, nbits: u8) {
213 debug_assert!(self.nbits >= nbits);
214 self.buffer >>= nbits;
215 self.nbits -= nbits;
216 }
217
218 fn read_block_header(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
219 self.fill_buffer(remaining_input);
220 if self.nbits < 3 {
221 return Ok(());
222 }
223
224 let start = self.peak_bits(3);
225 self.last_block = start & 1 != 0;
226 match start >> 1 {
227 0b00 => {
228 let align_bits = (self.nbits - 3) % 8;
229 let header_bits = 3 + 32 + align_bits;
230 if self.nbits < header_bits {
231 return Ok(());
232 }
233
234 let len = (self.peak_bits(align_bits + 19) >> (align_bits + 3)) as u16;
235 let nlen = (self.peak_bits(header_bits) >> (align_bits + 19)) as u16;
236 if nlen != !len {
237 return Err(DecompressionError::InvalidUncompressedBlockLength);
238 }
239
240 self.state = State::UncompressedData;
241 self.uncompressed_bytes_left = len;
242 self.consume_bits(header_bits);
243 Ok(())
244 }
245 0b01 => {
246 self.consume_bits(3);
247 Self::build_tables(288, &FIXED_CODE_LENGTHS, &mut self.compression, 6)?;
249 self.state = State::CompressedData;
250 Ok(())
251 }
252 0b10 => {
253 if self.nbits < 17 {
254 return Ok(());
255 }
256
257 self.header.hlit = (self.peak_bits(8) >> 3) as usize + 257;
258 self.header.hdist = (self.peak_bits(13) >> 8) as usize + 1;
259 self.header.hclen = (self.peak_bits(17) >> 13) as usize + 4;
260 if self.header.hlit > 286 {
261 return Err(DecompressionError::InvalidHlit);
262 }
263 if self.header.hdist > 30 {
264 return Err(DecompressionError::InvalidHdist);
265 }
266
267 self.consume_bits(17);
268 self.state = State::CodeLengthCodes;
269 Ok(())
270 }
271 0b11 => Err(DecompressionError::InvalidBlockType),
272 _ => unreachable!(),
273 }
274 }
275
276 fn read_code_length_codes(
277 &mut self,
278 remaining_input: &mut &[u8],
279 ) -> Result<(), DecompressionError> {
280 self.fill_buffer(remaining_input);
281 if self.nbits as usize + remaining_input.len() * 8 < 3 * self.header.hclen {
282 return Ok(());
283 }
284
285 let mut code_length_lengths = [0; 19];
286 for i in 0..self.header.hclen {
287 code_length_lengths[CLCL_ORDER[i]] = self.peak_bits(3) as u8;
288 self.consume_bits(3);
289
290 if i == 17 {
293 self.fill_buffer(remaining_input);
294 }
295 }
296 let code_length_codes: [u16; 19] = crate::compute_codes(&code_length_lengths)
297 .ok_or(DecompressionError::BadCodeLengthHuffmanTree)?;
298
299 self.header.table = [255; 128];
300 for i in 0..19 {
301 let length = code_length_lengths[i];
302 if length > 0 {
303 let mut j = code_length_codes[i];
304 while j < 128 {
305 self.header.table[j as usize] = ((i as u8) << 3) | length;
306 j += 1 << length;
307 }
308 }
309 }
310
311 self.state = State::CodeLengths;
312 self.header.num_lengths_read = 0;
313 Ok(())
314 }
315
316 fn read_code_lengths(&mut self, remaining_input: &mut &[u8]) -> Result<(), DecompressionError> {
317 let total_lengths = self.header.hlit + self.header.hdist;
318 while self.header.num_lengths_read < total_lengths {
319 self.fill_buffer(remaining_input);
320 if self.nbits < 7 {
321 return Ok(());
322 }
323
324 let code = self.peak_bits(7);
325 let entry = self.header.table[code as usize];
326 let length = entry & 0x7;
327 let symbol = entry >> 3;
328
329 debug_assert!(length != 0);
330 match symbol {
331 0..=15 => {
332 self.header.code_lengths[self.header.num_lengths_read] = symbol;
333 self.header.num_lengths_read += 1;
334 self.consume_bits(length);
335 }
336 16..=18 => {
337 let (base_repeat, extra_bits) = match symbol {
338 16 => (3, 2),
339 17 => (3, 3),
340 18 => (11, 7),
341 _ => unreachable!(),
342 };
343
344 if self.nbits < length + extra_bits {
345 return Ok(());
346 }
347
348 let value = match symbol {
349 16 => {
350 self.header.code_lengths[self
351 .header
352 .num_lengths_read
353 .checked_sub(1)
354 .ok_or(DecompressionError::InvalidCodeLengthRepeat)?]
355 }
357 17 => 0,
358 18 => 0,
359 _ => unreachable!(),
360 };
361
362 let repeat =
363 (self.peak_bits(length + extra_bits) >> length) as usize + base_repeat;
364 if self.header.num_lengths_read + repeat > total_lengths {
365 return Err(DecompressionError::InvalidCodeLengthRepeat);
366 }
367
368 for i in 0..repeat {
369 self.header.code_lengths[self.header.num_lengths_read + i] = value;
370 }
371 self.header.num_lengths_read += repeat;
372 self.consume_bits(length + extra_bits);
373 }
374 _ => unreachable!(),
375 }
376 }
377
378 self.header
379 .code_lengths
380 .copy_within(self.header.hlit..total_lengths, 288);
381 for i in self.header.hlit..288 {
382 self.header.code_lengths[i] = 0;
383 }
384 for i in 288 + self.header.hdist..320 {
385 self.header.code_lengths[i] = 0;
386 }
387
388 if self.header.hdist == 1
389 && self.header.code_lengths[..286] == tables::HUFFMAN_LENGTHS
390 && self.header.code_lengths[288] == 1
391 {
392 self.compression = FDEFLATE_COMPRESSED_BLOCK;
393 } else {
394 Self::build_tables(
395 self.header.hlit,
396 &self.header.code_lengths,
397 &mut self.compression,
398 6,
399 )?;
400 }
401 self.state = State::CompressedData;
402 Ok(())
403 }
404
405 fn build_tables(
406 hlit: usize,
407 code_lengths: &[u8],
408 compression: &mut CompressedBlock,
409 max_search_bits: u8,
410 ) -> Result<(), DecompressionError> {
411 if code_lengths[256] == 0 {
413 return Err(DecompressionError::BadLiteralLengthHuffmanTree);
415 }
416
417 let lengths = &code_lengths[..288];
419 let codes: [u16; 288] = crate::compute_codes(&lengths.try_into().unwrap())
420 .ok_or(DecompressionError::BadLiteralLengthHuffmanTree)?;
421
422 let table_bits = lengths.iter().cloned().max().unwrap().min(12).max(6);
423 let table_size = 1 << table_bits;
424
425 for i in 0..256 {
426 let code = codes[i];
427 let length = lengths[i];
428 let mut j = code;
429
430 while j < table_size && length != 0 && length <= 12 {
431 compression.litlen_table[j as usize] =
432 ((i as u32) << 16) | LITERAL_ENTRY | (1 << 8) | length as u32;
433 j += 1 << length;
434 }
435
436 if length > 0 && length <= max_search_bits {
437 for ii in 0..256 {
438 let code2 = codes[ii];
439 let length2 = lengths[ii];
440 if length2 != 0 && length + length2 <= table_bits {
441 let mut j = code | (code2 << length);
442
443 while j < table_size {
444 compression.litlen_table[j as usize] = (ii as u32) << 24
445 | (i as u32) << 16
446 | LITERAL_ENTRY
447 | (2 << 8)
448 | ((length + length2) as u32);
449 j += 1 << (length + length2);
450 }
451 }
452 }
453 }
454 }
455
456 if lengths[256] != 0 && lengths[256] <= 12 {
457 let mut j = codes[256];
458 while j < table_size {
459 compression.litlen_table[j as usize] = EXCEPTIONAL_ENTRY | lengths[256] as u32;
460 j += 1 << lengths[256];
461 }
462 }
463
464 let table_size = table_size as usize;
465 for i in (table_size..4096).step_by(table_size) {
466 compression.litlen_table.copy_within(0..table_size, i);
467 }
468
469 compression.eof_code = codes[256];
470 compression.eof_mask = (1 << lengths[256]) - 1;
471 compression.eof_bits = lengths[256];
472
473 for i in 257..hlit {
474 let code = codes[i];
475 let length = lengths[i];
476 if length != 0 && length <= 12 {
477 let mut j = code;
478 while j < 4096 {
479 compression.litlen_table[j as usize] = if i < 286 {
480 (LEN_SYM_TO_LEN_BASE[i - 257] as u32) << 16
481 | (LEN_SYM_TO_LEN_EXTRA[i - 257] as u32) << 8
482 | length as u32
483 } else {
484 EXCEPTIONAL_ENTRY
485 };
486 j += 1 << length;
487 }
488 }
489 }
490
491 for i in 0..hlit {
492 if lengths[i] > 12 {
493 compression.litlen_table[(codes[i] & 0xfff) as usize] = u32::MAX;
494 }
495 }
496
497 let mut secondary_table_len = 0;
498 for i in 0..hlit {
499 if lengths[i] > 12 {
500 let j = (codes[i] & 0xfff) as usize;
501 if compression.litlen_table[j] == u32::MAX {
502 compression.litlen_table[j] =
503 (secondary_table_len << 16) | EXCEPTIONAL_ENTRY | SECONDARY_TABLE_ENTRY;
504 secondary_table_len += 8;
505 }
506 }
507 }
508 assert!(secondary_table_len <= 0x7ff);
509 compression.secondary_table = vec![0; secondary_table_len as usize];
510 for i in 0..hlit {
511 let code = codes[i];
512 let length = lengths[i];
513 if length > 12 {
514 let j = (codes[i] & 0xfff) as usize;
515 let k = (compression.litlen_table[j] >> 16) as usize;
516
517 let mut s = code >> 12;
518 while s < 8 {
519 debug_assert_eq!(compression.secondary_table[k + s as usize], 0);
520 compression.secondary_table[k + s as usize] =
521 ((i as u16) << 4) | (length as u16);
522 s += 1 << (length - 12);
523 }
524 }
525 }
526 debug_assert!(compression
527 .secondary_table
528 .iter()
529 .all(|&x| x != 0 && (x & 0xf) > 12));
530
531 let lengths = &code_lengths[288..320];
533 if lengths == [0; 32] {
534 compression.dist_symbol_masks = [0; 30];
535 compression.dist_symbol_codes = [0xffff; 30];
536 compression.dist_table.fill(0);
537 } else {
538 let codes: [u16; 32] = match crate::compute_codes(&lengths.try_into().unwrap()) {
539 Some(codes) => codes,
540 None => {
541 if lengths.iter().filter(|&&l| l != 0).count() != 1 {
542 return Err(DecompressionError::BadDistanceHuffmanTree);
543 }
544 [0; 32]
545 }
546 };
547
548 compression.dist_symbol_codes.copy_from_slice(&codes[..30]);
549 compression
550 .dist_symbol_lengths
551 .copy_from_slice(&lengths[..30]);
552 compression.dist_table.fill(0);
553 for i in 0..30 {
554 let length = lengths[i];
555 let code = codes[i];
556 if length == 0 {
557 compression.dist_symbol_masks[i] = 0;
558 compression.dist_symbol_codes[i] = 0xffff;
559 } else {
560 compression.dist_symbol_masks[i] = (1 << lengths[i]) - 1;
561 if lengths[i] <= 9 {
562 let mut j = code;
563 while j < 512 {
564 compression.dist_table[j as usize] = (DIST_SYM_TO_DIST_BASE[i] as u32)
565 << 16
566 | (DIST_SYM_TO_DIST_EXTRA[i] as u32) << 8
567 | length as u32;
568 j += 1 << lengths[i];
569 }
570 }
571 }
572 }
573 }
574
575 Ok(())
576 }
577
578 fn read_compressed(
579 &mut self,
580 remaining_input: &mut &[u8],
581 output: &mut [u8],
582 mut output_index: usize,
583 ) -> Result<usize, DecompressionError> {
584 while let State::CompressedData = self.state {
585 self.fill_buffer(remaining_input);
586 if output_index == output.len() {
587 break;
588 }
589
590 let mut bits = self.buffer;
591 let litlen_entry = self.compression.litlen_table[(bits & 0xfff) as usize];
592 let litlen_code_bits = litlen_entry as u8;
593
594 if litlen_entry & LITERAL_ENTRY != 0 {
595 if self.nbits >= 48 {
597 let litlen_entry2 =
598 self.compression.litlen_table[(bits >> litlen_code_bits & 0xfff) as usize];
599 let litlen_code_bits2 = litlen_entry2 as u8;
600 let litlen_entry3 = self.compression.litlen_table
601 [(bits >> (litlen_code_bits + litlen_code_bits2) & 0xfff) as usize];
602 let litlen_code_bits3 = litlen_entry3 as u8;
603 let litlen_entry4 = self.compression.litlen_table[(bits
604 >> (litlen_code_bits + litlen_code_bits2 + litlen_code_bits3)
605 & 0xfff)
606 as usize];
607 let litlen_code_bits4 = litlen_entry4 as u8;
608 if litlen_entry2 & litlen_entry3 & litlen_entry4 & LITERAL_ENTRY != 0 {
609 let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
610 let advance_output_bytes2 = ((litlen_entry2 & 0xf00) >> 8) as usize;
611 let advance_output_bytes3 = ((litlen_entry3 & 0xf00) >> 8) as usize;
612 let advance_output_bytes4 = ((litlen_entry4 & 0xf00) >> 8) as usize;
613 if output_index
614 + advance_output_bytes
615 + advance_output_bytes2
616 + advance_output_bytes3
617 + advance_output_bytes4
618 < output.len()
619 {
620 self.consume_bits(
621 litlen_code_bits
622 + litlen_code_bits2
623 + litlen_code_bits3
624 + litlen_code_bits4,
625 );
626
627 output[output_index] = (litlen_entry >> 16) as u8;
628 output[output_index + 1] = (litlen_entry >> 24) as u8;
629 output_index += advance_output_bytes;
630 output[output_index] = (litlen_entry2 >> 16) as u8;
631 output[output_index + 1] = (litlen_entry2 >> 24) as u8;
632 output_index += advance_output_bytes2;
633 output[output_index] = (litlen_entry3 >> 16) as u8;
634 output[output_index + 1] = (litlen_entry3 >> 24) as u8;
635 output_index += advance_output_bytes3;
636 output[output_index] = (litlen_entry4 >> 16) as u8;
637 output[output_index + 1] = (litlen_entry4 >> 24) as u8;
638 output_index += advance_output_bytes4;
639 continue;
640 }
641 }
642 }
643
644 let advance_output_bytes = ((litlen_entry & 0xf00) >> 8) as usize;
647
648 if self.nbits < litlen_code_bits {
664 break;
665 } else if output_index + 1 < output.len() {
666 output[output_index] = (litlen_entry >> 16) as u8;
667 output[output_index + 1] = (litlen_entry >> 24) as u8;
668 output_index += advance_output_bytes;
669 self.consume_bits(litlen_code_bits);
670 continue;
671 } else if output_index + advance_output_bytes == output.len() {
672 debug_assert_eq!(advance_output_bytes, 1);
673 output[output_index] = (litlen_entry >> 16) as u8;
674 output_index += 1;
675 self.consume_bits(litlen_code_bits);
676 break;
677 } else {
678 debug_assert_eq!(advance_output_bytes, 2);
679 output[output_index] = (litlen_entry >> 16) as u8;
680 self.queued_rle = Some(((litlen_entry >> 24) as u8, 1));
681 output_index += 1;
682 self.consume_bits(litlen_code_bits);
683 break;
684 }
685 }
686
687 let (length_base, length_extra_bits, litlen_code_bits) =
688 if litlen_entry & EXCEPTIONAL_ENTRY == 0 {
689 (
690 litlen_entry >> 16,
691 (litlen_entry >> 8) as u8,
692 litlen_code_bits,
693 )
694 } else if litlen_entry & SECONDARY_TABLE_ENTRY != 0 {
695 let secondary_index = litlen_entry >> 16;
696 let secondary_entry = self.compression.secondary_table
697 [secondary_index as usize + ((bits >> 12) & 0x7) as usize];
698 let litlen_symbol = secondary_entry >> 4;
699 let litlen_code_bits = (secondary_entry & 0xf) as u8;
700
701 if self.nbits < litlen_code_bits {
702 break;
703 } else if litlen_symbol < 256 {
704 self.consume_bits(litlen_code_bits);
707 output[output_index] = litlen_symbol as u8;
708 output_index += 1;
709 continue;
710 } else if litlen_symbol == 256 {
711 self.consume_bits(litlen_code_bits);
713 self.state = match self.last_block {
714 true => State::Checksum,
715 false => State::BlockHeader,
716 };
717 break;
718 }
719
720 (
721 LEN_SYM_TO_LEN_BASE[litlen_symbol as usize - 257] as u32,
722 LEN_SYM_TO_LEN_EXTRA[litlen_symbol as usize - 257],
723 litlen_code_bits,
724 )
725 } else if litlen_code_bits == 0 {
726 return Err(DecompressionError::InvalidLiteralLengthCode);
727 } else {
728 if self.nbits < litlen_code_bits {
729 break;
730 }
731 self.consume_bits(litlen_code_bits);
733 self.state = match self.last_block {
734 true => State::Checksum,
735 false => State::BlockHeader,
736 };
737 break;
738 };
739 bits >>= litlen_code_bits;
740
741 let length_extra_mask = (1 << length_extra_bits) - 1;
742 let length = length_base as usize + (bits & length_extra_mask) as usize;
743 bits >>= length_extra_bits;
744
745 let dist_entry = self.compression.dist_table[(bits & 0x1ff) as usize];
746 let (dist_base, dist_extra_bits, dist_code_bits) = if dist_entry != 0 {
747 (
748 (dist_entry >> 16) as u16,
749 (dist_entry >> 8) as u8,
750 dist_entry as u8,
751 )
752 } else if self.nbits > litlen_code_bits + length_extra_bits + 9 {
753 let mut dist_extra_bits = 0;
754 let mut dist_base = 0;
755 let mut dist_advance_bits = 0;
756 for i in 0..self.compression.dist_symbol_lengths.len() {
757 if bits as u16 & self.compression.dist_symbol_masks[i]
758 == self.compression.dist_symbol_codes[i]
759 {
760 dist_extra_bits = DIST_SYM_TO_DIST_EXTRA[i];
761 dist_base = DIST_SYM_TO_DIST_BASE[i];
762 dist_advance_bits = self.compression.dist_symbol_lengths[i];
763 break;
764 }
765 }
766 if dist_advance_bits == 0 {
767 return Err(DecompressionError::InvalidDistanceCode);
768 }
769 (dist_base, dist_extra_bits, dist_advance_bits)
770 } else {
771 break;
772 };
773 bits >>= dist_code_bits;
774
775 let dist = dist_base as usize + (bits & ((1 << dist_extra_bits) - 1)) as usize;
776 let total_bits =
777 litlen_code_bits + length_extra_bits + dist_code_bits + dist_extra_bits;
778
779 if self.nbits < total_bits {
780 break;
781 } else if dist > output_index {
782 return Err(DecompressionError::DistanceTooFarBack);
783 }
784
785 self.consume_bits(total_bits);
787
788 let copy_length = length.min(output.len() - output_index);
789 if dist == 1 {
790 let last = output[output_index - 1];
791 output[output_index..][..copy_length].fill(last);
792
793 if copy_length < length {
794 self.queued_rle = Some((last, length - copy_length));
795 output_index = output.len();
796 break;
797 }
798 } else if output_index + length + 15 <= output.len() {
799 let start = output_index - dist;
800 output.copy_within(start..start + 16, output_index);
801
802 if length > 16 || dist < 16 {
803 for i in (0..length).step_by(dist.min(16)).skip(1) {
804 output.copy_within(start + i..start + i + 16, output_index + i);
805 }
806 }
807 } else {
808 if dist < copy_length {
809 for i in 0..copy_length {
810 output[output_index + i] = output[output_index + i - dist];
811 }
812 } else {
813 output.copy_within(
814 output_index - dist..output_index + copy_length - dist,
815 output_index,
816 )
817 }
818
819 if copy_length < length {
820 self.queued_backref = Some((dist, length - copy_length));
821 output_index = output.len();
822 break;
823 }
824 }
825 output_index += copy_length;
826 }
827
828 if self.state == State::CompressedData
829 && self.queued_backref.is_none()
830 && self.queued_rle.is_none()
831 && self.nbits >= 15
832 && self.peak_bits(15) as u16 & self.compression.eof_mask == self.compression.eof_code
833 {
834 self.consume_bits(self.compression.eof_bits);
835 self.state = match self.last_block {
836 true => State::Checksum,
837 false => State::BlockHeader,
838 };
839 }
840
841 Ok(output_index)
842 }
843
844 pub fn read(
863 &mut self,
864 input: &[u8],
865 output: &mut [u8],
866 output_position: usize,
867 end_of_input: bool,
868 ) -> Result<(usize, usize), DecompressionError> {
869 if let State::Done = self.state {
870 return Ok((0, 0));
871 }
872
873 assert!(output_position <= output.len());
874
875 let mut remaining_input = input;
876 let mut output_index = output_position;
877
878 if let Some((data, len)) = self.queued_rle.take() {
879 let n = len.min(output.len() - output_index);
880 output[output_index..][..n].fill(data);
881 output_index += n;
882 if n < len {
883 self.queued_rle = Some((data, len - n));
884 return Ok((0, n));
885 }
886 }
887 if let Some((dist, len)) = self.queued_backref.take() {
888 let n = len.min(output.len() - output_index);
889 for i in 0..n {
890 output[output_index + i] = output[output_index + i - dist];
891 }
892 output_index += n;
893 if n < len {
894 self.queued_backref = Some((dist, len - n));
895 return Ok((0, n));
896 }
897 }
898
899 let mut last_state = None;
901 while last_state != Some(self.state) {
902 last_state = Some(self.state);
903 match self.state {
904 State::ZlibHeader => {
905 self.fill_buffer(&mut remaining_input);
906 if self.nbits < 16 {
907 break;
908 }
909
910 let input0 = self.peak_bits(8);
911 let input1 = self.peak_bits(16) >> 8 & 0xff;
912 if input0 & 0x0f != 0x08
913 || (input0 & 0xf0) > 0x70
914 || input1 & 0x20 != 0
915 || (input0 << 8 | input1) % 31 != 0
916 {
917 return Err(DecompressionError::BadZlibHeader);
918 }
919
920 self.consume_bits(16);
921 self.state = State::BlockHeader;
922 }
923 State::BlockHeader => {
924 self.read_block_header(&mut remaining_input)?;
925 }
926 State::CodeLengthCodes => {
927 self.read_code_length_codes(&mut remaining_input)?;
928 }
929 State::CodeLengths => {
930 self.read_code_lengths(&mut remaining_input)?;
931 }
932 State::CompressedData => {
933 output_index =
934 self.read_compressed(&mut remaining_input, output, output_index)?
935 }
936 State::UncompressedData => {
937 debug_assert_eq!(self.nbits % 8, 0);
939 while self.nbits > 0
940 && self.uncompressed_bytes_left > 0
941 && output_index < output.len()
942 {
943 output[output_index] = self.peak_bits(8) as u8;
944 self.consume_bits(8);
945 output_index += 1;
946 self.uncompressed_bytes_left -= 1;
947 }
948 if self.nbits == 0 {
950 self.buffer = 0;
951 }
952
953 let copy_bytes = (self.uncompressed_bytes_left as usize)
955 .min(remaining_input.len())
956 .min(output.len() - output_index);
957 output[output_index..][..copy_bytes]
958 .copy_from_slice(&remaining_input[..copy_bytes]);
959 remaining_input = &remaining_input[copy_bytes..];
960 output_index += copy_bytes;
961 self.uncompressed_bytes_left -= copy_bytes as u16;
962
963 if self.uncompressed_bytes_left == 0 {
964 self.state = if self.last_block {
965 State::Checksum
966 } else {
967 State::BlockHeader
968 };
969 }
970 }
971 State::Checksum => {
972 self.fill_buffer(&mut remaining_input);
973
974 let align_bits = self.nbits % 8;
975 if self.nbits >= 32 + align_bits {
976 self.checksum.write(&output[output_position..output_index]);
977 if align_bits != 0 {
978 self.consume_bits(align_bits);
979 }
980 #[cfg(not(fuzzing))]
981 if !self.ignore_adler32
982 && (self.peak_bits(32) as u32).swap_bytes() != self.checksum.finish()
983 {
984 return Err(DecompressionError::WrongChecksum);
985 }
986 self.state = State::Done;
987 self.consume_bits(32);
988 break;
989 }
990 }
991 State::Done => unreachable!(),
992 }
993 }
994
995 if !self.ignore_adler32 && self.state != State::Done {
996 self.checksum.write(&output[output_position..output_index]);
997 }
998
999 if self.state == State::Done || !end_of_input || output_index == output.len() {
1000 let input_left = remaining_input.len();
1001 Ok((input.len() - input_left, output_index - output_position))
1002 } else {
1003 Err(DecompressionError::InsufficientInput)
1004 }
1005 }
1006
1007 pub fn is_done(&self) -> bool {
1009 self.state == State::Done
1010 }
1011}
1012
1013pub fn decompress_to_vec(input: &[u8]) -> Result<Vec<u8>, DecompressionError> {
1015 match decompress_to_vec_bounded(input, usize::MAX) {
1016 Ok(output) => Ok(output),
1017 Err(BoundedDecompressionError::DecompressionError { inner }) => Err(inner),
1018 Err(BoundedDecompressionError::OutputTooLarge { .. }) => {
1019 unreachable!("Impossible to allocate more than isize::MAX bytes")
1020 }
1021 }
1022}
1023
1024pub enum BoundedDecompressionError {
1026 DecompressionError {
1028 inner: DecompressionError,
1030 },
1031
1032 OutputTooLarge {
1034 partial_output: Vec<u8>,
1036 },
1037}
1038impl From<DecompressionError> for BoundedDecompressionError {
1039 fn from(inner: DecompressionError) -> Self {
1040 BoundedDecompressionError::DecompressionError { inner }
1041 }
1042}
1043
1044pub fn decompress_to_vec_bounded(
1047 input: &[u8],
1048 maxlen: usize,
1049) -> Result<Vec<u8>, BoundedDecompressionError> {
1050 let mut decoder = Decompressor::new();
1051 let mut output = vec![0; 1024.min(maxlen)];
1052 let mut input_index = 0;
1053 let mut output_index = 0;
1054 loop {
1055 let (consumed, produced) =
1056 decoder.read(&input[input_index..], &mut output, output_index, true)?;
1057 input_index += consumed;
1058 output_index += produced;
1059 if decoder.is_done() || output_index == maxlen {
1060 break;
1061 }
1062 output.resize((output_index + 32 * 1024).min(maxlen), 0);
1063 }
1064 output.resize(output_index, 0);
1065
1066 if decoder.is_done() {
1067 Ok(output)
1068 } else {
1069 Err(BoundedDecompressionError::OutputTooLarge {
1070 partial_output: output,
1071 })
1072 }
1073}
1074
1075#[cfg(test)]
1076mod tests {
1077 use crate::tables::{LENGTH_TO_LEN_EXTRA, LENGTH_TO_SYMBOL};
1078
1079 use super::*;
1080 use rand::Rng;
1081
1082 fn roundtrip(data: &[u8]) {
1083 let compressed = crate::compress_to_vec(data);
1084 let decompressed = decompress_to_vec(&compressed).unwrap();
1085 assert_eq!(&decompressed, data);
1086 }
1087
1088 fn roundtrip_miniz_oxide(data: &[u8]) {
1089 let compressed = miniz_oxide::deflate::compress_to_vec_zlib(data, 3);
1090 let decompressed = decompress_to_vec(&compressed).unwrap();
1091 assert_eq!(decompressed.len(), data.len());
1092 for (i, (a, b)) in decompressed.chunks(1).zip(data.chunks(1)).enumerate() {
1093 assert_eq!(a, b, "chunk {}..{}", i, i + 1);
1094 }
1095 assert_eq!(&decompressed, data);
1096 }
1097
1098 #[allow(unused)]
1099 fn compare_decompression(data: &[u8]) {
1100 let decompressed = decompress_to_vec(data).unwrap();
1105 let decompressed2 = miniz_oxide::inflate::decompress_to_vec_zlib(data).unwrap();
1106 for i in 0..decompressed.len().min(decompressed2.len()) {
1107 if decompressed[i] != decompressed2[i] {
1108 panic!(
1109 "mismatch at index {} {:?} {:?}",
1110 i,
1111 &decompressed[i.saturating_sub(1)..(i + 16).min(decompressed.len())],
1112 &decompressed2[i.saturating_sub(1)..(i + 16).min(decompressed2.len())]
1113 );
1114 }
1115 }
1116 if decompressed != decompressed2 {
1117 panic!(
1118 "length mismatch {} {} {:x?}",
1119 decompressed.len(),
1120 decompressed2.len(),
1121 &decompressed2[decompressed.len()..][..16]
1122 );
1123 }
1124 }
1126
1127 #[test]
1128 fn tables() {
1129 for (i, &bits) in LEN_SYM_TO_LEN_EXTRA.iter().enumerate() {
1130 let len_base = LEN_SYM_TO_LEN_BASE[i];
1131 for j in 0..(1 << bits) {
1132 if i == 27 && j == 31 {
1133 continue;
1134 }
1135 assert_eq!(LENGTH_TO_LEN_EXTRA[len_base + j - 3], bits, "{} {}", i, j);
1136 assert_eq!(
1137 LENGTH_TO_SYMBOL[len_base + j - 3],
1138 i as u16 + 257,
1139 "{} {}",
1140 i,
1141 j
1142 );
1143 }
1144 }
1145 }
1146
1147 #[test]
1148 fn fdeflate_table() {
1149 let mut compression = CompressedBlock {
1150 litlen_table: [0; 4096],
1151 dist_table: [0; 512],
1152 dist_symbol_lengths: [0; 30],
1153 dist_symbol_masks: [0; 30],
1154 dist_symbol_codes: [0; 30],
1155 secondary_table: Vec::new(),
1156 eof_code: 0,
1157 eof_mask: 0,
1158 eof_bits: 0,
1159 };
1160 let mut lengths = tables::HUFFMAN_LENGTHS.to_vec();
1161 lengths.resize(288, 0);
1162 lengths.push(1);
1163 lengths.resize(320, 0);
1164 Decompressor::build_tables(286, &lengths, &mut compression, 11).unwrap();
1165
1166 assert_eq!(
1167 compression, FDEFLATE_COMPRESSED_BLOCK,
1168 "{:#x?}",
1169 compression
1170 );
1171 }
1172
1173 #[test]
1174 fn it_works() {
1175 roundtrip(b"Hello world!");
1176 }
1177
1178 #[test]
1179 fn constant() {
1180 roundtrip_miniz_oxide(&[0; 50]);
1181 roundtrip_miniz_oxide(&vec![5; 2048]);
1182 roundtrip_miniz_oxide(&vec![128; 2048]);
1183 roundtrip_miniz_oxide(&vec![254; 2048]);
1184 }
1185
1186 #[test]
1187 fn random() {
1188 let mut rng = rand::thread_rng();
1189 let mut data = vec![0; 50000];
1190 for _ in 0..10 {
1191 for byte in &mut data {
1192 *byte = rng.gen::<u8>() % 5;
1193 }
1194 println!("Random data: {:?}", data);
1195 roundtrip_miniz_oxide(&data);
1196 }
1197 }
1198
1199 #[test]
1200 fn ignore_adler32() {
1201 let mut compressed = crate::compress_to_vec(b"Hello world!");
1202 let last_byte = compressed.len() - 1;
1203 compressed[last_byte] = compressed[last_byte].wrapping_add(1);
1204
1205 match decompress_to_vec(&compressed) {
1206 Err(DecompressionError::WrongChecksum) => {}
1207 r => panic!("expected WrongChecksum, got {:?}", r),
1208 }
1209
1210 let mut decompressor = Decompressor::new();
1211 decompressor.ignore_adler32();
1212 let mut decompressed = vec![0; 1024];
1213 let decompressed_len = decompressor
1214 .read(&compressed, &mut decompressed, 0, true)
1215 .unwrap()
1216 .1;
1217 assert_eq!(&decompressed[..decompressed_len], b"Hello world!");
1218 }
1219
1220 #[test]
1221 fn checksum_after_eof() {
1222 let input = b"Hello world!";
1223 let compressed = crate::compress_to_vec(input);
1224
1225 let mut decompressor = Decompressor::new();
1226 let mut decompressed = vec![0; 1024];
1227 let (input_consumed, output_written) = decompressor
1228 .read(
1229 &compressed[..compressed.len() - 1],
1230 &mut decompressed,
1231 0,
1232 false,
1233 )
1234 .unwrap();
1235 assert_eq!(output_written, input.len());
1236 assert_eq!(input_consumed, compressed.len() - 1);
1237
1238 let (input_consumed, output_written) = decompressor
1239 .read(
1240 &compressed[input_consumed..],
1241 &mut decompressed[..output_written],
1242 output_written,
1243 true,
1244 )
1245 .unwrap();
1246 assert!(decompressor.is_done());
1247 assert_eq!(input_consumed, 1);
1248 assert_eq!(output_written, 0);
1249
1250 assert_eq!(&decompressed[..input.len()], input);
1251 }
1252
1253 #[test]
1254 fn zero_length() {
1255 let mut compressed = crate::compress_to_vec(b"").to_vec();
1256
1257 for _ in 0..10 {
1259 println!("compressed len: {}", compressed.len());
1260 compressed.splice(2..2, [0u8, 0, 0, 0xff, 0xff].into_iter());
1261 }
1262
1263 for end_of_input in [true, false] {
1266 let mut decompressor = Decompressor::new();
1267 let (input_consumed, output_written) = decompressor
1268 .read(&compressed, &mut [], 0, end_of_input)
1269 .unwrap();
1270
1271 assert!(decompressor.is_done());
1272 assert_eq!(input_consumed, compressed.len());
1273 assert_eq!(output_written, 0);
1274 }
1275 }
1276
1277 mod test_utils;
1278 use test_utils::{decompress_by_chunks, TestDecompressionError};
1279
1280 fn verify_no_sensitivity_to_input_chunking(
1281 input: &[u8],
1282 ) -> Result<Vec<u8>, TestDecompressionError> {
1283 let r_whole = decompress_by_chunks(input, vec![input.len()], false);
1284 let r_bytewise = decompress_by_chunks(input, std::iter::repeat(1), false);
1285 assert_eq!(r_whole, r_bytewise);
1286 r_whole }
1288
1289 #[test]
1294 fn test_input_chunking_sensitivity_when_handling_distance_codes() {
1295 let result = verify_no_sensitivity_to_input_chunking(include_bytes!(
1296 "../tests/input-chunking-sensitivity-example1.zz"
1297 ))
1298 .unwrap();
1299 assert_eq!(result.len(), 281);
1300 assert_eq!(simd_adler32::adler32(&result.as_slice()), 751299);
1301 }
1302
1303 #[test]
1308 fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example1() {
1309 let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
1310 "../tests/input-chunking-sensitivity-example2.zz"
1311 ))
1312 .unwrap_err();
1313 assert_eq!(
1314 err,
1315 TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
1316 );
1317 }
1318
1319 #[test]
1324 fn test_input_chunking_sensitivity_when_no_end_of_block_symbol_example2() {
1325 let err = verify_no_sensitivity_to_input_chunking(include_bytes!(
1326 "../tests/input-chunking-sensitivity-example3.zz"
1327 ))
1328 .unwrap_err();
1329 assert_eq!(
1330 err,
1331 TestDecompressionError::ProdError(DecompressionError::BadLiteralLengthHuffmanTree)
1332 );
1333 }
1334}