1use super::Adler32Imp;
2
3pub fn get_imp() -> Option<Adler32Imp> {
5 get_imp_inner()
6}
7
/// Runtime-detection variant, compiled when the `std` feature is enabled on
/// an x86 or x86_64 target.
///
/// Hands out `imp::update` only if the running CPU reports AVX2 support.
#[cfg(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))]
#[inline]
fn get_imp_inner() -> Option<Adler32Imp> {
    if !std::is_x86_feature_detected!("avx2") {
        return None;
    }

    Some(imp::update)
}
17
/// Compile-time variant, used when the runtime-detection variant is not
/// compiled (no `std` feature, or a non-x86 target) but the build itself
/// enables `avx2` as a target feature.
///
/// AVX2 is then guaranteed statically, so the implementation is always
/// returned.
#[cfg(all(
    target_feature = "avx2",
    not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
#[inline]
fn get_imp_inner() -> Option<Adler32Imp> {
    Some(imp::update)
}
26
27#[inline]
28#[cfg(all(
29 not(target_feature = "avx2"),
30 not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
31))]
32fn get_imp_inner() -> Option<Adler32Imp> {
33 None
34}
35
/// AVX2 implementation of the Adler-32 state update.
///
/// Only compiled on x86/x86_64 when AVX2 can be reached either through
/// runtime detection (`std` feature) or as a compile-time target feature.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    any(feature = "std", target_feature = "avx2")
))]
mod imp {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    /// Modulus applied to both running sums.
    const MOD: u32 = 65521;
    /// Byte budget from which `CHUNK_SIZE` is derived.
    // NOTE(review): presumably the usual Adler-32 NMAX (the most bytes that can
    // be accumulated in `u32` sums before a reduction is required) - confirm.
    const NMAX: usize = 5552;
    /// Bytes consumed per vector iteration (one 256-bit load).
    const BLOCK_SIZE: usize = 32;
    /// `NMAX` rounded down to a whole number of blocks.
    const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

    /// Feeds `data` into the Adler-32 state `(a, b)` and returns the new state.
    pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        // SAFETY: `update_imp` is compiled with AVX2 enabled, so this call is
        // only sound on a CPU that supports AVX2. The `get_imp_inner` variants
        // in this file hand out `update` only after runtime detection succeeds
        // or when AVX2 is a compile-time target feature.
        unsafe { update_imp(a, b, data) }
    }

    /// Processes `data` in `CHUNK_SIZE` pieces, reducing both sums modulo
    /// `MOD` after each piece, then handles the shorter trailing piece.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        let mut sum_a = u32::from(a);
        let mut sum_b = u32::from(b);

        let mut chunks = data.chunks_exact(CHUNK_SIZE);
        for chunk in chunks.by_ref() {
            update_chunk_block(&mut sum_a, &mut sum_b, chunk);
        }
        update_block(&mut sum_a, &mut sum_b, chunks.remainder());

        // Both sums were just reduced modulo `MOD`, which is below 2^16.
        (sum_a as u16, sum_b as u16)
    }

    /// Folds one full `CHUNK_SIZE` chunk into `(a, b)` and reduces both.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline]
    unsafe fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert_eq!(
            chunk.len(),
            CHUNK_SIZE,
            "Unexpected chunk size (expected {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        // `CHUNK_SIZE` is a multiple of `BLOCK_SIZE`, so a full chunk leaves no
        // scalar tail; the returned slice is intentionally ignored.
        let _ = reduce_add_blocks(a, b, chunk);

        *a %= MOD;
        *b %= MOD;
    }

    /// Folds a chunk of at most `CHUNK_SIZE` bytes into `(a, b)`: whole blocks
    /// go through the vector path, leftover bytes are added one at a time.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline]
    unsafe fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert!(
            chunk.len() <= CHUNK_SIZE,
            "Unexpected chunk size (expected <= {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        let tail = reduce_add_blocks(a, b, chunk);
        for &byte in tail {
            *a += u32::from(byte);
            *b += *a;
        }

        *a %= MOD;
        *b %= MOD;
    }

    /// Accumulates every whole `BLOCK_SIZE` block of `chunk` into `(a, b)`
    /// without reducing modulo `MOD`, and returns the bytes that did not fill
    /// a block. Chunks shorter than one block are returned untouched.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline(always)]
    unsafe fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
        if chunk.len() < BLOCK_SIZE {
            return chunk;
        }

        let blocks = chunk.chunks_exact(BLOCK_SIZE);
        let tail = blocks.remainder();

        let ones = _mm256_set1_epi16(1);
        let zeros = _mm256_setzero_si256();
        let weights = get_weights();

        // The incoming `a` contributes to `b` once per byte, i.e. `BLOCK_SIZE`
        // times per block. Seed the prefix accumulator with `a * block_count`;
        // the `BLOCK_SIZE` factor is applied by the shift after the loop.
        let seed = (*a * blocks.len() as u32) as i32;
        let mut prefix_acc = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, seed);
        let mut a_acc = _mm256_setzero_si256();
        let mut b_acc = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, *b as i32);

        for block in blocks {
            // Unaligned load: `block` is exactly `BLOCK_SIZE` (32) bytes long.
            let bytes = _mm256_loadu_si256(block.as_ptr() as *const __m256i);

            // Record the byte total seen *before* this block.
            prefix_acc = _mm256_add_epi32(prefix_acc, a_acc);

            // Sum of absolute differences against zero = plain byte sums.
            let byte_sums = _mm256_sad_epu8(bytes, zeros);
            a_acc = _mm256_add_epi32(a_acc, byte_sums);

            // Weight each byte, then widen the pairwise 16-bit sums to 32 bits.
            let weighted_pairs = _mm256_maddubs_epi16(bytes, weights);
            let weighted_quads = _mm256_madd_epi16(weighted_pairs, ones);
            b_acc = _mm256_add_epi32(b_acc, weighted_quads);
        }

        // Shift by 5 multiplies by 32 (`BLOCK_SIZE`).
        let scaled_prefix = _mm256_slli_epi32(prefix_acc, 5);
        b_acc = _mm256_add_epi32(b_acc, scaled_prefix);

        *a += reduce_add(a_acc);
        // `b_acc` was seeded with the incoming `b`, hence assignment, not `+=`.
        *b = reduce_add(b_acc);

        tail
    }

    /// Horizontally adds the eight 32-bit lanes of `v`.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline(always)]
    unsafe fn reduce_add(v: __m256i) -> u32 {
        // 8 lanes -> 4 lanes: add the upper 128-bit half onto the lower one.
        let halves = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));

        // 4 lanes -> 2 lanes: add the upper 64 bits onto the lower 64 bits.
        let upper = _mm_unpackhi_epi64(halves, halves);
        let pairs = _mm_add_epi32(upper, halves);

        // 2 lanes -> 1 lane: swap neighbouring 32-bit lanes and add.
        let swapped = _mm_shuffle_epi32(pairs, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));
        let total = _mm_add_epi32(pairs, swapped);

        _mm_cvtsi128_si32(total) as u32
    }

    /// Builds the per-byte weights used for the `b` sum.
    ///
    /// `_mm256_set_epi8` lists its arguments from the most significant byte
    /// down, so the first byte of a block is weighted 32 and the last one 1.
    ///
    /// # Safety
    ///
    /// The executing CPU must support AVX2.
    #[inline(always)]
    unsafe fn get_weights() -> __m256i {
        _mm256_set_epi8(
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 25, 26, 27, 28, 29, 30, 31, 32,
        )
    }
}
163
#[cfg(test)]
mod tests {
    use rand::{rngs::SmallRng, Rng, SeedableRng};

    /// Slice lengths exercised by the constant-fill tests, in call order.
    const FILL_LENGTHS: [usize; 6] = [0, 1, 2, 100, 1024, 1024 - 5];

    /// Slice lengths exercised by the random-data test, in call order.
    const RANDOM_LENGTHS: [usize; 5] = [1, 100, 1024, 1024 - 5, 1024 * 10];

    /// Runs the resolved implementation over `data` from the initial state
    /// `(1, 0)` and compares the packed result with `adler::adler32_slice`.
    ///
    /// Does nothing when no implementation is available.
    fn assert_sum_eq(data: &[u8]) {
        let update = match super::get_imp() {
            Some(update) => update,
            None => return,
        };

        let (a, b) = update(1, 0, data);
        let left = (u32::from(b) << 16) | u32::from(a);
        let right = adler::adler32_slice(data);

        assert_eq!(left, right, "len({})", data.len());
    }

    #[test]
    fn zeroes() {
        let buf = [0u8; 1024];
        for &len in FILL_LENGTHS.iter() {
            assert_sum_eq(&buf[..len]);
        }

        // The 1 MiB input is skipped under Miri.
        #[cfg(not(miri))]
        assert_sum_eq(&[0; 1024 * 1024]);
    }

    #[test]
    fn ones() {
        let buf = [1u8; 1024];
        for &len in FILL_LENGTHS.iter() {
            assert_sum_eq(&buf[..len]);
        }

        // The 1 MiB input is skipped under Miri.
        #[cfg(not(miri))]
        assert_sum_eq(&[1; 1024 * 1024]);
    }

    #[test]
    fn random() {
        // Nothing to compare without an implementation; skip the RNG work.
        if super::get_imp().is_none() {
            return;
        }

        let mut random = [0u8; 1024 * 10];
        SmallRng::from_entropy().fill(&mut random[..]);

        for &len in RANDOM_LENGTHS.iter() {
            assert_sum_eq(&random[..len]);
        }
    }

    /// Checks the short ASCII input `Wikipedia`.
    #[test]
    fn wiki() {
        assert_sum_eq(b"Wikipedia");
    }
}