Skip to content

Commit 510c0e2

Browse files
committed
move some stuff to OnlineStatsBase
1 parent e751293 commit 510c0e2

4 files changed

Lines changed: 9 additions & 233 deletions

File tree

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
julia 0.7
2-
OnlineStatsBase 0.9 0.10
2+
OnlineStatsBase 0.10 0.11
33
LearnBase
44
StatsBase
55
DataStructures

src/OnlineStats.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ module OnlineStats
33
using RecipesBase, Reexport, Statistics, LinearAlgebra, Dates
44
@reexport using OnlineStatsBase, LossFunctions, PenaltyFunctions, LearnBase
55

6-
import OnlineStatsBase: OnlineStat, name, _fit!, _merge!, eachrow, Mean, Variance,
7-
smooth, smooth!, smooth_syr!, bessel
6+
import OnlineStatsBase: OnlineStat, name, _fit!, _merge!, eachrow, smooth, smooth!,
7+
smooth_syr!, bessel, StatCollection, Mean, Variance, Series, FTSeries
88
import LearnBase: fit!, nobs, value, predict, transform, transform!
99
import StatsBase: autocov, autocor, confint, skewness, kurtosis, entropy, midpoints,
1010
fweights, sample, coef, Histogram

src/stats/stats.jl

Lines changed: 0 additions & 175 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,3 @@
1-
#-----------------------------------------------------# StatCollection (Series and Group)
2-
abstract type StatCollection{T} <: OnlineStat{T} end
3-
4-
function Base.show(io::IO, o::StatCollection)
5-
print(io, name(o, false, false))
6-
print_stat_tree(io, o.stats)
7-
end
8-
9-
function print_stat_tree(io::IO, stats)
10-
for (i, stat) in enumerate(stats)
11-
char = i == length(stats) ? '└' : '├'
12-
print(io, "\n $(char)── $stat")
13-
end
14-
end
15-
161
#-----------------------------------------------------------------------# AutoCov and Lag
172
# Lag
183
"""
@@ -206,12 +191,6 @@ nkeys(o::CountMap) = length(o.value)
206191
Base.values(o::CountMap) = values(o.value)
207192
Base.getindex(o::CountMap, i) = o.value[i]
208193

209-
#-----------------------------------------------------------------------# Counter
210-
mutable struct Counter <: OnlineStat{Any}
211-
n::Int
212-
end
213-
_fit!(o::Counter) = (o.n += 1)
214-
215194
#-----------------------------------------------------------------------# CovMatrix
216195
"""
217196
CovMatrix(p=0; weight=EqualWeight())
@@ -307,109 +286,6 @@ end
307286
Base.last(o::Diff) = o.lastval
308287
Base.diff(o::Diff) = o.diff
309288

310-
311-
#-----------------------------------------------------------------------# Extrema
312-
"""
313-
Extrema(T::Type = Float64)
314-
315-
Maximum and minimum.
316-
317-
# Example
318-
319-
o = fit!(Extrema(), rand(10^5))
320-
extrema(o)
321-
maximum(o)
322-
minimum(o)
323-
"""
324-
# T is type to store data, S is type of single observation.
325-
# E.g. you may want to accept any Number even if you are storing values as Float64
326-
mutable struct Extrema{T,S} <: OnlineStat{S}
327-
min::T
328-
max::T
329-
n::Int
330-
function Extrema(T::Type = Float64)
331-
a, b, S = extrema_init(T)
332-
new{T,S}(a, b, 0)
333-
end
334-
end
335-
extrema_init(T::Type{<:Number}) = typemax(T), typemin(T), Number
336-
extrema_init(T::Type{String}) = "", "", String
337-
extrema_init(T::Type{Date}) = typemax(Date), typemin(Date), Date
338-
extrema_init(T::Type) = rand(T), rand(T), T
339-
function _fit!(o::Extrema, y)
340-
(o.n += 1) == 1 && (o.min = o.max = y)
341-
o.min = min(o.min, y)
342-
o.max = max(o.max, y)
343-
end
344-
function _merge!(o::Extrema, o2::Extrema)
345-
o.min = min(o.min, o2.min)
346-
o.max = max(o.max, o2.max)
347-
o.n += o2.n
348-
o
349-
end
350-
value(o::Extrema) = (o.min, o.max)
351-
Base.extrema(o::Extrema) = value(o)
352-
Base.maximum(o::Extrema) = o.max
353-
Base.minimum(o::Extrema) = o.min
354-
355-
#-----------------------------------------------------------------------# FTSeries
356-
"""
357-
FTSeries(stats...; filter=x->true, transform=identity)
358-
359-
Track multiple stats for one data stream that is filtered and transformed before being
360-
fitted.
361-
362-
FTSeries(T, stats...; filter, transform)
363-
364-
Create an FTSeries and specify the type `T` of the transformed values.
365-
366-
# Example
367-
368-
o = FTSeries(Mean(), Variance(); transform=abs)
369-
fit!(o, -rand(1000))
370-
371-
# Remove missing values represented as DataValues
372-
using DataValues
373-
y = DataValueArray(randn(100), rand(Bool, 100))
374-
o = FTSeries(DataValue, Mean(); transform=get, filter=!isna)
375-
fit!(o, y)
376-
"""
377-
mutable struct FTSeries{N, OS<:Tup, F, T} <: StatCollection{Union{N,Missing}}
378-
stats::OS
379-
filter::F
380-
transform::T
381-
nfiltered::Int
382-
end
383-
function FTSeries(stats::OnlineStat...; filter=x->true, transform=identity)
384-
Ts = input.(stats)
385-
FTSeries{Union{Ts...}, typeof(stats), typeof(filter), typeof(transform)}(
386-
stats, filter, transform, 0
387-
)
388-
end
389-
function FTSeries(T::Type, stats::OnlineStat...; filter=x->true, transform=identity)
390-
FTSeries{T, typeof(stats), typeof(filter), typeof(transform)}(stats, filter, transform, 0)
391-
end
392-
value(o::FTSeries) = value.(o.stats)
393-
nobs(o::FTSeries) = nobs(o.stats[1])
394-
@generated function _fit!(o::FTSeries{N, OS}, y) where {N, OS}
395-
n = length(fieldnames(OS))
396-
quote
397-
if o.filter(y)
398-
yt = o.transform(y)
399-
Base.Cartesian.@nexprs $n i -> @inbounds begin
400-
_fit!(o.stats[i], yt)
401-
end
402-
else
403-
o.nfiltered += 1
404-
end
405-
end
406-
end
407-
function _merge!(o::FTSeries, o2::FTSeries)
408-
o.nfiltered += o2.nfiltered
409-
_merge!.(o.stats, o2.stats)
410-
o
411-
end
412-
413289
#-----------------------------------------------------------------------# Group
414290
"""
415291
Group(stats::OnlineStat...)
@@ -1118,57 +994,6 @@ function _merge!(o::T, o2::T) where {T<:ReservoirSample}
1118994
end
1119995
end
1120996

1121-
#-----------------------------------------------------------------------# Series
1122-
"""
1123-
Series(stats)
1124-
Series(stats...)
1125-
Series(; stats...)
1126-
1127-
Track a collection stats for one data stream.
1128-
1129-
# Example
1130-
1131-
s = Series(Mean(), Variance())
1132-
fit!(s, randn(1000))
1133-
"""
1134-
struct Series{IN, T} <: StatCollection{IN}
1135-
stats::T
1136-
function Series(stats::T) where T
1137-
IN = Union{map(input, stats)...}
1138-
new{IN, T}(stats)
1139-
end
1140-
end
1141-
Series(t::OnlineStat...) = Series(t)
1142-
Series(; t...) = Series(t.data)
1143-
1144-
value(o::Series) = map(value, o.stats)
1145-
nobs(o::Series) = nobs(o.stats[1])
1146-
@generated function _fit!(o::Series{IN, T}, y) where {IN, T}
1147-
n = length(fieldnames(T))
1148-
:(Base.Cartesian.@nexprs $n i -> _fit!(o.stats[i], y))
1149-
end
1150-
_merge!(o::Series, o2::Series) = map(_merge!, o.stats, o2.stats)
1151-
1152-
#-----------------------------------------------------------------------# Sum
1153-
"""
1154-
Sum(T::Type = Float64)
1155-
1156-
Track the overall sum.
1157-
1158-
# Example
1159-
1160-
fit!(Sum(Int), fill(1, 100))
1161-
"""
1162-
mutable struct Sum{T} <: OnlineStat{Number}
1163-
sum::T
1164-
n::Int
1165-
end
1166-
Sum(T::Type = Float64) = Sum(T(0), 0)
1167-
Base.sum(o::Sum) = o.sum
1168-
_fit!(o::Sum{T}, x::Real) where {T<:AbstractFloat} = (o.sum += convert(T, x); o.n += 1)
1169-
_fit!(o::Sum{T}, x::Real) where {T<:Integer} = (o.sum += round(T, x); o.n += 1)
1170-
_merge!(o::T, o2::T) where {T <: Sum} = (o.sum += o2.sum; o.n += o2.n; o)
1171-
1172997
# #-----------------------------------------------------------------------# Summarizer
1173998
# mutable struct Summarizer{T} <: OnlineStat{T}
1174999
# group::Group

test/runtests.jl

Lines changed: 6 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -101,26 +101,6 @@ end
101101
@test diff(o) == 1
102102
@test last(o) == 10
103103
end
104-
#-----------------------------------------------------------------------# Extrema
105-
@testset "Extrema" begin
106-
o = fit!(Extrema(), y)
107-
@test extrema(o) == extrema(y)
108-
@test minimum(o) == minimum(y)
109-
@test maximum(o) == maximum(y)
110-
111-
@test value(fit!(Extrema(Bool), x)) == extrema(x)
112-
@test value(fit!(Extrema(Int), z)) == extrema(z)
113-
114-
@test ==(mergevals(Extrema(), y, y2)...)
115-
116-
o = fit!(Extrema(Date), Date(2010):Day(1):Date(2011))
117-
@test minimum(o) == Date(2010)
118-
@test maximum(o) == Date(2011)
119-
120-
@test value(fit!(Extrema(Char), 'a':'z')) == ('a', 'z')
121-
@test value(fit!(Extrema(Char), "abc")) == ('a', 'c')
122-
@test value(fit!(Extrema(String), ["a", "b"])) == ("a", "b")
123-
end
124104
#-----------------------------------------------------------------------# Fit[Dist]
125105
@testset "Fit[Dist]" begin
126106
@testset "FitBeta" begin
@@ -226,16 +206,6 @@ end
226206
@test classify(o, randn(10)) in 1:2
227207
@test mean(classify(o, X) .== Y) > .5
228208
end
229-
#-----------------------------------------------------------------------# FTSeries
230-
@testset "FTSeries" begin
231-
o = fit!(FTSeries(Mean(); transform=abs), y)
232-
@test value(o)[1] ≈ mean(abs, y)
233-
234-
data = vcat(y, fill(missing, 20))
235-
o = fit!(FTSeries(Mean(); transform=abs, filter=!ismissing), data)
236-
@test value(o)[1] ≈ mean(abs, y)
237-
@test o.nfiltered == 20
238-
end
239209
#-----------------------------------------------------------------------# Group
240210
@testset "Group" begin
241211
o = fit!(5Mean(), OnlineStatsBase.eachrow(ymat))
@@ -266,7 +236,7 @@ end
266236
end
267237
end
268238
#-----------------------------------------------------------------------# HeatMap
269-
@testset "HeatMap" begin
239+
@testset "HeatMap" begin
270240
data1 = OnlineStatsBase.eachrow(ymat[:, 1:2])
271241
data2 = OnlineStatsBase.eachrow(ymat2[:, 1:2])
272242
@test ==(mergevals(HeatMap(-5:.1:5, -5:.1:5), data1, data2)...)
@@ -283,7 +253,7 @@ end
283253
@test fit!(Hist(edges, Number; closed=false, left=false), data).counts == w2
284254
end
285255
end
286-
end
256+
end
287257
o = fit!(Hist(-5:.1:5), y)
288258
for (v1, v2) in zip(extrema(o), extrema(y))
289259
@test ≈(v1, v2; atol=.1)
@@ -370,7 +340,7 @@ end
370340
≈(mergevals(LinReg(), OnlineStatsBase.eachrow(ymat, y), OnlineStatsBase.eachrow(ymat2, y2))...)
371341

372342
o = fit!(LinReg(), (ymat, y))
373-
@test coef(o) ≈ ymat \ y
343+
@test coef(o) ≈ ymat \ y
374344
@test coef(o, .1) ≈ (ymat'ymat ./ n + .1I) \ ymat'y ./ n
375345
@test coef(o, .1:.1:.5) ≈ (ymat'ymat ./ n + Diagonal(.1:.1:.5)) \ ymat'y ./ n
376346
@test predict(o, ymat) == ymat * o.β
@@ -511,16 +481,6 @@ end
511481
@test (yi ∈ y) || (yi ∈ y2)
512482
end
513483
end
514-
#-----------------------------------------------------------------------# Series
515-
@testset "Series" begin
516-
a, b = mergevals(Series(Mean(), Variance()), y, y2)
517-
@test a[1] ≈ b[1]
518-
@test a[2] ≈ b[2]
519-
520-
a, b = mergevals(Series(m=Mean(), v=Variance()), y, y2)
521-
@test a.m b.m
522-
@test a.v b.v
523-
end
524484
#-----------------------------------------------------------------------# StatHistory
525485
@testset "StatHistory" begin
526486
o = fit!(StatHistory(Mean(), 10), 1:20)
@@ -560,23 +520,14 @@ end
560520
println()
561521
end
562522
end
563-
#-----------------------------------------------------------------------# Sum
564-
@testset "Sum" begin
565-
@test value(fit!(Sum(Int), x)) == sum(x)
566-
@test value(fit!(Sum(), y)) ≈ sum(y)
567-
@test value(fit!(Sum(Int), z)) == sum(z)
568-
569-
@test ==(mergevals(Sum(Int), x, x2)...)
570-
@test ≈(mergevals(Sum(), y, y2)...)
571-
@test ==(mergevals(Sum(Int), z, z2)...)
572-
end
523+
#-----------------------------------------------------------------------# Kahan
573524

574525
include("test_kahan.jl")
575526

576527
#-----------------------------------------------------------------------# Show methods
577528
@testset "Show methods" begin
578-
for stat in [BiasVec([1,2,3]), Bootstrap(Mean()), CallFun(Mean(), println), FastNode(5),
579-
FastTree(5), FastForest(5), FTSeries(Variance()), Group(Mean(), Mean()),
529+
for stat in [BiasVec([1,2,3]), Bootstrap(Mean()), CallFun(Mean(), println), FastNode(5),
530+
FastTree(5), FastForest(5), FTSeries(Variance()), Group(Mean(), Mean()),
580531
HyperLogLog{10}(), LinRegBuilder(4), NBClassifier(5, Float64), ProbMap(Int),
581532
P2Quantile(.5), Series(Mean()), StatLearn(5)]
582533
println(" > ", stat)

0 commit comments

Comments
 (0)