1use super::Adler32Imp;
2
3pub fn get_imp() -> Option<Adler32Imp> {
5 get_imp_inner()
6}
7
/// Runtime-detection variant, compiled when both the `std` and `nightly`
/// features are enabled on an x86 / x86_64 target.
///
/// Hands out `imp::update` only if the running CPU reports both `avx512f`
/// and `avx512bw`; otherwise returns `None`.
#[inline]
#[cfg(all(
    feature = "std",
    feature = "nightly",
    any(target_arch = "x86", target_arch = "x86_64")
))]
fn get_imp_inner() -> Option<Adler32Imp> {
    // Both probes are evaluated up front, then matched as a pair.
    let detected = (
        std::is_x86_feature_detected!("avx512f"),
        std::is_x86_feature_detected!("avx512bw"),
    );

    match detected {
        (true, true) => Some(imp::update),
        _ => None,
    }
}
24
/// Compile-time variant, used under the `nightly` feature when `avx512f` and
/// `avx512bw` are enabled as target features and the runtime-detection
/// variant (`std` on x86 / x86_64) is not compiled in.
///
/// The required instructions are guaranteed by the build configuration, so
/// the implementation is returned unconditionally.
#[inline]
#[cfg(all(
    feature = "nightly",
    all(target_feature = "avx512f", target_feature = "avx512bw"),
    not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
    let update: Adler32Imp = imp::update;
    Some(update)
}
34
35#[inline]
36#[cfg(all(
37 not(all(feature = "nightly", target_feature = "avx512f", target_feature = "avx512bw")),
38 not(all(
39 feature = "std",
40 feature = "nightly",
41 any(target_arch = "x86", target_arch = "x86_64")
42 ))
43))]
44fn get_imp_inner() -> Option<Adler32Imp> {
45 None
46}
47
/// AVX-512 implementation of the Adler-32 running-sum update.
///
/// Compiled only with the `nightly` feature on x86 / x86_64, and only when
/// either `std` is available or `avx512f` + `avx512bw` are enabled as target
/// features.
#[cfg(all(
    feature = "nightly",
    any(target_arch = "x86", target_arch = "x86_64"),
    any(
        feature = "std",
        all(target_feature = "avx512f", target_feature = "avx512bw")
    )
))]
mod imp {
    /// Modulus applied to both running sums.
    const MOD: u32 = 65521;
    /// Upper bound used to derive `CHUNK_SIZE`.
    // NOTE(review): presumably the usual Adler-32 bound on how many bytes can
    // be summed before a `u32` may overflow -- confirm against the reference.
    const NMAX: usize = 5552;
    /// Bytes consumed per vector iteration (one 512-bit load).
    const BLOCK_SIZE: usize = 64;
    /// `NMAX` rounded down to a whole number of blocks; the sums are reduced
    /// modulo `MOD` after each chunk of this size.
    const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    /// Folds `data` into the running sums `a` and `b` and returns the updated
    /// `(a, b)` pair.
    pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        // SAFETY: `update_imp` needs `avx512f` and `avx512bw`. Within this
        // file the function is only handed out by `get_imp_inner`, which
        // either detects both features at runtime or is compiled in only
        // when both are enabled as target features.
        unsafe { update_imp(a, b, data) }
    }

    /// Splits `data` into `CHUNK_SIZE` pieces, reducing modulo `MOD` after
    /// each one, then handles whatever is left over.
    ///
    /// # Safety
    ///
    /// The CPU must support `avx512f` and `avx512bw`.
    #[inline]
    #[target_feature(enable = "avx512f")]
    #[target_feature(enable = "avx512bw")]
    unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
        let (mut sum_a, mut sum_b) = (u32::from(a), u32::from(b));

        let mut full_chunks = data.chunks_exact(CHUNK_SIZE);
        for chunk in full_chunks.by_ref() {
            update_chunk_block(&mut sum_a, &mut sum_b, chunk);
        }

        // `remainder()` is the tail shorter than `CHUNK_SIZE`; it does not
        // depend on how far the iterator has advanced.
        update_block(&mut sum_a, &mut sum_b, full_chunks.remainder());

        (sum_a as u16, sum_b as u16)
    }

    /// Processes exactly one `CHUNK_SIZE` chunk and reduces both sums.
    ///
    /// # Safety
    ///
    /// Executes AVX-512 intrinsics; see [`update_imp`].
    #[inline]
    unsafe fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert_eq!(
            chunk.len(),
            CHUNK_SIZE,
            "Unexpected chunk size (expected {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        // `CHUNK_SIZE` is a multiple of `BLOCK_SIZE`, so the returned tail is
        // not needed here.
        let _ = reduce_add_blocks(a, b, chunk);

        *a %= MOD;
        *b %= MOD;
    }

    /// Processes a chunk of at most `CHUNK_SIZE` bytes: whole blocks go
    /// through the vector path, trailing bytes are summed one at a time, and
    /// both sums are then reduced.
    ///
    /// # Safety
    ///
    /// Executes AVX-512 intrinsics; see [`update_imp`].
    #[inline]
    unsafe fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
        debug_assert!(
            chunk.len() <= CHUNK_SIZE,
            "Unexpected chunk size (expected <= {}, got {})",
            CHUNK_SIZE,
            chunk.len()
        );

        let tail = reduce_add_blocks(a, b, chunk);
        for &byte in tail {
            *a += u32::from(byte);
            *b += *a;
        }

        *a %= MOD;
        *b %= MOD;
    }

    /// Folds every whole `BLOCK_SIZE` block of `chunk` into `a` and `b`
    /// (without reducing modulo `MOD`) and returns the bytes that did not
    /// fill a block. A chunk shorter than one block is returned untouched.
    ///
    /// # Safety
    ///
    /// Executes AVX-512 intrinsics; see [`update_imp`].
    #[inline(always)]
    unsafe fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
        if chunk.len() < BLOCK_SIZE {
            return chunk;
        }

        let blocks = chunk.chunks_exact(BLOCK_SIZE);
        let tail = blocks.remainder();
        let block_count = blocks.len() as u32;

        let ones = _mm512_set1_epi16(1);
        let zeros = _mm512_setzero_si512();
        let weights = get_weights();

        // The incoming `a` is added to `b` once per byte; that contribution
        // is seeded here as `a * block_count` and multiplied by the block
        // size when `prefix_acc` is shifted below.
        let mut prefix_acc = low_lane_epi32(*a * block_count);
        let mut a_acc = _mm512_setzero_si512();
        // Seeding with the incoming `b` lets the final horizontal sum
        // overwrite `*b` instead of adding to it.
        let mut b_acc = low_lane_epi32(*b);

        for block in blocks {
            let bytes = _mm512_loadu_si512(block.as_ptr().cast());

            // Accumulate the byte sums of all *previous* blocks before this
            // block is added to `a_acc`.
            prefix_acc = _mm512_add_epi32(prefix_acc, a_acc);

            a_acc = _mm512_add_epi32(a_acc, _mm512_sad_epu8(bytes, zeros));
            let weighted = _mm512_maddubs_epi16(bytes, weights);
            b_acc = _mm512_add_epi32(b_acc, _mm512_madd_epi16(weighted, ones));
        }

        // Shift by 6 == multiply by 64 (`BLOCK_SIZE`).
        b_acc = _mm512_add_epi32(b_acc, _mm512_slli_epi32(prefix_acc, 6));

        *a += reduce_add(a_acc);
        *b = reduce_add(b_acc);

        tail
    }

    /// Builds a vector whose last `_mm512_set_epi32` argument is `value` and
    /// whose other fifteen 32-bit lanes are zero.
    #[inline(always)]
    unsafe fn low_lane_epi32(value: u32) -> __m512i {
        _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, value as i32)
    }

    /// Horizontal sum of all 32-bit lanes of a 512-bit vector.
    #[inline(always)]
    unsafe fn reduce_add(v: __m512i) -> u32 {
        let [lower, upper]: [__m256i; 2] = core::mem::transmute(v);

        reduce_add_256(lower) + reduce_add_256(upper)
    }

    /// Horizontal sum of all 32-bit lanes of a 256-bit vector.
    #[inline(always)]
    unsafe fn reduce_add_256(v: __m256i) -> u32 {
        let [lower, upper]: [__m128i; 2] = core::mem::transmute(v);

        // 8 lanes -> 4 lanes.
        let quad = _mm_add_epi32(lower, upper);
        // Fold the upper 64 bits onto the lower 64 bits.
        let pair = _mm_add_epi32(_mm_unpackhi_epi64(quad, quad), quad);
        // Swap neighbouring lanes and add; the low lane then holds the total.
        let swapped = _mm_shuffle_epi32(pair, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));
        let total = _mm_add_epi32(pair, swapped);

        _mm_cvtsi128_si32(total) as u32
    }

    /// Per-byte multipliers 1..=64 fed to `_mm512_maddubs_epi16`.
    // NOTE(review): `_mm512_set_epi8` is documented as taking its arguments
    // from the highest byte lane down to the lowest, which would give the
    // first byte in memory weight 64 and the last weight 1 -- confirm against
    // the intrinsic reference.
    #[inline(always)]
    unsafe fn get_weights() -> __m512i {
        _mm512_set_epi8(
            1, 2, 3, 4, 5, 6, 7, 8,
            9, 10, 11, 12, 13, 14, 15, 16,
            17, 18, 19, 20, 21, 22, 23, 24,
            25, 26, 27, 28, 29, 30, 31, 32,
            33, 34, 35, 36, 37, 38, 39, 40,
            41, 42, 43, 44, 45, 46, 47, 48,
            49, 50, 51, 52, 53, 54, 55, 56,
            57, 58, 59, 60, 61, 62, 63, 64,
        )
    }
}
191
#[cfg(test)]
mod tests {
    use rand::{rngs::SmallRng, Rng, SeedableRng};

    /// Prefix lengths checked against a constant-filled buffer, in order.
    const FILL_LENS: [usize; 6] = [0, 1, 2, 100, 1024, 1024 - 5];

    #[test]
    fn zeroes() {
        let buf = [0u8; 1024];
        for &len in FILL_LENS.iter() {
            assert_sum_eq(&buf[..len]);
        }

        // The 1 MiB input is skipped under Miri.
        #[cfg(not(miri))]
        assert_sum_eq(&[0; 1024 * 1024]);
    }

    #[test]
    fn ones() {
        let buf = [1u8; 1024];
        for &len in FILL_LENS.iter() {
            assert_sum_eq(&buf[..len]);
        }

        // The 1 MiB input is skipped under Miri.
        #[cfg(not(miri))]
        assert_sum_eq(&[1; 1024 * 1024]);
    }

    #[test]
    fn random() {
        // Nothing to compare when no AVX-512 implementation is available, so
        // skip generating the input altogether.
        if super::get_imp().is_none() {
            return;
        }

        let mut random = [0u8; 1024 * 10];
        SmallRng::from_entropy().fill(&mut random[..]);

        for &len in &[1, 100, 1024, 1024 - 5, 1024 * 10] {
            assert_sum_eq(&random[..len]);
        }
    }

    // NOTE(review): presumably the worked example from the Wikipedia article
    // on Adler-32 -- confirm.
    #[test]
    fn wiki() {
        assert_sum_eq(b"Wikipedia");
    }

    /// Runs the AVX-512 implementation (when available) over `data`, starting
    /// from the sums `(1, 0)`, and compares the combined checksum with
    /// `adler::adler32_slice`. Does nothing when no implementation is
    /// available.
    fn assert_sum_eq(data: &[u8]) {
        let update = match super::get_imp() {
            Some(update) => update,
            None => return,
        };

        let (a, b) = update(1, 0, data);
        let left = (u32::from(b) << 16) | u32::from(a);
        let right = adler::adler32_slice(data);

        assert_eq!(left, right, "len({})", data.len());
    }
}