1use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
12use crate::distributions::Distribution;
13use crate::Rng;
14use core::cmp::PartialOrd;
15use core::fmt;
16
17use alloc::vec::Vec;
19
20#[cfg(feature = "serde1")]
21use serde::{Serialize, Deserialize};
22
23#[derive(Debug, Clone, PartialEq)]
79#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
80#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
81pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
82    cumulative_weights: Vec<X>,
83    total_weight: X,
84    weight_distribution: X::Sampler,
85}
86
impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
    /// Creates a new `WeightedIndex` from a sequence of weights. The weight at
    /// position `i` is the relative likelihood of sampling index `i`.
    ///
    /// `weights` can be anything iterable whose items borrow as `X`, for
    /// example a `Vec<X>`, a `&[X]`, or an iterator over `&X`.
    ///
    /// # Errors
    ///
    /// - `WeightedError::NoItem` if `weights` yields no items.
    /// - `WeightedError::InvalidWeight` if any weight fails `w >= X::default()`.
    ///   This rejects negative values and also NaN, which compares false with
    ///   everything.
    /// - `WeightedError::AllWeightsZero` if the weights sum to `X::default()`.
    ///
    /// NOTE(review): nothing here guards against overflow of the running sum.
    /// For primitive integers `+=` overflow panics in debug builds and wraps in
    /// release builds by default. For floats, a non-finite total is handed
    /// straight to `X::Sampler::new`. Confirm how the sampler reacts to that.
    pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
    where
        I: IntoIterator,
        I::Item: SampleBorrow<X>,
        X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
    {
        let mut iter = weights.into_iter();
        // The first weight seeds the running total; an empty input is an error.
        let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();

        let zero = <X as Default>::default();
        // Written as `!(w >= zero)` rather than `w < zero`: for partially
        // ordered types (floats), NaN fails every comparison and must be
        // rejected as well.
        if !(total_weight >= zero) {
            return Err(WeightedError::InvalidWeight);
        }

        // Each running sum is pushed *before* the next weight is added. The
        // vector therefore ends up holding the prefix sums of every weight
        // except the last. The last prefix sum is `total_weight` itself.
        let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
        for w in iter {
            if !(w.borrow() >= &zero) {
                return Err(WeightedError::InvalidWeight);
            }
            weights.push(total_weight.clone());
            total_weight += w.borrow();
        }

        if total_weight == zero {
            return Err(WeightedError::AllWeightsZero);
        }
        // Sampler between `zero` and `total_weight`. `sample` maps a drawn
        // point back to an index through the prefix sums (assumed half-open
        // range, as for rand's `UniformSampler::new`; confirm).
        let distr = X::Sampler::new(zero, total_weight.clone());

        Ok(WeightedIndex {
            cumulative_weights: weights,
            total_weight,
            weight_distribution: distr,
        })
    }

    /// Replaces the weights at the given indices and keeps every other weight.
    ///
    /// `new_weights` is a list of `(index, weight)` pairs whose indices must be
    /// strictly increasing. An empty list is a no-op that returns `Ok(())`.
    ///
    /// Every input is validated before `self` is modified, so when an `Err` is
    /// returned the distribution is left unchanged.
    ///
    /// Work is linear in `new_weights.len()` plus the number of stored prefix
    /// sums from the first updated index to the end.
    ///
    /// # Errors
    ///
    /// - `WeightedError::InvalidWeight` if the indices are not strictly
    ///   increasing, or a new weight fails `w >= X::default()`.
    /// - `WeightedError::TooMany` if an index is past the last weight.
    /// - `WeightedError::AllWeightsZero` if the updated total is not greater
    ///   than `X::default()`.
    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError>
    where X: for<'a> ::core::ops::AddAssign<&'a X>
            + for<'a> ::core::ops::SubAssign<&'a X>
            + Clone
            + Default {
        if new_weights.is_empty() {
            return Ok(());
        }

        let zero = <X as Default>::default();

        // Local copy of the total, so the validation pass below can compute the
        // new total without touching `self`.
        let mut total_weight = self.total_weight.clone();

        // Validation pass. It only reads `self`. For each pair it also swaps
        // the old weight out of the local total and the new one in.
        let mut prev_i = None;
        for &(i, w) in new_weights {
            // Indices must be strictly increasing (this also rejects duplicates).
            if let Some(old_i) = prev_i {
                if old_i >= i {
                    return Err(WeightedError::InvalidWeight);
                }
            }
            // Same NaN-aware check as in `new`.
            if !(*w >= zero) {
                return Err(WeightedError::InvalidWeight);
            }
            // Valid indices are `0..=cumulative_weights.len()`, because the last
            // weight has no stored prefix sum.
            if i > self.cumulative_weights.len() {
                return Err(WeightedError::TooMany);
            }

            // Recover the current weight at `i` as a difference of prefix sums.
            // The last weight's prefix sum is `self.total_weight`.
            let mut old_w = if i < self.cumulative_weights.len() {
                self.cumulative_weights[i].clone()
            } else {
                self.total_weight.clone()
            };
            if i > 0 {
                old_w -= &self.cumulative_weights[i - 1];
            }

            // NOTE(review): for floating-point X, this running total (subtract
            // old, add new) and the prefix sums rebuilt below are rounded
            // differently. `self.total_weight` may therefore not exactly equal
            // the rebuilt last prefix sum plus the last weight. Confirm whether
            // that matters.
            total_weight -= &old_w;
            total_weight += w;
            prev_i = Some(i);
        }
        if total_weight <= zero {
            return Err(WeightedError::AllWeightsZero);
        }

        // Rebuild pass. Every precondition was checked above, so nothing past
        // this point returns an error.
        let mut iter = new_weights.iter();

        // Old prefix sum at the previous index. It is only read in the
        // unchanged-weight arm, which never runs on the first iteration
        // because the first visited index is always an updated one.
        let mut prev_weight = zero.clone();
        let mut next_new_weight = iter.next();
        // `new_weights` is non-empty (checked at the top), so this cannot fail.
        let &(first_new_index, _) = next_new_weight.unwrap();
        // Prefix sums before the first updated index are unchanged; start
        // accumulating from the one just before it.
        let mut cumulative_weight = if first_new_index > 0 {
            self.cumulative_weights[first_new_index - 1].clone()
        } else {
            zero.clone()
        };
        for i in first_new_index..self.cumulative_weights.len() {
            match next_new_weight {
                // Updated index: accumulate the replacement weight.
                Some(&(j, w)) if i == j => {
                    cumulative_weight += w;
                    next_new_weight = iter.next();
                }
                // Unchanged index: its weight is the old prefix sum at `i` minus
                // the old prefix sum at `i - 1` (held in `prev_weight`).
                _ => {
                    let mut tmp = self.cumulative_weights[i].clone();
                    tmp -= &prev_weight; cumulative_weight += &tmp;
                }
            }
            // Store the new prefix sum at `i`. Keep the old one in `prev_weight`
            // for the next iteration's difference.
            prev_weight = cumulative_weight.clone();
            core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
        }

        // Commit the new total and rebuild the sampler over the new range.
        self.total_weight = total_weight;
        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone());

        Ok(())
    }
}
222
223impl<X> Distribution<usize> for WeightedIndex<X>
224where X: SampleUniform + PartialOrd
225{
226    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
227        use ::core::cmp::Ordering;
228        let chosen_weight = self.weight_distribution.sample(rng);
229        self.cumulative_weights
231            .binary_search_by(|w| {
232                if *w <= chosen_weight {
233                    Ordering::Less
234                } else {
235                    Ordering::Greater
236                }
237            })
238            .unwrap_err()
239    }
240}
241
#[cfg(test)]
mod test {
    use super::*;

    /// A serialized and then deserialized `WeightedIndex` keeps its weights.
    #[cfg(feature = "serde1")]
    #[test]
    fn test_weightedindex_serde1() {
        let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
        let de_weighted_index: WeightedIndex<i32> =
            bincode::deserialize(&ser_weighted_index).unwrap();

        assert_eq!(
            de_weighted_index.cumulative_weights,
            weighted_index.cumulative_weights
        );
        assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
    }

    /// NaN weights are rejected both at construction and on update.
    #[test]
    fn test_accepting_nan(){
        assert_eq!(
            WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(),
            WeightedError::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new(&[core::f32::NAN]).unwrap_err(),
            WeightedError::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(),
            WeightedError::InvalidWeight,
        );

        assert_eq!(
            WeightedIndex::new(&[0.5, 7.0])
                .unwrap()
                .update_weights(&[(0, &core::f32::NAN)])
                .unwrap_err(),
            WeightedError::InvalidWeight,
        )
    }


    /// Statistical check of sampling frequencies for each way of passing the
    /// weights, plus construction error cases.
    #[test]
    // Skipped under Miri, presumably because drawing 3 * N_REPS samples is too
    // slow there.
    #[cfg_attr(miri, ignore)]
    fn test_weightedindex() {
        let mut r = crate::test::rng(700);
        const N_REPS: u32 = 5000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Relative error of each observed count must be at most 25%. A
        // zero-weight index must never be chosen: a non-zero count divided by
        // an expected 0 gives infinity and fails the assert.
        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25);
            }
        };

        // WeightedIndex from a Vec
        let mut chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // WeightedIndex from a slice
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(&weights[..]).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // WeightedIndex from an iterator
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.iter()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        for _ in 0..5 {
            assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
            assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
            assert_eq!(
                WeightedIndex::new(&[0, 0, 0, 0, 10, 0])
                    .unwrap()
                    .sample(&mut r),
                4
            );
        }

        assert_eq!(
            WeightedIndex::new(&[10][0..0]).unwrap_err(),
            WeightedError::NoItem
        );
        assert_eq!(
            WeightedIndex::new(&[0]).unwrap_err(),
            WeightedError::AllWeightsZero
        );
        assert_eq!(
            WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new(&[-10]).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    /// A successful update must produce the same state as building the
    /// distribution from the updated weights directly.
    #[test]
    fn test_update_weights() {
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // positive change
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // also includes the last index
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for (weights, update, expected_weights) in data.iter() {
            let total_weight = weights.iter().sum::<u32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            distr.update_weights(update).unwrap();
            let expected_total_weight = expected_weights.iter().sum::<u32>();
            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total_weight);
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }

    /// Each rejected update reports the expected error and leaves the
    /// distribution exactly as it was, because validation runs before any
    /// mutation.
    #[test]
    fn test_update_weights_errors() {
        let data = [
            (
                &[1i32, 0, 0][..],
                &[(0, &0)][..], // removes the only non-zero weight
                WeightedError::AllWeightsZero,
            ),
            (
                &[10, 10, 10, 10][..],
                &[(1, &-1)][..], // negative weight
                WeightedError::InvalidWeight,
            ),
            (
                &[1, 2, 3, 4, 5][..],
                &[(1, &1), (5, &1)][..], // index 5 is past the last weight
                WeightedError::TooMany,
            ),
            (
                &[1, 2, 3, 4, 5][..],
                &[(2, &1), (1, &1)][..], // indices not increasing
                WeightedError::InvalidWeight,
            ),
            (
                &[1, 2, 3, 4, 5][..],
                &[(1, &1), (1, &2)][..], // duplicate index
                WeightedError::InvalidWeight,
            ),
        ];

        for (weights, update, expected_err) in data.iter() {
            let total_weight = weights.iter().sum::<i32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            let cumulative_before = distr.cumulative_weights.clone();
            assert_eq!(distr.total_weight, total_weight);

            assert_eq!(distr.update_weights(update).unwrap_err(), *expected_err);

            assert_eq!(distr.total_weight, total_weight);
            assert_eq!(distr.cumulative_weights, cumulative_before);
        }
    }

    /// Pins exact sample sequences for a fixed seed, so changes to the
    /// sampling algorithm are noticed.
    #[test]
    fn value_stability() {
        fn test_samples<X: SampleUniform + PartialOrd, I>(
            weights: I, buf: &mut [usize], expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
            X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
            0, 6, 2, 6, 3, 4, 7, 8, 2, 5,
        ]);
        test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
            0, 0, 0, 1, 0, 0, 2, 3, 0, 0,
        ]);
        test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
            2, 2, 1, 3, 2, 1, 3, 3, 2, 1,
        ]);
    }

    /// The derived `PartialEq` treats distributions with equal weights as equal.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2]));
    }
}
427
/// Errors reported when building or updating a `WeightedIndex`.
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WeightedError {
    /// No weights were provided: the weight sequence was empty.
    NoItem,

    /// A weight fails `w >= zero`, which covers negative values and NaN.
    /// `WeightedIndex::update_weights` also reports this variant when its
    /// indices are not strictly increasing.
    InvalidWeight,

    /// The weights add up to zero (or, after an update, to no more than
    /// zero), so no index can be chosen.
    AllWeightsZero,

    /// Too many weights. The `Display` message cites `u32::MAX` as the limit.
    /// `WeightedIndex::update_weights` also reports this variant for an index
    /// past the last weight.
    TooMany,
}
445
446#[cfg(feature = "std")]
447impl std::error::Error for WeightedError {}
448
449impl fmt::Display for WeightedError {
450    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
451        f.write_str(match *self {
452            WeightedError::NoItem => "No weights provided in distribution",
453            WeightedError::InvalidWeight => "A weight is invalid in distribution",
454            WeightedError::AllWeightsZero => "All weights are zero in distribution",
455            WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution",
456        })
457    }
458}