Move `fold` logic to `iter_fold` method and reuse it in `count` and `last`

This commit is contained in:
Tim Vermeulen 2022-07-20 16:33:53 +02:00
parent cbc5f62782
commit 3f7004920c
2 changed files with 118 additions and 16 deletions

View File

@ -78,6 +78,16 @@ where
fn advance_by(&mut self, n: usize) -> Result<(), usize> { fn advance_by(&mut self, n: usize) -> Result<(), usize> {
self.inner.advance_by(n) self.inner.advance_by(n)
} }
#[inline]
fn count(self) -> usize {
self.inner.count()
}
#[inline]
fn last(self) -> Option<Self::Item> {
self.inner.last()
}
} }
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
@ -229,6 +239,16 @@ where
fn advance_by(&mut self, n: usize) -> Result<(), usize> { fn advance_by(&mut self, n: usize) -> Result<(), usize> {
self.inner.advance_by(n) self.inner.advance_by(n)
} }
#[inline]
fn count(self) -> usize {
self.inner.count()
}
#[inline]
fn last(self) -> Option<Self::Item> {
self.inner.last()
}
} }
#[stable(feature = "iterator_flatten", since = "1.29.0")] #[stable(feature = "iterator_flatten", since = "1.29.0")]
@ -304,6 +324,35 @@ impl<I, U> FlattenCompat<I, U>
where where
I: Iterator<Item: IntoIterator<IntoIter = U>>, I: Iterator<Item: IntoIterator<IntoIter = U>>,
{ {
/// Folds the inner iterators into an accumulator by applying an operation.
///
/// Folds over the inner iterators, not over their elements. Is used by the `fold`, `count`,
/// and `last` methods.
#[inline]
fn iter_fold<Acc, Fold>(self, mut acc: Acc, mut fold: Fold) -> Acc
where
Fold: FnMut(Acc, U) -> Acc,
{
#[inline]
fn flatten<T: IntoIterator, Acc>(
fold: &mut impl FnMut(Acc, T::IntoIter) -> Acc,
) -> impl FnMut(Acc, T) -> Acc + '_ {
move |acc, iter| fold(acc, iter.into_iter())
}
if let Some(iter) = self.frontiter {
acc = fold(acc, iter);
}
acc = self.iter.fold(acc, flatten(&mut fold));
if let Some(iter) = self.backiter {
acc = fold(acc, iter);
}
acc
}
/// Folds over the inner iterators as long as the given function returns successfully, /// Folds over the inner iterators as long as the given function returns successfully,
/// always storing the most recent inner iterator in `self.frontiter`. /// always storing the most recent inner iterator in `self.frontiter`.
/// ///
@ -440,28 +489,18 @@ where
} }
#[inline] #[inline]
fn fold<Acc, Fold>(self, mut init: Acc, mut fold: Fold) -> Acc fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
where where
Fold: FnMut(Acc, Self::Item) -> Acc, Fold: FnMut(Acc, Self::Item) -> Acc,
{ {
#[inline] #[inline]
fn flatten<T: IntoIterator, Acc>( fn flatten<U: Iterator, Acc>(
fold: &mut impl FnMut(Acc, T::Item) -> Acc, mut fold: impl FnMut(Acc, U::Item) -> Acc,
) -> impl FnMut(Acc, T) -> Acc + '_ { ) -> impl FnMut(Acc, U) -> Acc {
move |acc, x| x.into_iter().fold(acc, &mut *fold) move |acc, iter| iter.fold(acc, &mut fold)
} }
if let Some(front) = self.frontiter { self.iter_fold(init, flatten(fold))
init = front.fold(init, &mut fold);
}
init = self.iter.fold(init, flatten(&mut fold));
if let Some(back) = self.backiter {
init = back.fold(init, &mut fold);
}
init
} }
#[inline] #[inline]
@ -481,6 +520,27 @@ where
_ => Ok(()), _ => Ok(()),
} }
} }
#[inline]
fn count(self) -> usize {
#[inline]
#[rustc_inherit_overflow_checks]
fn count<U: Iterator>(acc: usize, iter: U) -> usize {
acc + iter.count()
}
self.iter_fold(0, count)
}
#[inline]
fn last(self) -> Option<Self::Item> {
#[inline]
fn last<U: Iterator>(last: Option<U::Item>, iter: U) -> Option<U::Item> {
iter.last().or(last)
}
self.iter_fold(None, last)
}
} }
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U> impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>

View File

@ -168,3 +168,45 @@ fn test_trusted_len_flatten() {
assert_trusted_len(&iter); assert_trusted_len(&iter);
assert_eq!(iter.size_hint(), (20, Some(20))); assert_eq!(iter.size_hint(), (20, Some(20)));
} }
#[test]
fn test_flatten_count() {
let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten();
assert_eq!(it.clone().count(), 40);
it.advance_by(5).unwrap();
assert_eq!(it.clone().count(), 35);
it.advance_back_by(5).unwrap();
assert_eq!(it.clone().count(), 30);
it.advance_by(10).unwrap();
assert_eq!(it.clone().count(), 20);
it.advance_back_by(8).unwrap();
assert_eq!(it.clone().count(), 12);
it.advance_by(4).unwrap();
assert_eq!(it.clone().count(), 8);
it.advance_back_by(5).unwrap();
assert_eq!(it.clone().count(), 3);
it.advance_by(3).unwrap();
assert_eq!(it.clone().count(), 0);
}
#[test]
fn test_flatten_last() {
let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten();
assert_eq!(it.clone().last(), Some(39));
it.advance_by(5).unwrap(); // 5..40
assert_eq!(it.clone().last(), Some(39));
it.advance_back_by(5).unwrap(); // 5..35
assert_eq!(it.clone().last(), Some(34));
it.advance_by(10).unwrap(); // 15..35
assert_eq!(it.clone().last(), Some(34));
it.advance_back_by(8).unwrap(); // 15..27
assert_eq!(it.clone().last(), Some(26));
it.advance_by(4).unwrap(); // 19..27
assert_eq!(it.clone().last(), Some(26));
it.advance_back_by(5).unwrap(); // 19..22
assert_eq!(it.clone().last(), Some(21));
it.advance_by(3).unwrap(); // 22..22
assert_eq!(it.clone().last(), None);
}