Skip to content

Commit 01e6140

Browse files
committed
tests
1 parent 31372ea commit 01e6140

2 files changed

Lines changed: 31 additions & 22 deletions

File tree

src/stats/stats.jl

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ end
1515

1616
#-----------------------------------------------------------------------# Variance
1717
"""
18-
Variance(; weight=EqualWeight())
18+
Variance(T = Float64; weight=EqualWeight())
1919
20-
Univariate variance.
20+
Univariate variance, tracked as type `T`.
2121
2222
# Example
2323
@@ -26,19 +26,21 @@ Univariate variance.
2626
var(o)
2727
std(o)
2828
"""
29-
mutable struct Variance{W} <: OnlineStat{Number}
30-
σ2::Float64
31-
μ::Float64
29+
mutable struct Variance{T, W} <: OnlineStat{Number}
30+
σ2::T
31+
μ::T
3232
weight::W
3333
n::Int
3434
end
35-
Variance(;weight = EqualWeight()) = Variance(0.0, 0.0, weight, 0)
35+
function Variance(T::Type{<:Number} = Float64; weight = EqualWeight())
36+
Variance(zero(T), zero(T), weight, 0)
37+
end
3638
Base.copy(o::Variance) = Variance(o.σ2, o.μ, o.weight, o.n)
37-
function _fit!(o::Variance, x)
39+
function _fit!(o::Variance{T}, x) where {T}
3840
μ = o.μ
39-
γ = o.weight(o.n += 1)
40-
o.μ = smooth(o.μ, x, γ)
41-
o.σ2 = smooth(o.σ2, (x - o.μ) * (x - μ), γ)
41+
γ = T(o.weight(o.n += 1))
42+
o.μ = smooth(o.μ, T(x), γ)
43+
o.σ2 = smooth(o.σ2, (T(x) - o.μ) * (T(x) - μ), γ)
4244
end
4345
function _merge!(o::Variance, o2::Variance)
4446
γ = o2.n / (o.n += o2.n)
@@ -713,13 +715,9 @@ end
713715

714716
#-----------------------------------------------------------------------# Mean
715717
"""
716-
Mean(; weight=EqualWeight())
717-
718-
Track a univariate mean.
719-
720-
# Update
718+
Mean(T = Float64; weight=EqualWeight())
721719
722-
``μ = (1 - γ) * μ + γ * x``
720+
Track a univariate mean, stored as type `T`.
723721
724722
# Example
725723
@@ -730,7 +728,7 @@ mutable struct Mean{T,W} <: OnlineStat{Number}
730728
weight::W
731729
n::Int
732730
end
733-
Mean(T::Type = Float64; weight = EqualWeight()) = Mean(zero(T), weight, 0)
731+
Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight, 0)
734732
_fit!(o::Mean{T}, x) where {T} = (o.μ = smooth(o.μ, x, T(o.weight(o.n += 1))))
735733
function _merge!(o::Mean, o2::Mean)
736734
o.n += o2.n

test/runtests.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
module OnlineStatsTests
2-
31
using OnlineStats, Test, Statistics, Random, LinearAlgebra, Dates
42
O = OnlineStats
5-
import StatsBase: countmap, fit, Histogram
3+
import StatsBase: countmap, fit, Histogram, sample
64
import DataStructures: OrderedDict, SortedDict
75

86
const y = randn(1000)
@@ -49,6 +47,7 @@ function test_merge(o, y1, y2, compare = ≈; kw...)
4947
@test nobs(o) == nobs(o2) == nrows(y1) + nrows(y2)
5048
end
5149

50+
# @test compare(fo(fit!(o,y)), fy(y))
5251
function test_exact(o, y, fo, fy::Function, compare = ; kw...)
5352
fit!(o, y)
5453
for (v1, v2) in zip(fo(o), fy(y))
@@ -58,6 +57,8 @@ function test_exact(o, y, fo, fy::Function, compare = ≈; kw...)
5857
end
5958
@test nobs(o) == nrows(y)
6059
end
60+
61+
# @test compare(fo(fit!(o,y )), fy)
6162
function test_exact(o, y, fo, fy, compare = ; kw...)
6263
fit!(o, y)
6364
for (v1, v2) in zip(fo(o), fy)
@@ -73,6 +74,17 @@ nrows(m::AbstractMatrix) = size(m, 1)
7374
nrows(t::Tuple) = length(t[2])
7475
nrows(y::Base.Iterators.Zip2) = length(y)
7576

77+
function testfit(o::OnlineStat, y, val, compare = )
78+
@test nobs(o) == nobs(o)
79+
@test compare(value(fit!(o, y)), val)
80+
end
81+
82+
function testmerge(o::OnlineStat, y, compare = ; n=10)
83+
for i in 1:n
84+
a = sample(y, floor(Int, length(y)/2); replace=false)
85+
b = sample(y, floor(Int, length(y)/2); replace=false)
86+
end
87+
end
7688

7789
#-----------------------------------------------------------------------# utils
7890
println("\n\n")
@@ -395,6 +407,7 @@ end
395407
@inferred Mean()
396408
@inferred Mean(Complex{Float64})
397409
test_exact(Mean(), y, mean, mean)
410+
test_exact(Mean(BigFloat), big.(y), mean, mean, , atol=1e-16)
398411
test_exact(Mean(Complex{Float64}), y + y2*im, mean, mean)
399412
test_merge(Mean(), y, y2)
400413
end
@@ -609,5 +622,3 @@ end
609622
end
610623

611624
include("test_kahan.jl")
612-
613-
end #module

0 commit comments

Comments
 (0)