From 8ee4f55caf7fefd147d28929a8876c08e2aaa04c Mon Sep 17 00:00:00 2001 From: Adriano Meligrana <68152031+Tortar@users.noreply.github.com> Date: Sat, 14 Sep 2024 00:30:30 +0200 Subject: [PATCH] Remove algorithm L for single sample (#83) --- Project.toml | 2 +- src/UnweightedSamplingSingle.jl | 50 +++++------------------- src/WeightedSamplingSingle.jl | 39 ++---------------- src/precompile.jl | 4 -- test/benchmark/benchmark_tests.jl | 2 + test/unweighted_sampling_single_tests.jl | 2 +- test/weighted_sampling_single_tests.jl | 2 +- 7 files changed, 18 insertions(+), 83 deletions(-) diff --git a/Project.toml b/Project.toml index 225de10..f91c937 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StreamSampling" uuid = "ff63dad9-3335-55d8-95ec-f8139d39e468" -version = "0.3.7" +version = "0.3.8" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/src/UnweightedSamplingSingle.jl b/src/UnweightedSamplingSingle.jl index 1f56434..1d9b68d 100644 --- a/src/UnweightedSamplingSingle.jl +++ b/src/UnweightedSamplingSingle.jl @@ -1,42 +1,26 @@ -mutable struct SampleSingleAlgL{T,R} <: AbstractReservoirSampleSingle - state::Float64 - seen_k::Int - skip_k::Int - const rng::R - value::T - SampleSingleAlgL{T,R}(state, seen_k, skip_k, rng) where {T,R} = new{T,R}(state, seen_k, skip_k, rng) -end - mutable struct SampleSingleAlgR{T,R} <: AbstractReservoirSampleSingle seen_k::Int skip_k::Int const rng::R value::T - SampleSingleAlgR{T,R}(seen_k, rng, value) where {T,R} = new{T,R}(seen_k, rng, value) - SampleSingleAlgR{T,R}(seen_k, rng) where {T,R} = new{T,R}(seen_k, rng) + SampleSingleAlgR{T,R}(rng) where {T,R} = new{T,R}(0, 0, rng) + SampleSingleAlgR{T,R}(seen_k, skip_k, rng, value) where {T,R} = new{T,R}(seen_k, skip_k, rng, value) end -function value(s::SampleSingleAlgL) - s.state === 1.0 && return nothing - return s.value -end function value(s::SampleSingleAlgR) s.seen_k === 0 && return nothing return s.value end -function ReservoirSample(T, method::ReservoirAlgorithm = algR) - return ReservoirSample(Random.default_rng(), T, method, ms) -end -function ReservoirSample(rng::AbstractRNG, T, method::ReservoirAlgorithm = algR) - return ReservoirSample(rng, T, method, ms) +function ReservoirSample(T, method::ReservoirAlgorithm = AlgR()) + return ReservoirSample(Random.default_rng(), T, method, MutSample()) end -function ReservoirSample(rng::AbstractRNG, T, ::AlgL, ::MutSample) - return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, 0, rng) +function ReservoirSample(rng::AbstractRNG, T, method::ReservoirAlgorithm = AlgR()) + return ReservoirSample(rng, T, method, MutSample()) end function ReservoirSample(rng::AbstractRNG, T, ::AlgR, ::MutSample) - return SampleSingleAlgR{T, typeof(rng)}(0, 0, rng) + return SampleSingleAlgR{T, typeof(rng)}(rng) end @inline function update!(s::SampleSingleAlgR, el) @@ -47,24 +31,7 @@ end end return s end -@inline function update!(s::SampleSingleAlgL, el) - s.seen_k += 1 - if s.skip_k > 0 - s.skip_k -= 1 - else - s.value = el - s.state *= rand(s.rng) - s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state)) - end - return s -end -function reset!(s::SampleSingleAlgL) - s.state = 1.0 - s.seen_k = 0 - s.skip_k = 0 - return s -end function reset!(s::SampleSingleAlgR) s.seen_k = 0 s.skip_k = 0 @@ -75,7 +42,7 @@ function Base.merge(s1::AbstractReservoirSampleSingle, s2::AbstractReservoirSamp n1, n2 = n_seen(s1), n_seen(s2) n_tot = n1 + n2 value = rand(s1.rng) < n1/n_tot ? s1.value : s2.value - return SampleSingleAlgR{typeof(value), typeof(s1.rng)}(n_tot, s1.rng, value) + return SampleSingleAlgR{typeof(value), typeof(s1.rng)}(n_tot, s1.skip_k + s2.skip_k, s1.rng, value) end function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle) @@ -87,6 +54,7 @@ function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle) s1.value = s2.value end s1.seen_k = n_tot + s1.skip_k += s2.skip_k return s1 end diff --git a/src/WeightedSamplingSingle.jl b/src/WeightedSamplingSingle.jl index d06f538..4eb3b4a 100644 --- a/src/WeightedSamplingSingle.jl +++ b/src/WeightedSamplingSingle.jl @@ -5,19 +5,6 @@ mutable struct RefVal{T} RefVal(value::T) where T = new{T}(value) end -struct ImmutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle - state::Float64 - rng::R - rvalue::RefVal{T} -end -mutable struct MutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle - state::Float64 - const rng::R - value::T - MutSampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng) -end -const SampleSingleAlgARes = Union{ImmutSampleSingleAlgARes, MutSampleSingleAlgARes} - struct ImmutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle state::Float64 skip_w::Float64 @@ -33,12 +20,6 @@ mutable struct MutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSi end const SampleSingleAlgAExpJ = Union{ImmutSampleSingleAlgAExpJ, MutSampleSingleAlgAExpJ} -function ReservoirSample(rng::R, T, ::AlgARes, ::MutSample) where {R<:AbstractRNG} - return MutSampleSingleAlgARes{T,R}(typemax(Float64), rng) -end -function ReservoirSample(rng::R, T, ::AlgARes, ::ImmutSample) where {R<:AbstractRNG} - return ImmutSampleSingleAlgARes(typemax(Float64), rng, RefVal{T}()) -end function ReservoirSample(rng::R, T, ::AlgAExpJ, ::MutSample) where {R<:AbstractRNG} return MutSampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng) end @@ -51,14 +32,6 @@ function value(s::AbstractWeightedReservoirSampleSingle) return get_val(s) end -@inline function update!(s::SampleSingleAlgARes, el, w) - priority = randexp(s.rng)/w - if priority < s.state - @imm_reset s.state = priority - s = set_val(s, el) - end - return s -end @inline function update!(s::SampleSingleAlgAExpJ, el, weight) @imm_reset s.state += weight if s.skip_w <= s.state @@ -68,23 +41,19 @@ end return s end -function reset!(s::MutSampleSingleAlgARes) - s.state = typemax(Float64) - return s -end function reset!(s::MutSampleSingleAlgAExpJ) s.state = 0.0 s.skip_w = 0.0 return s end -get_val(s::Union{ImmutSampleSingleAlgARes, ImmutSampleSingleAlgAExpJ}) = s.rvalue.value -function set_val(s::Union{ImmutSampleSingleAlgARes, ImmutSampleSingleAlgAExpJ}, el) +get_val(s::ImmutSampleSingleAlgAExpJ) = s.rvalue.value +function set_val(s::ImmutSampleSingleAlgAExpJ, el) @reset s.rvalue.value = el return s end -get_val(s::Union{MutSampleSingleAlgARes, MutSampleSingleAlgAExpJ}) = s.value -function set_val(s::Union{MutSampleSingleAlgARes, MutSampleSingleAlgAExpJ}, el) +get_val(s::MutSampleSingleAlgAExpJ) = s.value +function set_val(s::MutSampleSingleAlgAExpJ, el) s.value = el return s end diff --git a/src/precompile.jl b/src/precompile.jl index 6b8c96c..bf1ac67 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -9,10 +9,6 @@ using PrecompileTools @compile_workload let rs = ReservoirSample(Int, algR) update_s_no_weights!(rs, iter) - rs = ReservoirSample(Int, algL) - update_s_no_weights!(rs, iter) - rs = ReservoirSample(Int, algARes) - update_s!(rs, iter) rs = ReservoirSample(Int, algAExpJ) update_s!(rs, iter) rs = ReservoirSample(Int, 2, algR) diff --git a/test/benchmark/benchmark_tests.jl b/test/benchmark/benchmark_tests.jl index e852310..6523ba7 100644 --- a/test/benchmark/benchmark_tests.jl +++ b/test/benchmark/benchmark_tests.jl @@ -4,6 +4,7 @@ wv(el) = 1.0 for m in (algR, algL, algRSWRSKIP) for size in (nothing, 10) + size == nothing && m === algL && continue size == nothing && m === algRSWRSKIP && continue s = size == nothing ? () : (size,) b = @benchmark itsample($rng, $iter, $s..., $m) evals=1 @@ -15,6 +16,7 @@ end for m in (algARes, algAExpJ, algWRSWRSKIP) for size in (nothing, 10) + size == nothing && m === algARes && continue size == nothing && m === algWRSWRSKIP && continue s = size == nothing ? () : (size,) b = @benchmark itsample($rng, $iter, $wv, $s..., $m) evals=1 diff --git a/test/unweighted_sampling_single_tests.jl b/test/unweighted_sampling_single_tests.jl index b6c5aee..afb63cc 100644 --- a/test/unweighted_sampling_single_tests.jl +++ b/test/unweighted_sampling_single_tests.jl @@ -1,6 +1,6 @@ @testset "Unweighted sampling single tests" begin - @testset "method=$method" for method in (algL, algR) + @testset "method=$method" for method in (algR,) a, b = 1, 100 z = itsample(a:b, method) @test a <= z <= b diff --git a/test/weighted_sampling_single_tests.jl b/test/weighted_sampling_single_tests.jl index 36be714..3d42243 100644 --- a/test/weighted_sampling_single_tests.jl +++ b/test/weighted_sampling_single_tests.jl @@ -1,6 +1,6 @@ @testset "Weighted sampling single tests" begin - @testset "method=$method" for method in (algARes, algAExpJ) + @testset "method=$method" for method in (algAExpJ,) wv(el) = 1.0 a, b = 1, 100 z = itsample(a:b, wv, method)