// core/iter/adapters/take.rs

1use crate::cmp;
2use crate::iter::adapters::SourceIter;
3use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen, TrustedRandomAccess};
4use crate::num::NonZero;
5use crate::ops::{ControlFlow, Try};
6
7/// An iterator that only iterates over the first `n` iterations of `iter`.
8///
9/// This `struct` is created by the [`take`] method on [`Iterator`]. See its
10/// documentation for more.
11///
12/// [`take`]: Iterator::take
13/// [`Iterator`]: trait.Iterator.html
14#[derive(Clone, Debug)]
15#[must_use = "iterators are lazy and do nothing unless consumed"]
16#[stable(feature = "rust1", since = "1.0.0")]
17pub struct Take<I> {
18    iter: I,
19    n: usize,
20}
21
22impl<I> Take<I> {
23    pub(in crate::iter) fn new(iter: I, n: usize) -> Take<I> {
24        Take { iter, n }
25    }
26}
27
28#[stable(feature = "rust1", since = "1.0.0")]
29impl<I> Iterator for Take<I>
30where
31    I: Iterator,
32{
33    type Item = <I as Iterator>::Item;
34
35    #[inline]
36    fn next(&mut self) -> Option<<I as Iterator>::Item> {
37        if self.n != 0 {
38            self.n -= 1;
39            self.iter.next()
40        } else {
41            None
42        }
43    }
44
45    #[inline]
46    fn nth(&mut self, n: usize) -> Option<I::Item> {
47        if self.n > n {
48            self.n -= n + 1;
49            self.iter.nth(n)
50        } else {
51            if self.n > 0 {
52                self.iter.nth(self.n - 1);
53                self.n = 0;
54            }
55            None
56        }
57    }
58
59    #[inline]
60    fn size_hint(&self) -> (usize, Option<usize>) {
61        if self.n == 0 {
62            return (0, Some(0));
63        }
64
65        let (lower, upper) = self.iter.size_hint();
66
67        let lower = cmp::min(lower, self.n);
68
69        let upper = match upper {
70            Some(x) if x < self.n => Some(x),
71            _ => Some(self.n),
72        };
73
74        (lower, upper)
75    }
76
77    #[inline]
78    fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
79    where
80        Fold: FnMut(Acc, Self::Item) -> R,
81        R: Try<Output = Acc>,
82    {
83        fn check<'a, T, Acc, R: Try<Output = Acc>>(
84            n: &'a mut usize,
85            mut fold: impl FnMut(Acc, T) -> R + 'a,
86        ) -> impl FnMut(Acc, T) -> ControlFlow<R, Acc> + 'a {
87            move |acc, x| {
88                *n -= 1;
89                let r = fold(acc, x);
90                if *n == 0 { ControlFlow::Break(r) } else { ControlFlow::from_try(r) }
91            }
92        }
93
94        if self.n == 0 {
95            try { init }
96        } else {
97            let n = &mut self.n;
98            self.iter.try_fold(init, check(n, fold)).into_try()
99        }
100    }
101
102    #[inline]
103    fn fold<B, F>(self, init: B, f: F) -> B
104    where
105        Self: Sized,
106        F: FnMut(B, Self::Item) -> B,
107    {
108        Self::spec_fold(self, init, f)
109    }
110
111    #[inline]
112    fn for_each<F: FnMut(Self::Item)>(self, f: F) {
113        Self::spec_for_each(self, f)
114    }
115
116    #[inline]
117    #[rustc_inherit_overflow_checks]
118    fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
119        let min = self.n.min(n);
120        let rem = match self.iter.advance_by(min) {
121            Ok(()) => 0,
122            Err(rem) => rem.get(),
123        };
124        let advanced = min - rem;
125        self.n -= advanced;
126        NonZero::new(n - advanced).map_or(Ok(()), Err)
127    }
128}
129
130#[unstable(issue = "none", feature = "inplace_iteration")]
131unsafe impl<I> SourceIter for Take<I>
132where
133    I: SourceIter,
134{
135    type Source = I::Source;
136
137    #[inline]
138    unsafe fn as_inner(&mut self) -> &mut I::Source {
139        // SAFETY: unsafe function forwarding to unsafe function with the same requirements
140        unsafe { SourceIter::as_inner(&mut self.iter) }
141    }
142}
143
144#[unstable(issue = "none", feature = "inplace_iteration")]
145unsafe impl<I: InPlaceIterable> InPlaceIterable for Take<I> {
146    const EXPAND_BY: Option<NonZero<usize>> = I::EXPAND_BY;
147    const MERGE_BY: Option<NonZero<usize>> = I::MERGE_BY;
148}
149
150#[stable(feature = "double_ended_take_iterator", since = "1.38.0")]
151impl<I> DoubleEndedIterator for Take<I>
152where
153    I: DoubleEndedIterator + ExactSizeIterator,
154{
155    #[inline]
156    fn next_back(&mut self) -> Option<Self::Item> {
157        if self.n == 0 {
158            None
159        } else {
160            let n = self.n;
161            self.n -= 1;
162            self.iter.nth_back(self.iter.len().saturating_sub(n))
163        }
164    }
165
166    #[inline]
167    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
168        let len = self.iter.len();
169        if self.n > n {
170            let m = len.saturating_sub(self.n) + n;
171            self.n -= n + 1;
172            self.iter.nth_back(m)
173        } else {
174            if len > 0 {
175                self.iter.nth_back(len - 1);
176            }
177            None
178        }
179    }
180
181    #[inline]
182    fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
183    where
184        Self: Sized,
185        Fold: FnMut(Acc, Self::Item) -> R,
186        R: Try<Output = Acc>,
187    {
188        if self.n == 0 {
189            try { init }
190        } else {
191            let len = self.iter.len();
192            if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
193                try { init }
194            } else {
195                self.iter.try_rfold(init, fold)
196            }
197        }
198    }
199
200    #[inline]
201    fn rfold<Acc, Fold>(mut self, init: Acc, fold: Fold) -> Acc
202    where
203        Self: Sized,
204        Fold: FnMut(Acc, Self::Item) -> Acc,
205    {
206        if self.n == 0 {
207            init
208        } else {
209            let len = self.iter.len();
210            if len > self.n && self.iter.nth_back(len - self.n - 1).is_none() {
211                init
212            } else {
213                self.iter.rfold(init, fold)
214            }
215        }
216    }
217
218    #[inline]
219    #[rustc_inherit_overflow_checks]
220    fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
221        // The amount by which the inner iterator needs to be shortened for it to be
222        // at most as long as the take() amount.
223        let trim_inner = self.iter.len().saturating_sub(self.n);
224        // The amount we need to advance inner to fulfill the caller's request.
225        // take(), advance_by() and len() all can be at most usize, so we don't have to worry
226        // about having to advance more than usize::MAX here.
227        let advance_by = trim_inner.saturating_add(n);
228
229        let remainder = match self.iter.advance_back_by(advance_by) {
230            Ok(()) => 0,
231            Err(rem) => rem.get(),
232        };
233        let advanced_by_inner = advance_by - remainder;
234        let advanced_by = advanced_by_inner - trim_inner;
235        self.n -= advanced_by;
236        NonZero::new(n - advanced_by).map_or(Ok(()), Err)
237    }
238}
239
240#[stable(feature = "rust1", since = "1.0.0")]
241impl<I> ExactSizeIterator for Take<I> where I: ExactSizeIterator {}
242
243#[stable(feature = "fused", since = "1.26.0")]
244impl<I> FusedIterator for Take<I> where I: FusedIterator {}
245
246#[unstable(issue = "none", feature = "trusted_fused")]
247unsafe impl<I: TrustedFused> TrustedFused for Take<I> {}
248
249#[unstable(feature = "trusted_len", issue = "37572")]
250unsafe impl<I: TrustedLen> TrustedLen for Take<I> {}
251
// Specialization hook behind `Take::fold` and `Take::for_each`. The default
// implementation goes through `try_fold`, while sources that implement
// `TrustedRandomAccess` get an indexed fast path.
trait SpecTake: Iterator {
    fn spec_for_each<F: FnMut(Self::Item)>(self, f: F);

    fn spec_fold<B, F>(self, init: B, f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B;
}
260
261impl<I: Iterator> SpecTake for Take<I> {
262    #[inline]
263    default fn spec_fold<B, F>(mut self, init: B, f: F) -> B
264    where
265        Self: Sized,
266        F: FnMut(B, Self::Item) -> B,
267    {
268        use crate::ops::NeverShortCircuit;
269        self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
270    }
271
272    #[inline]
273    default fn spec_for_each<F: FnMut(Self::Item)>(mut self, f: F) {
274        // The default implementation would use a unit accumulator, so we can
275        // avoid a stateful closure by folding over the remaining number
276        // of items we wish to return instead.
277        fn check<'a, Item>(
278            mut action: impl FnMut(Item) + 'a,
279        ) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
280            move |more, x| {
281                action(x);
282                more.checked_sub(1)
283            }
284        }
285
286        let remaining = self.n;
287        if remaining > 0 {
288            self.iter.try_fold(remaining - 1, check(f));
289        }
290    }
291}
292
293impl<I: Iterator + TrustedRandomAccess> SpecTake for Take<I> {
294    #[inline]
295    fn spec_fold<B, F>(mut self, init: B, mut f: F) -> B
296    where
297        Self: Sized,
298        F: FnMut(B, Self::Item) -> B,
299    {
300        let mut acc = init;
301        let end = self.n.min(self.iter.size());
302        for i in 0..end {
303            // SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
304            let val = unsafe { self.iter.__iterator_get_unchecked(i) };
305            acc = f(acc, val);
306        }
307        acc
308    }
309
310    #[inline]
311    fn spec_for_each<F: FnMut(Self::Item)>(mut self, mut f: F) {
312        let end = self.n.min(self.iter.size());
313        for i in 0..end {
314            // SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
315            let val = unsafe { self.iter.__iterator_get_unchecked(i) };
316            f(val);
317        }
318    }
319}
320
321#[stable(feature = "exact_size_take_repeat", since = "1.82.0")]
322impl<T: Clone> DoubleEndedIterator for Take<crate::iter::Repeat<T>> {
323    #[inline]
324    fn next_back(&mut self) -> Option<Self::Item> {
325        self.next()
326    }
327
328    #[inline]
329    fn nth_back(&mut self, n: usize) -> Option<Self::Item> {
330        self.nth(n)
331    }
332
333    #[inline]
334    fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
335    where
336        Self: Sized,
337        Fold: FnMut(Acc, Self::Item) -> R,
338        R: Try<Output = Acc>,
339    {
340        self.try_fold(init, fold)
341    }
342
343    #[inline]
344    fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
345    where
346        Self: Sized,
347        Fold: FnMut(Acc, Self::Item) -> Acc,
348    {
349        self.fold(init, fold)
350    }
351
352    #[inline]
353    #[rustc_inherit_overflow_checks]
354    fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
355        self.advance_by(n)
356    }
357}
358
// Note: It may be tempting to impl DoubleEndedIterator for Take<RepeatWith>.
// One must resist that temptation, since such an implementation wouldn't be
// correct: we have no way to return the value of the nth invocation of the
// repeater followed by the (n-1)th without remembering all results.
363
364#[stable(feature = "exact_size_take_repeat", since = "1.82.0")]
365impl<T: Clone> ExactSizeIterator for Take<crate::iter::Repeat<T>> {
366    fn len(&self) -> usize {
367        self.n
368    }
369}
370
371#[stable(feature = "exact_size_take_repeat", since = "1.82.0")]
372impl<F: FnMut() -> A, A> ExactSizeIterator for Take<crate::iter::RepeatWith<F>> {
373    fn len(&self) -> usize {
374        self.n
375    }
376}