// simd_adler32/imp/avx512.rs

1use super::Adler32Imp;
2
3/// Resolves update implementation if CPU supports avx512f and avx512bw instructions.
4pub fn get_imp() -> Option<Adler32Imp> {
5  get_imp_inner()
6}
7
/// Runtime-detection variant: compiled when both the `std` and `nightly`
/// features are enabled on an x86 / x86_64 target.
///
/// Queries the CPU for `avx512f` and `avx512bw` and hands out the AVX-512
/// routine only when both are reported as present.
#[inline]
#[cfg(all(
  feature = "std",
  feature = "nightly",
  any(target_arch = "x86", target_arch = "x86_64")
))]
fn get_imp_inner() -> Option<Adler32Imp> {
  let detected = (
    std::is_x86_feature_detected!("avx512f"),
    std::is_x86_feature_detected!("avx512bw"),
  );

  match detected {
    (true, true) => Some(imp::update),
    _ => None,
  }
}
24
/// Compile-time variant: used when runtime detection is not compiled in
/// (no `std` feature, or not an x86 / x86_64 target) but the build itself
/// enables both `avx512f` and `avx512bw` as target features.
///
/// No check is needed at runtime, so the routine is returned unconditionally.
#[inline]
#[cfg(all(
  feature = "nightly",
  target_feature = "avx512f",
  target_feature = "avx512bw",
  not(all(feature = "std", any(target_arch = "x86", target_arch = "x86_64")))
))]
fn get_imp_inner() -> Option<Adler32Imp> {
  Some(imp::update)
}
34
35#[inline]
36#[cfg(all(
37  not(all(feature = "nightly", target_feature = "avx512f", target_feature = "avx512bw")),
38  not(all(
39    feature = "std",
40    feature = "nightly",
41    any(target_arch = "x86", target_arch = "x86_64")
42  ))
43))]
44fn get_imp_inner() -> Option<Adler32Imp> {
45  None
46}
47
#[cfg(all(
  feature = "nightly",
  any(target_arch = "x86", target_arch = "x86_64"),
  any(
    feature = "std",
    all(target_feature = "avx512f", target_feature = "avx512bw")
  )
))]
mod imp {
  //! AVX-512 kernel: folds 64-byte blocks into the two running sums `a` and
  //! `b`, reducing both modulo `MOD` after every chunk.

  #[cfg(target_arch = "x86")]
  use core::arch::x86::*;
  #[cfg(target_arch = "x86_64")]
  use core::arch::x86_64::*;

  /// Modulus applied to both running sums.
  const MOD: u32 = 65521;
  /// Byte budget between two modulo reductions; only used to derive
  /// `CHUNK_SIZE`. (Presumably the usual Adler-32 overflow bound for `u32`
  /// accumulators - not re-derived here.)
  const NMAX: usize = 5552;
  /// Bytes consumed per vector iteration (one 512-bit load).
  const BLOCK_SIZE: usize = 64;
  /// `NMAX` rounded down to a whole number of blocks.
  const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

  /// Safe entry point: feeds `data` into the running sums `(a, b)` and
  /// returns the updated pair.
  pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
    // SAFETY: `update_imp` requires avx512f + avx512bw. Within this file the
    // function is only handed out by `get_imp_inner`, after a runtime or
    // compile-time check that both feature sets are available.
    unsafe { update_imp(a, b, data) }
  }

  /// Splits `data` into full `CHUNK_SIZE` chunks plus a shorter tail and
  /// folds each piece into the sums.
  ///
  /// # Safety
  ///
  /// The CPU must support `avx512f` and `avx512bw`.
  #[inline]
  #[target_feature(enable = "avx512f")]
  #[target_feature(enable = "avx512bw")]
  unsafe fn update_imp(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
    let mut sum_a = a as u32;
    let mut sum_b = b as u32;

    let mut full_chunks = data.chunks_exact(CHUNK_SIZE);
    for chunk in full_chunks.by_ref() {
      update_chunk_block(&mut sum_a, &mut sum_b, chunk);
    }

    // Whatever did not fill a whole chunk (possibly nothing at all).
    update_block(&mut sum_a, &mut sum_b, full_chunks.remainder());

    (sum_a as u16, sum_b as u16)
  }

  /// Folds one chunk of exactly `CHUNK_SIZE` bytes into the sums and reduces
  /// them modulo `MOD`.
  #[inline]
  unsafe fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    let len = chunk.len();
    debug_assert_eq!(
      len,
      CHUNK_SIZE,
      "Unexpected chunk size (expected {}, got {})",
      CHUNK_SIZE,
      len
    );

    // CHUNK_SIZE is a multiple of BLOCK_SIZE, so there are no leftover bytes.
    reduce_add_blocks(a, b, chunk);

    *a %= MOD;
    *b %= MOD;
  }

  /// Folds a chunk of at most `CHUNK_SIZE` bytes into the sums: whole blocks
  /// go through the vector path, the remaining bytes are added one by one,
  /// then both sums are reduced modulo `MOD`.
  #[inline]
  unsafe fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    let len = chunk.len();
    debug_assert!(
      len <= CHUNK_SIZE,
      "Unexpected chunk size (expected <= {}, got {})",
      CHUNK_SIZE,
      len
    );

    let leftover = reduce_add_blocks(a, b, chunk);
    for &byte in leftover {
      *a += byte as u32;
      *b += *a;
    }

    *a %= MOD;
    *b %= MOD;
  }

  /// Vector path: consumes every whole `BLOCK_SIZE` block of `chunk`, updates
  /// `a` and `b` (without the modulo reduction) and returns the bytes that did
  /// not fill a block. Inputs shorter than one block are returned untouched.
  #[inline(always)]
  unsafe fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
    if chunk.len() < BLOCK_SIZE {
      return chunk;
    }

    let blocks = chunk.chunks_exact(BLOCK_SIZE);
    let leftover = blocks.remainder();

    let ones = _mm512_set1_epi16(1);
    let zeros = _mm512_setzero_si512();
    let weights = get_weights();

    // The incoming `a` contributes to `b` once per byte; seed the prefix
    // accumulator with `a * block_count` and scale by BLOCK_SIZE after the loop.
    let seeded_a = (*a * blocks.len() as u32) as i32;
    let mut prefix_acc = _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, seeded_a);
    let mut byte_acc = _mm512_setzero_si512();
    let mut weighted_acc =
      _mm512_set_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, *b as i32);

    for block in blocks {
      let bytes = _mm512_loadu_si512(block.as_ptr() as *const _);

      // Must happen before `byte_acc` absorbs the current block: only bytes of
      // earlier blocks are counted once more per byte of this block.
      prefix_acc = _mm512_add_epi32(prefix_acc, byte_acc);

      // Plain byte sums (feeds `a`).
      byte_acc = _mm512_add_epi32(byte_acc, _mm512_sad_epu8(bytes, zeros));

      // Position-weighted byte sums within the block (feeds `b`).
      let weighted_pairs = _mm512_maddubs_epi16(bytes, weights);
      weighted_acc = _mm512_add_epi32(weighted_acc, _mm512_madd_epi16(weighted_pairs, ones));
    }

    // Multiply the prefix sums by BLOCK_SIZE (64 == 1 << 6).
    weighted_acc = _mm512_add_epi32(weighted_acc, _mm512_slli_epi32(prefix_acc, 6));

    *a += reduce_add(byte_acc);
    // `weighted_acc` was seeded with the old `b`, hence plain assignment.
    *b = reduce_add(weighted_acc);

    leftover
  }

  /// Horizontal sum of the sixteen 32-bit lanes of `v`.
  #[inline(always)]
  unsafe fn reduce_add(v: __m512i) -> u32 {
    let [lower, upper]: [__m256i; 2] = core::mem::transmute(v);

    reduce_add_256(lower) + reduce_add_256(upper)
  }

  /// Horizontal sum of the eight 32-bit lanes of `v`.
  #[inline(always)]
  unsafe fn reduce_add_256(v: __m256i) -> u32 {
    let [lower, upper]: [__m128i; 2] = core::mem::transmute(v);

    // 8 lanes -> 4 lanes.
    let quad = _mm_add_epi32(lower, upper);

    // 4 lanes -> 2 lanes: add the upper 64 bits onto the lower 64 bits.
    let quad_hi = _mm_unpackhi_epi64(quad, quad);
    let pair = _mm_add_epi32(quad_hi, quad);

    // 2 lanes -> 1 lane: swap neighbouring lanes and add.
    let pair_swapped = _mm_shuffle_epi32(pair, crate::imp::_MM_SHUFFLE(2, 3, 0, 1));
    let total = _mm_add_epi32(pair, pair_swapped);

    // The result lives in the lowest lane.
    _mm_cvtsi128_si32(total) as u32
  }

  /// Per-byte weights for one block: `_mm512_set_epi8` takes the highest byte
  /// first, so the first byte in memory gets weight 64 and the last weight 1.
  #[inline(always)]
  unsafe fn get_weights() -> __m512i {
    _mm512_set_epi8(
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
      17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
      33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
      49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
    )
  }
}
191
#[cfg(test)]
mod tests {
  use rand::{Rng, SeedableRng, rngs::SmallRng};

  /// All-zero inputs of various lengths.
  #[test]
  fn zeroes() {
    assert_constant_fill(0);
    #[cfg(not(miri))]
    assert_sum_eq(&[0; 1024 * 1024]);
  }

  /// All-one inputs of various lengths.
  #[test]
  fn ones() {
    assert_constant_fill(1);
    #[cfg(not(miri))]
    assert_sum_eq(&[1; 1024 * 1024]);
  }

  /// Random bytes, checked over several prefixes of the same buffer.
  #[test]
  fn random() {
    // Nothing to verify when the AVX-512 implementation is unavailable.
    if super::get_imp().is_none() {
      return;
    }

    let mut buffer = [0u8; 1024 * 10];
    let mut rng = SmallRng::from_entropy();
    rng.fill(&mut buffer[..]);

    // `1024 - 5` is a non-power-of-2 length to test remainder handling.
    let lengths = [1, 100, 1024, 1024 - 5, 1024 * 10];
    for &len in lengths.iter() {
      assert_sum_eq(&buffer[..len]);
    }
  }

  /// Example calculation from https://en.wikipedia.org/wiki/Adler-32.
  #[test]
  fn wiki() {
    assert_sum_eq(b"Wikipedia");
  }

  /// Checks inputs made of a single repeated byte `value`, for lengths
  /// 0, 1, 2, 100, 1024 and 1024 - 5 (the last one exercises the remainder
  /// handling).
  fn assert_constant_fill(value: u8) {
    let buffer = [value; 1024];
    let lengths = [0, 1, 2, 100, 1024, 1024 - 5];
    for &len in lengths.iter() {
      assert_sum_eq(&buffer[..len]);
    }
  }

  /// Compares the AVX-512 result against the `adler` crate's reference
  /// checksum; a no-op when the implementation is unavailable.
  fn assert_sum_eq(data: &[u8]) {
    let update = match super::get_imp() {
      Some(update) => update,
      None => return,
    };

    let (a, b) = update(1, 0, data);
    let actual = (u32::from(b) << 16) | u32::from(a);
    let expected = adler::adler32_slice(data);

    assert_eq!(actual, expected, "len({})", data.len());
  }
}
248}