Skip to content

Commit 8055fa2

Browse files
authored
Create benchmark_comparison_stream_WWR.jl (#108)
1 parent 5656842 commit 8055fa2

File tree

2 files changed

+85
-7
lines changed

2 files changed

+85
-7
lines changed

benchmark/benchmark_comparison_non_stream.jl renamed to benchmark/benchmark_comparison_non_stream_WWR.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,28 @@ function weighted_reservoir_sample_seq(rng, a, ws, n)
1919
w_sum = sum(view_w_f_n)
2020
reservoir = sample(rng, @view(a[1:m]), Weights(view_w_f_n, w_sum), n)
2121
length(a) <= n && return reservoir, w_sum
22-
w_skip = skip(rng, w_sum, n)
22+
w_skip = @inline skip(rng, w_sum, n)
2323
@inbounds for i in n+1:length(a)
2424
w_el = ws[i]
2525
w_sum += w_el
2626
if w_sum > w_skip
2727
p = w_el/w_sum
2828
q = 1-p
29-
z = q^(n-4)
29+
z = exp((n-4)*log1p(-p))
3030
t = rand(rng, Uniform(z*q*q*q*q,1.0))
31-
k = choose(n, p, q, t, z)
32-
for j in 1:k
31+
k = @inline choose(n, p, q, t, z)
32+
@inbounds for j in 1:k
3333
r = rand(rng, j:n)
34-
@inbounds reservoir[r], reservoir[j] = reservoir[j], a[i]
34+
reservoir[r], reservoir[j] = reservoir[j], a[i]
3535
end
36-
w_skip = skip(rng, w_sum, n)
36+
w_skip = @inline skip(rng, w_sum, n)
3737
end
3838
end
3939
return reservoir, w_sum
4040
end
4141

4242
function skip(rng, w_sum::AbstractFloat, n)
43-
k = rand(rng)^(1/n)
43+
k = exp(-randexp(rng)/n)
4444
return w_sum/k
4545
end
4646

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
using StreamSampling
3+
using StatsBase
4+
using OnlineStatsBase
5+
using HybridStructs
6+
using DataStructures
7+
using Random
8+
9+
struct AlgAExpJWR end
10+
11+
struct SampleMultiAlgAExpJWR{B, R, T} <: AbstractReservoirSample
12+
n::Int
13+
seen_k::Int
14+
w_sum::Float64
15+
rng::R
16+
value::B
17+
value_prev::Vector{T}
18+
weights::Vector{Float64}
19+
end
20+
21+
function StreamSampling.ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgAExpJWR,
22+
::StreamSampling.ImmutSample, ::StreamSampling.Unord) where T
23+
value = BinaryHeap(Base.By(first, DataStructures.FasterForward()), Tuple{Float64,T}[])
24+
sizehint!(value, n)
25+
v = Vector{T}(undef, n)
26+
w = Vector{Float64}(undef, n)
27+
return SampleMultiAlgAExpJWR(n, 0, 0.0, rng, value, v, w)
28+
end
29+
30+
@inline function OnlineStatsBase._fit!(s::SampleMultiAlgAExpJWR, el, w)
31+
n = s.n
32+
s = @inline update_state!(s, w)
33+
if s.seen_k <= n
34+
@inbounds s.value_prev[s.seen_k] = el
35+
@inbounds s.weights[s.seen_k] = w
36+
if s.seen_k === n
37+
for x in sample(s.rng, s.value_prev, Weights(s.weights, s.w_sum), n)
38+
push!(s.value, (skip_single(s.rng, s.w_sum), x))
39+
end
40+
empty!(s.value_prev)
41+
empty!(s.weights)
42+
end
43+
else
44+
while first(s.value)[1] <= s.w_sum
45+
pop!(s.value)
46+
push!(s.value, (skip_single(s.rng, s.w_sum), el))
47+
end
48+
end
49+
return s
50+
end
51+
52+
skip_single(rng, n) = n/rand(rng)
53+
54+
function update_state!(s::SampleMultiAlgAExpJWR, w)
55+
@update s.seen_k += 1
56+
@update s.w_sum += w
57+
return s
58+
end
59+
60+
function OnlineStatsBase.value(s::SampleMultiAlgAExpJWR)
61+
return shuffle!(s.rng, last.(s.value.valtree))
62+
end
63+
64+
a = Iterators.filter(x -> x != 1, 1:10^8)
65+
wv_const(x) = 1.0
66+
wv_incr(x) = Float64(x)
67+
wv_decr(x) = 1/x
68+
wvs = (wv_decr, wv_const, wv_incr)
69+
70+
for wv in wvs
71+
for m in (AlgWRSWRSKIP(), AlgAExpJWR())
72+
for sz in [10^i for i in 0:7]
73+
b = @benchmark itsample($a, $wv, $sz, $m) seconds=10
74+
println(wv, " ", m, " ", sz, " ", median(b.times))
75+
end
76+
end
77+
end
78+

0 commit comments

Comments
 (0)