// simd_adler32/imp/neon.rs — NEON (ARM SIMD) Adler-32 implementation.

1use super::Adler32Imp;
2
/// Returns the NEON-accelerated Adler-32 implementation.
///
/// This variant is compiled only when the `neon` target feature is enabled
/// at build time, so the implementation is always available here.
#[cfg(target_feature = "neon")]
pub fn get_imp() -> Option<Adler32Imp> {
  Some(imp::update)
}
7
/// Fallback compiled when the `neon` target feature is disabled: no SIMD
/// implementation is available on this target, so the caller must select a
/// different implementation.
#[cfg(not(target_feature = "neon"))]
pub fn get_imp() -> Option<Adler32Imp> {
  None
}
12
#[cfg(target_feature = "neon")]
mod imp {
  // The Adler-32 modulus: the largest prime below 2^16.
  const MOD: u32 = 65521;
  // Largest number of bytes that can be summed before `b` must be reduced
  // modulo MOD to keep the u32 accumulators from overflowing (the same
  // constant zlib uses as NMAX).
  const NMAX: usize = 5552;
  // Bytes consumed per SIMD iteration: two 16-byte NEON registers.
  const BLOCK_SIZE: usize = 32;
  // NMAX rounded down to a whole number of SIMD blocks; a chunk of this size
  // can be accumulated with a single modular reduction at the end.
  const CHUNK_SIZE: usize = NMAX / BLOCK_SIZE * BLOCK_SIZE;

  #[cfg(target_arch = "aarch64")]
  use core::arch::aarch64::*;
  #[cfg(target_arch = "arm")]
  use core::arch::arm::*;

  /// Advances the Adler-32 state `(a, b)` across `data` using NEON and
  /// returns the updated state, reduced modulo [`MOD`].
  ///
  /// NOTE(review): callers are presumed to pass previously reduced sums
  /// (`a`, `b` < MOD), per the usual Adler-32 invariant — the NMAX-based
  /// chunking relies on the accumulators starting small.
  pub fn update(a: u16, b: u16, data: &[u8]) -> (u16, u16) {
    let mut a = a as u32;
    let mut b = b as u32;

    // Process full overflow-safe chunks first; each call reduces mod MOD.
    let chunks = data.chunks_exact(CHUNK_SIZE);
    let remainder = chunks.remainder();
    for chunk in chunks {
      update_chunk_block(&mut a, &mut b, chunk);
    }

    // Fold in the final partial chunk (possibly empty).
    update_block(&mut a, &mut b, remainder);

    // Both sums are < MOD < 2^16 after reduction, so the casts are lossless.
    (a as u16, b as u16)
  }

  /// Folds a partial chunk (`len <= CHUNK_SIZE`) into the running sums:
  /// complete 32-byte blocks go through the SIMD path, the remaining tail
  /// (< 32 bytes) is handled with the scalar recurrence, then both sums are
  /// reduced mod MOD.
  fn update_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    debug_assert!(
      chunk.len() <= CHUNK_SIZE,
      "Unexpected chunk size (expected <= {}, got {})",
      CHUNK_SIZE,
      chunk.len()
    );

    // `reduce_add_blocks` consumes the 32-byte-aligned prefix and returns
    // the leftover bytes, which are folded in scalar-style here.
    for byte in reduce_add_blocks(a, b, chunk) {
      *a += *byte as u32;
      *b += *a;
    }

    *a %= MOD;
    *b %= MOD;
  }

  /// Folds one complete `CHUNK_SIZE` chunk into the running sums via the
  /// SIMD path. Because CHUNK_SIZE is a multiple of BLOCK_SIZE there is no
  /// scalar tail, and `reduce_add_blocks` performs the modular reduction
  /// itself (its early-return-without-reduction path cannot trigger here,
  /// since CHUNK_SIZE >= BLOCK_SIZE).
  fn update_chunk_block(a: &mut u32, b: &mut u32, chunk: &[u8]) {
    debug_assert_eq!(
      chunk.len(),
      CHUNK_SIZE,
      "Unexpected chunk size (expected {}, got {})",
      CHUNK_SIZE,
      chunk.len()
    );

    reduce_add_blocks(a, b, chunk);
  }

  /// SIMD core: folds every complete 32-byte block of `chunk` into `*a` and
  /// `*b` (reducing both mod MOD) and returns the unprocessed tail slice.
  /// If `chunk` is shorter than one block it is returned untouched and no
  /// reduction is performed.
  fn reduce_add_blocks<'a>(a: &mut u32, b: &mut u32, chunk: &'a [u8]) -> &'a [u8] {
    if chunk.len() < BLOCK_SIZE {
      return chunk;
    }
    let blocks = chunk.chunks_exact(BLOCK_SIZE);
    let blocks_remainder = blocks.remainder();

    // Conversion of the code from Chromium zlib:
    // https://chromium.googlesource.com/chromium/src/third_party/+/main/zlib/adler32_simd.c
    unsafe {
      // a and b accumulators are initially zero.
      let mut a_v: uint32x4_t = vdupq_n_u32(0);
      let mut b_v: uint32x4_t = vdupq_n_u32(0);
      // The starting `a` contributes once to `b` for every byte processed,
      // i.e. a * blocks.len() * 32 in total. Seed lane 3 with
      // a * blocks.len(); the final `vshlq_n_u32(b_v, 5)` below supplies the
      // remaining factor of 32.
      b_v = vsetq_lane_u32(*a * (blocks.len() as u32), b_v, 3);

      // Per-position column sums: used afterwards to weight each byte's
      // contribution to `b` by its offset within a block.
      let mut v_column_sum_1: uint16x8_t = vdupq_n_u16(0);
      let mut v_column_sum_2: uint16x8_t = vdupq_n_u16(0);
      let mut v_column_sum_3: uint16x8_t = vdupq_n_u16(0);
      let mut v_column_sum_4: uint16x8_t = vdupq_n_u16(0);

      for block in blocks {
        let block_ptr = block.as_ptr();
        // Slurp in 32 bytes.
        // SAFETY: `block` is exactly BLOCK_SIZE (32) bytes long, so both
        // 16-byte loads are in bounds; vld1q_u8 is an unaligned load.
        let bytes1: uint8x16_t = vld1q_u8(block_ptr);
        let bytes2: uint8x16_t = vld1q_u8(block_ptr.add(16));

        // Wrapping-add the sums from the previous block together: every
        // earlier byte contributes once more to `b` for each later block.
        // b_v[i] += a_v[i]
        b_v = vaddq_u32(b_v, a_v);

        // Unsigned add, accumulate long pairwise: adjacent bytes are zipped,
        // added, and widened, summing all 32 bytes of the block into a_v.
        a_v = vpadalq_u16(a_v, vpadalq_u8(vpaddlq_u8(bytes1), bytes2));

        // Have to oscillate between low and high elements, since vaddw's first
        // argument is already q-length.
        v_column_sum_1 = vaddw_u8(v_column_sum_1, vget_low_u8(bytes1));
        v_column_sum_2 = vaddw_u8(v_column_sum_2, vget_high_u8(bytes1));
        v_column_sum_3 = vaddw_u8(v_column_sum_3, vget_low_u8(bytes2));
        v_column_sum_4 = vaddw_u8(v_column_sum_4, vget_high_u8(bytes2));
      }

      // No more data/updates to a, so now we shake out all of the accumulated
      // data. Each full block spans 32 indices, so scale B by 32 (<< 5); this
      // also applies the deferred factor of 32 to the a * n seed in lane 3.
      b_v = vshlq_n_u32(b_v, 5);

      // Then product-sum of each D column: a byte at offset i within a block
      // contributes (32 - i) times to `b`, hence the descending weights.
      let w1: [u16; 4] = [32, 31, 30, 29];
      let w2: [u16; 4] = [28, 27, 26, 25];
      let w3: [u16; 4] = [24, 23, 22, 21];
      let w4: [u16; 4] = [20, 19, 18, 17];
      let w5: [u16; 4] = [16, 15, 14, 13];
      let w6: [u16; 4] = [12, 11, 10, 9];
      let w7: [u16; 4] = [8, 7, 6, 5];
      let w8: [u16; 4] = [4, 3, 2, 1];
      b_v = vmlal_u16(b_v, vget_low_u16(v_column_sum_1), vld1_u16(w1.as_ptr()));
      b_v = vmlal_u16(b_v, vget_high_u16(v_column_sum_1), vld1_u16(w2.as_ptr()));
      b_v = vmlal_u16(b_v, vget_low_u16(v_column_sum_2), vld1_u16(w3.as_ptr()));
      b_v = vmlal_u16(b_v, vget_high_u16(v_column_sum_2), vld1_u16(w4.as_ptr()));
      b_v = vmlal_u16(b_v, vget_low_u16(v_column_sum_3), vld1_u16(w5.as_ptr()));
      b_v = vmlal_u16(b_v, vget_high_u16(v_column_sum_3), vld1_u16(w6.as_ptr()));
      b_v = vmlal_u16(b_v, vget_low_u16(v_column_sum_4), vld1_u16(w7.as_ptr()));
      b_v = vmlal_u16(b_v, vget_high_u16(v_column_sum_4), vld1_u16(w8.as_ptr()));

      // Pyramid pairwise-add to get the final output:
      // sum3[0] = sum of a_v's four lanes, sum3[1] = sum of b_v's four lanes.
      // *a = vaddvq_u32(a_v) would also do the job (aarch64 only).
      let sum1: uint32x2_t = vpadd_u32(vget_low_u32(a_v), vget_high_u32(a_v));
      let sum2: uint32x2_t = vpadd_u32(vget_low_u32(b_v), vget_high_u32(b_v));
      let sum3: uint32x2_t = vpadd_u32(sum1, sum2);
      *a += vget_lane_u32(sum3, 0);
      *b += vget_lane_u32(sum3, 1);

      *a %= MOD;
      *b %= MOD;

      blocks_remainder
    }
  }
}
150
#[cfg(test)]
mod tests {
  use rand::{rngs::SmallRng, Rng, SeedableRng};

  /// All-zero inputs across a range of lengths, including the empty slice.
  #[test]
  fn zeroes() {
    for &len in &[0usize, 1, 2, 100, 1024, 1024 * 1024] {
      assert_sum_eq(&vec![0u8; len]);
    }
  }

  /// All-one inputs across a range of lengths, including the empty slice.
  #[test]
  fn ones() {
    for &len in &[0usize, 1, 2, 100, 1024, 1024 * 1024] {
      assert_sum_eq(&vec![1u8; len]);
    }
  }

  /// Random data checked at several prefix lengths of one shared buffer.
  #[test]
  fn random() {
    let mut data = vec![0u8; 1024 * 1024];
    SmallRng::from_entropy().fill(&mut data[..]);

    for &len in &[1usize, 100, 1024, 1024 * 1024] {
      assert_sum_eq(&data[..len]);
    }
  }

  /// Example calculation from https://en.wikipedia.org/wiki/Adler-32.
  #[test]
  fn wiki() {
    assert_sum_eq(b"Wikipedia");
  }

  /// Asserts that the SIMD implementation (when available on this target)
  /// produces the same checksum as the reference `adler` crate for `data`.
  fn assert_sum_eq(data: &[u8]) {
    let update = match super::get_imp() {
      Some(update) => update,
      None => return,
    };

    let (a, b) = update(1, 0, data);
    let actual = u32::from(b) << 16 | u32::from(a);
    let expected = adler::adler32_slice(data);

    assert_eq!(actual, expected, "len({})", data.len());
  }
}