diff --git a/benchmark/benchmark_comparison_non_stream_WWR.jl b/benchmark/benchmark_comparison_non_stream_WWR.jl index dbfc7a30..3d454148 100644 --- a/benchmark/benchmark_comparison_non_stream_WWR.jl +++ b/benchmark/benchmark_comparison_non_stream_WWR.jl @@ -125,43 +125,43 @@ end rng = Xoshiro(42); rngs = Tuple(Xoshiro(rand(rng, 1:10000)) for _ in 1:Threads.nthreads()); -a = collect(1:10^7); +a = collect(1:10^8); wsa = Float64.(a); times_other_parallel = Float64[] -for i in 0:6 - b = @benchmark sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) +for i in 0:7 + b = @benchmark sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) seconds=20 push!(times_other_parallel, median(b.times)/10^6) println("other $(10^i): $(median(b.times)/10^6) ms") end times_other = Float64[] -for i in 0:6 - b = @benchmark sample($rng, $a, Weights($wsa), 10^$i; replace = true) +for i in 0:7 + b = @benchmark sample($rng, $a, Weights($wsa), 10^$i; replace = true) seconds=20 push!(times_other, median(b.times)/10^6) println("other $(10^i): $(median(b.times)/10^6) ms") end ## single thread times_single_thread = Float64[] -for i in 0:6 - b = @benchmark weighted_reservoir_sample($rng, $a, $wsa, 10^$i) +for i in 0:7 + b = @benchmark weighted_reservoir_sample($rng, $a, $wsa, 10^$i) seconds=20 push!(times_single_thread, median(b.times)/10^6) println("sequential $(10^i): $(median(b.times)/10^6) ms") end # multi thread 1 pass - 6 threads times_multi_thread = Float64[] -for i in 0:6 - b = @benchmark weighted_reservoir_sample_parallel_1_pass($rngs, $a, $wsa, 10^$i) +for i in 0:7 + b = @benchmark weighted_reservoir_sample_parallel_1_pass($rngs, $a, $wsa, 10^$i) seconds=20 push!(times_multi_thread, median(b.times)/10^6) println("parallel $(10^i): $(median(b.times)/10^6) ms") end # multi thread 2 pass - 6 threads times_multi_thread_2 = Float64[] -for i in 0:6 - b = @benchmark weighted_reservoir_sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) +for i in 0:7 + b = @benchmark weighted_reservoir_sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) seconds=20 push!(times_multi_thread_2, median(b.times)/10^6) println("parallel $(10^i): $(median(b.times)/10^6) ms") end @@ -170,13 +170,13 @@ py""" import numpy as np import timeit -a = np.arange(1, 10**7+1, dtype=np.int64); -wsa = np.arange(1, 10**7+1, dtype=np.float64) +a = np.arange(1, 10**8+1, dtype=np.int64); +wsa = np.arange(1, 10**8+1, dtype=np.float64) p = wsa/np.sum(wsa); def sample_times_numpy(): times_numpy = [] - for i in range(7): + for i in range(8): ts = [] for j in range(11): t = timeit.timeit("np.random.choice(a, size=10**i, replace=True, p=p)", @@ -196,20 +196,20 @@ ax1 = Axis(f[1, 1], yscale=log10, xscale=log10, yminorticksvisible = true, yminorgridvisible = true, yminorticks = IntervalsBetween(10)) -scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_numpy[2:end], label = "numpy.choice sequential", marker = :circle, markersize = 12, linestyle = :dot) -scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_other[2:end], label = "StatsBase.sample sequential", marker = :rect, markersize = 12, linestyle = :dot) -scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_other_parallel[2:end], label = "StatsBase.sample parallel (2 passes)", marker = :diamond, markersize = 12, linestyle = :dot) -scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_single_thread[2:end], label = "WRSWR-SKIP sequential", marker = :hexagon, markersize = 12, linestyle = :dot) -scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_multi_thread[2:end], label = "WRSWR-SKIP parallel (1 pass)", marker = :cross, markersize = 12, linestyle = :dot) -scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_multi_thread_2[2:end], label = "WRSWR-SKIP parallel (2 passes)", marker = :xcross, markersize = 12, linestyle = :dot) -Legend(f[1,2], ax1, labelsize=10, framevisible = false) +scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_numpy[3:end], label = "numpy.choice sequential", marker = :circle, markersize = 12, linestyle = :dot) +scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_other[3:end], label = "StatsBase.sample sequential", marker = :rect, markersize = 12, linestyle = :dot) +scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_other_parallel[3:end], label = "StatsBase.sample parallel (2 passes)", marker = :diamond, markersize = 12, linestyle = :dot) +scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_single_thread[3:end], label = "WRSWR-SKIP sequential", marker = :hexagon, markersize = 12, linestyle = :dot) +scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_multi_thread[3:end], label = "WRSWR-SKIP parallel (1 pass)", marker = :cross, markersize = 12, linestyle = :dot) +scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_multi_thread_2[3:end], label = "WRSWR-SKIP parallel (2 passes)", marker = :xcross, markersize = 12, linestyle = :dot) +Legend(f[2,1], ax1, labelsize=10, framevisible = false, orientation = :horizontal) ax1.xtickformat = x -> string.(round.(x.*100, digits=10)) .* "%" ax1.title = "Comparison between weighted sampling algorithms in a non-streaming context" -ax1.xticks = [10^(i)/10^7 for i in 1:6] +ax1.xticks = [10^(i)/10^8 for i in 2:7] ax1.xlabel = "sample ratio" ax1.ylabel = "time (ms)" f -save("comparison_WRSWR_SKIP_alg.png", f) +save("comparison_WRSWR_SKIP_alg_no_stream.png", f) diff --git a/benchmark/benchmark_comparison_stream.jl b/benchmark/benchmark_comparison_stream.jl index 29f8c203..6c3c5093 100644 --- a/benchmark/benchmark_comparison_stream.jl +++ b/benchmark/benchmark_comparison_stream.jl @@ -3,7 +3,7 @@ using Random, Printf, BenchmarkTools using CairoMakie rng = Xoshiro(42); -stream = Iterators.filter(x -> x != 10, 1:10^7); +stream = Iterators.filter(x -> x != 1, 1:10^8); pop = collect(stream); w(el) = Float64(el); weights = Weights(w.(stream)); @@ -11,7 +11,7 @@ weights = Weights(w.(stream)); algs = (AlgL(), AlgRSWRSKIP(), AlgAExpJ(), AlgWRSWRSKIP()); algsweighted = (AlgAExpJ(), AlgWRSWRSKIP()); algsreplace = (AlgRSWRSKIP(), AlgWRSWRSKIP()); -sizes = (10^3, 10^4, 10^5, 10^6) +sizes = (10^4, 10^5, 10^6, 10^7) p = Dict((0, 0) => 1, (0, 1) => 2, (1, 0) => 3, (1, 1) => 4); m_times = Matrix{Vector{Float64}}(undef, (3, 4)); @@ -24,13 +24,13 @@ for m in algs replace = m in algsreplace weighted = m in algsweighted if weighted - b1 = @benchmark itsample($rng, $stream, $w, $size, $m) evals=1 - b2 = @benchmark sample($rng, collect($stream), Weights($w.($stream)), $size; replace = $replace) evals=1 - b3 = @benchmark sample($rng, $pop, $weights, $size; replace = $replace) evals=1 + b1 = @benchmark itsample($rng, $stream, $w, $size, $m) seconds=20 + b2 = @benchmark sample($rng, collect($stream), Weights($w.($stream)), $size; replace = $replace) seconds=20 + b3 = @benchmark sample($rng, $pop, $weights, $size; replace = $replace) seconds=20 else - b1 = @benchmark itsample($rng, $stream, $size, $m) evals=1 - b2 = @benchmark sample($rng, collect($stream), $size; replace = $replace) evals=1 - b3 = @benchmark sample($rng, $pop, $size; replace = $replace) evals=1 + b1 = @benchmark itsample($rng, $stream, $size, $m) evals=1 seconds=20 + b2 = @benchmark sample($rng, collect($stream), $size; replace = $replace) seconds=20 + b3 = @benchmark sample($rng, $pop, $size; replace = $replace) seconds=20 end ts = [median(b1.times), median(b2.times), median(b3.times)] .* 1e-6 ms = [b1.memory, b2.memory, b3.memory] .* 1e-6 @@ -39,6 +39,7 @@ for m in algs push!(m_times[r, c], ts[r]) push!(m_mems[r, c], ms[r]) end + println("c") end end diff --git a/benchmark/benchmark_comparison_stream_WWR.jl b/benchmark/benchmark_comparison_stream_WWR.jl index ea291f26..675a063f 100644 --- a/benchmark/benchmark_comparison_stream_WWR.jl +++ b/benchmark/benchmark_comparison_stream_WWR.jl @@ -65,14 +65,61 @@ a = Iterators.filter(x -> x != 1, 1:10^8) wv_const(x) = 1.0 wv_incr(x) = Float64(x) wv_decr(x) = 1/x -wvs = (wv_decr, wv_const, wv_incr) +wvs = ((:wv_decr, wv_decr), + (:wv_const, wv_const), + (:wv_incr, wv_incr)) -for wv in wvs - for m in (AlgWRSWRSKIP(), AlgAExpJWR()) - for sz in [10^i for i in 0:7] - b = @benchmark itsample($a, $wv, $sz, $m) seconds=10 - println(wv, " ", m, " ", sz, " ", median(b.times)) +benchs = [] +for (wvn, wv) in wvs + for m in (AlgAExpJWR(), AlgWRSWRSKIP()) + bs = [] + for sz in [10^i for i in 3:7] + b = @benchmark itsample($a, $wv, $sz, $m) seconds=20 + push!(bs, median(b.times)) + println(median(b.times)) end + push!(benchs, (wvn, m, bs)) + println(benchs) end end +using CairoMakie + +f = Figure(backgroundcolor = RGBf(0.98, 0.98, 0.98), size = (1100, 700)); + +f.title = "Comparison between AExpJ-WR and WRSWR-SKIP Algorithms" + +ax1 = Axis(f[1, 1], yscale=log10, xscale=log10, + yminorticksvisible = true, yminorgridvisible = true, + yminorticks = IntervalsBetween(10)) +ax2 = Axis(f[1, 2], yscale=log10, xscale=log10, + yminorticksvisible = true, yminorgridvisible = true, + yminorticks = IntervalsBetween(10)) +ax3 = Axis(f[1, 3], yscale=log10, xscale=log10, + yminorticksvisible = true, yminorgridvisible = true, + yminorticks = IntervalsBetween(10)) + +#ax4 = Axis(f[2, 1]) + +for x in benchs + label = x[1] == :wv_const ? (x[2] == AlgAExpJWR() ? "ExpJ-WR" : "WRSWR-SKIP") : "" + ax = x[1] == :wv_decr ? ax1 : (x[1] == :wv_const ? ax2 : ax3) + marker = x[2] == AlgAExpJWR() ? :circle : (:xcross) + scatterlines!(ax, [10^i/10^8 for i in 3:7], x[3] ./ 10^6, marker = marker, + label = label, markersize = 12, linestyle = :dot) +end + +Legend(ax4, labelsize=10, framevisible = false, orientation = :horizontal) + +for ax in [ax1, ax2, ax3] + ax.xtickformat = x -> string.(round.(x.*100, digits=10)) .* "%" + #ax.ytickformat = y -> y .* "^" + ax.title = ax == ax1 ? "decreasing weights" : (ax == ax2 ? "constant weights" : "increasing weights") + ax.xticks = [10^(i)/10^8 for i in 3:7] + ax.yticks = [10^i for i in 2:4] + ax.xlabel = "sample ratio" + ax == ax1 && (ax.ylabel = "time (ms)") +end + +save("comparison_WRSWR_SKIP_alg_stream.png", f) +f \ No newline at end of file diff --git a/benchmark/comparison_WRSWR_SKIP_alg.png b/benchmark/comparison_WRSWR_SKIP_alg.png deleted file mode 100644 index c406ddd0..00000000 Binary files a/benchmark/comparison_WRSWR_SKIP_alg.png and /dev/null differ diff --git a/benchmark/comparison_WRSWR_SKIP_alg_no_stream.png b/benchmark/comparison_WRSWR_SKIP_alg_no_stream.png new file mode 100644 index 00000000..49f1db84 Binary files /dev/null and b/benchmark/comparison_WRSWR_SKIP_alg_no_stream.png differ diff --git a/benchmark/comparison_WRSWR_SKIP_alg_stream.png b/benchmark/comparison_WRSWR_SKIP_alg_stream.png new file mode 100644 index 00000000..03a68971 Binary files /dev/null and b/benchmark/comparison_WRSWR_SKIP_alg_stream.png differ diff --git a/benchmark/comparison_stream_algs.png b/benchmark/comparison_stream_algs.png index 4e4c4845..2b555e8d 100644 Binary files a/benchmark/comparison_stream_algs.png and b/benchmark/comparison_stream_algs.png differ diff --git a/src/SamplingUtils.jl b/src/SamplingUtils.jl index 10db6315..5c8c729f 100644 --- a/src/SamplingUtils.jl +++ b/src/SamplingUtils.jl @@ -34,6 +34,7 @@ Base.eltype(::SeqSampleIterWR) = Int Base.IteratorSize(::SeqSampleIterWR) = Base.HasLength() Base.length(s::SeqSampleIterWR) = s.n +# courtesy of StatsBase.jl for part of the implementation struct SeqSampleIter{R} rng::R N::Int diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 70c52d25..947e0747 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -1,5 +1,6 @@ @hybrid struct SampleMultiAlgR{O,T,R} <: AbstractReservoirSample + const n::Int seen_k::Int const rng::R const value::Vector{T} @@ -8,6 +9,7 @@ end const SampleMultiOrdAlgR = SampleMultiAlgR{<:Vector} @hybrid struct SampleMultiAlgL{O,T,R} <: AbstractReservoirSample + const n::Int state::Float64 skip_k::Int seen_k::Int @@ -18,6 +20,7 @@ end const SampleMultiOrdAlgL = SampleMultiAlgL{<:Vector} @hybrid struct SampleMultiAlgRSWRSKIP{O,T,R} <: AbstractReservoirSample + const n::Int skip_k::Int seen_k::Int const rng::R @@ -27,65 +30,67 @@ end const SampleMultiOrdAlgRSWRSKIP = SampleMultiAlgRSWRSKIP{<:Vector} function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::MutSample, ::Ord) where T - return SampleMultiAlgL_Mut(0.0, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) + return SampleMultiAlgL_Mut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::MutSample, ::Unord) where T - return SampleMultiAlgL_Mut(0.0, 0, 0, rng, Vector{T}(undef, n), nothing) + return SampleMultiAlgL_Mut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), nothing) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::ImmutSample, ::Ord) where T - return SampleMultiAlgL_Immut(0.0, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) + return SampleMultiAlgL_Immut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgL, ::ImmutSample, ::Unord) where T - return SampleMultiAlgL_Immut(0.0, 0, 0, rng, Vector{T}(undef, n), nothing) + return SampleMultiAlgL_Immut(n, 0.0, 0, 0, rng, Vector{T}(undef, n), nothing) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::MutSample, ::Ord) where T - return SampleMultiAlgR_Mut(0, rng, Vector{T}(undef, n), collect(1:n)) + return SampleMultiAlgR_Mut(n, 0, rng, Vector{T}(undef, n), collect(1:n)) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::MutSample, ::Unord) where T - return SampleMultiAlgR_Mut(0, rng, Vector{T}(undef, n), nothing) + return SampleMultiAlgR_Mut(n, 0, rng, Vector{T}(undef, n), nothing) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::ImmutSample, ::Ord) where T - return SampleMultiAlgR_Immut(0, rng, Vector{T}(undef, n), collect(1:n)) + return SampleMultiAlgR_Immut(n, 0, rng, Vector{T}(undef, n), collect(1:n)) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgR, ::ImmutSample, ::Unord) where T - return SampleMultiAlgR_Immut(0, rng, Vector{T}(undef, n), nothing) + return SampleMultiAlgR_Immut(n, 0, rng, Vector{T}(undef, n), nothing) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::MutSample, ::Ord) where T - return SampleMultiAlgRSWRSKIP_Mut(0, 0, rng, Vector{T}(undef, n), collect(1:n)) + return SampleMultiAlgRSWRSKIP_Mut(n, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::MutSample, ::Unord) where T - return SampleMultiAlgRSWRSKIP_Mut(0, 0, rng, Vector{T}(undef, n), nothing) + return SampleMultiAlgRSWRSKIP_Mut(n, 0, 0, rng, Vector{T}(undef, n), nothing) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::ImmutSample, ::Ord) where T - return SampleMultiAlgRSWRSKIP_Immut(0, 0, rng, Vector{T}(undef, n), collect(1:n)) + return SampleMultiAlgRSWRSKIP_Immut(n, 0, 0, rng, Vector{T}(undef, n), collect(1:n)) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgRSWRSKIP, ::ImmutSample, ::Unord) where T - return SampleMultiAlgRSWRSKIP_Immut(0, 0, rng, Vector{T}(undef, n), nothing) + return SampleMultiAlgRSWRSKIP_Immut(n, 0, 0, rng, Vector{T}(undef, n), nothing) end @inline function OnlineStatsBase._fit!(s::SampleMultiAlgR, el) - n = length(s.value) + n = s.n s = @inline update_state!(s) if s.seen_k <= n @inbounds s.value[s.seen_k] = el - else - j = rand(s.rng, 1:s.seen_k) - if j <= n - @inbounds s.value[j] = el - update_order!(s, j) - end + return s + end + j = rand(s.rng, 1:s.seen_k) + if j <= n + @inbounds s.value[j] = el + update_order!(s, j) end return s end @inline function OnlineStatsBase._fit!(s::SampleMultiAlgL, el) - n = length(s.value) + n = s.n s = @inline update_state!(s) if s.seen_k <= n @inbounds s.value[s.seen_k] = el if s.seen_k === n s = @inline recompute_skip!(s, n) end - elseif s.skip_k < s.seen_k + return s + end + if s.skip_k < s.seen_k j = rand(s.rng, 1:n) @inbounds s.value[j] = el update_order!(s, j) @@ -94,7 +99,7 @@ end return s end @inline function OnlineStatsBase._fit!(s::SampleMultiAlgRSWRSKIP, el) - n = length(s.value) + n = s.n s = @inline update_state!(s) if s.seen_k <= n @inbounds s.value[s.seen_k] = el @@ -105,11 +110,13 @@ end s.value[i] = new_values[i] end end - elseif s.skip_k < s.seen_k + return s + end + if s.skip_k < s.seen_k p = 1/s.seen_k z = exp((n-4)*log1p(-p)) - q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0)) - k = @inline choose(n, p, q, z) + c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0)) + k = @inline choose(n, p, c, z) @inbounds for j in 1:k r = rand(s.rng, j:n) s.value[r], s.value[j] = s.value[j], el @@ -160,19 +167,18 @@ function recompute_skip!(s::SampleMultiAlgRSWRSKIP, n) return s end -function choose(n, p, q, z) - m = 1-p - s = z - z = s*m*m*m*(m + n*p) - z > q && return 1 - z += n*p*(n-1)*p*s*m*m/2 - z > q && return 2 - z += n*p*(n-1)*p*(n-2)*p*s*m/6 - z > q && return 3 - z += n*p*(n-1)*p*(n-2)*p*(n-3)*p*s/24 - z > q && return 4 +function choose(n, p, c, z) + q = 1-p + k = z*q*q*q*(q + n*p) + k > c && return 1 + k += n*p*(n-1)*p*z*q*q/2 + k > c && return 2 + k += n*p*(n-1)*p*(n-2)*p*z*q/6 + k > c && return 3 + k += n*p*(n-1)*p*(n-2)*p*(n-3)*p*z/24 + k > c && return 4 b = Binomial(n, p) - return quantile(b, q) + return quantile(b, c) end update_order!(s::Union{SampleMultiAlgR, SampleMultiAlgL}, j) = nothing @@ -197,7 +203,7 @@ function Base.merge(ss::SampleMultiAlgRSWRSKIP...) newvalue = reduce_samples(TypeUnion(), ss...) skip_k = sum(getfield(s, :skip_k) for s in ss) seen_k = sum(getfield(s, :seen_k) for s in ss) - return SampleMultiAlgRSWRSKIP_Mut(skip_k, seen_k, ss[1].rng, newvalue, nothing) + return SampleMultiAlgRSWRSKIP_Mut(ss[1].n, skip_k, seen_k, ss[1].rng, newvalue, nothing) end function Base.merge!(s1::SampleMultiAlgRSWRSKIP{<:Nothing}, ss::SampleMultiAlgRSWRSKIP...) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index e9051ba0..9495a2d6 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -20,6 +20,7 @@ end const SampleMultiOrdAlgAExpJ = Union{SampleMultiAlgAExpJ_Immut{<:OrdWeighted}, SampleMultiAlgAExpJ_Mut{<:OrdWeighted}} @hybrid struct SampleMultiAlgWRSWRSKIP{O,T,R} <: AbstractReservoirSample + const n::Int state::Float64 skip_w::Float64 seen_k::Int @@ -72,17 +73,17 @@ function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgARes, ::ImmutSamp end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Ord) where T ord = collect(1:n) - return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) + return SampleMultiAlgWRSWRSKIP_Mut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::MutSample, ::Unord) where T - return SampleMultiAlgWRSWRSKIP_Mut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) + return SampleMultiAlgWRSWRSKIP_Mut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Ord) where T ord = collect(1:n) - return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) + return SampleMultiAlgWRSWRSKIP_Immut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), ord) end function ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgWRSWRSKIP, ::ImmutSample, ::Unord) where T - return SampleMultiAlgWRSWRSKIP_Immut(0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) + return SampleMultiAlgWRSWRSKIP_Immut(n, 0.0, 0.0, 0, rng, Vector{Float64}(undef, n), Vector{T}(undef, n), nothing) end @inline function OnlineStatsBase._fit!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el, w) @@ -90,13 +91,13 @@ end s = @inline update_state!(s, w) priority = -randexp(s.rng)/w if s.seen_k <= n - push_value!(s, el, priority) - else - min_priority = last(first(s.value)) - if priority > min_priority - pop!(s.value) - push_value!(s, el, priority) - end + @inline push_value!(s, el, priority) + return s + end + min_priority = last(first(s.value)) + if priority > min_priority + pop!(s.value) + @inline push_value!(s, el, priority) end return s end @@ -105,37 +106,42 @@ end s = @inline update_state!(s, w) if s.seen_k <= n priority = exp(-randexp(s.rng)/w) - push_value!(s, el, priority) + @inline push_value!(s, el, priority) if s.seen_k == n s = @inline recompute_skip!(s) end - elseif s.state <= 0.0 + return s + end + if s.state <= 0.0 priority = @inline compute_skip_priority(s, w) pop!(s.value) - push_value!(s, el, priority) + @inline push_value!(s, el, priority) s = @inline recompute_skip!(s) end return s end @inline function OnlineStatsBase._fit!(s::SampleMultiAlgWRSWRSKIP, el, w) - n = length(s.value) + n = s.n s = @inline update_state!(s, w) if s.seen_k <= n @inbounds s.value[s.seen_k] = el @inbounds s.weights[s.seen_k] = w if s.seen_k == n - new_values = sample(s.rng, s.value, Weights(s.weights, s.state), n; ordered = is_ordered(s)) + new_values = sample(s.rng, s.value, Weights(s.weights, s.state), n; + ordered = is_ordered(s)) @inbounds for i in 1:n s.value[i] = new_values[i] end s = @inline recompute_skip!(s, n) empty!(s.weights) end - elseif s.skip_w <= s.state + return s + end + if s.skip_w <= s.state p = w/s.state z = exp((n-4)*log1p(-p)) - q = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p),1.0)) - k = @inline choose(n, p, q, z) + c = rand(s.rng, Uniform(z*(1-p)*(1-p)*(1-p)*(1-p), 1.0)) + k = @inline choose(n, p, c, z) @inbounds for j in 1:k r = rand(s.rng, j:n) s.value[r], s.value[j] = s.value[j], el @@ -182,7 +188,7 @@ function Base.merge(ss::SampleMultiAlgWRSWRSKIP...) skip_w = sum(getfield(s, :skip_w) for s in ss) state = sum(getfield(s, :state) for s in ss) seen_k = sum(getfield(s, :seen_k) for s in ss) - s = SampleMultiAlgWRSWRSKIP_Mut(state, skip_w, seen_k, ss[1].rng, Float64[], newvalue, nothing) + s = SampleMultiAlgWRSWRSKIP_Mut(ss[1].n, state, skip_w, seen_k, ss[1].rng, Float64[], newvalue, nothing) return s end