1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::ops::{Index, Shr};
4
5use rustc_index::Idx;
6use rustc_span::{DUMMY_SP, Span, SpanData};
7use smallvec::SmallVec;
8
9use super::data_race::NaReadType;
10use crate::helpers::ToUsize;
11
12#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)]
16pub(super) struct VectorIdx(u32);
17
18impl VectorIdx {
19    #[inline(always)]
20    fn to_u32(self) -> u32 {
21        self.0
22    }
23}
24
25impl Idx for VectorIdx {
26    #[inline]
27    fn new(idx: usize) -> Self {
28        VectorIdx(u32::try_from(idx).unwrap())
29    }
30
31    #[inline]
32    fn index(self) -> usize {
33        usize::try_from(self.0).unwrap()
34    }
35}
36
37impl From<u32> for VectorIdx {
38    #[inline]
39    fn from(id: u32) -> Self {
40        Self(id)
41    }
42}
43
/// Inline capacity of the `SmallVec` backing a `VClock`: clocks with at most this
/// many entries are stored without a separate heap allocation.
const SMALL_VECTOR: usize = 4;
47
48#[derive(Clone, Copy, Debug)]
52pub(super) struct VTimestamp {
53    time_and_read_type: u32,
56    pub span: Span,
57}
58
59impl VTimestamp {
60    pub const ZERO: VTimestamp = VTimestamp::new(0, NaReadType::Read, DUMMY_SP);
61
62    #[inline]
63    const fn encode_time_and_read_type(time: u32, read_type: NaReadType) -> u32 {
64        let read_type_bit = match read_type {
65            NaReadType::Read => 0,
66            NaReadType::Retag => 1,
67        };
68        read_type_bit | time.checked_mul(2).expect("Vector clock overflow")
70    }
71
72    #[inline]
73    const fn new(time: u32, read_type: NaReadType, span: Span) -> Self {
74        Self { time_and_read_type: Self::encode_time_and_read_type(time, read_type), span }
75    }
76
77    #[inline]
78    fn time(&self) -> u32 {
79        self.time_and_read_type.shr(1)
80    }
81
82    #[inline]
83    fn set_time(&mut self, time: u32) {
84        self.time_and_read_type = Self::encode_time_and_read_type(time, self.read_type());
85    }
86
87    #[inline]
88    pub(super) fn read_type(&self) -> NaReadType {
89        if self.time_and_read_type & 1 == 0 { NaReadType::Read } else { NaReadType::Retag }
90    }
91
92    #[inline]
93    pub(super) fn set_read_type(&mut self, read_type: NaReadType) {
94        self.time_and_read_type = Self::encode_time_and_read_type(self.time(), read_type);
95    }
96
97    #[inline]
98    pub(super) fn span_data(&self) -> SpanData {
99        self.span.data()
100    }
101}
102
103impl PartialEq for VTimestamp {
104    fn eq(&self, other: &Self) -> bool {
105        self.time() == other.time()
106    }
107}
108
109impl Eq for VTimestamp {}
110
111impl PartialOrd for VTimestamp {
112    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
113        Some(self.cmp(other))
114    }
115}
116
117impl Ord for VTimestamp {
118    fn cmp(&self, other: &Self) -> Ordering {
119        self.time().cmp(&other.time())
120    }
121}
122
123#[derive(PartialEq, Eq, Default, Debug)]
137pub struct VClock(SmallVec<[VTimestamp; SMALL_VECTOR]>);
138
139impl VClock {
140    pub(super) fn new_with_index(index: VectorIdx, timestamp: VTimestamp) -> VClock {
143        if timestamp.time() == 0 {
144            return VClock::default();
145        }
146        let len = index.index() + 1;
147        let mut vec = smallvec::smallvec![VTimestamp::ZERO; len];
148        vec[index.index()] = timestamp;
149        VClock(vec)
150    }
151
152    #[inline]
154    pub(super) fn as_slice(&self) -> &[VTimestamp] {
155        debug_assert!(self.0.last().is_none_or(|t| t.time() != 0));
156        self.0.as_slice()
157    }
158
159    #[inline]
160    pub(super) fn index_mut(&mut self, index: VectorIdx) -> &mut VTimestamp {
161        self.0.as_mut_slice().get_mut(index.to_u32().to_usize()).unwrap()
162    }
163
164    #[inline]
168    fn get_mut_with_min_len(&mut self, min_len: usize) -> &mut [VTimestamp] {
169        if self.0.len() < min_len {
170            self.0.resize(min_len, VTimestamp::ZERO);
171        }
172        assert!(self.0.len() >= min_len);
173        self.0.as_mut_slice()
174    }
175
176    #[inline]
179    pub(super) fn increment_index(&mut self, idx: VectorIdx, current_span: Span) {
180        let idx = idx.index();
181        let mut_slice = self.get_mut_with_min_len(idx + 1);
182        let idx_ref = &mut mut_slice[idx];
183        idx_ref.set_time(idx_ref.time().checked_add(1).expect("Vector clock overflow"));
184        if !current_span.is_dummy() {
185            idx_ref.span = current_span;
186        }
187    }
188
189    pub fn join(&mut self, other: &Self) {
193        let rhs_slice = other.as_slice();
194        let lhs_slice = self.get_mut_with_min_len(rhs_slice.len());
195        for (l, &r) in lhs_slice.iter_mut().zip(rhs_slice.iter()) {
196            let l_span = l.span;
197            let r_span = r.span;
198            *l = r.max(*l);
199            l.span = l.span.substitute_dummy(r_span).substitute_dummy(l_span);
200        }
201    }
202
203    pub(super) fn set_at_index(&mut self, other: &Self, idx: VectorIdx) {
205        let new_timestamp = other[idx];
206        if new_timestamp.time() == 0 {
208            if idx.index() >= self.0.len() {
209                return;
211            }
212            }
215
216        let mut_slice = self.get_mut_with_min_len(idx.index() + 1);
217        let mut_timestamp = &mut mut_slice[idx.index()];
218
219        let prev_span = mut_timestamp.span;
220
221        assert!(*mut_timestamp <= new_timestamp, "set_at_index: may only increase the timestamp");
222        *mut_timestamp = new_timestamp;
223
224        let span = &mut mut_timestamp.span;
225        *span = span.substitute_dummy(prev_span);
226    }
227
228    #[inline]
230    pub(super) fn set_zero_vector(&mut self) {
231        self.0.clear();
232    }
233}
234
235impl Clone for VClock {
236    fn clone(&self) -> Self {
237        VClock(self.0.clone())
238    }
239
240    fn clone_from(&mut self, source: &Self) {
245        let source_slice = source.as_slice();
246        self.0.clear();
247        self.0.extend_from_slice(source_slice);
248    }
249}
250
251impl PartialOrd for VClock {
252    fn partial_cmp(&self, other: &VClock) -> Option<Ordering> {
253        let lhs_slice = self.as_slice();
255        let rhs_slice = other.as_slice();
256
257        let mut iter = lhs_slice.iter().zip(rhs_slice.iter());
266        let mut order = match iter.next() {
267            Some((lhs, rhs)) => lhs.cmp(rhs),
268            None => Ordering::Equal,
269        };
270        for (l, r) in iter {
271            match order {
272                Ordering::Equal => order = l.cmp(r),
273                Ordering::Less =>
274                    if l > r {
275                        return None;
276                    },
277                Ordering::Greater =>
278                    if l < r {
279                        return None;
280                    },
281            }
282        }
283
284        let l_len = lhs_slice.len();
289        let r_len = rhs_slice.len();
290        match l_len.cmp(&r_len) {
291            Ordering::Equal => Some(order),
293            Ordering::Less =>
296                match order {
297                    Ordering::Less | Ordering::Equal => Some(Ordering::Less),
298                    Ordering::Greater => None,
299                },
300            Ordering::Greater =>
303                match order {
304                    Ordering::Greater | Ordering::Equal => Some(Ordering::Greater),
305                    Ordering::Less => None,
306                },
307        }
308    }
309
310    fn lt(&self, other: &VClock) -> bool {
311        let lhs_slice = self.as_slice();
313        let rhs_slice = other.as_slice();
314
315        let l_len = lhs_slice.len();
320        let r_len = rhs_slice.len();
321        if l_len <= r_len {
322            let mut equal = l_len == r_len;
329            for (&l, &r) in lhs_slice.iter().zip(rhs_slice.iter()) {
330                if l > r {
331                    return false;
332                } else if l < r {
333                    equal = false;
334                }
335            }
336            !equal
337        } else {
338            false
339        }
340    }
341
342    fn le(&self, other: &VClock) -> bool {
343        let lhs_slice = self.as_slice();
345        let rhs_slice = other.as_slice();
346
347        let l_len = lhs_slice.len();
352        let r_len = rhs_slice.len();
353        if l_len <= r_len {
354            !lhs_slice.iter().zip(rhs_slice.iter()).any(|(&l, &r)| l > r)
359        } else {
360            false
361        }
362    }
363
364    fn gt(&self, other: &VClock) -> bool {
365        let lhs_slice = self.as_slice();
367        let rhs_slice = other.as_slice();
368
369        let l_len = lhs_slice.len();
374        let r_len = rhs_slice.len();
375        if l_len >= r_len {
376            let mut equal = l_len == r_len;
383            for (&l, &r) in lhs_slice.iter().zip(rhs_slice.iter()) {
384                if l < r {
385                    return false;
386                } else if l > r {
387                    equal = false;
388                }
389            }
390            !equal
391        } else {
392            false
393        }
394    }
395
396    fn ge(&self, other: &VClock) -> bool {
397        let lhs_slice = self.as_slice();
399        let rhs_slice = other.as_slice();
400
401        let l_len = lhs_slice.len();
406        let r_len = rhs_slice.len();
407        if l_len >= r_len {
408            !lhs_slice.iter().zip(rhs_slice.iter()).any(|(&l, &r)| l < r)
413        } else {
414            false
415        }
416    }
417}
418
419impl Index<VectorIdx> for VClock {
420    type Output = VTimestamp;
421
422    #[inline]
423    fn index(&self, index: VectorIdx) -> &VTimestamp {
424        self.as_slice().get(index.to_u32().to_usize()).unwrap_or(&VTimestamp::ZERO)
425    }
426}
427
#[cfg(test)]
mod tests {
    use std::cmp::Ordering;

    use rustc_span::DUMMY_SP;

    use super::{VClock, VTimestamp, VectorIdx};
    use crate::concurrency::data_race::NaReadType;

    /// Two clocks become equal only after both have incremented the same indices.
    #[test]
    fn test_equal() {
        let mut c1 = VClock::default();
        let mut c2 = VClock::default();
        assert_eq!(c1, c2);
        // (increment `c1`?, index to increment, clocks equal afterwards?)
        let steps = [(true, 5, false), (false, 53, false), (true, 53, false), (false, 5, true)];
        for (on_c1, idx, equal_after) in steps {
            let target = if on_c1 { &mut c1 } else { &mut c2 };
            target.increment_index(VectorIdx(idx), DUMMY_SP);
            if equal_after {
                assert_eq!(c1, c2);
            } else {
                assert_ne!(c1, c2);
            }
        }
    }

    /// Table of (lhs, rhs, expected `lhs.partial_cmp(rhs)`); trailing zeros are
    /// stripped by `from_slice`, so they must not affect the result.
    #[test]
    fn test_partial_order() {
        let cases: &[(&[u32], &[u32], Option<Ordering>)] = &[
            (&[1], &[1], Some(Ordering::Equal)),
            (&[1], &[2], Some(Ordering::Less)),
            (&[2], &[1], Some(Ordering::Greater)),
            (&[1], &[1, 2], Some(Ordering::Less)),
            (&[2], &[1, 2], None),
            (&[400], &[0, 1], None),
            (
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0],
                Some(Ordering::Equal),
            ),
            (
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 0],
                Some(Ordering::Less),
            ),
            (
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11],
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0],
                Some(Ordering::Greater),
            ),
            (
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11],
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 0],
                None,
            ),
            (
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9],
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0, 0],
                Some(Ordering::Less),
            ),
            (
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9],
                &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 0],
                Some(Ordering::Less),
            ),
        ];
        for &(l, r, o) in cases {
            assert_order(l, r, o);
        }
    }

    /// Builds a normalized clock from raw times by dropping any trailing zeros.
    fn from_slice(slice: &[u32]) -> VClock {
        let len = slice.iter().rposition(|&time| time != 0).map_or(0, |last| last + 1);
        let timestamps = slice[..len]
            .iter()
            .map(|&time| VTimestamp::new(time, NaReadType::Read, DUMMY_SP))
            .collect();
        VClock(timestamps)
    }

    /// Checks `l.partial_cmp(r) == o`, that the reversed comparison agrees, and that
    /// `<`, `<=`, `>`, `>=` are consistent with both results.
    fn assert_order(l: &[u32], r: &[u32], o: Option<Ordering>) {
        let l = from_slice(l);
        let r = from_slice(r);

        let compare = l.partial_cmp(&r);
        assert_eq!(compare, o, "Invalid comparison\n l: {l:?}\n r: {r:?}");
        let alt_compare = r.partial_cmp(&l);
        assert_eq!(
            alt_compare,
            o.map(Ordering::reverse),
            "Invalid alt comparison\n l: {l:?}\n r: {r:?}"
        );

        let fwd_lt = matches!(compare, Some(Ordering::Less));
        let fwd_le = matches!(compare, Some(Ordering::Less | Ordering::Equal));
        let fwd_gt = matches!(compare, Some(Ordering::Greater));
        let fwd_ge = matches!(compare, Some(Ordering::Greater | Ordering::Equal));
        let rev_lt = matches!(alt_compare, Some(Ordering::Less));
        let rev_le = matches!(alt_compare, Some(Ordering::Less | Ordering::Equal));
        let rev_gt = matches!(alt_compare, Some(Ordering::Greater));
        let rev_ge = matches!(alt_compare, Some(Ordering::Greater | Ordering::Equal));

        assert_eq!(fwd_lt, l < r, "Invalid (<):\n l: {l:?}\n r: {r:?}");
        assert_eq!(fwd_le, l <= r, "Invalid (<=):\n l: {l:?}\n r: {r:?}");
        assert_eq!(fwd_gt, l > r, "Invalid (>):\n l: {l:?}\n r: {r:?}");
        assert_eq!(fwd_ge, l >= r, "Invalid (>=):\n l: {l:?}\n r: {r:?}");
        assert_eq!(rev_lt, r < l, "Invalid alt (<):\n l: {l:?}\n r: {r:?}");
        assert_eq!(rev_le, r <= l, "Invalid alt (<=):\n l: {l:?}\n r: {r:?}");
        assert_eq!(rev_gt, r > l, "Invalid alt (>):\n l: {l:?}\n r: {r:?}");
        assert_eq!(rev_ge, r >= l, "Invalid alt (>=):\n l: {l:?}\n r: {r:?}");
    }

    /// Copying a zero entry from beyond the end of the target clock must not append
    /// a trailing zero to it.
    #[test]
    fn set_index_to_0() {
        let mut clock1 = from_slice(&[0, 1, 2, 3]);
        let clock2 = from_slice(&[0, 2, 3, 4, 0, 5]);
        clock1.set_at_index(&clock2, VectorIdx(4));
        let last = clock1.0.last().unwrap();
        assert!(last.time() != 0);
    }
}