Skip to content

Commit

Permalink
Benchmark improvements (#109)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar authored Oct 28, 2024
1 parent 9898f87 commit 30605ae
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 95 deletions.
46 changes: 23 additions & 23 deletions benchmark/benchmark_comparison_non_stream_WWR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,43 +125,43 @@ end
rng = Xoshiro(42);
rngs = Tuple(Xoshiro(rand(rng, 1:10000)) for _ in 1:Threads.nthreads());

a = collect(1:10^7);
a = collect(1:10^8);
wsa = Float64.(a);

times_other_parallel = Float64[]
for i in 0:6
b = @benchmark sample_parallel_2_pass($rngs, $a, $wsa, 10^$i)
for i in 0:7
b = @benchmark sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) seconds=20
push!(times_other_parallel, median(b.times)/10^6)
println("other $(10^i): $(median(b.times)/10^6) ms")
end

times_other = Float64[]
for i in 0:6
b = @benchmark sample($rng, $a, Weights($wsa), 10^$i; replace = true)
for i in 0:7
b = @benchmark sample($rng, $a, Weights($wsa), 10^$i; replace = true) seconds=20
push!(times_other, median(b.times)/10^6)
println("other $(10^i): $(median(b.times)/10^6) ms")
end

## single thread
times_single_thread = Float64[]
for i in 0:6
b = @benchmark weighted_reservoir_sample($rng, $a, $wsa, 10^$i)
for i in 0:7
b = @benchmark weighted_reservoir_sample($rng, $a, $wsa, 10^$i) seconds=20
push!(times_single_thread, median(b.times)/10^6)
println("sequential $(10^i): $(median(b.times)/10^6) ms")
end

# multi thread 1 pass - 6 threads
times_multi_thread = Float64[]
for i in 0:6
b = @benchmark weighted_reservoir_sample_parallel_1_pass($rngs, $a, $wsa, 10^$i)
for i in 0:7
b = @benchmark weighted_reservoir_sample_parallel_1_pass($rngs, $a, $wsa, 10^$i) seconds=20
push!(times_multi_thread, median(b.times)/10^6)
println("parallel $(10^i): $(median(b.times)/10^6) ms")
end

# multi thread 2 pass - 6 threads
times_multi_thread_2 = Float64[]
for i in 0:6
b = @benchmark weighted_reservoir_sample_parallel_2_pass($rngs, $a, $wsa, 10^$i)
for i in 0:7
b = @benchmark weighted_reservoir_sample_parallel_2_pass($rngs, $a, $wsa, 10^$i) seconds=20
push!(times_multi_thread_2, median(b.times)/10^6)
println("parallel $(10^i): $(median(b.times)/10^6) ms")
end
Expand All @@ -170,13 +170,13 @@ py"""
import numpy as np
import timeit
a = np.arange(1, 10**7+1, dtype=np.int64);
wsa = np.arange(1, 10**7+1, dtype=np.float64)
a = np.arange(1, 10**8+1, dtype=np.int64);
wsa = np.arange(1, 10**8+1, dtype=np.float64)
p = wsa/np.sum(wsa);
def sample_times_numpy():
times_numpy = []
for i in range(7):
for i in range(8):
ts = []
for j in range(11):
t = timeit.timeit("np.random.choice(a, size=10**i, replace=True, p=p)",
Expand All @@ -196,20 +196,20 @@ ax1 = Axis(f[1, 1], yscale=log10, xscale=log10,
yminorticksvisible = true, yminorgridvisible = true,
yminorticks = IntervalsBetween(10))

scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_numpy[2:end], label = "numpy.choice sequential", marker = :circle, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_other[2:end], label = "StatsBase.sample sequential", marker = :rect, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_other_parallel[2:end], label = "StatsBase.sample parallel (2 passes)", marker = :diamond, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_single_thread[2:end], label = "WRSWR-SKIP sequential", marker = :hexagon, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_multi_thread[2:end], label = "WRSWR-SKIP parallel (1 pass)", marker = :cross, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^7 for i in 1:6], times_multi_thread_2[2:end], label = "WRSWR-SKIP parallel (2 passes)", marker = :xcross, markersize = 12, linestyle = :dot)
Legend(f[1,2], ax1, labelsize=10, framevisible = false)
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_numpy[3:end], label = "numpy.choice sequential", marker = :circle, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_other[3:end], label = "StatsBase.sample sequential", marker = :rect, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_other_parallel[3:end], label = "StatsBase.sample parallel (2 passes)", marker = :diamond, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_single_thread[3:end], label = "WRSWR-SKIP sequential", marker = :hexagon, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_multi_thread[3:end], label = "WRSWR-SKIP parallel (1 pass)", marker = :cross, markersize = 12, linestyle = :dot)
scatterlines!(ax1, [10^i/10^8 for i in 2:7], times_multi_thread_2[3:end], label = "WRSWR-SKIP parallel (2 passes)", marker = :xcross, markersize = 12, linestyle = :dot)
Legend(f[2,1], ax1, labelsize=10, framevisible = false, orientation = :horizontal)

ax1.xtickformat = x -> string.(round.(x.*100, digits=10)) .* "%"
ax1.title = "Comparison between weighted sampling algorithms in a non-streaming context"
ax1.xticks = [10^(i)/10^7 for i in 1:6]
ax1.xticks = [10^(i)/10^8 for i in 2:7]

ax1.xlabel = "sample ratio"
ax1.ylabel = "time (ms)"

f
save("comparison_WRSWR_SKIP_alg.png", f)
save("comparison_WRSWR_SKIP_alg_no_stream.png", f)
17 changes: 9 additions & 8 deletions benchmark/benchmark_comparison_stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ using Random, Printf, BenchmarkTools
using CairoMakie

rng = Xoshiro(42);
stream = Iterators.filter(x -> x != 10, 1:10^7);
stream = Iterators.filter(x -> x != 1, 1:10^8);
pop = collect(stream);
w(el) = Float64(el);
weights = Weights(w.(stream));

algs = (AlgL(), AlgRSWRSKIP(), AlgAExpJ(), AlgWRSWRSKIP());
algsweighted = (AlgAExpJ(), AlgWRSWRSKIP());
algsreplace = (AlgRSWRSKIP(), AlgWRSWRSKIP());
sizes = (10^3, 10^4, 10^5, 10^6)
sizes = (10^4, 10^5, 10^6, 10^7)

p = Dict((0, 0) => 1, (0, 1) => 2, (1, 0) => 3, (1, 1) => 4);
m_times = Matrix{Vector{Float64}}(undef, (3, 4));
Expand All @@ -24,13 +24,13 @@ for m in algs
replace = m in algsreplace
weighted = m in algsweighted
if weighted
b1 = @benchmark itsample($rng, $stream, $w, $size, $m) evals=1
b2 = @benchmark sample($rng, collect($stream), Weights($w.($stream)), $size; replace = $replace) evals=1
b3 = @benchmark sample($rng, $pop, $weights, $size; replace = $replace) evals=1
b1 = @benchmark itsample($rng, $stream, $w, $size, $m) seconds=20
b2 = @benchmark sample($rng, collect($stream), Weights($w.($stream)), $size; replace = $replace) seconds=20
b3 = @benchmark sample($rng, $pop, $weights, $size; replace = $replace) seconds=20
else
b1 = @benchmark itsample($rng, $stream, $size, $m) evals=1
b2 = @benchmark sample($rng, collect($stream), $size; replace = $replace) evals=1
b3 = @benchmark sample($rng, $pop, $size; replace = $replace) evals=1
b1 = @benchmark itsample($rng, $stream, $size, $m) evals=1 seconds=20
b2 = @benchmark sample($rng, collect($stream), $size; replace = $replace) seconds=20
b3 = @benchmark sample($rng, $pop, $size; replace = $replace) seconds=20
end
ts = [median(b1.times), median(b2.times), median(b3.times)] .* 1e-6
ms = [b1.memory, b2.memory, b3.memory] .* 1e-6
Expand All @@ -39,6 +39,7 @@ for m in algs
push!(m_times[r, c], ts[r])
push!(m_mems[r, c], ms[r])
end
println("c")
end
end

Expand Down
59 changes: 53 additions & 6 deletions benchmark/benchmark_comparison_stream_WWR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,61 @@ a = Iterators.filter(x -> x != 1, 1:10^8)
wv_const(x) = 1.0
wv_incr(x) = Float64(x)
wv_decr(x) = 1/x
wvs = (wv_decr, wv_const, wv_incr)
wvs = ((:wv_decr, wv_decr),
(:wv_const, wv_const),
(:wv_incr, wv_incr))

for wv in wvs
for m in (AlgWRSWRSKIP(), AlgAExpJWR())
for sz in [10^i for i in 0:7]
b = @benchmark itsample($a, $wv, $sz, $m) seconds=10
println(wv, " ", m, " ", sz, " ", median(b.times))
benchs = []
for (wvn, wv) in wvs
for m in (AlgAExpJWR(), AlgWRSWRSKIP())
bs = []
for sz in [10^i for i in 3:7]
b = @benchmark itsample($a, $wv, $sz, $m) seconds=20
push!(bs, median(b.times))
println(median(b.times))
end
push!(benchs, (wvn, m, bs))
println(benchs)
end
end

using CairoMakie

# Canvas holding three side-by-side panels, one per weight function being benchmarked.
f = Figure(backgroundcolor = RGBf(0.98, 0.98, 0.98), size = (1100, 700));

# One log-log axis per column of the first layout row. The three axes share every
# setting, so they are built from a single comprehension instead of three copies.
ax1, ax2, ax3 = [
    Axis(f[1, col]; yscale = log10, xscale = log10,
         yminorticksvisible = true, yminorgridvisible = true,
         yminorticks = IntervalsBetween(10))
    for col in 1:3
]

# A Makie `Figure` has no `title` property to assign to, so the figure-wide heading is
# drawn as a `Label` in a new row above the axes. It is created after the axes so that
# `:` spans all three columns.
Label(f[0, :], "Comparison between AExpJ-WR and WRSWR-SKIP Algorithms", fontsize = 18);

# Draw one curve per entry of `benchs`. Each entry is a tuple
# `(weight_name, algorithm, median_times)` pushed by the benchmarking loop; the x values
# are the sample sizes 10^3..10^7 expressed as a fraction of 10^8, and the y values are
# the medians divided by 10^6 (the y axis is labelled in ms).
for x in benchs
    # Only the curves on the constant-weights panel carry a label, so that each
    # algorithm appears exactly once in the legend built from that panel.
    label = x[1] == :wv_const ? (x[2] == AlgAExpJWR() ? "ExpJ-WR" : "WRSWR-SKIP") : ""
    # Route the curve to the panel matching its weight function.
    ax = x[1] == :wv_decr ? ax1 : (x[1] == :wv_const ? ax2 : ax3)
    # The marker shape identifies the algorithm on every panel.
    marker = x[2] == AlgAExpJWR() ? :circle : (:xcross)
    scatterlines!(ax, [10^i/10^8 for i in 3:7], x[3] ./ 10^6, marker = marker,
                  label = label, markersize = 12, linestyle = :dot)
end

# Shared horizontal legend in a new row below everything already in the figure. It is
# built from `ax2`, the only axis whose curves have non-empty labels. (The previous
# call referenced `ax4`, which is never defined, and raised an UndefVarError.)
Legend(f[end+1, :], ax2, labelsize=10, framevisible = false, orientation = :horizontal)

# Apply the shared cosmetics to the three panels, pairing each axis with its title.
for (ax, title) in zip((ax1, ax2, ax3),
                       ("decreasing weights", "constant weights", "increasing weights"))
    ax.title = title
    ax.xlabel = "sample ratio"
    # Ticks sit exactly on the benchmarked sample ratios and are shown as percentages.
    ax.xticks = [10^i / 10^8 for i in 3:7]
    ax.xtickformat = ticks -> string.(round.(ticks .* 100, digits = 10)) .* "%"
    ax.yticks = [10^i for i in 2:4]
    # Only the leftmost panel gets the y-axis label.
    if ax == ax1
        ax.ylabel = "time (ms)"
    end
end

# Write the figure to disk, then leave it as the last value so it is displayed when the
# script is run interactively.
save("comparison_WRSWR_SKIP_alg_stream.png", f)
f
Binary file removed benchmark/comparison_WRSWR_SKIP_alg.png
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added benchmark/comparison_WRSWR_SKIP_alg_stream.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified benchmark/comparison_stream_algs.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions src/SamplingUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Base.eltype(::SeqSampleIterWR) = Int
Base.IteratorSize(::SeqSampleIterWR) = Base.HasLength()
Base.length(s::SeqSampleIterWR) = s.n

# courtesy of StatsBase.jl for part of the implementation
struct SeqSampleIter{R}
rng::R
N::Int
Expand Down
Loading

0 comments on commit 30605ae

Please sign in to comment.