|
| 1 | + |
| 2 | +using StreamSampling |
| 3 | +using StatsBase |
| 4 | +using OnlineStatsBase |
| 5 | +using HybridStructs |
| 6 | +using DataStructures |
| 7 | +using Random |
| 8 | + |
| 9 | +struct AlgAExpJWR end |
| 10 | + |
| 11 | +struct SampleMultiAlgAExpJWR{B, R, T} <: AbstractReservoirSample |
| 12 | + n::Int |
| 13 | + seen_k::Int |
| 14 | + w_sum::Float64 |
| 15 | + rng::R |
| 16 | + value::B |
| 17 | + value_prev::Vector{T} |
| 18 | + weights::Vector{Float64} |
| 19 | +end |
| 20 | + |
| 21 | +function StreamSampling.ReservoirSample{T}(rng::AbstractRNG, n::Integer, ::AlgAExpJWR, |
| 22 | + ::StreamSampling.ImmutSample, ::StreamSampling.Unord) where T |
| 23 | + value = BinaryHeap(Base.By(first, DataStructures.FasterForward()), Tuple{Float64,T}[]) |
| 24 | + sizehint!(value, n) |
| 25 | + v = Vector{T}(undef, n) |
| 26 | + w = Vector{Float64}(undef, n) |
| 27 | + return SampleMultiAlgAExpJWR(n, 0, 0.0, rng, value, v, w) |
| 28 | +end |
| 29 | + |
| 30 | +@inline function OnlineStatsBase._fit!(s::SampleMultiAlgAExpJWR, el, w) |
| 31 | + n = s.n |
| 32 | + s = @inline update_state!(s, w) |
| 33 | + if s.seen_k <= n |
| 34 | + @inbounds s.value_prev[s.seen_k] = el |
| 35 | + @inbounds s.weights[s.seen_k] = w |
| 36 | + if s.seen_k === n |
| 37 | + for x in sample(s.rng, s.value_prev, Weights(s.weights, s.w_sum), n) |
| 38 | + push!(s.value, (skip_single(s.rng, s.w_sum), x)) |
| 39 | + end |
| 40 | + empty!(s.value_prev) |
| 41 | + empty!(s.weights) |
| 42 | + end |
| 43 | + else |
| 44 | + while first(s.value)[1] <= s.w_sum |
| 45 | + pop!(s.value) |
| 46 | + push!(s.value, (skip_single(s.rng, s.w_sum), el)) |
| 47 | + end |
| 48 | + end |
| 49 | + return s |
| 50 | +end |
| 51 | + |
| 52 | +skip_single(rng, n) = n/rand(rng) |
| 53 | + |
| 54 | +function update_state!(s::SampleMultiAlgAExpJWR, w) |
| 55 | + @update s.seen_k += 1 |
| 56 | + @update s.w_sum += w |
| 57 | + return s |
| 58 | +end |
| 59 | + |
| 60 | +function OnlineStatsBase.value(s::SampleMultiAlgAExpJWR) |
| 61 | + return shuffle!(s.rng, last.(s.value.valtree)) |
| 62 | +end |
| 63 | + |
| 64 | +a = Iterators.filter(x -> x != 1, 1:10^8) |
| 65 | +wv_const(x) = 1.0 |
| 66 | +wv_incr(x) = Float64(x) |
| 67 | +wv_decr(x) = 1/x |
| 68 | +wvs = (wv_decr, wv_const, wv_incr) |
| 69 | + |
| 70 | +for wv in wvs |
| 71 | + for m in (AlgWRSWRSKIP(), AlgAExpJWR()) |
| 72 | + for sz in [10^i for i in 0:7] |
| 73 | + b = @benchmark itsample($a, $wv, $sz, $m) seconds=10 |
| 74 | + println(wv, " ", m, " ", sz, " ", median(b.times)) |
| 75 | + end |
| 76 | + end |
| 77 | +end |
| 78 | + |
0 commit comments