Skip to content

Commit d9cdda2

Browse files
committed
compliance with new preprint
1 parent 2a8ce4f commit d9cdda2

3 files changed

Lines changed: 126 additions & 65 deletions

File tree

examples/CO24/functions_classification.jl

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function classiferrortable(
3939
tvec;
4040
multi_thread::Bool=false,
4141
methods::Vector{Symbol}=[:ag, :sp],
42-
clusterings::Vector{Symbol}=[:kmeans, :hccomp, :hcsing, :hcavg, :hcward],
42+
clusterings::Vector{Symbol}=[:kmeans, :threshold, :single, :average, :complete, :ward],
4343
)::DataFrame
4444
## Set default
4545
N, r₊, β, λ, p, Nsimu = default_values
@@ -142,10 +142,11 @@ function metadatacomplete!(
142142
methods_fullstr = Dict("ag" => "aggregated", "sp" => "spectral")
143143
clusterings_fullstr = Dict(
144144
"kmeans" => "kmeans",
145-
"hccomp" => "hierarchical (complete)",
146-
"hcsing" => "hierarchical (single)",
147-
"hcavg" => "hierarchical (average)",
148-
"hcward" => "hierarchical (ward)",
145+
"threshold" => "mean threshold",
146+
"single" => "hierarchical (single)",
147+
"average" => "hierarchical (average)",
148+
"complete" => "hierarchical (complete)",
149+
"ward" => "hierarchical (ward)",
149150
)
150151
for (i, col_str) in enumerate(names(df))
151152
col_str in ("parameter", "T") && continue
@@ -172,25 +173,27 @@ function misclassificationrates(
172173
data::DiscreteTimeData,
173174
excitatory::Vector{Bool};
174175
methods::Vector{Symbol}=[:ag, :sp],
175-
clusterings::Vector{Symbol}=[:kmeans, :hccomp, :hcsing, :hcavg, :hcward],
176+
clusterings::Vector{Symbol}=[:kmeans, :threshold, :single, :average, :complete, :ward],
176177
)::Vector{Float64}
177178
output = Float64[]
178179

179180
σ̂ag = MeanFieldGraph.covariance_vector(data)
180181

181182
for method in methods
182-
if method == :ag
183-
# aggregated classifications
183+
if method == :ag # aggregated estimation
184184
σ̂ = σ̂ag
185-
else
186-
# spectral classifications
185+
else # spectral estimation
187186
Σ̂ = MeanFieldGraph.covariance_matrix(data)
188-
_, vecs = eigsolve(transpose(Σ̂) * Σ̂)
189-
σ̂ = vecs[1]
190-
# FIXME : how to chose the sign of σ̂sp ?
191-
if mapreduce(abs, +, σ̂ag - σ̂) > mapreduce(abs, +, σ̂ag + σ̂)
192-
σ̂ *= -1
193-
end
187+
_, vecs = eigsolve(transpose(Σ̂) * Σ̂) # faster than full SVD
188+
= vecs[1]
189+
190+
# sign disambiguation
191+
m_v̌ = mean(v̌)
192+
=.>= m_v̌
193+
σ̌₊ = sum(σ̂ag[P̌]) / sum(P̌)
194+
σ̌₋ = sum(σ̂ag[.!P̌]) / sum(.!P̌)
195+
196+
σ̂ = σ̌₊ >= σ̌₋ ?: -
194197
end
195198
for c in clusterings
196199
push!(output, misclassificationrate(σ̂, excitatory, c))
@@ -208,12 +211,12 @@ function misclassificationrate(
208211
classif = MeanFieldGraph.cluster2bool(
209212
kmeans(transpose(σ), 2; init=[argmin(σ), argmax(σ)])
210213
)
214+
elseif clustering == :threshold
215+
threshold = mean(σ)
216+
classif = σ .>= threshold
211217
else
212-
linkages = Dict(
213-
:hccomp => :complete, :hcsing => :single, :hcavg => :average, :hcward => :ward
214-
)
215218
distances = [abs(σ[i] - σ[j]) for i in eachindex(σ), j in eachindex(σ)]
216-
ct = cutree(hclust(distances; linkage=linkages[clustering]); k=2)
219+
ct = cutree(hclust(distances; linkage=clustering); k=2)
217220
id_excitatory = ct[argmax(σ)]
218221
classif = ct .== id_excitatory
219222
end
@@ -377,13 +380,7 @@ function simulationandsave(
377380
tvec,
378381
)
379382
df = classiferrortable(
380-
Paramsymbol,
381-
Paramvec,
382-
default_values,
383-
tvec;
384-
multi_thread=true,
385-
methods=[:ag],
386-
clusterings=[:kmeans],
383+
Paramsymbol, Paramvec, default_values, tvec; multi_thread=true, methods=[:ag]
387384
)
388385
paramstring = string(Paramsymbol)
389386

examples/CO24/table.jl

Lines changed: 84 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,43 @@
11
include("functions_classification.jl")
22
using Latexify
33

4+
# Helper functions
5+
"""
6+
non_optimal_couples(df, tup, metric; aside_tuples=Tuple{Symbol,Symbol}[])
7+
8+
Give the couples (N,T) for which the method and clustering combination given in `tup` does not have the optimal value of the metric given in `metric` (either :er or :mr), excluding the combinations given in `aside_tuples`.
9+
"""
10+
function non_optimal_couples(
11+
df::DataFrame,
12+
tup::Tuple{Symbol,Symbol},
13+
metric::Symbol;
14+
aside_tuples::Vector{Tuple{Symbol,Symbol}}=Tuple{Symbol,Symbol}[],
15+
)
16+
df_copy = deepcopy(df) # to avoid modifying the original dataframe with @rsubset!
17+
f = metric == :er ? identity : (x -> -x) # to maximize ER and minimize MR
18+
19+
for a_tup in aside_tuples
20+
@rsubset!(df_copy, !((:method, :clustering) == string.(a_tup))) # exclude the aside methods and clusterings to compare to
21+
end
22+
23+
output = @chain df_copy begin
24+
transform(Symbol(metric, :_mean) => f => :target_metric) # create a column with the target metric to maximize (either ER or minus MR)
25+
@aside target = @rsubset(_, (:method, :clustering) == string.(tup))
26+
@groupby([:parameter, :T])
27+
@combine(:optimal_target_metric = maximum(:target_metric))
28+
_[target[!, :target_metric] .< _[!, :optimal_target_metric], [:parameter, :T]]
29+
@rename(:N = :parameter)
30+
end
31+
32+
return output
33+
end
34+
35+
round_percent(x) = round(Int, 100 * x)
36+
437
## Load table
538
df_wide = estimatorsload("data/CO24/data_for_color_plot")
639
df_mean_bands = mmr_per(df_wide)
740

8-
### helper function
9-
round_percent(x) = round(Int, 100 * x)
10-
1141
## Misclassification rate - table
1242
df_mr = @chain df_mean_bands begin
1343
# value of interest is MMR ± standard error
@@ -19,14 +49,14 @@ df_mr = @chain df_mean_bands begin
1949
:mr = string(round_percent(:mr_mean)) * " ± " * string(round_percent(:mr_std)),
2050
)
2151
# create a column with the combination of clustering and method to be able to unstack both at the same time
22-
@rtransform!(:col_name = string(:clustering, "_", :method))
52+
@rtransform!(:col_name = string(:method, "_", :clustering))
2353
# unstack the table to have one column per method and clustering combination
2454
unstack([:N, :T], :col_name, :mr)
2555
end
2656

2757
# couples are selected so that the lowest MMR (usually ag_kmeans) is around 0.02
2858
function isselected(n, t)
29-
return (n, t) in [(34, 2641), (94, 7903), (142, 13165), (190, 21058), (250, 31582)]
59+
return (n, t) in [(34, 2641), (94, 7903), (142, 10534), (190, 13165), (250, 18427)]
3060
end
3161
latexify(@rsubset(df_mr, isselected(:N, :T)))
3262

@@ -37,20 +67,31 @@ latexify(@rsubset(df_mr, isselected(:N, :T)))
3767
end
3868

3969
## Misclassification rate - best method ?
40-
### Most of the time, the lowest Mean Misclassification Rate (MMR) is achieved for ag_kmeans
41-
### 25 couples (N, T) out of 420 where the lowest MMR is not ag_kmeans
42-
@chain df_mean_bands begin
43-
@groupby([:parameter, :T])
44-
@combine(:id_lowest_mmr = argmin(:mr_mean))
45-
@rsubset(:id_lowest_mmr != 1)
46-
end
70+
### Most of the time, ag_threshold is amongst the lowest Mean Misclassification Rate (MMR)
71+
### 25 couples (N, T) out of 420 where ag_threshold is not amongst the lowest MMR
72+
size(non_optimal_couples(df_mean_bands, (:ag, :threshold), :mr))[1]
73+
#### If we remove threshold clustering,
74+
#### 27 couples where ag_kmeans is not amongst the lowest MMR
75+
size(
76+
non_optimal_couples(
77+
df_mean_bands,
78+
(:ag, :kmeans),
79+
:mr;
80+
aside_tuples=[(:ag, :threshold), (:sp, :threshold)],
81+
),
82+
)[1]
4783

48-
### When MMR is rounded up to 2 decimals, only 11 couples (N, T) out of 420
49-
@chain df_mean_bands begin
50-
@groupby([:parameter, :T])
51-
@combine(:id_lowest_mmr = argmin(round.(100 * :mr_mean)))
52-
@rsubset(:id_lowest_mmr != 1)
53-
end
84+
### When MMR is rounded up to 2 decimals,
85+
df_rounded = @rtransform(df_mean_bands, :mr_mean = round_percent(:mr_mean))
86+
### only 13 couples where ag_threshold is not amongst the lowest MMR
87+
size(non_optimal_couples(df_rounded, (:ag, :threshold), :mr))[1]
88+
#### If we remove threshold clustering,
89+
#### 14 couples where ag_kmeans is not amongst the lowest MMR
90+
size(
91+
non_optimal_couples(
92+
df_rounded, (:ag, :kmeans), :mr; aside_tuples=[(:ag, :threshold), (:sp, :threshold)]
93+
),
94+
)[1]
5495

5596
## Exact recovery
5697
factor = quantile(Normal(), 0.975) / sqrt(metadata(df_mean_bands, "Number of simulations"))
@@ -67,7 +108,7 @@ df_er = @chain df_mean_bands begin
67108
string(round_percent(factor * :er_std)),
68109
)
69110
# create a column with the combination of clustering and method to be able to unstack both at the same time
70-
@rtransform!(:col_name = string(:clustering, "_", :method))
111+
@rtransform!(:col_name = string(:method, "_", :clustering))
71112
# unstack the table to have one column per method and clustering combination
72113
unstack([:N, :T], :col_name, :er)
73114
end
@@ -85,17 +126,28 @@ latexify(@rsubset(df_er, isselected(:N, :T)))
85126
end
86127

87128
## Exact recovery - best method ?
88-
### Most of the time, the highest Probability of Exact Recovery (PER) is achieved for ag_kmeans
89-
### 11 couples (N, T) out of 420 where the highest PER is not ag_kmeans
90-
@chain df_mean_bands begin
91-
@groupby([:parameter, :T])
92-
@combine(:id_highest_er = argmax(:er_mean))
93-
@rsubset(:id_highest_er != 1)
94-
end
129+
### Most of the time, ag_threshold is amongst the highest Probability of Exact Recovery (PER)
130+
### 7 couples (N, T) out of 420 where ag_threshold is not amongst the highest PER
131+
size(non_optimal_couples(df_mean_bands, (:ag, :threshold), :er))[1]
132+
#### If we remove threshold clustering,
133+
#### 11 couples where ag_kmeans is not amongst the highest PER
134+
size(
135+
non_optimal_couples(
136+
df_mean_bands,
137+
(:ag, :kmeans),
138+
:er;
139+
aside_tuples=[(:ag, :threshold), (:sp, :threshold)],
140+
),
141+
)[1]
95142

96-
### When PER is rounded up to 2 decimals, ag_kmeans is always among the best
97-
@chain df_mean_bands begin
98-
@groupby([:parameter, :T])
99-
@combine(:id_highest_er = argmax(round.(100 * :er_mean)))
100-
@rsubset(:id_highest_er != 1)
101-
end
143+
### When PER is rounded up to 2 decimals,
144+
df_rounded = @rtransform(df_mean_bands, :er_mean = round_percent(:er_mean))
145+
### only 1 couple where ag_threshold is not amongst the highest PER
146+
size(non_optimal_couples(df_rounded, (:ag, :threshold), :er))[1]
147+
#### If we remove threshold clustering,
148+
#### 2 couples where ag_kmeans is not amongst the highest PER
149+
size(
150+
non_optimal_couples(
151+
df_rounded, (:ag, :kmeans), :er; aside_tuples=[(:ag, :threshold), (:sp, :threshold)]
152+
),
153+
)[1]

src/classification.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Estimates the two underlying communities (one excitatory and one inhibitory) fro
55
66
# Keyword arguments
77
- `method::Symbol`: the method applied to estimate the covariance vector σ. Valid choices are: `:aggregated` (the default) and `:spectral`.
8-
- `clustering::Symbol`: the clustering method applied to the estimated covariance vector. Valid choices are: `:kmeans` (the default) and `:hclust`.
8+
- `clustering::Symbol`: the clustering method applied to the estimated covariance vector. Valid choices are: `:kmeans` (the default), `:threshold`, and the *linkage* choices for the `hclust` function (`:single`, `:average`, `:complete`, `:ward`).
99
"""
1010
function classification(
1111
data::DiscreteTimeData; method::Symbol=:aggregated, clustering::Symbol=:kmeans
@@ -15,21 +15,33 @@ function classification(
1515
if method == :aggregated
1616
σ̂ = covariance_vector(data)
1717
elseif method == :spectral
18+
# compute the leading singular vector of the covariance matrix
1819
Σ̂ = covariance_matrix(data)
1920
_, vecs = eigsolve(transpose(Σ̂) * Σ̂) # faster than full SVD
20-
σ̂ = vecs[1]
21-
# FIXME : how to chose the sign of σ̂sp ?
21+
= vecs[1]
22+
23+
# sign disambiguation
24+
σ̂_ag = sum(Σ̂; dims=1)[1, :]
25+
m_v̌ = mean(v̌)
26+
=.>= m_v̌
27+
σ̌₊ = sum(σ̂_ag[P̌]) / sum(P̌)
28+
σ̌₋ = sum(σ̂_ag[.!P̌]) / sum(.!P̌)
29+
30+
σ̂ = σ̌₊ >= σ̌₋ ?: -
2231
else
2332
throw(ArgumentError("Unsupported method $method"))
2433
end
2534

26-
# Clustering based on the estimated σ̂
35+
# Clustering based on the estimator σ̂
2736
if clustering == :kmeans
2837
initialisation = [argmin(σ̂), argmax(σ̂)]
2938
output = cluster2bool(kmeans(transpose(σ̂), 2; init=initialisation))
30-
elseif clustering == :hclust
39+
elseif clustering == :threshold
40+
threshold = mean(σ̂)
41+
output = σ̂ .>= threshold
42+
elseif clustering in (:single, :average, :complete, :ward)
3143
distances = [abs(σ̂[i] - σ̂[j]) for i in eachindex(σ̂), j in eachindex(σ̂)]
32-
ct = cutree(hclust(distances); k=2)
44+
ct = cutree(hclust(distances; linkage=clustering); k=2)
3345
id_excitatory = ct[argmax(σ̂)]
3446
output = ct .== id_excitatory
3547
else

0 commit comments

Comments
 (0)