diff --git a/amadeus-core/src/par_sink/sum.rs b/amadeus-core/src/par_sink/sum.rs index 65d952c4..e674296b 100644 --- a/amadeus-core/src/par_sink/sum.rs +++ b/amadeus-core/src/par_sink/sum.rs @@ -1,8 +1,8 @@ use derive_new::new; use educe::Educe; -use replace_with::replace_with_or_abort; +use replace_with::{replace_with, replace_with_or_abort}; use serde::{Deserialize, Serialize}; -use std::{iter, marker::PhantomData, mem}; +use std::{iter, marker::PhantomData}; use super::{folder_par_sink, FolderSync, FolderSyncReducer, ParallelPipe, ParallelSink}; @@ -44,14 +44,15 @@ where #[inline(always)] fn zero(&mut self) -> Self::Done { - iter::empty::().sum() + B::sum(iter::empty::()) } #[inline(always)] - fn push(&mut self, state: &mut Self::State, item: Item) { - let zero = iter::empty::().sum(); - let left = mem::replace(state, zero); - let right = iter::once(item).sum::(); - *state = B::sum(iter::once(left).chain(iter::once(right))); + fn push(&mut self, state: &mut Self::Done, item: Item) { + let default = || B::sum(iter::empty::()); + replace_with(state, default, |left| { + let right = iter::once(item).sum::(); + B::sum(iter::once(left).chain(iter::once(right))) + }) } #[inline(always)] @@ -79,11 +80,11 @@ where type Done = Self::State; #[inline(always)] - fn zero(&mut self) -> Self::State { + fn zero(&mut self) -> Self::Done { self.zero.take().unwrap() } #[inline(always)] - fn push(&mut self, state: &mut Self::State, item: Item) { + fn push(&mut self, state: &mut Self::Done, item: Item) { replace_with_or_abort(state, |left| { let right = iter::once(item).sum::>().unwrap(); as iter::Sum>::sum(iter::once(left).chain(iter::once(right)))