From f0e0e536189ad8a3205d980b8b51b5d191359fa3 Mon Sep 17 00:00:00 2001 From: alecmocatta Date: Fri, 7 Aug 2020 12:58:33 +0100 Subject: [PATCH] tweaks --- amadeus-core/src/par_pipe.rs | 32 +++++++++++----------- amadeus-core/src/par_sink/mean.rs | 3 +-- amadeus-core/src/par_sink/sample.rs | 6 ++++- amadeus-core/src/par_sink/stddev.rs | 31 +++++++++++---------- amadeus-core/src/par_stream.rs | 36 ++++++++++++++++--------- amadeus-core/src/par_stream/identity.rs | 20 +++++++------- 6 files changed, 71 insertions(+), 57 deletions(-) diff --git a/amadeus-core/src/par_pipe.rs b/amadeus-core/src/par_pipe.rs index f8c2bf1d..4d9bd797 100644 --- a/amadeus-core/src/par_pipe.rs +++ b/amadeus-core/src/par_pipe.rs @@ -204,6 +204,22 @@ macro_rules! pipe { $assert_sink(Sum::new(self)) } + #[inline] + fn mean(self) -> Mean + where + Self: $pipe + Sized, + { + $assert_sink(Mean::new(self)) + } + + #[inline] + fn stddev(self) -> StdDev + where + Self: $pipe + Sized, + { + $assert_sink(StdDev::new(self)) + } + #[inline] fn combine(self, f: F) -> Combine where @@ -244,14 +260,6 @@ macro_rules! pipe { $assert_sink(MaxByKey::new(self, f)) } - #[inline] - fn mean(self) -> Mean - where - Self: $pipe + Sized, - { - $assert_sink(Mean::new(self)) - } - #[inline] fn min(self) -> Min where @@ -318,14 +326,6 @@ macro_rules! pipe { $assert_sink(SampleUnstable::new(self, samples)) } - #[inline] - fn stddev(self) -> StdDev - where - Self: $pipe + Sized, - { - $assert_sink(StdDev::new(self)) - } - #[inline] fn all(self, f: F) -> All where diff --git a/amadeus-core/src/par_sink/mean.rs b/amadeus-core/src/par_sink/mean.rs index d041529f..c9f1288e 100644 --- a/amadeus-core/src/par_sink/mean.rs +++ b/amadeus-core/src/par_sink/mean.rs @@ -13,8 +13,7 @@ pub struct Mean

{ } impl_par_dist! { - impl, Item> ParallelSink for Mean

- { + impl, Item> ParallelSink for Mean

{ folder_par_sink!( MeanFolder, MeanFolder, diff --git a/amadeus-core/src/par_sink/sample.rs b/amadeus-core/src/par_sink/sample.rs index d0753e4f..68aba007 100644 --- a/amadeus-core/src/par_sink/sample.rs +++ b/amadeus-core/src/par_sink/sample.rs @@ -91,7 +91,8 @@ impl FolderSync for SortFolder where F: traits::Fn(&Item, &Item) -> Ordering + Clone, { - type Done = SASort; + type State = SASort; + type Done = Self::State; fn zero(&mut self) -> Self::Done { SASort::new(self.f.clone(), self.n) @@ -99,6 +100,9 @@ where fn push(&mut self, state: &mut Self::Done, item: Item) { state.push(item) } + fn done(&mut self, state: Self::State) -> Self::Done { + state + } } #[derive(new)] diff --git a/amadeus-core/src/par_sink/stddev.rs b/amadeus-core/src/par_sink/stddev.rs index ae5343cb..dbf3403e 100644 --- a/amadeus-core/src/par_sink/stddev.rs +++ b/amadeus-core/src/par_sink/stddev.rs @@ -13,7 +13,7 @@ pub struct StdDev

{ } impl_par_dist! { - impl, Item, > ParallelSink for StdDev

{ + impl, Item, > ParallelSink for StdDev

{ folder_par_sink!( SDFolder, SDFolder, @@ -27,7 +27,6 @@ impl_par_dist! { #[derive(Educe, Serialize, Deserialize, new)] #[educe(Clone)] #[serde(bound = "")] - pub struct SDFolder { marker: PhantomData Step>, } @@ -40,7 +39,7 @@ pub struct SDState { #[new(default)] count: u64, #[new(default)] - sum: f64, + mean: f64, #[new(default)] variance: f64, } @@ -56,16 +55,13 @@ impl FolderSync for SDFolder { #[inline(always)] fn push(&mut self, state: &mut Self::State, item: f64) { + // Taken from https://docs.rs/streaming-stats/0.2.3/src/stats/online.rs.html#64-103 + let q_prev = state.variance * u64_to_f64(state.count); + let mean_prev = state.mean; state.count += 1; - state.sum += item; - if state.count > 1 { - let diff = u64_to_f64(state.count) * item - state.sum; - state.variance += - diff * diff / (u64_to_f64(state.count) * (u64_to_f64(state.count) - 1.0)); - state.variance /= u64_to_f64(state.count) - 1.0; - } else { - state.variance = f64::NAN; - } + let count = u64_to_f64(state.count); + state.mean += (item - state.mean) / count; + state.variance = (q_prev + (item - mean_prev) * (item - state.mean)) / count; } #[inline(always)] @@ -85,11 +81,14 @@ impl FolderSync for SDFolder { #[inline(always)] fn push(&mut self, state: &mut Self::State, item: SDState) { - state.variance = ((u64_to_f64(state.count) - 1.0) * state.variance - + (u64_to_f64(item.count) - 1.0) * item.variance) - / ((u64_to_f64(state.count) + u64_to_f64(item.count)) - 2.0); - state.sum += item.sum; + let (s1, s2) = (u64_to_f64(state.count), u64_to_f64(item.count)); + let meandiffsq = (state.mean - item.mean) * (state.mean - item.mean); + let mean = ((s1 * state.mean) + (s2 * item.mean)) / (s1 + s2); + let var = (((s1 * state.variance) + (s2 * item.variance)) / (s1 + s2)) + + ((s1 * s2 * meandiffsq) / ((s1 + s2) * (s1 + s2))); state.count += item.count; + state.mean = mean; + state.variance = var; } #[inline(always)] diff --git a/amadeus-core/src/par_stream.rs b/amadeus-core/src/par_stream.rs index 2168f270..495399be 100644 --- a/amadeus-core/src/par_stream.rs +++ b/amadeus-core/src/par_stream.rs @@ -215,6 +215,30 @@ macro_rules! stream { .await } + #[inline] + async fn mean

(self, pool: &P) -> f64 + where + P: $pool, + Self::Item: 'static, + Self::Task: 'static, + Self: $stream + Sized, + { + self.pipe(pool, $pipe::::mean(Identity)) + .await + } + + #[inline] + async fn stddev

(self, pool: &P) -> f64 + where + P: $pool, + Self::Item: 'static, + Self::Task: 'static, + Self: $stream + Sized, + { + self.pipe(pool, $pipe::::stddev(Identity)) + .await + } + #[inline] async fn combine(self, pool: &P, f: F) -> Option where @@ -267,18 +291,6 @@ macro_rules! stream { .await } - #[inline] - async fn mean

(self, pool: &P) -> f64 - where - P: $pool, - Self::Item: 'static, - Self::Task: 'static, - Self: $stream + Sized, - { - self.pipe(pool, $pipe::::mean(Identity)) - .await - } - #[inline] async fn min

(self, pool: &P) -> Option where diff --git a/amadeus-core/src/par_stream/identity.rs b/amadeus-core/src/par_stream/identity.rs index f6592694..9a301550 100644 --- a/amadeus-core/src/par_stream/identity.rs +++ b/amadeus-core/src/par_stream/identity.rs @@ -133,6 +133,16 @@ mod workaround { Sum::new(self) } + #[inline] + pub fn mean(self) -> Mean { + Mean::new(self) + } + + #[inline] + pub fn stddev(self) -> StdDev { + StdDev::new(self) + } + #[inline] pub fn combine(self, f: F) -> Combine where @@ -162,11 +172,6 @@ mod workaround { MaxByKey::new(self, f) } - #[inline] - pub fn mean(self) -> Mean { - Mean::new(self) - } - #[inline] pub fn min(self) -> Min { Min::new(self) @@ -207,11 +212,6 @@ mod workaround { SampleUnstable::new(self, samples) } - #[inline] - pub fn stddev(self) -> StdDev { - StdDev::new(self) - } - #[inline] pub fn all(self, f: F) -> All where