diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 453925c..d42d738 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1 +1,4 @@ -style = "sciml" \ No newline at end of file +indent = 4 +margin = 92 +normalize_line_endings = "unix" +style = "sciml" diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..8400d25 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + open-pull-requests-limit: 99 + labels: + - "dependencies" + - "github-actions" \ No newline at end of file diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 3d4b576..4c7327d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,7 +18,7 @@ jobs: fail-fast: false matrix: version: - - '1.8' + - '1.12' # - 'nightly' os: - ubuntu-latest @@ -26,30 +26,42 @@ jobs: - x64 # - x86 steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v6 + - uses: julia-actions/setup-julia@v3 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v3 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v5 with: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} docs: name: Documentation - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1.12' + # - 'nightly' + os: + - ubuntu-latest + arch: + - x64 + # - x86 permissions: contents: write steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v6 + - uses: julia-actions/setup-julia@v3 with: version: '1' + - uses: julia-actions/cache@v3 - uses: julia-actions/julia-buildpkg@v1 - uses: 
julia-actions/julia-docdeploy@v1 env: diff --git a/.github/workflows/Documentation.yml b/.github/workflows/Documentation.yml deleted file mode 100644 index 4538118..0000000 --- a/.github/workflows/Documentation.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Documentation - -on: - push: - branches: - - master - tags: '*' - pull_request: - -jobs: - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@latest - with: - version: '1.8' - - name: Install dependencies - run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()' - - name: Build and deploy - env: - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # If authenticating with SSH deploy key - run: julia --project=docs/ docs/make.jl diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml new file mode 100644 index 0000000..dfad667 --- /dev/null +++ b/.github/workflows/benchmark_pr.yml @@ -0,0 +1,78 @@ +name: Benchmark a pull request + +on: + pull_request_target: + branches: + - master + +permissions: + pull-requests: write + +jobs: + generate_plots: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: "1.11" + - uses: julia-actions/cache@v2 + - name: Extract Package Name from Project.toml + id: extract-package-name + run: | + PACKAGE_NAME=$(grep "^name" Project.toml | sed 's/^name = "\(.*\)"$/\1/') + echo "::set-output name=package_name::$PACKAGE_NAME" + - name: Build AirspeedVelocity + env: + JULIA_NUM_THREADS: 2 + run: | + # Lightweight build step, as sometimes the runner runs out of memory: + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.add(;url="https://github.com/MilesCranmer/AirspeedVelocity.jl.git")' + julia -e 'ENV["JULIA_PKG_PRECOMPILE_AUTO"]=0; import Pkg; Pkg.build("AirspeedVelocity")' + - name: Add ~/.julia/bin to PATH + run: | + echo "$HOME/.julia/bin" >> $GITHUB_PATH + - name: Run benchmarks + run: | 
+ echo $PATH + ls -l ~/.julia/bin + mkdir results + benchpkg ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.repository.default_branch}}" --output-dir=results/ --tune + - name: Create plots from benchmarks + run: | + mkdir -p plots + benchpkgplot ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --npart=10 --format=png --input-dir=results/ --output-dir=plots/ + - name: Upload plot as artifact + uses: actions/upload-artifact@v4 + with: + name: plots + path: plots + - name: Create markdown table from benchmarks + run: | + benchpkgtable ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --input-dir=results/ --ratio > table.md + echo '### Benchmark Results' > body.md + echo '' >> body.md + echo '' >> body.md + cat table.md >> body.md + echo '' >> body.md + echo '' >> body.md + echo '### Benchmark Plots' >> body.md + echo 'A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.' >> body.md + echo 'Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).' 
>> body.md + + - name: Find Comment + uses: peter-evans/find-comment@v3 + id: fcbenchmark + with: + issue-number: ${{ github.event.pull_request.number }} + comment-author: 'github-actions[bot]' + body-includes: Benchmark Results + + - name: Comment on PR + uses: peter-evans/create-or-update-comment@v4 + with: + comment-id: ${{ steps.fcbenchmark.outputs.comment-id }} + issue-number: ${{ github.event.pull_request.number }} + body-path: body.md + edit-mode: replace \ No newline at end of file diff --git a/.gitignore b/.gitignore index 13c44d4..4b2ed35 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,34 @@ .vscode *dev/ +.CondaPkg/ + +# Files generated by invoking Julia with --code-coverage +*.jl.cov +*.jl.*.cov + +# Files generated by invoking Julia with --track-allocation +*.jl.mem + +# System-specific files and directories generated by the BinaryProvider and BinDeps packages +# They contain absolute paths specific to the host computer, and so should not be committed +deps/deps.jl +deps/build.log +deps/downloads/ +deps/usr/ +deps/src/ + +# Build artifacts for creating documentation generated by the Documenter package docs/build/ +docs/site/ docs/Manifest.toml -Manifest.toml \ No newline at end of file +# File generated by Pkg, the package manager, based on a corresponding Project.toml +# It records a fixed state of all packages used by the project. As such, it should not be +# committed for packages, but should be committed for applications that require a static +# environment. 
+Manifest.toml + + +## python +__pycache__/ \ No newline at end of file diff --git a/CondaPkg.toml b/CondaPkg.toml new file mode 100644 index 0000000..9446691 --- /dev/null +++ b/CondaPkg.toml @@ -0,0 +1,2 @@ +[deps] +pot = "" diff --git a/Project.toml b/Project.toml index a7719ad..91a1031 100644 --- a/Project.toml +++ b/Project.toml @@ -1,36 +1,49 @@ name = "NetworkHistogram" uuid = "7806f430-7229-459c-b2e6-df35e8e4eb5d" +version = "0.6.0" authors = ["Charles Dufour", "Jake Grainger"] -version = "0.5.2" [deps] -ArnoldiMethod = "ec485272-7323-5ecc-a04f-4719b315124d" -Arpack = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -CodecZstd = "6b39b394-51ab-5f42-8807-6242bab2b4c2" -HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" -JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8" -Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" +Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Graphons = "e0c12bfd-47d7-434e-afb7-632611640ca5" +Hungarian = "e91730f6-4275-51fb-a7a0-7064cfbd3b39" +IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -TranscodingStreams = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7" + +[weakdeps] +Bootstrap = "e28b5b4c-05e8-5b66-bc03-6f0c0a0a06e0" +LightMC = "b58f5c6e-c887-41d6-b553-02118416cd5d" +Makie = 
"ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" + +[extensions] +BootstrapExt = "Bootstrap" +LightMCExt = "LightMC" +MakieExt = "Makie" +PythonOptimalTransport = "PythonCall" [compat] -ArnoldiMethod = "0.2.0" -Arpack = "0.5.4" -BenchmarkTools = "1.3.2" -CodecZstd = "0.7.2" -HTTP = "1.7.4" -JLD = "0.13.3" -Kronecker = "0.5" -ProgressMeter = "1.7.2" -StatsBase = "0.33.21" -TranscodingStreams = "0.9.11" -ValueHistories = "0.5.4" -julia = "1.8" +Accessors = "0.1.42" +ArgCheck = "2.5.0" +BenchmarkTools = "1.6.3" +Clustering = "0.15.8" +Hungarian = "0.7.0" +IntervalSets = "0.7.11" +LinearAlgebra = "1.12.0" +LogExpFunctions = "0.3.29" +Reexport = "1.2.2" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index 8d7bf2e..90614ae 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,31 @@ [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://sds-epfl.github.io/NetworkHistogram.jl/stable/) [![DOI](https://zenodo.org/badge/572018079.svg)](https://zenodo.org/doi/10.5281/zenodo.10212851) - -Implementation of the network histogram for graphon estimation from the paper [Network histograms and universality of blockmodel approximation](https://doi.org/10.1073/pnas.1400374111) by Sofia C. Olhede and Patrick J. Wolfe. - +Implementation of the network histogram for graphon estimation from the paper +[Network histograms and universality of blockmodel approximation (2014)](https://doi.org/10.1073/pnas.1400374111) +by Sofia C. Olhede and Patrick J. Wolfe and its extension to decorated graphs +by Charles Dufour and Sofia C. Olhede +[Inference for decorated graphs and application to multiplex networks (2024)](https://arxiv.org/abs/2408.12339). + +The network histogram is a nonparametric estimator for the generating mechanism +of an exchangeable random graph (see graphons, decorated graphons and +probability graphons). 
We assume our observed graph is +$A \in \mathcal{K}^{n \times n}$, where $\mathcal{K}$ is a set of edge +decorations (e.g. $\{0,1\}$ for unweighted graphs, $\mathbb{N}$ for count +edges, $\mathbb{R}$ for real-valued edges, etc.). Using the Aldous-Hoover +theorem, we know that $A$ is generated from a graphon +$W: [0,1]^2 \to \mathcal{P}\left(\mathcal{K}\right)$, where +$\mathcal{P}\left(\mathcal{K}\right)$ is the set of probability measures on +$\mathcal{K}$ in the following way: + +1. Sample $U_1, \ldots, U_n \sim \text{iid } \text{Uniform}[0,1]$. +2. For each pair of nodes $i,j$, sample the edge $A_{ij} \sim W(U_i, U_j)$ + independently. + +The network histogram approximates the generating graphon +$W: [0,1]^2 \to \mathcal{P}\left(\mathcal{K}\right)$ by a piecewise constant +function, i.e. a stochastic block model with $k$ blocks. For details, see the +papers mentioned above. ## Installation @@ -24,7 +46,10 @@ Pkg.add("NetworkHistogram") ## Usage -We fit the estimator and then extract the estimated graphon matrix and node labels. +### Basic Usage + +We fit the estimator and then extract the estimated graphon matrix and node +labels. ```julia using NetworkHistogram, LinearAlgebra @@ -35,7 +60,7 @@ A[diagind(A)] .= 0 # approximate the graphon with a network histogram hist = graphhist(A) -# get the graphist structure +# get the graphhist structure estimate = hist.graphhist # get the estimated graphon matrix @@ -45,4 +70,105 @@ sbm_matrix = estimate.θ node_labels = estimate.node_labels ``` -You can control the optimization process by modifying the rules used in the optimization. Check out the docs for more information. 
+### Advanced Usage with Custom Parameters + +You can control the optimization process by modifying the rules used in the +optimization: + +```julia +using NetworkHistogram + +# Binary network +A = Symmetric(rand(0:1, 100, 100)) +A[diagind(A)] .= 0 + +# Initial partition into k groups +k = 3 +initial_labels = rand(1:k, 100) + +# Configure optimization parameters +params = GreedyParams( + 50_000, # Maximum iterations + RandomNodeSwap(), # How to select nodes to swap + Strict(), # Only accept improvements + PreviousBestValue(5000), # Stop after 5000 iterations without improvement + true # Show progress bar +) + +# Fit the network histogram +result = nethist(A, Bernoulli(0.5), initial_labels, params) + +# Extract results +ll = loglikelihood(result) +block_params = result.θ +node_groups = result.node_labels +``` + +### Working with Different Edge Types + +The package supports various edge types through custom distributions: + +```julia +using NetworkHistogram +using Distributions # For standard distributions + +# Example 1: Weighted networks with continuous edges +W = Symmetric(rand(100, 100)) +W[diagind(W)] .= 0 +# You can use any distribution that implements the required interface + +# Example 2: Count data (e.g., number of interactions) +C = Symmetric(rand(Poisson(2), 100, 100)) +C[diagind(C)] .= 0 +# Use appropriate count distribution + +# Example 3: Sparse networks with missing edges +A_sparse = Symmetric(rand([0, 1, missing], 100, 100)) +A_sparse[diagind(A_sparse)] .= 0 +# Missing values are treated as absent edges +``` + +### Visualizing Results (with Makie.jl) + +```julia +using NetworkHistogram +using CairoMakie # or GLMakie + +# Fit model +A = Symmetric(rand(0:1, 100, 100)) +A[diagind(A)] .= 0 +result = nethist(A, Bernoulli(0.5), rand(1:3, 100), GreedyParams()) + +# Create heatmap of estimated parameters +fig = heatmap_params(result, ordering=true, colormap=:viridis) +save("network_histogram.png", fig) +``` + +### Sampling from a Block Model + +```julia 
+using NetworkHistogram + +# Define a 3-block model +k = 3 +bm = BlockModel(k, Bernoulli(0.5)) + +# Set custom edge probabilities between blocks +bm[1, 1] = Bernoulli(0.8) # High within-group connectivity +bm[2, 2] = Bernoulli(0.7) +bm[3, 3] = Bernoulli(0.6) +bm[1, 2] = Bernoulli(0.1) # Low between-group connectivity +bm[1, 3] = Bernoulli(0.05) +bm[2, 3] = Bernoulli(0.05) + +# Sample a network +n_nodes = 150 +latents, A = sample(bm, n_nodes) + +# latents contains the true block assignments +# A is the sampled adjacency matrix +``` + +Check out the +[documentation](https://sds-epfl.github.io/NetworkHistogram.jl/dev/) for more +examples and detailed API information. diff --git a/docs/Project.toml b/docs/Project.toml index 56dc1c7..0270131 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,17 @@ [deps] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" +DiscretizeDistributions = "1dbf0e27-43cd-4e03-8ecf-3f7be9d12b15" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656" +Graphons = "e0c12bfd-47d7-434e-afb7-632611640ca5" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" +Kneedle = "4ef9287f-f14a-4b13-b4c1-9bb5ae54399a" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +LiveServer = "16fef848-5104-11e9-1b77-fb7a48bbb589" NetworkHistogram = "7806f430-7229-459c-b2e6-df35e8e4eb5d" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" diff --git a/docs/examples/custom_suffstats.jl b/docs/examples/custom_suffstats.jl new file mode 100644 index 0000000..bf7bfb2 --- /dev/null +++ b/docs/examples/custom_suffstats.jl @@ -0,0 +1,88 @@ +import NetworkHistogram: SuffStats, add_sample, remove_sample, make_k_block, loss, + to_params, 
CategoricalConvertor, num_bins, to_distribution +using StaticArrays +using Accessors + +struct MyCustomSuffStats{M, T} <: SuffStats + h::SVector{M, T} +end + +function MyCustomSuffStats(num_categories::Int) + h = SVector{num_categories, Int}(zeros(Int, num_categories)) + return MyCustomSuffStats{num_categories, Int}(h) +end + +@inline function add_sample(ss::MyCustomSuffStats, sample::Int) + ss = @set ss.h[sample] += 1 + return ss +end + +@inline function remove_sample(ss::MyCustomSuffStats, sample::Int) + ss = @set ss.h[sample] -= 1 + return ss +end + +function make_k_block(k, ::Val{:custom}; num_categories, kwargs...) + k_block = SymArray{MyCustomSuffStats{num_categories, Int}}(undef, k, k) + fill!(k_block, MyCustomSuffStats(num_categories)) + return k_block +end + +@inline function loss(ss::MyCustomSuffStats) + n = sum(ss.h) + return n - sum(abs2, ss.h) / max(n, 1) +end + +function to_params(ss::MyCustomSuffStats) + n = max(sum(ss.h), 1) + return ss.h ./ n +end +## + +using Distributions +using NetworkHistogram +using Random +using StatsBase + +function W_multiplex(x, y) + ps = zeros(4) + ps[2] = sqrt(abs(x - y)) / 2 # layer 1 only + ps[3] = abs(sin(2π * x) * sin(2π * y)) / 2 # layer 2 only + ps[4] = min(x, y) / 2 # both layers + ps[1] = 1 - sum(ps[2:4]) # no edge + return DiscreteNonParametric(0:3, SVector{4}(ps)) +end + +convertor = CategoricalConvertor(4, Dict(0 => 1, 1 => 2, 2 => 3, 3 => 4)) + +graphon = DecoratedGraphon(W_multiplex) + +n = 2000 +true_latents = range(0, 1; length = n) +A = sample_graph(graphon, true_latents); + +k = 20 +oracle_labels = ordered_start_labels(n, k); +initial_labels = shuffle(oracle_labels); + +max_iter = 1_000_000 +stalled_iters = 5_000 + +data = convertor.(A) +es_new = NetworkHistogram.make_greedy_suffstats_estimator( + data, initial_labels, num_categories = num_bins(convertor), + type_suff_stats = Val(:custom), + max_iter = max_iter, + swap_rule = NetworkHistogram.RandomGroupSwap(), + stop_rule = 
NetworkHistogram.PreviousBestValue(stalled_iters, Inf, :min), + progress = true +); +node_labels_es_new, +parameters = NetworkHistogram.estimate!( + es_new, data, initial_labels; iter_progress = 10_000) + +model_es_new = NetworkHistogram.DecoratedSBM(to_distribution.(convertor, parameters), + counts(node_labels_es_new) ./ length(node_labels_es_new)); + +res_new = NetworkHistogram.NethistResult(node_labels_es_new, model_es_new); +NetworkHistogram.align_res_true_latents!(res_new, oracle_labels); diff --git a/docs/literate/tutorials/multiplex_network.jl b/docs/literate/tutorials/multiplex_network.jl new file mode 100644 index 0000000..655c5a4 --- /dev/null +++ b/docs/literate/tutorials/multiplex_network.jl @@ -0,0 +1,133 @@ +#= +# Decorated Graphon Tutorial for Multiplex Networks +=# +using NetworkHistogram +using Distributions +using StaticArrays +import CairoMakie as Mke + +using Random +Random.seed!(1234); +h = 300; + +function W_multiplex(x, y) + ps = zeros(4) + ps[2] = sqrt(abs(x - y)) / 2 # layer 1 only + ps[3] = abs(sin(2π * x) * sin(2π * y)) / 2 # layer 2 only + ps[4] = min(x, y) / 2 # both layers + ps[1] = 1 - sum(ps[2:4]) # no edge + return DiscreteNonParametric(0:3, SVector{4}(ps)) +end + +function W3(x, y) + ps = zeros(4) + ps[1] = 3 * x * y + ps[2] = 3 * sin(2 * π * x) * sin(2 * π * y) + ps[3] = exp(-3 * (x - 0.5)^2 - 3 * (y - 0.5)^2) + ps[4] = 2 - 3 * (x + y) + e_ps = exp.(ps) + return DiscreteNonParametric(0:3, SVector{4}(e_ps ./ sum(e_ps))) +end + +graphon = DecoratedGraphon(W3) + +let + fig = Mke.Figure(size = (4 * h, h)) + for m in 1:4 + ax = Mke.Axis(fig[1, m], aspect = Mke.DataAspect()) + Mke.heatmap!(ax, graphon, k = m, colormap = :binary, colorrange = (0, 1)) + end + fig + display(fig) #src +end + +n = 1000 +true_latents = range(0, 1; length = n) +A = sample_graph(graphon, true_latents); + +k = 14 +oracle_labels = ordered_start_labels(n, k); +initial_labels = shuffle(oracle_labels); + +oracle_res = NetworkHistogram.oracle_estimator( + A, 
oracle_labels, NetworkHistogram.CategoricalConvertor(A)); + +res = NetworkHistogram.nethist_categorical(A, k, initial_labels) + +# Visualize the fitted models for different numbers of groups after aligning with true latents + +NetworkHistogram.align_res_true_latents!(res, oracle_res.labels); +let + fig = Mke.Figure(size = (4 * h, h)) + for m in 1:4 + ax = Mke.Axis(fig[1, m], aspect = Mke.DataAspect()) + Mke.heatmap!(ax, res.model, k = m, colormap = :binary, colorrange = (0, 1)) + end + fig + display(fig) #src +end + +# We can also align the fitted model to the true one using optimal transport. We need to load the `PythonCall.jl` +# package for that, as we will use the `POT` Python library. + +ENV["JULIA_CONDAPKG_VERBOSITY"] = "-1" # hide conda messages #hide +using PythonCall +θ_oracle = probs.(oracle_res.model.θ); +θ_hat = probs.(res.model.θ); + +perm = NetworkHistogram.get_perm_alignment(θ_hat, θ_oracle); + +θ_hat_aligned = θ_hat[perm, perm]; +estimator_aligned = DecoratedSBM(DiscreteNonParametric.(Ref(0:3), θ_hat_aligned), + res.model.size[perm]); + +let + fig = Mke.Figure(size = (2 * h, h)) + for m in 1:4 + ax = Mke.Axis( + fig[1, m], aspect = Mke.DataAspect(), ylabel = m == 1 ? "Estimated" : "") + Mke.heatmap!(ax, estimator_aligned, k = m, colormap = :binary, colorrange = (0, 1)) + ax2 = Mke.Axis( + fig[2, m], aspect = Mke.DataAspect(), ylabel = m == 1 ? "Oracle" : "") + Mke.heatmap!( + ax2, oracle_res.model, k = m, colormap = :binary, colorrange = (0, 1)) + end + fig + display(fig) #src +end + +# The fitted network histogram can be further processed to obtain a smoother estimate of the underlying graphon. 
+ +using Clustering +shape_range = 1:30 +ssm_estimated, +criterion_values = Graphons.estimate_ssm( + res.model, A, true_latents, shape_range); + +using Kneedle +kr = kneedle(shape_range, criterion_values, "convex_dec", 1, + kneedle_scan_algorithm = ScanSmoothing(; S = 1.0)); +# Let's extract the optimal number of shapes using the Kneedle algorithm: + +k_knee = knees(kr)[1] +ssm = SSM(res.model, k_knee) + +models_to_plot = [graphon, res.model, ssm_estimated, ssm] +model_names = ["True graphon", "Block model", + "SSM argmin k=$(length(ssm_estimated.θ))", "SSM knee k=$k_knee"] + +let + fig = Mke.Figure(size = (4 * h, length(models_to_plot) * h)) + for (i, model) in enumerate(models_to_plot) + for m in 1:4 + ax = Mke.Axis( + fig[i, m], aspect = Mke.DataAspect(), ylabel = m == 1 ? model_names[i] : "") + Mke.hidedecorations!(ax, label = false) + Mke.heatmap!(ax, model, k = m, colormap = :lipari, colorrange = (0, 1)) + end + end + Mke.Colorbar(fig[2:3, end + 1], colormap = :lipari, + limits = (0, 1), width = 0.05 * h) + fig + display(fig) #src +end diff --git a/docs/literate/tutorials/simple_graph.jl b/docs/literate/tutorials/simple_graph.jl new file mode 100644 index 0000000..680d74c --- /dev/null +++ b/docs/literate/tutorials/simple_graph.jl @@ -0,0 +1,235 @@ +#= +# A Simple Graphon Tutorial with NetworkHistogram.jl +=# + +# This tutorial introduces the concept of a graphon, demonstrates how to sample a graph from one, and then shows how to estimate the graphon from the sampled graph using the Network Histogram method provided by `NetworkHistogram.jl`. + +# ## What is a Graphon? + +# A graphon (or graph function) is a symmetric, measurable function $$W: [0, 1]^2 \to [0, 1]$$. + +# It serves as a generative model for random graphs. Think of it as a continuous and more general version of a stochastic block model. + +# In simple terms, each node `i` in a graph is assigned a latent (unobserved) position $u_i \in [0, 1]$. 
The probability of an edge existing between two nodes `i` and `j` is then given by the graphon function evaluated at their latent positions: $P(A_{ij} = 1 \mid u_i, u_j) = W(u_i, u_j)$. + +# Let's define a simple graphon in Julia. For this example, we'll use the smooth product graphon $W(x, y) = xy$. + +import CairoMakie as Mke +using LinearAlgebra +using Random +import StatsBase: inverse_rle +using Statistics +using NetworkHistogram +using Distributions + +h = 300; # hide +Random.seed!(1234); + +# Define the product graphon W(x, y) = x * y +w = SimpleContinuousGraphon((x, y) -> x * y) + +# We can visualize this graphon as a heatmap. +let + fig = Mke.Figure(size = (h + 20, h)) + ax = Mke.Axis(fig[1, 1], title = "True Graphon W(u,v)") + hm = Mke.heatmap!(ax, w, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 2], hm) + fig +end + +#md +# ## Sampling a Graph from a Graphon + +# To generate a random graph from a graphon, we follow these steps: +# 1. **Assign latent positions:** For a graph with `n` nodes, we sample `n` independent and identically distributed random variables $u_1, u_2, \dots, u_n$ from a Uniform(0, 1) distribution. These are the latent positions of our nodes. +# 2. **Generate edges:** For each pair of nodes `(i, j)` with `i < j`, we generate a random number from a Bernoulli distribution with probability $W(u_i, u_j)$. This determines whether an edge exists between them. The resulting adjacency matrix `A` will be symmetric. +# Let's sample a graph with 3000 nodes from our graphon `w`. +n = 3000 +u_true = rand(n); # Latent positions +A = sample_graph(w, u_true); + +# We can visualize the adjacency matrix of the sampled graph. +# To make the structure visible, we sort the nodes by their latent positions.
+perm = sortperm(u_true) +A = A[perm, perm] +let + fig = Mke.Figure(size = (h, h)) + ax = Mke.Axis( + fig[1, 1], title = "Sampled Adjacency Matrix (Sorted)", aspect = Mke.DataAspect()) + Mke.heatmap!(ax, A, colormap = :binary) + fig +end + +#md +# ## The Network Histogram Method + +# The Network Histogram method is a non-parametric approach to estimate a graphon from a single observed network. The core idea is to approximate the (unknown) graphon `W` with a piecewise constant function. + +# This is achieved by: +# 1. **Partitioning the nodes:** The nodes of the graph are partitioned into `k` groups. +# 2. **Estimating block probabilities:** The probability of an edge between any two groups is estimated by the density of edges between them. +# 3. **Constructing the histogram:** These estimated probabilities form a `k x k` matrix, which is a step-function approximation of the true graphon. + +# The main challenge is to find the optimal partition of nodes. `NetworkHistogram.jl` provides tools to find a good partition by optimizing an objective function, such as the log-likelihood of the observed graph under the model. + +# ## Fitting a Network Histogram with NetworkHistogram.jl + +# Now, let's use `NetworkHistogram.jl` to fit a network histogram to the graph `A` we sampled earlier. We will approximate the underlying smooth graphon with a `k`-block histogram. + +# We start with a random initial assignment of nodes to `k=10` groups. +k = 10 +oracle_labels = ordered_start_labels(n, k); + +initial_assignment = shuffle(oracle_labels); + +## We can compute the "oracle" estimator, which uses the true latent positions to assign nodes to groups. This serves as a benchmark for our estimation.
+oracle_res = NetworkHistogram.oracle_estimator( + A, oracle_labels, NetworkHistogram.BinaryConvertor(); type_suff_stats = Val(:binary)); + +let + fig = Mke.Figure(size = (400, 300)) + ax = Mke.Axis(fig[1, 1], aspect = Mke.DataAspect()) + Mke.heatmap!(ax, oracle_res.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 2], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end +## +# `NetworkHistogram.jl` provides optimization algorithms to improve the initial assignment. +# Let's use the `nethist` function with `GreedyParams`, which iteratively moves nodes between +# groups to maximize the log-likelihood. + +# params_opti = NetworkHistogram.GreedyParams( +# 100_000, NetworkHistogram.RandomNodeSwap(), NetworkHistogram.Strict(), +# NetworkHistogram.PreviousBestValue(2_000), false); + +# a = nethist(A, dist, initial_assignment, params_opti, false); + +res = NetworkHistogram.nethist_binary(A, k, initial_assignment); + +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["True Graphon W(u,v)", "Oracle Estimator", "Fitted Network Histogram"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], w, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], oracle_res.model, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], res.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end + +# the block labels found by the optimization are not necessarily aligned with the true latent positions, hence the need to align them for better visualization. 
+ +NetworkHistogram.align_res_true_latents!(res, oracle_res.labels); + +# and display the true function, the oracle estimator, and the fitted model +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["True Graphon W(u,v)", "Oracle Estimator", "Fitted Network Histogram"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], w, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], oracle_res.model, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], res.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end + +# We can even fit a Stochastic Shape Model quite easily from the fitted SBM. + +using Clustering + +# ξ = NetworkHistogram.node_labels_to_latents(res.labels, res.model); +shape_range = 1:(k * (k + 1) ÷ 2 - 1) +ssm_estimated, +criterion_values = Graphons.estimate_ssm( + res.model, A, res.labels, shape_range) + +using Kneedle +kr = kneedle(shape_range, criterion_values, "convex_dec", 1, + kneedle_scan_algorithm = ScanSmoothing(; S = 1.0)) +# Let's extract the optimal number of shapes using the Kneedle algorithm: + +k_knee = knees(kr)[1] +ssm_knee = SSM(res.model, k_knee) + +println("Number of shapes in SSM argmin: ", length(ssm_estimated.θ)) +println("Number of shapes in SSM knee: ", length(ssm_knee.θ)) +println("Number of shapes in SBM: ", length(res.model.θ)) + +# We greatly reduced the number of parameters from the original SBM estimate while preserving much of the structure of the estimated graphon as seen below: + +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["SBM", "SSM argmin", "SSM knee"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], res.model, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], ssm_estimated, + colormap = :binary, colorrange = (0, 1)) + 
Mke.heatmap!(axes[3], ssm_knee, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end + +## + +k_kmeans = 10; +clustering_res = kmeans(A, k_kmeans); + +res_kmeans = NetworkHistogram.oracle_estimator( + A, assignments(clustering_res), NetworkHistogram.BinaryConvertor(); + type_suff_stats = Val(:binary), + name = "k-means"); + +NetworkHistogram.align_res_true_latents!(res_kmeans, oracle_res.labels); + +# and display the true function, the oracle estimator, and the fitted model +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["True Graphon W(u,v)", "Oracle Estimator", "Fitted Network Histogram"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], w, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], oracle_res.model, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], res_kmeans.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end + +# ξ = NetworkHistogram.node_labels_to_latents(res.labels, res.model); +shape_range = 1:(k_kmeans * (k_kmeans + 1) ÷ 2 - 1) +ssm_estimated, +criterion_values = Graphons.estimate_ssm( + res_kmeans.model, A, res_kmeans.labels, shape_range) + +using Kneedle +kr = kneedle(shape_range, criterion_values, "convex_dec", 1, + kneedle_scan_algorithm = ScanSmoothing(; S = 1.0)) +# Let's extract the optimal number of shapes using the Kneedle algorithm: + +k_knee = knees(kr)[1] +ssm_knee = SSM(res_kmeans.model, k_knee) + +println("Number of shapes in SSM argmin: ", length(ssm_estimated.θ)) +println("Number of shapes in SSM knee: ", length(ssm_knee.θ)) +println("Number of shapes in SBM: ", length(res_kmeans.model.θ)) + +# We greatly reduced the number of parameters from the original SBM estimate while preserving much of the structure of the 
estimated graphon as seen below: + +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["SBM", "SSM argmin", "SSM knee"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], res_kmeans.model, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], ssm_estimated, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], ssm_knee, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end diff --git a/docs/literate/tutorials/temporal_networks.jl b/docs/literate/tutorials/temporal_networks.jl new file mode 100644 index 0000000..27b9860 --- /dev/null +++ b/docs/literate/tutorials/temporal_networks.jl @@ -0,0 +1,5 @@ +#= +# Decorated Graphon Tutorial for Temporal Networks +=# + +# # How to use NetworkHistogram.jl for Temporal Networks diff --git a/docs/literate/tutorials/weighted_network.jl b/docs/literate/tutorials/weighted_network.jl new file mode 100644 index 0000000..4d938c1 --- /dev/null +++ b/docs/literate/tutorials/weighted_network.jl @@ -0,0 +1,100 @@ +#= +# Decorated Graphon Tutorial for Weighted Networks +=# +using Clustering +using NetworkHistogram +using Distributions +using LinearAlgebra +using Random +using Graphons + +import Distributions: pdf + +pdf_kuma(α, β, x, p = 1.0) = @. 
p * (α * β * x^(α - 1) .* (1 - x^α)^(β - 1)) + +graphon_params = (x, y) -> (3 * abs(sin(2 * π * x) * sin(2 * π * y)) + 0.8, max(x, y) * 8) + +graphon = DecoratedGraphon((x, y) -> Kumaraswamy(graphon_params(x, y)...)) + +import CairoMakie as Mke +let + fig = Mke.Figure(size = (510, 200)) + ax = Mke.Axis(fig[1, 1], aspect = Mke.DataAspect(), title = "α") + hm = Mke.heatmap!(ax, graphon, k = 1, colormap = :viridis) + Mke.Colorbar(fig[1, 2], hm) + ax2 = Mke.Axis(fig[1, 3], aspect = Mke.DataAspect(), title = "β") + hm2 = Mke.heatmap!(ax2, graphon, k = 2, colormap = :viridis) + Mke.Colorbar(fig[1, 4], hm2) + fig +end + +# We sample a weighted network from the graphon + +Random.seed!(1234); +n = 2000 +k = 15 +n_bins = 10 +p = 0.8 + +A = sample_graph(graphon, n) .* Symmetric(rand(Bernoulli(p), n, n)); +ξs = range(0, 1; length = n) +oracle_latents = ordered_start_labels(n, k); + +res_oracle = NetworkHistogram.oracle_estimator( + A, oracle_latents, NetworkHistogram.UnitIntervalConvertor(n_bins)); + +starting_labels = shuffle(oracle_latents); + +max_iter = 1_000_000 +stalled_iters = 10_000 + +res_new = NetworkHistogram.nethist_continuous( + A, k, + starting_labels; + bins = n_bins +); + +ENV["JULIA_CONDAPKG_VERBOSITY"] = "-1" # hide conda messages #hide +using PythonCall + +θ_oracle = Graphons._extract_param.(res_oracle.model.θ); +θ_hat = Graphons._extract_param.(res_new.model.θ); +perm = NetworkHistogram.get_perm_alignment(θ_oracle, θ_hat); + +fitted_labels = map(x -> perm[x], res_new.labels); +res_ot_aligned = NetworkHistogram.oracle_estimator( + A, fitted_labels, NetworkHistogram.UnitIntervalConvertor(n_bins), + name = "aligned with OT perm"); + +xs = range(0, 1; length = 20) + +function viz_one_group!(axis, g1, g2, A, ξs, res_oracle, res_new, xs; n_viz = 20, p = p) + nodes_1 = findall(res_oracle.labels .== g1) + nodes_2 = findall(res_oracle.labels .== g2) + edge_values = [A[x, y] for y in nodes_2 for x in nodes_1] + Mke.vlines!(axis, edge_values, ymax = 0.025, color = 
:lightgray) + x1 = sample(ξs[nodes_1], n_viz, replace = false) + x2 = sample(ξs[nodes_2], n_viz, replace = false) + for x_ in x1 + for y_ in x2 + Mke.lines!(axis, xs, pdf_kuma(graphon_params(x_, y_)..., xs, p), + color = :gray, alpha = 0.1) + end + end + Mke.lines!(axis, xs, map(Base.Fix1(pdf, res_oracle.model.θ[g1, g2]), xs), + color = :blue, label = "True") + Mke.lines!(axis, xs, map(Base.Fix1(pdf, res_new.model.θ[g1, g2]), xs), + color = :black, linestyle = :dash, label = "Estimated") +end + +fig = Mke.Figure(size = (1000, 1000)) +for g in 1:k + for g2 in 1:g + ax = Mke.Axis(fig[g, g2]) + Mke.hidedecorations!(ax) + viz_one_group!(ax, g, g2, A, ξs, res_oracle, + res_ot_aligned, xs, p = p, n_viz = 5) + end +end +fig +Mke.display(fig) #src diff --git a/docs/make.jl b/docs/make.jl index e4219ec..a2d9cb3 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,26 +1,73 @@ using NetworkHistogram using Documenter -DocMeta.setdocmeta!(NetworkHistogram, :DocTestSetup, :(using NetworkHistogram); +## Literate preprocessing, maybe move to a separate script later for faster builds + +# to run with LiveServer and avoid infinite loops, use +# servedocs(literate_dir=joinpath("docs","literate","tutorials"),skip_dir = joinpath("docs","src","tutorials")) +# adapting `tutorials` to whatever subdir you are working on + +using Literate + +LITERATE_INPUT = joinpath(@__DIR__, "literate") +LITERATE_OUTPUT = joinpath(@__DIR__, "src") + +for dir_path in filter(isdir, readdir(joinpath(@__DIR__, "literate"), join = true)) + dirname = basename(dir_path) + + for (root, _, files) in walkdir(dir_path), file in files + # ignore non julia files + splitext(file)[2] == ".jl" || continue + # full path to a literate script + ipath = joinpath(root, file) + # generated output path + opath = splitdir(replace(ipath, LITERATE_INPUT => LITERATE_OUTPUT))[1] + # generate the markdown file calling Literate + Literate.markdown(ipath, opath) + end +end + +DocMeta.setdocmeta!( + NetworkHistogram, :DocTestSetup, 
:(using NetworkHistogram); recursive = true) +# based on available extensions, include them in the documentation +modules_all = [ + NetworkHistogram, + Base.get_extension(NetworkHistogram, :BootstrapExt), + Base.get_extension(NetworkHistogram, :LightMCExt), + Base.get_extension(NetworkHistogram, :MakieExt), + Base.get_extension(NetworkHistogram, :PythonOptimalTransport) +] + +#TODO: safety check, should probably throw an error instead +modules = [filter(!isnothing, modules_all)...] + +using DocumenterInterLinks + +links = InterLinks( + "ot" => "https://pythonot.github.io/", +) + makedocs(; - modules = [NetworkHistogram], + modules = modules, authors = "Jake Grainger, Charles Dufour", - #repo = "github.com/SDS-EPFL/NetworkHistogram.jl.git", + repo = "github.com/SDS-EPFL/NetworkHistogram.jl.git", sitename = "NetworkHistogram.jl", - #format = Documenter.HTML(; - # prettyurls = get(ENV, "CI", "false") == "true", - # canonical = "https://SDS-EPFL.github.io/NetworkHistogram.jl", - # edit_link = "main", - # assets = String[]), + format = Documenter.HTML(; + prettyurls = get(ENV, "CI", "false") == "true", + canonical = "https://SDS-EPFL.github.io/NetworkHistogram.jl", + edit_link = "main", + assets = String[]), pages = [ "Home" => "index.md", "API Reference" => "api.md", - "Optimization hyperparameters" => "rules.md", - "Development" => "internals.md", - ], - checkdocs = :none) + "Tutorials" => ["First steps" => "tutorials/simple_graph.md", + "Multiplex networks" => "tutorials/multiplex_network.md", + "Weighted networks" => "tutorials/weighted_network.md", + "Temporal networks" => "tutorials/temporal_networks.md"]], + checkdocs = :none, + plugins = [links]) deploydocs(; repo = "github.com/SDS-EPFL/NetworkHistogram.jl.git") diff --git a/docs/src/api.md b/docs/src/api.md index 26fd83e..8ecacaf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -1,26 +1,10 @@ ```@contents Pages = ["api.md"] -Depth = 1 +Depth = 2 ``` -# NetworkHistogram - -```@autodocs -Modules = 
[NetworkHistogram] -Pages = ["histogram.jl","optimize.jl"] -``` - -# Assignment  - -```@autodocs + diff --git a/docs/src/index.md b/docs/src/index.md index f496b95..0bd81de 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,7 +1,8 @@ # NetworkHistogram.jl -Implementation of the network histogram for graphon estimation from the paper [Network histograms and universality of blockmodel approximation](https://doi.org/10.1073/pnas.1400374111) by Sofia C. Olhede and Patrick J. Wolfe. - +Implementation of the network histogram for graphon estimation from the paper +[Network histograms and universality of blockmodel approximation](https://doi.org/10.1073/pnas.1400374111) +by Sofia C. Olhede and Patrick J. Wolfe. ## Installation @@ -11,7 +12,8 @@ Pkg.add("NetworkHistogram") ## Usage -We fit the estimator using [`graphhist`](@ref graphhist) and then extract the estimated graphon matrix and node labels. +We fit the estimator and then extract the estimated graphon matrix and node +labels. ```julia using NetworkHistogram, LinearAlgebra @@ -20,10 +22,7 @@ A = Symmetric(rand(0:1, 100, 100)) A[diagind(A)] .= 0 # approximate the graphon with a network histogram -hist = graphhist(A) - -# get the graphist structure -estimate = hist.graphhist +estimate = graph_hist(A) # get the estimated graphon matrix sbm_matrix = estimate.θ @@ -32,4 +31,5 @@ sbm_matrix = estimate.θ node_labels = estimate.node_labels ``` -You can control the optimization process by modifying the rules used in the optimization. Check out [Optimization hyper-parameters](@ref) for more information. \ No newline at end of file +You can control the optimization process by modifying the rules used in the +optimization. 
diff --git a/docs/src/internals.md b/docs/src/internals.md deleted file mode 100644 index 04ce12a..0000000 --- a/docs/src/internals.md +++ /dev/null @@ -1,8 +0,0 @@ -# Notes on optimisation - -- We have three different assignment variables: - - current - - best - - proposal -- We need best because we might accept a proposal which is worse than the curent value. -- This is dealt within `accept_reject_update!()`. \ No newline at end of file diff --git a/docs/src/rules.md b/docs/src/rules.md deleted file mode 100644 index 510db50..0000000 --- a/docs/src/rules.md +++ /dev/null @@ -1,38 +0,0 @@ -# Optimization hyper-parameters - -Here we discuss the different parameters that can be used to control the optimization process. The optimization greedily tries to find a good partition of the network. You can control the optimization process by setting the following parameters: - -## Starting node labels - -```@autodocs -Modules = [NetworkHistogram] -Pages = ["starting_assignment_rule.jl"] -``` - -!!! note - The groups will be of size `floor(h * n)` where `n` is the number of nodes if `h` is a - float. If `h` is an integer, the groups will be of size `h`. The last group may be - smaller if `n` is not exactly divisible by the group size. 
- - -## Swapping rule - -```@autodocs -Modules = [NetworkHistogram] -Pages = ["swap_rule.jl"] -``` - - -## Acceptance rule - -```@autodocs -Modules = [NetworkHistogram] -Pages = ["accept_rule.jl"] -``` - -## Stopping rule - -```@autodocs -Modules = [NetworkHistogram] -Pages = ["stop_rule.jl"] -``` \ No newline at end of file diff --git a/docs/src/tutorials/multiplex_network.md b/docs/src/tutorials/multiplex_network.md new file mode 100644 index 0000000..9fc4dc4 --- /dev/null +++ b/docs/src/tutorials/multiplex_network.md @@ -0,0 +1,147 @@ +```@meta +EditURL = "../../literate/tutorials/multiplex_network.jl" +``` + +# Decorated Graphon Tutorial for Multiplex Networks + +````@example multiplex_network +using NetworkHistogram +using Distributions +using StaticArrays +import CairoMakie as Mke + +using Random +Random.seed!(1234); +h = 300; + +function W_multiplex(x, y) + ps = zeros(4) + ps[2] = sqrt(abs(x - y)) / 2 # layer 1 only + ps[3] = abs(sin(2π * x) * sin(2π * y)) / 2 # layer 2 only + ps[4] = min(x, y) / 2 # both layers + ps[1] = 1 - sum(ps[2:4]) # no edge + return DiscreteNonParametric(0:3, SVector{4}(ps)) +end + +function W3(x, y) + ps = zeros(4) + ps[1] = 3 * x * y + ps[2] = 3 * sin(2 * π * x) * sin(2 * π * y) + ps[3] = exp(-3 * (x - 0.5)^2 - 3 * (y - 0.5)^2) + ps[4] = 2 - 3 * (x + y) + e_ps = exp.(ps) + return DiscreteNonParametric(0:3, SVector{4}(e_ps ./ sum(e_ps))) +end + +graphon = DecoratedGraphon(W3) + +let + fig = Mke.Figure(size = (4 * h, h)) + for m in 1:4 + ax = Mke.Axis(fig[1, m], aspect = Mke.DataAspect()) + Mke.heatmap!(ax, graphon, k = m, colormap = :binary, colorrange = (0, 1)) + end + fig +end + +n = 1000 +true_latents = range(0, 1; length = n) +A = sample_graph(graphon, true_latents); + +k = 14 +oracle_labels = ordered_start_labels(n, k); +initial_labels = shuffle(oracle_labels); + +oracle_res = NetworkHistogram.oracle_estimator( + A, oracle_labels, NetworkHistogram.CategoricalConvertor(A)); + +res = NetworkHistogram.nethist_categorical(A, k, 
initial_labels) +```` + +Visualize the fitted models for different numbers of groups after aligning with true latents + +````@example multiplex_network +NetworkHistogram.align_res_true_latents!(res, oracle_res.labels); +let + fig = Mke.Figure(size = (4 * h, h)) + for m in 1:4 + ax = Mke.Axis(fig[1, m], aspect = Mke.DataAspect()) + Mke.heatmap!(ax, res.model, k = m, colormap = :binary, colorrange = (0, 1)) + end + fig +end +```` + +We can also align the fitted model to the true one using optimal transport. We need to load the `PythonCall.jl` +package for that, as we will use the `POT` Python library. + +````@example multiplex_network +ENV["JULIA_CONDAPKG_VERBOSITY"] = "-1" # hide conda messages #hide +using PythonCall +θ_oracle = probs.(oracle_res.model.θ); +θ_hat = probs.(res.model.θ); + +perm = NetworkHistogram.get_perm_alignment(θ_hat, θ_oracle); + +θ_hat_aligned = θ_hat[perm, perm]; +estimator_aligned = DecoratedSBM(DiscreteNonParametric.(Ref(0:3), θ_hat_aligned), + res.model.size[perm]); + +let + fig = Mke.Figure(size = (2 * h, h)) + for m in 1:4 + ax = Mke.Axis( + fig[1, m], aspect = Mke.DataAspect(), ylabel = m == 1 ? "Estimated" : "") + Mke.heatmap!(ax, estimator_aligned, k = m, colormap = :binary, colorrange = (0, 1)) + ax2 = Mke.Axis( + fig[2, m], aspect = Mke.DataAspect(), ylabel = m == 1 ? "Oracle" : "") + Mke.heatmap!( + ax2, oracle_res.model, k = m, colormap = :binary, colorrange = (0, 1)) + end + fig +end +```` + +The fitted network histogram can be further processed to obtain a smoother estimate of the underlying graphon. 
+ +````@example multiplex_network +using Clustering +shape_range = 1:30 +ssm_estimated, criterion_values = Graphons.estimate_ssm( + res.model, A, true_latents, shape_range); + +using Kneedle +kr = kneedle(shape_range, criterion_values, "convex_dec", 1, scan_type = :smoothing); +nothing #hide +```` + + Let's extract the optimal number of shapes using the Kneedle algorithm: + +````@example multiplex_network +k_knee = knees(kr)[1] +ssm = SSM(res.model, k_knee) + +models_to_plot = [graphon, res.model, ssm_estimated, ssm] +model_names = ["True graphon", "Block model", + "SSM argmin k=$(length(ssm_estimated.θ))", "SSM knee k=$k_knee"] + +let + fig = Mke.Figure(size = (4 * h, length(models_to_plot) * h)) + for (i, model) in enumerate(models_to_plot) + for m in 1:4 + ax = Mke.Axis( + fig[i, m], aspect = Mke.DataAspect(), ylabel = m == 1 ? model_names[i] : "") + Mke.hidedecorations!(ax, label = false) + Mke.heatmap!(ax, model, k = m, colormap = :lipari, colorrange = (0, 1)) + end + end + Mke.Colorbar(fig[2:3, end + 1], colormap = :lipari, + limits = (0, 1), width = 0.05 * h) + fig +end +```` + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/docs/src/tutorials/simple_graph.md b/docs/src/tutorials/simple_graph.md new file mode 100644 index 0000000..2360201 --- /dev/null +++ b/docs/src/tutorials/simple_graph.md @@ -0,0 +1,290 @@ +```@meta +EditURL = "../../literate/tutorials/simple_graph.jl" +``` + +# A Simple Graphon Tutorial with NetworkHistogram.jl + +This tutorial introduces the concept of a graphon, demonstrates how to sample a graph from one, and then shows how to estimate the graphon from the sampled graph using the Network Histogram method provided by `NetworkHistogram.jl`. + +## What is a Graphon? + +A graphon (or graph function) is a symmetric, measurable function $$W: [0, 1]^2 \to [0, 1]$$. + +It serves as a generative model for random graphs. 
Think of it as a continuous and more general version of a stochastic block model. + +In simple terms, each node `i` in a graph is assigned a latent (unobserved) position $u_i \in [0, 1]$. The probability of an edge existing between two nodes `i` and `j` is then given by the graphon function evaluated at their latent positions: $\mathbb{P}(A_{ij} = 1 \mid u_i, u_j) = W(u_i, u_j)$. + +Let's define a simple graphon in Julia. For this example, we'll use the smooth product graphon $W(x, y) = xy$. + +````@example simple_graph +import CairoMakie as Mke +using LinearAlgebra +using Random +import StatsBase: inverse_rle +using Statistics +using NetworkHistogram +using Distributions + +h = 300; # hide +Random.seed!(1234); +nothing #hide +```` + +Define a simple product graphon + +````@example simple_graph +w = SimpleContinuousGraphon((x, y) -> x * y) +```` + +We can visualize this graphon as a heatmap. + +````@example simple_graph +let + fig = Mke.Figure(size = (h + 20, h)) + ax = Mke.Axis(fig[1, 1], title = "True Graphon W(u,v)") + hm = Mke.heatmap!(ax, w, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 2], hm) + fig +end + +#md +```` + +## Sampling a Graph from a Graphon + +To generate a random graph from a graphon, we follow these steps: +1. **Assign latent positions:** For a graph with `n` nodes, we sample `n` independent and identically distributed random variables $u_1, u_2, \dots, u_n$ from a Uniform(0, 1) distribution. These are the latent positions of our nodes. +2. **Generate edges:** For each pair of nodes `(i, j)` with `i < j`, we generate a random number from a Bernoulli distribution with probability $W(u_i, u_j)$. This determines whether an edge exists between them. The resulting adjacency matrix `A` will be symmetric. +Let's sample a graph with 3000 nodes from our graphon `w`. + +````@example simple_graph +n = 3000 +u_true = rand(n); # Latent positions +A = sample_graph(w, u_true); +nothing #hide +```` + +We can visualize the adjacency matrix of the sampled graph.
+To make the structure of the graphon visible, we sort the nodes by their latent positions. + +````@example simple_graph +perm = sortperm(u_true) +A = A[perm, perm] +let + fig = Mke.Figure(size = (h, h)) + ax = Mke.Axis( + fig[1, 1], title = "Sampled Adjacency Matrix (Sorted)", aspect = Mke.DataAspect()) + Mke.heatmap!(ax, A, colormap = :binary) + fig +end + +#md +```` + +## The Network Histogram Method + +The Network Histogram method is a non-parametric approach to estimate a graphon from a single observed network. The core idea is to approximate the (unknown) graphon `W` with a piecewise constant function. + +This is achieved by: +1. **Partitioning the nodes:** The nodes of the graph are partitioned into `k` groups. +2. **Estimating block probabilities:** The probability of an edge between any two groups is estimated by the density of edges between them. +3. **Constructing the histogram:** These estimated probabilities form a `k x k` matrix, which is a step-function approximation of the true graphon. + +The main challenge is to find the optimal partition of nodes. `NetworkHistogram.jl` provides tools to find a good partition by optimizing an objective function, such as the log-likelihood of the observed graph under the model. + +## Fitting a Network Histogram with NetworkHistogram.jl + +Now, let's use `NetworkHistogram.jl` to fit a network histogram to the graph `A` we sampled earlier. We will try to approximate the underlying graphon with a block model. + +We start with a random initial assignment of nodes to `k=10` groups. + +````@example simple_graph +k = 10 +oracle_labels = ordered_start_labels(n, k); + +initial_assignment = shuffle(oracle_labels); + +# We can compute the "oracle" estimator, which uses the true latent positions to assign nodes to groups. This serves as a benchmark for our estimation.
+oracle_res = NetworkHistogram.oracle_estimator( + A, oracle_labels, NetworkHistogram.BinaryConvertor(); type_suff_stats = Val(:binary)); + +let + fig = Mke.Figure(size = (400, 300)) + ax = Mke.Axis(fig[1, 1], aspect = Mke.DataAspect()) + Mke.heatmap!(ax, oracle_res.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 2], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end +# +```` + +`NetworkHistogram.jl` provides optimization algorithms to improve the initial assignment. +Let's use the `nethist` function with `GreedyParams`, which iteratively moves nodes between +groups to maximize the log-likelihood. + +params_opti = NetworkHistogram.GreedyParams( + 100_000, NetworkHistogram.RandomNodeSwap(), NetworkHistogram.Strict(), + NetworkHistogram.PreviousBestValue(2_000), false); + +a = nethist(A, dist, initial_assignment, params_opti, false); + +````@example simple_graph +res = NetworkHistogram.nethist_binary(A, k, initial_assignment); + +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["True Graphon W(u,v)", "Oracle Estimator", "Fitted Network Histogram"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], w, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], oracle_res.model, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], res.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end +```` + +the block labels found by the optimization are not necessarily aligned with the true latent positions, hence the need to align them for better visualization. 
+ +````@example simple_graph +NetworkHistogram.align_res_true_latents!(res, oracle_res.labels); +nothing #hide +```` + +and display the true function, the oracle estimator, and the fitted model + +````@example simple_graph +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["True Graphon W(u,v)", "Oracle Estimator", "Fitted Network Histogram"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], w, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], oracle_res.model, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], res.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end +```` + +We can even fit a Stochastic Shape Model quite easily from the fitted SBM. + +````@example simple_graph +using Clustering +```` + +ξ = NetworkHistogram.node_labels_to_latents(res.labels, res.model); + +````@example simple_graph +shape_range = 1:(k * (k + 1) ÷ 2 - 1) +ssm_estimated, criterion_values = Graphons.estimate_ssm( + res.model, A, res.labels, shape_range) + +using Kneedle +kr = kneedle(shape_range, criterion_values, "convex_dec", 1, scan_type = :smoothing) +```` + + Let's extract the optimal number of shapes using the Kneedle algorithm: + +````@example simple_graph +k_knee = knees(kr)[1] +ssm_knee = SSM(res.model, k_knee) + +println("Number of shapes in SSM argmin: ", length(ssm_estimated.θ)) +println("Number of shapes in SSM knee: ", length(ssm_knee.θ)) +println("Number of shapes in SBM: ", length(res.model.θ)) +```` + +We greatly reduced the number of parameters from the original SBM estimate while preserving much of the structure of the estimated graphon as seen below: + +````@example simple_graph +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["SBM", "SSM argmin", "SSM knee"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + 
Mke.heatmap!(axes[1], res.model, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], ssm_estimated, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], ssm_knee, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end + +# + +k_kmeans = 10; +clustering_res = kmeans(A, k_kmeans); + +res_kmeans = NetworkHistogram.oracle_estimator( + A, assignments(clustering_res), NetworkHistogram.BinaryConvertor(); + type_suff_stats = Val(:binary), + name = "k-means"); + +NetworkHistogram.align_res_true_latents!(res_kmeans, oracle_res.labels); +nothing #hide +```` + +and display the true function, the oracle estimator, and the fitted model + +````@example simple_graph +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["True Graphon W(u,v)", "Oracle Estimator", "Fitted Network Histogram"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], w, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], oracle_res.model, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], res_kmeans.model, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end +```` + +ξ = NetworkHistogram.node_labels_to_latents(res.labels, res.model); + +````@example simple_graph +shape_range = 1:(k_kmeans * (k_kmeans + 1) ÷ 2 - 1) +ssm_estimated, criterion_values = Graphons.estimate_ssm( + res_kmeans.model, A, res_kmeans.labels, shape_range) + +using Kneedle +kr = kneedle(shape_range, criterion_values, "convex_dec", 1, scan_type = :smoothing) +```` + + Let's extract the optimal number of shapes using the Kneedle algorithm: + +````@example simple_graph +k_knee = knees(kr)[1] +ssm_knee = SSM(res_kmeans.model, k_knee) + +println("Number of shapes in SSM argmin: ", length(ssm_estimated.θ)) +println("Number of 
shapes in SSM knee: ", length(ssm_knee.θ)) +println("Number of shapes in SBM: ", length(res_kmeans.model.θ)) +```` + +We greatly reduced the number of parameters from the original SBM estimate while preserving much of the structure of the estimated graphon as seen below: + +````@example simple_graph +let + fig = Mke.Figure(size = (1220, 400)) + titles = ["SBM", "SSM argmin", "SSM knee"] + axes = [Mke.Axis(fig[1, i], aspect = Mke.DataAspect(), title = titles[i]) for i in 1:3] + Mke.heatmap!(axes[1], res_kmeans.model, colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[2], ssm_estimated, + colormap = :binary, colorrange = (0, 1)) + Mke.heatmap!(axes[3], ssm_knee, colormap = :binary, colorrange = (0, 1)) + Mke.Colorbar(fig[1, 4], colormap = :binary, + limits = (0, 1), label = "Edge Probability", width = 20) + fig +end +```` + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/docs/src/tutorials/temporal_networks.md b/docs/src/tutorials/temporal_networks.md new file mode 100644 index 0000000..7174672 --- /dev/null +++ b/docs/src/tutorials/temporal_networks.md @@ -0,0 +1,12 @@ +```@meta +EditURL = "../../literate/tutorials/temporal_networks.jl" +``` + +# Decorated Graphon Tutorial for Temporal Networks + +# How to use NetworkHistogram.jl for Temporal Networks + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/docs/src/tutorials/weighted_network.md b/docs/src/tutorials/weighted_network.md new file mode 100644 index 0000000..1b4e24b --- /dev/null +++ b/docs/src/tutorials/weighted_network.md @@ -0,0 +1,111 @@ +```@meta +EditURL = "../../literate/tutorials/weighted_network.jl" +``` + +# Decorated Graphon Tutorial for Weighted Networks + +````@example weighted_network +using Clustering +using NetworkHistogram +using Distributions +using LinearAlgebra +using Random +using Graphons + +import Distributions: pdf + +pdf_kuma(α, β, x, p = 1.0) = 
@. p * (α * β * x^(α - 1) .* (1 - x^α)^(β - 1)) + +graphon_params = (x, y) -> (3 * abs(sin(2 * π * x) * sin(2 * π * y)) + 0.8, max(x, y) * 8) + +graphon = DecoratedGraphon((x, y) -> Kumaraswamy(graphon_params(x, y)...)) + +import CairoMakie as Mke +let + fig = Mke.Figure(size = (510, 200)) + ax = Mke.Axis(fig[1, 1], aspect = Mke.DataAspect(), title = "α") + hm = Mke.heatmap!(ax, graphon, k = 1, colormap = :viridis) + Mke.Colorbar(fig[1, 2], hm) + ax2 = Mke.Axis(fig[1, 3], aspect = Mke.DataAspect(), title = "β") + hm2 = Mke.heatmap!(ax2, graphon, k = 2, colormap = :viridis) + Mke.Colorbar(fig[1, 4], hm2) + fig +end +```` + +We sample a weighted network from the graphon + +````@example weighted_network +Random.seed!(1234); +n = 2000 +k = 15 +n_bins = 10 +p = 0.8 + +A = sample_graph(graphon, n) .* Symmetric(rand(Bernoulli(p), n, n)); +ξs = range(0, 1; length = n) +oracle_latents = ordered_start_labels(n, k); + +res_oracle = NetworkHistogram.oracle_estimator( + A, oracle_latents, NetworkHistogram.UnitIntervalConvertor(n_bins)); + +starting_labels = shuffle(oracle_latents); + +max_iter = 1_000_000 +stalled_iters = 10_000 + +res_new = NetworkHistogram.nethist_continuous( + A, k, + starting_labels; + bins = n_bins +); + +ENV["JULIA_CONDAPKG_VERBOSITY"] = "-1" # hide conda messages #hide +using PythonCall + +θ_oracle = Graphons._extract_param.(res_oracle.model.θ); +θ_hat = Graphons._extract_param.(res_new.model.θ); +perm = NetworkHistogram.get_perm_alignment(θ_oracle, θ_hat); + +fitted_labels = map(x -> perm[x], res_new.labels); +res_ot_aligned = NetworkHistogram.oracle_estimator( + A, fitted_labels, NetworkHistogram.UnitIntervalConvertor(n_bins), + name = "aligned with OT perm"); + +xs = range(0, 1; length = 20) + +function viz_one_group!(axis, g1, g2, A, ξs, res_oracle, res_new, xs; n_viz = 20, p = p) + nodes_1 = findall(res_oracle.labels .== g1) + nodes_2 = findall(res_oracle.labels .== g2) + edge_values = [A[x, y] for y in nodes_2 for x in nodes_1] + Mke.vlines!(axis, 
edge_values, ymax = 0.025, color = :lightgray) + x1 = sample(ξs[nodes_1], n_viz, replace = false) + x2 = sample(ξs[nodes_2], n_viz, replace = false) + for x_ in x1 + for y_ in x2 + Mke.lines!(axis, xs, pdf_kuma(graphon_params(x_, y_)..., xs, p), + color = :gray, alpha = 0.1) + end + end + Mke.lines!(axis, xs, map(Base.Fix1(pdf, res_oracle.model.θ[g1, g2]), xs), + color = :blue, label = "True") + Mke.lines!(axis, xs, map(Base.Fix1(pdf, res_new.model.θ[g1, g2]), xs), + color = :black, linestyle = :dash, label = "Estimated") +end + +fig = Mke.Figure(size = (1000, 1000)) +for g in 1:k + for g2 in 1:g + ax = Mke.Axis(fig[g, g2]) + Mke.hidedecorations!(ax) + viz_one_group!(ax, g, g2, A, ξs, res_oracle, + res_ot_aligned, xs, p = p, n_viz = 5) + end +end +fig +```` + +--- + +*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* + diff --git a/ext/BootstrapExt.jl b/ext/BootstrapExt.jl new file mode 100644 index 0000000..5584274 --- /dev/null +++ b/ext/BootstrapExt.jl @@ -0,0 +1,7 @@ +module BootstrapExt + +using NetworkHistogram + +using Bootstrap + +end diff --git a/ext/LightMCExt.jl b/ext/LightMCExt.jl new file mode 100644 index 0000000..4942c5d --- /dev/null +++ b/ext/LightMCExt.jl @@ -0,0 +1,62 @@ +module LightMCExt + +using StaticArrays +using Accessors +using NetworkHistogram +import NetworkHistogram: SuffStats, add_sample, remove_sample, make_k_block, loss, + to_params, AbstractConvertor, to_distribution, get_convertor + +using LightMC: DiscreteMarkovChain, SampleChain, transition_matrix, ConvertBinaryMC + +# need to define a convertor that only look at the possible transitions and not all of them +struct McConvertor <: AbstractConvertor end + +get_convertor(::Val{:mc}; kwargs...) = McConvertor(kwargs...) + +function (c::McConvertor)(chain::SampleChain) + return SVector([SVector(c...) for c in eachcol(chain.transitions)]...) +end + +function to_distribution(::McConvertor, transition_matrix; kwargs...) 
+ return DiscreteMarkovChain(transition_matrix, sum(transition_matrix; dims = 2)) +end +struct McSuffStats{M, T} <: SuffStats + h::SVector{M, T} +end + +# this will also need to be modified to take into account the structure of the markov chain +# as above (e.g. only count the transitions that are possible) +function McSuffStats(num_states::Int) + inter = @SVector zeros(SVector{num_states, Int}, num_states) + return McSuffStats(inter) +end + +function add_sample(ss::McSuffStats, sample) + @inbounds for (i, s) in enumerate(sample) + ss = @set ss.h[i] = ss.h[i] + s + end + return ss +end + +function remove_sample(ss::McSuffStats, sample) + @inbounds for (i, s) in enumerate(sample) + ss = @set ss.h[i] = ss.h[i] - s + end + return ss +end + +function _loss(counts::SVector) + n = sum(counts) + norm_ = max(n, 1) + return (n - sum(abs2, counts) / norm_) / norm_ +end + +function loss(ss::McSuffStats) + return sum(_loss, ss.h) +end + +function to_params(ss::McSuffStats) + return reduce(hcat, ss.h) +end + +end diff --git a/ext/MakieExt.jl b/ext/MakieExt.jl new file mode 100644 index 0000000..3a7cc28 --- /dev/null +++ b/ext/MakieExt.jl @@ -0,0 +1,8 @@ +module MakieExt + +using NetworkHistogram +using Makie + +Makie.convert_single_argument(A::SymArray) = Matrix(A) + +end diff --git a/ext/PythonOptimalTransport/PythonOptimalTransport.jl b/ext/PythonOptimalTransport/PythonOptimalTransport.jl new file mode 100644 index 0000000..969429b --- /dev/null +++ b/ext/PythonOptimalTransport/PythonOptimalTransport.jl @@ -0,0 +1,27 @@ +module PythonOptimalTransport +using PythonCall +using NetworkHistogram + +import NetworkHistogram: align_matrices, get_perm_alignment + +const ot = Ref{Py}() +const fngw = Ref{Py}() + +function __init__() + ot[] = pyimport("ot") + pyimport("sys").path.append(@__DIR__) + # TODO: find why I need to use fngw.x to access the functions later... 
+ fngw[] = pyimport("fngw") +end + +## helpers to convert Julia arrays to numpy arrays +jl_to_np(mat::AbstractArray{<:Real}) = Py(mat).to_numpy() +function jl_to_np(mat::AbstractMatrix{<:AbstractVector}) + Py(permutedims(stack(mat), (3, 2, 1))).to_numpy() +end + +include("alignment.jl") + +# look at https://pythonot.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html#sphx-glr-auto-examples-backends-plot-optim-gromov-pytorch-py +# to implement semi-relaxed gromov-wasserstein ? +end diff --git a/ext/PythonOptimalTransport/alignment.jl b/ext/PythonOptimalTransport/alignment.jl new file mode 100644 index 0000000..0c68ebf --- /dev/null +++ b/ext/PythonOptimalTransport/alignment.jl @@ -0,0 +1,67 @@ + +# helpers for optimal transport alignment + +function plan_to_permutation(plan) + ordering = argmax(plan, dims = 1) .|> Tuple |> vec + perm = sort(ordering, by = x -> x[1]) .|> last + return perm +end + +""" +Get the permutation aligning source and target matrices using optimal transport. + +This function converts a gromov-wasserstein plan into a permutation by taking the argmax +along the rows. + +This function uses [`gromov_wasserstein`](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein) + +# See also +- [`align_matrices`](@ref) +- [`ot.gromov.gromov_wasserstein`](@extref) +""" +function get_perm_alignment( + src::AbstractMatrix{<:Real}, + target::AbstractMatrix{<:Real}; + kwargs... +) + plan = ot[].gromov.gromov_wasserstein( + C2 = jl_to_np(src), C1 = jl_to_np(target), kwargs...) + plan = pyconvert(typeof(target), plan) + return plan_to_permutation(plan) +end + +function get_perm_alignment( + src::AbstractMatrix{T1}, + target::AbstractMatrix{T2}; + kwargs...
+) where {T1 <: AbstractVector, T2 <: AbstractVector} + C1 = jl_to_np(target) + C2 = jl_to_np(src) + dist, + log_ = fngw.x.fused_network_gromov_wasserstein2( + M = jl_to_np(zeros(size(target, 1), size(src, 1))), + C1 = C1, + C2 = C2, + A1 = jl_to_np(ones(size(target, 1), size(target, 1))), + A2 = jl_to_np(ones(size(src, 1), size(src, 1))), + p = jl_to_np(fill(1.0 / size(target, 1), size(target, 1))), + q = jl_to_np(fill(1.0 / size(src, 1), size(src, 1))), + alpha = 1.0, + beta = 0.0, + log = true, + kwargs... + ) + plan = pyconvert(Matrix{Float64}, log_["T"]) + return plan_to_permutation(plan) +end + +""" +Align the source and target matrices using optimal transport. + +# See also +- [`get_perm_alignment`](@ref). +""" +function align_matrices(src, target) + perm = get_perm_alignment(src, target) + return src[perm, perm] +end diff --git a/ext/PythonOptimalTransport/fngw.py b/ext/PythonOptimalTransport/fngw.py new file mode 100644 index 0000000..56f575d --- /dev/null +++ b/ext/PythonOptimalTransport/fngw.py @@ -0,0 +1,1002 @@ +import numpy as np + +from ot.utils import dist, UndefinedParameter, list_to_array +from ot.utils import check_random_state +from ot.backend import get_backend +from ot.optim import line_search_armijo, solve_1d_linesearch_quad +from ot.lp import emd + +from ot.gromov import init_matrix as init_matrix_A +from ot.gromov import gwloss, gwggrad + + +def fngw_barycenters( + N, + Fs, + As, + Cs, + ps, + lambdas, + alpha, + beta, + fixed_structure=False, + fixed_node_features=False, + fixed_edge_features=False, + p=None, + dist_fun_C="l2_norm", + dist_fun_A="square_loss", + max_iter=100, + tol=1e-9, + verbose=False, + log=False, + init_C=None, + init_F=None, + init_A=None, + random_state=None, +): + r"""Compute the FNGW barycenter as presented eq (12) in our paper + + Parameters + ---------- + N : int + Desired number of samples of the target barycenter + Fs: list of array-like, each element has shape (ns,d) + Node features of all samples + As : list 
of array-like, each element has shape (ns,ns) + Structure matrices of all samples + Cs : list of array-like, each element has shape (ns,ns,d') + Edge feature tensors of all samples + ps : list of array-like, each element has shape (ns,) + Masses of all samples. + lambdas : list of float + List of the `S` spaces' weights + alpha : float + Alpha parameter for the FNGW distance + beta : float + Beta parameter for the FNGW distance + fixed_structure : bool + Whether to fix the structure of the barycenter during the updates + fixed_node_features : bool + Whether to fix the node feature of the barycenter during the updates + fixed_edge_features : bool + Whether to fix the edge feature of the barycenter during the updates + dist_fun_A : str + Loss function used for the solver either 'square_loss' or 'kl_loss' + dist_fun_C : str + Inner distance function used for the solver. Now only 'l2_norm' + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on error (>0). + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : array-like, shape (N,N,d'), optional + Initialization for the barycenters' edge feature tensor. If not set + a random init is used. + init_F : array-like, shape (N,d), optional + Initialization for the barycenters' node features. If not set a + random init is used. + init_A : array-like, shape (N,N), optional + Initialization for the barycenters' structure matrix. If not set + a random init is used. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + F : array-like, shape (`N`, `d`) + Barycenters' features + A : array-like, shape (`N`, `N`) + Barycenters' structure matrix + C : array-like, shape (`N`, `N`, `d'`) + Barycenters' edge feature tensor + log : dict + Only returned when log=True.
It contains the keys: + + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature + of the barycenter and the other features + :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`) + + """ + Cs = list_to_array(*Cs) + As = list_to_array(*As) + ps = list_to_array( + *ps + ) # list to array bug when only one list has length one + Fs = list_to_array(*Fs) + if not isinstance(Cs, list): + Cs = [Cs] + if not isinstance(As, list): + As = [As] + if not isinstance(ps, list): + ps = [ps] + if not isinstance(Fs, list): + Fs = [Fs] + + p = list_to_array(p) + nx = get_backend(*Cs, *Fs, *As, *ps) + + S = len(Cs) + d = Fs[0].shape[1] # dimension on the node features + d_edge = Cs[0].shape[2] + if p is None: + p = nx.ones(N, type_as=Cs[0]) / N + + if fixed_edge_features: + if init_C is None: + raise UndefinedParameter("If C is fixed it must be initialized") + else: + C = init_C + else: + if init_C is None: + generator = check_random_state(random_state) + C = generator.randn(N, N, d_edge) + C = nx.from_numpy(C, type_as=ps[0]) + else: + C = init_C + + if fixed_structure: + if init_A is None: + raise UndefinedParameter("If A is fixed it must be initialized") + else: + A = init_A + else: + if init_A is None: + generator = check_random_state(random_state) + xalea = generator.randn(N, 2) + A = dist(xalea, xalea) + A = nx.from_numpy(A, type_as=ps[0]) + else: + A = init_A + + if fixed_node_features: + if init_F is None: + raise UndefinedParameter("If F is fixed it must be initialized") + else: + F = init_F + else: + if init_F is None: + F = nx.zeros((N, d), type_as=ps[0]) + else: + F = init_F + + T = [nx.outer(p, q) for q in ps] + + Ms = [dist(F, Fs[s]) for s in range(len(Fs))] + + cpt = 0 + err_node_feature = 1 + err_structure = 1 + err_edge_feature = 1 + + if log: + log_ = {} + log_["err_node_feature"] = [] + log_["err_edge_feature"] = [] + log_["err_structure"] = [] + log_["Ts_iter"] = [] + + while ( + 
err_node_feature > tol or err_structure > tol or err_edge_feature > tol + ) and cpt < max_iter: + Cprev = C + Aprev = A + Xprev = F + + if not fixed_node_features: + Fs_temp = [y.T for y in Fs] + F = update_node_feature_matrix(lambdas, Fs_temp, T, p).T + + Ms = [dist(F, Fs[s]) for s in range(len(Fs))] + + if not fixed_structure: + if dist_fun_A == "square_loss": + T_temp = [t.T for t in T] + A = update_structure_matrix(p, lambdas, T_temp, As) + + if not fixed_edge_features: + if dist_fun_C == "l2_norm": + T_temp = [t.T for t in T] + C = update_edge_feature_tensor(p, lambdas, T_temp, Cs) + + T = [ + fused_network_gromov_wasserstein2( + Ms[s], + C, + Cs[s], + A, + As[s], + p, + ps[s], + dist_fun_C, + dist_fun_A, + alpha, + beta, + numItermax=max_iter, + stopThr=1e-5, + verbose=verbose > 2, + log=True, + )[1]["T"] + for s in range(S) + ] + + # T is N,ns + err_node_feature = nx.norm(F - nx.reshape(Xprev, (N, d))) + err_structure = nx.norm(A - Aprev) + err_edge_feature = nx.norm(C - Cprev) + if log: + log_["err_node_feature"].append(err_node_feature) + log_["err_edge_feature"].append(err_edge_feature) + log_["err_structure"].append(err_structure) + log_["Ts_iter"].append(T) + + if verbose: + if cpt % 200 == 0: + print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19) + print("{:5d}|{:8e}|".format(cpt, err_structure)) + print("{:5d}|{:8e}|".format(cpt, err_node_feature)) + print("{:5d}|{:8e}|".format(cpt, err_edge_feature)) + print("\n") + + cpt += 1 + + if log: + log_["T"] = T # from target to Fs + log_["p"] = p + log_["Ms"] = Ms + + if log: + return F, A, C, log_ + else: + return F, A, C + + +def update_structure_matrix(p, lambdas, T, As): + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the + `S` :math:`\mathbf{T}_s` couplings. + It is calculated at each iteration + Parameters + ---------- + p : array-like, shape (N,) + Masses in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights. 
+ T : list of S array-like of shape (ns, N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. + As : list of S array-like, shape (ns, ns) + Metric cost matrices. + Returns + ------- + A : array-like, shape (`nt`, `nt`) + Updated :math:`\mathbf{A}` matrix. + """ + p = list_to_array(p) + T = list_to_array(*T) + As = list_to_array(*As) + nx = get_backend(*As, *T, p) + + tmpsum = sum( + [ + lambdas[s] * nx.dot(nx.dot(T[s].T, As[s]), T[s]) + for s in range(len(T)) + ] + ) + ppt = nx.outer(p, p) + return tmpsum / ppt + + +def update_node_feature_matrix(lambdas, Fs, Ts, p): + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` + couplings. + Parameters + ---------- + p : array-like, shape (N,) + masses in the targeted barycenter + lambdas : list of float + List of the `S` spaces' weights + Ts : list of S array-like, shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Fs : list of S array-like, shape (d,ns) + The features. + + Returns + ------- + F : array-like, shape (`d`, `N`) + """ + p = list_to_array(p) + Ts = list_to_array(*Ts) + Fs = list_to_array(*Fs) + if not isinstance(Ts, list): + Ts = [Ts] + if not isinstance(Fs, list): + Fs = [Fs] + nx = get_backend(*Fs, *Ts, p) + + p = 1.0 / p + tmpsum = sum( + [ + lambdas[s] * nx.dot(Fs[s], Ts[s].T) * p[None, :] + for s in range(len(Ts)) + ] + ) + return tmpsum + + +def update_edge_feature_tensor(p, lambdas, T, Cs): + r"""Updates :math:`\mathbf{C}` according to the l2 norm inner distance with + the `S` :math:`\mathbf{T}_s` couplings. + + It is calculated at each iteration + + Parameters + ---------- + p : array-like, shape (N,) + Masses in the targeted barycenter. + lambdas : list of float + List of the `S` spaces' weights. + T : list of S array-like of shape (ns,N) + The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. 
+ Cs : list of S array-like, shape (ns,ns,d') + Edge features tensors + + Returns + ------- + C : array-like, shape (nt,nt,d') + Updated :math:`\mathbf{C}` tensor. + """ + p = list_to_array(p) + T = list_to_array(*T) + Cs = list_to_array(*Cs) + if not isinstance(T, list): + T = [T] + if not isinstance(Cs, list): + Cs = [Cs] + nx = get_backend(*Cs, *T, p) + + # Proposition 2.10 in our paper + tmpsum = sum( + [ + lambdas[s] + * nx.einsum( + "ijd,jk...->ikd", + nx.einsum("ij...,jkd->ikd", T[s].T, Cs[s]), + T[s], + ) + for s in range(len(T)) + ] + ) + ppt = nx.reshape(nx.outer(p, p), shape=(len(p), len(p), 1)) + return tmpsum / ppt + + +def fused_network_gromov_wasserstein2( + M, + C1, + C2, + A1, + A2, + p, + q, + dist_fun_C="l2_norm", + dist_fun_A="square_loss", + alpha=0.33, + beta=0.33, + armijo=False, + G0=None, + log=False, + **kwargs, +): + r""" + Computes the FNGW transport between two graphs + + See Algorithm 1 in our paper. + + Parameters + ---------- + M : array-like, shape (ns, nt) + Metric cost matrix between node features of source and target graphs + C1 : array-like, shape (ns, ns, d') + Edge feature tensor of the source graph + C2 : array-like, shape (nt, nt, d') + Edge feature tensor of the target graph + A1 : array-like, shape (ns, ns) + Structure matrix of the source graph + A2 : array-like, shape (nt, nt) + Structure matrix of the target graph + p : array-like, shape (ns,) + Distribution in the source space + q : array-like, shape (nt,) + Distribution in the target space + dist_fun_C : str, optional + Inner distance used for the edge feature tensor + dist_fun_A : str, optional + Loss function used for the structure matrix + alpha : float, optional + Trade-off parameter (0 < alpha < 1) + beta : float, optional + Trade-off parameter (0 < beta < 1) + armijo : bool, optional + If True the step of the line-search is found via an armijo research. + Else closed form is used. + If there are convergence issues use False. 
+ G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as + initial transport of the solver. + log : bool, optional + record log if True + **kwargs : dict + parameters can be directly passed to the ot.optim.cg solver + + Returns + ------- + fngw_dist : float + FNGW distance for the given parameters. + log : dict + Log dictionary return only if log==True in parameters. + """ + assert alpha + beta <= 1 + p, q = list_to_array(p, q) + p0, q0, C10, C20, A10, A20, M0 = p, q, C1, C2, A1, A2, M + if G0 is None: + nx = get_backend(p0, q0, C10, C20, A10, A20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, A10, A20, M0, G0_) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + A1 = nx.to_numpy(A10) + A2 = nx.to_numpy(A20) + M = nx.to_numpy(M0) + + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) + + constA, hA1, hA2 = init_matrix_A(A1, A2, p, q, dist_fun_A) + constC = init_matrix_C(C1, C2, p, q, dist_fun_C) + + def f(G): + return ngwloss(constC, C1, C2, G) + + def df(G): + return ngwgrad(constC, C1, C2, G) + + def g(G): + return gwloss(constA, hA1, hA2, G) + + def dg(G): + return gwggrad(constA, hA1, hA2, G) + + T, cg_log = cg( + p, + q, + (1 - alpha - beta) * M, + reg_f=alpha, + reg_g=beta, + f=f, + df=df, + g=g, + dg=dg, + G0=G0, + armijo=armijo, + C1=C1, + C2=C2, + A1=A1, + A2=A2, + constC=constC, + constA=constA, + log=True, + **kwargs, + ) + + fngw_dist = nx.from_numpy(cg_log["loss"][-1], type_as=C10) + T0 = nx.from_numpy(T, type_as=C10) + cg_log["fngw_dist"] = fngw_dist + cg_log["u"] = nx.from_numpy(cg_log["u"], type_as=C10) + cg_log["v"] = nx.from_numpy(cg_log["v"], type_as=C10) + cg_log["T"] = T0 + + # TODO: implement the gradient 
for p0, q0 + if dist_fun_C == "l2_norm" and dist_fun_A == "square_loss": + gC1 = 2 * C1 * (p[:, None] * p[None, :])[:, :, None] - 2 * np.einsum( + "ilt, kl->ikt", np.einsum("ij,jlt->ilt", T, C2), T + ) + gC2 = 2 * C2 * (q[:, None] * q[None, :])[:, :, None] - 2 * np.einsum( + "jkt, kl->jlt", np.einsum("ij,ikt->jkt", T, C1), T + ) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) + + gA1 = 2 * A1 * (p[:, None] * p[None, :]) - 2 * T.dot(A2).dot(T.T) + gA2 = 2 * A2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(A1).dot(T) + gA1 = nx.from_numpy(gA1, type_as=A10) + gA2 = nx.from_numpy(gA2, type_as=A10) + + fngw_dist = nx.set_gradients( + fngw_dist, + (p0, q0, C10, C20, A10, A20, M0), + ( + cg_log["u"] + - nx.mean( + cg_log["u"] + ), # No need for p0, q0 since they will not be updated, + # keeps it right now + cg_log["v"] - nx.mean(cg_log["v"]), + alpha * gC1, + alpha * gC2, + beta * gA1, + beta * gA2, + (1 - alpha - beta) * T0, + ), + ) + if log: + return fngw_dist, cg_log + else: + return fngw_dist + + +def init_matrix_C(C1, C2, p, q, dist="l2_norm"): + r"""Computation of the sum of the first two terms of Equation (6) in our + paper. 
+ + Parameters + ---------- + C1 : array-like, shape (ns, ns, d') + Edge feature tensor of the source graph + C2 : array-like, shape (nt, nt, d') + Edge feature tensor of the target graph + T : array-like, shape (ns, nt) + Coupling between source and target spaces + p : array-like, shape (ns,) + + Returns + ------- + constC : array-like, shape (ns, nt) + + """ + C1, C2, p, q = list_to_array(C1, C2, p, q) + nx = get_backend(C1, C2, p, q) + + if dist == "l2_norm": + + def f1(a): + return nx.sum(nx.power(a, 2), axis=-1) + + def f2(b): + return nx.sum(nx.power(b, 2), axis=-1) + + else: + raise ValueError + + constC1 = nx.dot( + nx.dot(f1(C1), nx.reshape(p, (-1, 1))), nx.ones((1, len(q)), type_as=q) + ) + constC2 = nx.dot( + nx.ones((len(p), 1), type_as=p), + nx.dot(nx.reshape(q, (1, -1)), f2(C2).T), + ) + constC = constC1 + constC2 + + return constC + + +def tensor_product(constC, C1, C2, T): + r"""Implementation of the Prop. 2.5 in our paper. + + Parameters + ---------- + constC : array-like, shape (ns, nt) + the sum of the first two terms of Eq. (6) + C1 : array-like, shape (ns, ns, d') + Edge feature tensor of the source graph + C2 : array-like, shape (nt, nt, d') + Edge feature tensor of the target graph + + T : array-like, shape (ns, nt) + + Returns + ------- + tens : array-like, shape (ns, nt) + + """ + constC, C1, C2, T = list_to_array(constC, C1, C2, T) + nx = get_backend(constC, C1, C2, T) + + A = -2 * nx.einsum( + "ijd, kjd->ikd", nx.einsum("ijd,jk...->ikd", C1, T), C2 + ) # (ns, nt, d) + + A = nx.sum(A, axis=-1) # (ns, nt) + tens = constC + A + # tens -= tens.min() + return tens + + +def ngwloss(constC, C1, C2, T): + r"""Compute the third term of Eq.5 in our paper + + Parameters + ---------- + constC : array-like, shape (ns, nt) + the sum of the first two terms of Eq. 
(6) + C1 : array-like, shape (ns, ns, d') + Edge feature tensor of the source graph + C2 : array-like, shape (nt, nt, d') + Edge feature tensor of the target graph + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` + Current value of transport matrix :math:`\mathbf{T}` + + Returns + ------- + loss : float + + """ + + tens = tensor_product(constC, C1, C2, T) + + tens, T = list_to_array(tens, T) + nx = get_backend(tens, T) + + return nx.sum(tens * T) + + +def ngwgrad(constC, C1, C2, T): + r"""Compute the third term of Eq.7 in our paper + + Parameters + ---------- + constC : array-like, shape (ns, nt) + the sum of the first two terms of Eq. (6) + C1 : array-like, shape (ns, ns, d') + Edge feature tensor of the source graph + C2 : array-like, shape (nt, nt, d') + Edge feature tensor of the target graph + T : array-like, shape (ns, nt) + Current value of transport matrix :math:`\mathbf{T}` + + Returns + ------- + grad : array-like, shape (`ns`, `nt`) + + """ + return 2 * tensor_product(constC, C1, C2, T) + + +def cg( + a, + b, + M, + reg_f, + reg_g, + f, + df, + g, + dg, + G0=None, + numItermax=200, + numItermaxEmd=100000, + stopThr=1e-9, + stopThr2=1e-9, + verbose=False, + log=False, + **kwargs, +): + r""" + Solve the general regularized OT problem with conditional gradient + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} + \rangle_F + \mathrm{reg} \cdot f(\gamma) + + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + where : + + - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix + - :math:`f` is the regularization term (and `df` is its gradient) + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights + (sum to 1) + + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[1] <references-cg>` + + + Parameters + ---------- + a : array-like, shape (ns,) + samples weights in the source domain + b : array-like, shape (nt,) + samples weights in the target domain + M : array-like, shape (ns, nt) + loss matrix + reg_f : float + Regularization term >0 + reg_g : float + Regularization term >0 + G0 : array-like, shape (ns,nt), optional + initial guess (default is indep joint density) + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + **kwargs : dict + Parameters for linesearch + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if log==True in parameters + + + .. _references-cg: + References + ---------- + + .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). + Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, + 7(3), 1853-1882.
+ + See Also + -------- + ot.lp.emd : Unregularized optimal transport + ot.bregman.sinkhorn : Entropic regularized optimal transport + + """ + a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(a, b) + else: + nx = get_backend(a, b, M) + + loop = 1 + + if log: + log = {"loss": []} + + if G0 is None: + G = nx.outer(a, b) + else: + G = G0 + + def cost(G): + return nx.sum(M * G) + reg_f * f(G) + reg_g * g(G) + + f_val = cost(G) + if log: + log["loss"].append(f_val) + + it = 0 + + if verbose: + print( + "{:5s}|{:12s}|{:8s}|{:8s}".format( + "It.", "Loss", "Relative loss", "Absolute loss" + ) + + "\n" + + "-" * 48 + ) + print("{:5d}|{:8e}|{:8e}|{:8e}".format(it, f_val, 0, 0)) + + while loop: + + it += 1 + old_fval = f_val + + # problem linearization + Mi = M + reg_f * df(G) + reg_g * dg(G) + # set M positive + Mi += nx.min(Mi) + + # solve linear program + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) + + deltaG = Gc - G + + # line search + alpha, fc, f_val = solve_linesearch( + cost, + G, + deltaG, + Mi, + f_val, + reg_f=reg_f, + reg_g=reg_g, + M=M, + Gc=Gc, + alpha_min=0.0, + alpha_max=1.0, + **kwargs, + ) + + G = G + alpha * deltaG + + # test convergence + if it >= numItermax: + loop = 0 + + abs_delta_fval = abs(f_val - old_fval) + relative_delta_fval = abs_delta_fval / abs(f_val) + if relative_delta_fval < stopThr or abs_delta_fval < stopThr2: + loop = 0 + + if log: + log["loss"].append(f_val) + + if verbose: + if it % 20 == 0: + print( + "{:5s}|{:12s}|{:8s}|{:8s}".format( + "It.", "Loss", "Relative loss", "Absolute loss" + ) + + "\n" + + "-" * 48 + ) + print( + "{:5d}|{:8e}|{:8e}|{:8e}".format( + it, f_val, relative_delta_fval, abs_delta_fval + ) + ) + + if log: + log.update(logemd) + return G, log + else: + return G + + +def solve_linesearch( + cost, + G, + deltaG, + Mi, + val, + armijo=True, + C1=None, + C2=None, + reg_f=None, + A1=None, + A2=None, + reg_g=None, + Gc=None, + constC=None, +
constA=None, + M=None, + alpha_min=None, + alpha_max=None, +): + """ + Solve the linesearch in the FNGW iterations + + Parameters + ---------- + cost : method + Cost in the FNGW for the linesearch + G : array-like, shape(ns,nt) + The transport map at a given iteration of the FNGW + deltaG : array-like (ns,nt) + Difference between the optimal map found by linearization in the FW + algorithm and the value at a given iteration + Mi : array-like (ns,nt) + Cost matrix of the linearized transport problem. Corresponds to the + gradient of the cost + val : float + Value of the cost at `G` + armijo : bool, optional + If True the steps of the line-search is found via an armijo research. + Else closed form is used. + If there is convergence issues use False. + C1 : array-like (ns,ns,d'), optional + Edge feature tensor of the source graph. Only used and necessary when + armijo=False + C2 : array-like (nt,nt,d'), optional + Edge feature tensor in the target graph. Only used and necessary when + armijo=False + reg_f : float, optional + Regularization parameter. Only used and necessary when armijo=False + A1 : array-like (ns,ns), optional + Structure matrix of the source graph. Only used and necessary when + armijo=False + A2 : array-like (nt,nt), optional + Structure matrix of the target graph. Only used and necessary when + armijo=False + reg_g : float, optional + Regularization parameter. Only used and necessary when armijo=False + Gc : array-like (ns,nt) + Optimal map found by linearization in the FW algorithm. Only used and + necessary when armijo=False + constC : array-like (ns,nt) + Constant for the gromov cost. Only used and necessary when armijo=False + See :ref:`[24] `. + + M : array-like (ns,nt), optional + Cost matrix between the features. 
Only used and necessary when + armijo=False + alpha_min : float, optional + Minimum value for alpha + alpha_max : float, optional + Maximum value for alpha + + Returns + ------- + alpha : float + The optimal step size of the FW + fc : int + nb of function call. Useless here + f_val : float + The value of the cost for the next iteration + + """ + if armijo: + # TODO: Update for armijo + alpha, fc, f_val = line_search_armijo( + cost, G, deltaG, Mi, val, alpha_min=alpha_min, alpha_max=alpha_max + ) + else: + G, deltaG, C1, C2, constC, A1, A2, constA, M = list_to_array( + G, deltaG, C1, C2, constC, A1, A2, constA, M + ) + if isinstance(M, int) or isinstance(M, float): + nx = get_backend(G, deltaG, C1, C2, constC, A1, A2, constA) + else: + nx = get_backend(G, deltaG, C1, C2, constC, A1, A2, constA, M) + + dotC_1 = nx.sum( + nx.einsum( + "ijd, kjd->ikd", nx.einsum("ijd,jk...->ikd", C1, deltaG), C2 + ), + axis=-1, + ) + dotC_2 = nx.sum( + nx.einsum("ijd, kjd->ikd", nx.einsum("ijd,jk...->ikd", C1, G), C2), + axis=-1, + ) + + dotA_1 = nx.dot(nx.dot(A1, deltaG), A2.T) + dotA_2 = nx.dot(nx.dot(A1, G), A2.T) + + a = -2 * reg_f * nx.sum(dotC_1 * deltaG) - 2 * reg_g * nx.sum( + dotA_1 * deltaG + ) + + b = ( + nx.sum((M + reg_f * constC + reg_g * constA) * deltaG) + - 2 * reg_f * (nx.sum(dotC_1 * G) + nx.sum(dotC_2 * deltaG)) + - 2 * reg_g * (nx.sum(dotA_1 * G) + nx.sum(dotA_2 * deltaG)) + ) + + # c = cost(G) + # c was passed to solve_linesearch as c, which does not exist + alpha = solve_1d_linesearch_quad(a, b) + if alpha_min is not None or alpha_max is not None: + alpha = np.clip(alpha, alpha_min, alpha_max) + fc = None + f_val = cost(G + alpha * deltaG) + + return alpha, fc, f_val diff --git a/ext/PythonOptimalTransport/srGW.jl b/ext/PythonOptimalTransport/srGW.jl new file mode 100644 index 0000000..1523329 --- /dev/null +++ b/ext/PythonOptimalTransport/srGW.jl @@ -0,0 +1,510 @@ +# Julia implementation of semi-relaxed Gromov-Wasserstein algorithms +# Converted from Python
implementation by cvincentcuaz +# This module provides conditional gradient, mirror descent, and MM algorithms +# for semi-relaxed (fused) Gromov-Wasserstein optimal transport + +using LinearAlgebra +using Random + +# ============================================================================= +# Utility Functions +# ============================================================================= + +# Initialize transport plan for semi-relaxed GW +# Arguments: +# init_mode: "product", "random", or "random_product" +# p: source distribution (N1,) +# N1, N2: dimensions +# seed: random seed (nothing for no seeding) +# Returns: T - initial transport plan (N1, N2) +function initializer_semirelaxed_GW( + init_mode::String, p::AbstractVector{T}, N1::Int, N2::Int; + seed::Union{Int, Nothing} = 0) where {T <: Real} + if init_mode == "product" + q = ones(T, N2) / N2 + T_plan = p * q' + elseif init_mode == "random" + if !isnothing(seed) + Random.seed!(seed) + end + T_plan = rand(T, N1, N2) + # Scale to satisfy first marginal constraint + scale = p ./ sum(T_plan, dims = 2) + T_plan .*= scale + elseif init_mode == "random_product" + if !isnothing(seed) + Random.seed!(seed) + end + q = rand(T, N2) + q ./= sum(q) + T_plan = p * q' + else + error("Unknown init mode: $init_mode") + end + return T_plan +end + +# Initialize matrices for symmetric GW computation +function init_matrix_GW2(C1::AbstractMatrix{T}, C2::AbstractMatrix{T}, + p::AbstractVector{T}, q::AbstractVector{T}, + ones_p::AbstractVector{T}, ones_q::AbstractVector{T}) where {T <: Real} + f1_ = C1 .^ 2 + f2_ = C2 .^ 2 + constC1 = f1_ * (p * ones_q') + constC2 = (ones_p * q') * f2_ + constC = constC1 + constC2 + hC1 = C1 + hC2 = 2 * C2 + return constC, hC1, hC2 +end + +# Initialize matrices for asymmetric GW computation +function init_matrix_asymGW2(C1::AbstractMatrix{T}, C2::AbstractMatrix{T}, + p::AbstractVector{T}, q::AbstractVector{T}, + ones_p::AbstractVector{T}, ones_q::AbstractVector{T}) where {T <: Real} + f1_ = (C1 
.^ 2) / 2.0 + f2_ = (C2 .^ 2) / 2.0 + constC1 = f1_ * (p * ones_q') + constC2 = (ones_p * q') * f2_' + constC = constC1 + constC2 + hC1 = C1 + hC2 = C2 + return constC, hC1, hC2 +end + +# Compute tensor product for GW distance +function tensor_product(constC::AbstractMatrix{T}, hC1::AbstractMatrix{T}, + hC2::AbstractMatrix{T}, T_plan::AbstractMatrix{T}) where {T <: Real} + A = -hC1 * T_plan * hC2' + return constC + A +end + +# ============================================================================= +# Conditional Gradient Descent Algorithms +# ============================================================================= + +# Conditional gradient algorithm for semi-relaxed (fused) Gromov-Wasserstein +# Solves: min_T α * ⟨L(C₁, C₂) ⊗ T, T⟩ + ⟨M, T⟩ +function cg_semirelaxed(C1::AbstractMatrix{T}, p::AbstractVector{T}, C2::AbstractMatrix{T}; + alpha::Real = 1.0, linear_cost::Union{Nothing, AbstractMatrix{T}} = nothing, + init_mode::String = "product", T_init::Union{Nothing, AbstractMatrix{T}} = nothing, + symmetry::Bool = true, use_log::Bool = false, eps::Real = 1e-5, + max_iter::Int = 1000, seed::Int = 0, verbose::Bool = false) where {T <: Real} + N1, N2 = size(C1, 1), size(C2, 1) + + # Initialize transport plan + if isnothing(T_init) + T_plan = initializer_semirelaxed_GW(init_mode, p, N1, N2; seed = seed) + else + @assert size(T_init) == (N1, N2) + T_plan = copy(T_init) + end + + # Check symmetry + if isnothing(symmetry) + symmetry = (C1 == C1') && (C2 == C2') + end + + # Initialize + q = vec(sum(T_plan, dims = 1)) + ones_p = ones(T, N1) + ones_q = ones(T, N2) + + # Compute initial gradient + if symmetry + constC, hC1, hC2 = init_matrix_GW2(C1, C2, p, q, ones_p, ones_q) + G = 2 * tensor_product(constC, hC1, hC2, T_plan) + else + constC, hC1, hC2 = init_matrix_asymGW2(C1, C2, p, q, ones_p, ones_q) + constCt, hC1t, hC2t = init_matrix_asymGW2(C1', C2', p, q, ones_p, ones_q) + subG = tensor_product(constC, hC1, hC2, T_plan) + subGt = tensor_product(constCt, hC1t, 
hC2t, T_plan) + G = subG + subGt + end + G .*= alpha + + srgw_loss = 0.5 * sum(G .* T_plan) + add_linear_cost = !isnothing(linear_cost) + + if add_linear_cost + linear_loss = sum(linear_cost .* T_plan) + current_loss = srgw_loss + linear_loss + G .+= linear_cost + else + current_loss = srgw_loss + end + + log = use_log ? Dict("loss" => [current_loss]) : nothing + convergence_criterion = Inf + outer_count = 0 + + while convergence_criterion > eps && outer_count < max_iter + previous_loss = current_loss + + # Direction finding by solving subproblem on rows + min_vals = minimum(G, dims = 2) + X = (G .== min_vals) .* T(1.0) + row_sums = vec(sum(X, dims = 2)) + scale = p ./ row_sums + X .*= scale + + # Exact line search + qX = vec(sum(X, dims = 1)) + + if symmetry + constCX, hC1X, hC2X = init_matrix_GW2(C1, C2, p, qX, ones_p, ones_q) + GX = 2 * alpha * tensor_product(constCX, hC1X, hC2X, X) + GXX = 0.5 * sum(GX .* X) + GXT = 0.5 * sum(GX .* T_plan) + + a = srgw_loss + GXX - 2 * GXT + b = 2 * (GXT - srgw_loss) + else + constCX, hC1X, hC2X = init_matrix_asymGW2(C1, C2, p, qX, ones_p, ones_q) + constCXt, hC1Xt, hC2Xt = init_matrix_asymGW2(C1', C2', p, qX, ones_p, ones_q) + subGX = tensor_product(constCX, hC1X, hC2X, X) + subGXt = tensor_product(constCXt, hC1Xt, hC2Xt, X) + GX = alpha * (subGX + subGXt) + GXX = 0.5 * sum(GX .* X) + subGXt_dotT = sum(subGXt .* T_plan) + subGTt_dotX = sum(subGt .* X) + + a = srgw_loss + GXX - subGXt_dotT - subGTt_dotX + b = -2 * srgw_loss + subGXt_dotT + subGTt_dotX + end + + if add_linear_cost + linear_loss_X = sum(linear_cost .* X) + b += linear_loss_X - linear_loss + end + + # Compute step size + if a > 0 + gamma = min(1, max(0, -b / (2 * a))) + elseif a + b < 0 + gamma = 1 + else + gamma = 0 + end + + # Update + T_plan .= (1 - gamma) * T_plan + gamma * X + current_loss += a * gamma^2 + b * gamma + + if add_linear_cost + linear_loss = (1 - gamma) * linear_loss + gamma * linear_loss_X + srgw_loss = current_loss - linear_loss + G .= (1 - 
gamma) * G + gamma * (GX + linear_cost) + else + srgw_loss = current_loss + G .= (1 - gamma) * G + gamma * GX + end + + outer_count += 1 + use_log && push!(log["loss"], current_loss) + + convergence_criterion = abs(previous_loss - current_loss) / + (abs(previous_loss) + 1e-15) + end + + return use_log ? (T_plan, current_loss, log) : (T_plan, current_loss) +end + +# Conditional gradient for semi-relaxed Gromov-Wasserstein +# Wrapper for cg_semirelaxed with α=1 and no linear cost +function cg_semirelaxed_gromov_wasserstein(C1::AbstractMatrix{T}, p::AbstractVector{T}, + C2::AbstractMatrix{T}; kwargs...) where {T <: Real} + return cg_semirelaxed(C1, p, C2; alpha = 1.0, linear_cost = nothing, kwargs...) +end + +# Conditional gradient for semi-relaxed fused Gromov-Wasserstein +# A1, A2: Feature matrices (N1×d), (N2×d) +# alpha: Trade-off parameter (0 for pure OT, 1 for pure GW) +function cg_semirelaxed_fused_gromov_wasserstein( + C1::AbstractMatrix{T}, A1::AbstractMatrix{T}, + p::AbstractVector{T}, C2::AbstractMatrix{T}, + A2::AbstractMatrix{T}, alpha::Real; + kwargs...) where {T <: Real} + N1, N2 = size(A1, 1), size(A2, 1) + d = size(A1, 2) + + # Compute Euclidean distance matrix between features + A1_sq = sum(A1 .^ 2, dims = 2) * ones(T, 1, N2) + A2_sq = ones(T, N1, 1) * sum(A2 .^ 2, dims = 2)' + D = A1_sq + A2_sq - 2 * A1 * A2' + + return cg_semirelaxed( + C1, p, C2; alpha = alpha, linear_cost = (1 - alpha) * D, kwargs...) 
+end + +# ============================================================================= +# Mirror Descent Algorithms (Entropic Regularization) +# ============================================================================= + +# Mirror descent algorithm using KL geometry for semi-relaxed (fused) GW +# gamma_entropy: Entropy regularization parameter (must be > 0) +function md_semirelaxed(C1::AbstractMatrix{T}, p::AbstractVector{T}, C2::AbstractMatrix{T}, + gamma_entropy::Real; alpha::Real = 1.0, + linear_cost::Union{Nothing, AbstractMatrix{T}} = nothing, + init_mode::String = "product", T_init::Union{Nothing, AbstractMatrix{T}} = nothing, + symmetry::Bool = true, use_log::Bool = false, eps::Real = 1e-5, + max_iter::Int = 1000, seed::Int = 0, verbose::Bool = false) where {T <: Real} + @assert gamma_entropy>0 "gamma_entropy must be positive" + + N1, N2 = size(C1, 1), size(C2, 1) + + # Initialize transport plan + if isnothing(T_init) + T_plan = initializer_semirelaxed_GW(init_mode, p, N1, N2; seed = seed) + else + @assert size(T_init) == (N1, N2) + T_plan = copy(T_init) + end + + # Check symmetry + if isnothing(symmetry) + symmetry = (C1 == C1') && (C2 == C2') + end + + # Initialize + q = vec(sum(T_plan, dims = 1)) + ones_p = ones(T, N1) + ones_q = ones(T, N2) + + # Compute initial gradient + if symmetry + constC, hC1, hC2 = init_matrix_GW2(C1, C2, p, q, ones_p, ones_q) + G = 2 * alpha * tensor_product(constC, hC1, hC2, T_plan) + else + constC, hC1, hC2 = init_matrix_asymGW2(C1, C2, p, q, ones_p, ones_q) + constCt, hC1t, hC2t = init_matrix_asymGW2(C1', C2', p, q, ones_p, ones_q) + subG = tensor_product(constC, hC1, hC2, T_plan) + subGt = tensor_product(constCt, hC1t, hC2t, T_plan) + G = alpha * (subG + subGt) + end + + current_loss = 0.5 * sum(G .* T_plan) + add_linear_cost = !isnothing(linear_cost) + + if add_linear_cost + linear_loss = sum(linear_cost .* T_plan) + current_loss += linear_loss + G .+= linear_cost + end + + log = use_log ? 
Dict("loss" => [current_loss]) : nothing + convergence_criterion = Inf + outer_count = 0 + + while convergence_criterion > eps && outer_count < max_iter + previous_loss = current_loss + + # Compute Bregman projection + M = G - gamma_entropy * Base.log.(T_plan) + K = Base.exp.(-M / gamma_entropy) + scaling = p ./ vec(sum(K, dims = 2)) + T_plan .= (scaling .* ones(T, 1, N2)) .* K + + q = vec(sum(T_plan, dims = 1)) + + # Update gradient + if symmetry + constC, hC1, hC2 = init_matrix_GW2(C1, C2, p, q, ones_p, ones_q) + G = 2 * alpha * tensor_product(constC, hC1, hC2, T_plan) + else + constC, hC1, hC2 = init_matrix_asymGW2(C1, C2, p, q, ones_p, ones_q) + constCt, hC1t, hC2t = init_matrix_asymGW2(C1', C2', p, q, ones_p, ones_q) + subG = tensor_product(constC, hC1, hC2, T_plan) + subGt = tensor_product(constCt, hC1t, hC2t, T_plan) + G = alpha * (subG + subGt) + end + + current_loss = 0.5 * sum(G .* T_plan) + + if add_linear_cost + linear_loss = sum(linear_cost .* T_plan) + current_loss += linear_loss + G .+= linear_cost + end + + outer_count += 1 + use_log && push!(log["loss"], current_loss) + + convergence_criterion = abs(previous_loss - current_loss) / + (abs(previous_loss) + 1e-15) + end + + return use_log ? (T_plan, current_loss, log) : (T_plan, current_loss) +end + +# Mirror descent for semi-relaxed Gromov-Wasserstein with entropic regularization +function md_semirelaxed_gromov_wasserstein(C1::AbstractMatrix{T}, p::AbstractVector{T}, + C2::AbstractMatrix{T}, gamma_entropy::Real; + kwargs...) where {T <: Real} + return md_semirelaxed( + C1, p, C2, gamma_entropy; alpha = 1.0, linear_cost = nothing, kwargs...) +end + +# Mirror descent for semi-relaxed fused Gromov-Wasserstein with entropic regularization +function md_semirelaxed_fused_gromov_wasserstein( + C1::AbstractMatrix{T}, A1::AbstractMatrix{T}, + p::AbstractVector{T}, C2::AbstractMatrix{T}, + A2::AbstractMatrix{T}, gamma_entropy::Real, + alpha::Real; kwargs...) 
where {T <: Real} + N1, N2 = size(A1, 1), size(A2, 1) + d = size(A1, 2) + + # Compute Euclidean distance matrix + A1_sq = sum(A1 .^ 2, dims = 2) * ones(T, 1, N2) + A2_sq = ones(T, N1, 1) * sum(A2 .^ 2, dims = 2)' + D = A1_sq + A2_sq - 2 * A1 * A2' + + return md_semirelaxed(C1, p, C2, gamma_entropy; alpha = alpha, + linear_cost = (1 - alpha) * D, kwargs...) +end + +# ============================================================================= +# Majorization-Minimization Algorithms with Sparsity Regularization +# ============================================================================= + +# MM algorithm with ℓₚ-ℓ₁ sparsity regularization for semi-relaxed (fused) GW +# Solves: min_T α⟨L(C₁,C₂)⊗T,T⟩ + ⟨M,T⟩ + λ∑ⱼ(∑ᵢTᵢⱼ)^p +function mm_lpl1_semirelaxed( + C1::AbstractMatrix{T}, p::AbstractVector{T}, C2::AbstractMatrix{T}, + gamma_entropy::Real; alpha::Real = 1.0, + linear_cost::Union{Nothing, AbstractMatrix{T}} = nothing, + T_init::Union{Nothing, AbstractMatrix{T}} = nothing, + init_mode::String = "product", symmetry::Bool = true, + p_reg::Real = 0.5, lambda_reg::Real = 0.001, + use_log::Bool = false, use_warmstart::Bool = false, + eps_inner::Real = 1e-6, eps_outer::Real = 1e-6, + max_iter_inner::Int = 1000, max_iter_outer::Int = 50, + seed::Int = 0, verbose::Bool = false, + inner_log::Bool = false) where {T <: Real} + @assert 0<p_reg<1 "p_reg must be in (0, 1)" + @assert gamma_entropy>=0 "gamma_entropy must be non-negative" + + N1, N2 = size(C1, 1), size(C2, 1) + + # Initialize + if isnothing(T_init) + T_plan = initializer_semirelaxed_GW(init_mode, p, N1, N2; seed = seed) + T_init_warm = use_warmstart ? 
copy(T_plan) : nothing + else + @assert size(T_init) == (N1, N2) + T_plan = copy(T_init) + T_init_warm = nothing + end + + # Inner solver selection + if gamma_entropy == 0 + inner_solver = (total_linear_cost, + T_init_local) -> cg_semirelaxed( + C1, p, C2; alpha = alpha, linear_cost = total_linear_cost, + init_mode = init_mode, T_init = T_init_local, symmetry = symmetry, + use_log = inner_log, eps = eps_inner, max_iter = max_iter_inner, + seed = seed, verbose = verbose + ) + else + inner_solver = (total_linear_cost, + T_init_local) -> md_semirelaxed( + C1, p, C2, gamma_entropy; alpha = alpha, linear_cost = total_linear_cost, + init_mode = init_mode, T_init = T_init_local, symmetry = symmetry, + use_log = inner_log, eps = eps_inner, max_iter = max_iter_inner, + seed = seed, verbose = verbose + ) + end + + reg_linear_cost = zeros(T, N1, N2) + total_linear_cost = isnothing(linear_cost) ? nothing : copy(linear_cost) + + best_T = copy(T_plan) + ones_p = ones(T, N1, 1) + + log = use_log ? Dict("loss" => T[], "inner_loss" => []) : nothing + best_loss = T(Inf) + current_loss = T(1e15) + convergence_criterion = Inf + outer_count = 0 + + while convergence_criterion > eps_outer && outer_count < max_iter_outer + previous_loss = current_loss + + # Solve generalized problem + result = inner_solver(total_linear_cost, use_warmstart ? 
T_init_warm : nothing) + + if inner_log + T_plan, majorization_loss, inner_log_data = result + else + T_plan, majorization_loss = result + end + + # Compute linearized reg loss + linearized_reg_loss = sum(reg_linear_cost .* T_plan) + + if use_warmstart + T_init_warm = copy(T_plan) + end + + # Update regularization + q = vec(sum(T_plan, dims = 1)) + reg_loss = lambda_reg * sum((q .+ 1e-15) .^ p_reg) + current_loss = majorization_loss - linearized_reg_loss + reg_loss + + reg_linear_cost .= lambda_reg * p_reg * ((ones_p * q') .+ 1e-15) .^ (p_reg - 1.0) + + if isnothing(linear_cost) + total_linear_cost = reg_linear_cost + else + total_linear_cost = reg_linear_cost + linear_cost + end + + if verbose + println("Outer iter $outer_count: loss = $current_loss, q = $q") + end + + outer_count += 1 + + if use_log + push!(log["loss"], current_loss) + inner_log && push!(log["inner_loss"], inner_log_data) + end + + convergence_criterion = abs(previous_loss - current_loss) / + (abs(previous_loss) + 1e-15) + + if current_loss < best_loss + best_loss = current_loss + best_T = copy(T_plan) + end + end + + return use_log ? (best_T, best_loss, log) : (best_T, best_loss) +end + +# MM algorithm with sparsity for semi-relaxed Gromov-Wasserstein +function mm_lpl1_semirelaxed_gromov_wasserstein( + C1::AbstractMatrix{T}, p::AbstractVector{T}, + C2::AbstractMatrix{T}, gamma_entropy::Real; + kwargs...) where {T <: Real} + return mm_lpl1_semirelaxed(C1, p, C2, gamma_entropy; alpha = 1.0, + linear_cost = nothing, kwargs...) +end + +# MM algorithm with sparsity for semi-relaxed fused Gromov-Wasserstein +function mm_lpl1_semirelaxed_fused_gromov_wasserstein( + C1::AbstractMatrix{T}, A1::AbstractMatrix{T}, + p::AbstractVector{T}, C2::AbstractMatrix{T}, + A2::AbstractMatrix{T}, alpha::Real, + gamma_entropy::Real; kwargs...) 
where {T <: Real} + N1, N2 = size(A1, 1), size(A2, 1) + d = size(A1, 2) + + # Compute Euclidean distance matrix + A1_sq = sum(A1 .^ 2, dims = 2) * ones(T, 1, N2) + A2_sq = ones(T, N1, 1) * sum(A2 .^ 2, dims = 2)' + D = A1_sq + A2_sq - 2 * A1 * A2' + + return mm_lpl1_semirelaxed(C1, p, C2, gamma_entropy; alpha = alpha, + linear_cost = (1 - alpha) * D, kwargs...) +end diff --git a/src/GreedySuffStats.jl b/src/GreedySuffStats.jl new file mode 100644 index 0000000..d164e84 --- /dev/null +++ b/src/GreedySuffStats.jl @@ -0,0 +1,163 @@ +abstract type SBMEstimator end + +abstract type Result end + +struct NethistResult{L, M} <: Result + labels::L + model::M +end + +function permute!(res::NethistResult, perm::AbstractVector{<:Integer}) + permute!(res.model, perm) + res.labels .= map(x -> perm[x], res.labels) + return res +end + +struct GreedySuffStats{M, NodeR <: NodeSwapRule, StopR <: StopRule} <: SBMEstimator + block_ss::M + block_ss_swap::M + node_swap_rule::NodeR + stop_rule::StopR + max_iter::Int +end + +function init!(es::GreedySuffStats, data, node_labels) + # Initialize the sufficient statistics for each block + for j in axes(data, 2) + gj = node_labels[j] + for i in 1:(j - 1) # More efficient than i < j check inside loop + edge_value = data[i, j] + gi = node_labels[i] + es.block_ss[gi, gj] = add_sample(es.block_ss[gi, gj], edge_value, i, j) + es.block_ss_swap[gi, gj] = add_sample( + es.block_ss_swap[gi, gj], edge_value, i, j) + end + end +end + +loss(es::GreedySuffStats; norm = 1.0) = loss(es.block_ss; norm = norm) + +# TODO: allow for non-symmetric data +@inline function loss(matrix_ss::SymArray{<:SuffStats}; norm = 1.0) + total_loss = 0.0 + for m in matrix_ss.uppertrian.nzval + total_loss += loss(m) + end + return total_loss / norm +end + +@inline function loss(matrix_ss::AbstractMatrix{<:SuffStats}; norm = 1.0) + total_loss = 0.0 + @inbounds for j in axes(matrix_ss, 2) + for i in 1:j + inter = loss(matrix_ss[i, j]) + total_loss += inter + end + end + return 
total_loss / norm +end + +function make_greedy_suffstats_estimator( + data, + node_labels; + type_suff_stats = Val(:categorical), + max_iter = 10_000, + node_swap_rule = RandomGroupSwap(), + stop_rule = PreviousBestValue(5_000, Inf, :min), + kwargs... +) + k = length(unique(node_labels)) + block_ss = make_k_block(k, type_suff_stats; data = data, kwargs...) + block_ss_swap = make_k_block(k, type_suff_stats; data = data, kwargs...) + return GreedySuffStats{typeof(block_ss), typeof(node_swap_rule), typeof(stop_rule)}( + block_ss, block_ss_swap, node_swap_rule, stop_rule, max_iter) +end + +function estimate!( + es::GreedySuffStats, + data, + node_labels_init; + progress = true, + iter_progress = 5000 +) + # Initialize node labels + node_labels = copy(node_labels_init) + n = length(node_labels) + k = length(unique(node_labels)) + n_edges = n * (n - 1) / 2 + init!(es, data, node_labels) + + # Progress tracking + pbar = ProgressUnknown( + enabled = progress, + showspeed = true, + desc = "Greedy search: " + ) + + # Update progress bar only every N iterations to reduce overhead + progress_update_interval = max(1, es.max_iter ÷ iter_progress) + # Initial log-likelihood + + current_loss = loss(es.block_ss, norm = n_edges) + reset!(es.stop_rule, current_loss) + # Main optimization loop + for iter in 1:(es.max_iter) + # Select two nodes to potentially swap + index1, index2 = select_indices_swap(node_labels, es.node_swap_rule, k) + + group1 = node_labels[index1] + group2 = node_labels[index2] + + @inbounds for j in axes(data, 2) + if j != index1 && j != index2 + # extract data + groupj = node_labels[j] + edge_value_1 = data[j, index1] + edge_value_2 = data[j, index2] + + es.block_ss_swap[group1, groupj] = remove_sample( + es.block_ss_swap[group1, groupj], edge_value_1, j, index1) + es.block_ss_swap[group2, groupj] = add_sample( + es.block_ss_swap[group2, groupj], edge_value_1, j, index1) + + es.block_ss_swap[group2, groupj] = remove_sample( + es.block_ss_swap[group2, groupj], 
edge_value_2, j, index2) + es.block_ss_swap[group1, groupj] = add_sample( + es.block_ss_swap[group1, groupj], edge_value_2, j, index2) + end + end + + # tentative swap + @inbounds node_labels[index1], node_labels[index2] = group2, group1 + new_loss = loss(es.block_ss_swap, norm = n_edges) + + if compare_to_best(new_loss, current_loss, es.stop_rule) + # apply swap + copy!(es.block_ss, es.block_ss_swap) + current_loss = new_loss + else + # revert swap + node_labels[index1], node_labels[index2] = group1, group2 + # revert sufficient statistics + copy!(es.block_ss_swap, es.block_ss) + end + + if progress && (iter % progress_update_interval == 0 || iter == es.max_iter) + update!( + pbar, iter; + showvalues = [ + ("loss", current_loss), + info_to_print(es.stop_rule) + ]) + end + + # Check stopping criterion + if stopping_rule(current_loss, es.stop_rule) + break + end + end + finish!(pbar) + @info "Optimization finished. Final loss: $current_loss" + + return node_labels, to_params.(es.block_ss) +end diff --git a/src/NetworkHistogram.jl b/src/NetworkHistogram.jl index 1272972..d489b7f 100644 --- a/src/NetworkHistogram.jl +++ b/src/NetworkHistogram.jl @@ -1,29 +1,35 @@ module NetworkHistogram +using Accessors +using StatsBase +using StaticArrays +using ProgressMeter +import StatsAPI: loglikelihood, fit, params +import Base: convert, eltype, zero +using Distributions +using LinearAlgebra +using ArgCheck +import Random: randperm, AbstractRNG, rand, shuffle +import Distributions: logpdf, pdf +import LogExpFunctions: xlogx +using IntervalSets +using Hungarian +using Reexport +@reexport using Graphons + +import Graphons: _extract_param, convert_to_params, node_labels_to_latents + +include("SymArray.jl") +@reexport using .FastSymArray + +include("distributions/hist_dist.jl") +include("preprocessor/abstractConvertor.jl") +include("config_rules/include.jl") +include("pseudo_suff_stats/abstract_suffstat.jl") +include("GreedySuffStats.jl") +include("utils/utils_node_labels.jl") 
+include("api.jl") + +export GreedyParams, nethist, nethist_discrete_edges, ordered_start_labels, RandomGroupSwap, + Strict, PreviousBestValue, nethist_binary_edges -using ValueHistories, StatsBase, Random, LinearAlgebra, Kronecker - -using Arpack: eigs -using ArnoldiMethod: partialschur, partialeigen, SR, LR - -export graphhist, PreviousBestValue, Strict, RandomNodeSwap -export OrderedStart, RandomStart, EigenStart, DistStart - -include("group_numbering.jl") -include("assignment.jl") -include("history.jl") - -include("config_rules/starting_assignment_rule.jl") -include("config_rules/swap_rule.jl") -include("config_rules/accept_rule.jl") -include("config_rules/stop_rule.jl") -include("config_rules/bandwidth_selection_rule.jl") - -include("optimize.jl") -include("histogram.jl") -include("proposal.jl") - -include("utils.jl") - -include("data/gt.jl") -include("data/datasets.jl") end diff --git a/src/SymArray.jl b/src/SymArray.jl new file mode 100644 index 0000000..f886556 --- /dev/null +++ b/src/SymArray.jl @@ -0,0 +1,281 @@ +""" +FastSymArray - Efficient symmetric matrix storage + +This module provides `SymArray`, a memory-efficient storage for symmetric matrices +that only stores the upper triangle (including diagonal) of the matrix using a sparse matrix. +""" +module FastSymArray + +using SparseArrays +using LinearAlgebra +import Base: eltype, convert, size, getindex, setindex!, copy!, similar, + IndexStyle, axes, length, iterate, copyto!, fill! +import SparseArrays: getcolptr, nonzeros, FixedSparseCSC + +export SymArray, eltype, deepcopy! + +""" + SymArray{F} <: AbstractSparseMatrix{F, Int} + +A symmetric matrix that stores only the upper triangle using a sparse matrix. + +For a k×k symmetric matrix, only k(k+1)/2 elements are stored instead of k². +This implementation uses Julia's SparseMatrixCSC for efficient storage and access. 
+ +# Fields +- `uppertrian::SparseMatrixCSC{F, Int}`: Sparse matrix storing the upper triangle (i ≤ j) + +# Examples +```julia +# Create a 3×3 symmetric matrix +sym = SymArray{Float64}(undef, 3, 3) +sym .= 0.0 + +# Access elements (symmetric) +sym[1, 2] = 5.0 +sym[2, 1] # Returns 5.0 + +# Convert from regular matrix +A = [1 2 3; 2 4 5; 3 5 6] +sym = SymArray(A) +``` +""" +mutable struct SymArray{F} <: AbstractSparseMatrix{F, Int} + uppertrian::SparseMatrixCSC{F, Int} +end + +function check_size(a) + if length(a.uppertrian.nzval) == 0 + println("Warning: SymArray has zero stored elements.") + end +end + +SymArray(::Type{F}, dims::Int...) where {F} = SymArray(F, dims) +function SymArray(::Type{F}, dims::NTuple{2, Int}) where {F} + if dims[1] != dims[2] + throw(ArgumentError("SymArray must be square, got dims=$(dims)")) + end + a = SymArray{F}(SparseMatrixCSC{F, Int}(make_csc_format(dims[1], F)...)) + check_size(a) + return a +end + +SymArray{F}(::UndefInitializer, dims::Int...) where {F} = SymArray{F}(undef, dims) +function SymArray{F}(::UndefInitializer, dims::NTuple{2, Int}) where {F} + return SymArray(F, dims) +end + +function make_csc_format(k::Int, ::Type{F}) where {F} + k > 0 || throw(ArgumentError("Matrix dimension k=$k must be positive")) + + n_elements = div(k * (k + 1), 2) # Number of non-zeros in upper triangle + + colptr = Vector{Int}(undef, k + 1) + rowval = Vector{Int}(undef, n_elements) + nzval = Vector{F}(undef, n_elements) + + @inbounds for j in 1:(k + 1) + colptr[j] = div((j - 1) * j, 2) + 1 + end + + idx = 1 + @inbounds for j in 1:k + for i in 1:j + rowval[idx] = i + idx += 1 + end + end + return k, k, colptr, rowval, nzval +end + +""" + SymArray(d::AbstractMatrix{F}) + +Create a SymArray from an existing matrix. The matrix must be square and is assumed to be symmetric. 
+""" +function SymArray(d::AbstractMatrix{F}) where {F} + m, n = size(d) + m == n || throw(ArgumentError("Input matrix must be square, got size $(size(d))")) + return convert(SymArray{F}, d) +end + +function size(a::SymArray) + return size(a.uppertrian) +end + +# axes function +function axes(a::SymArray) + return axes(a.uppertrian) +end + +# length function +function length(a::SymArray) + return length(a.uppertrian) +end + +# faster indexing by avoiding search, modified from SparseArrays +Base.@propagate_inbounds function getindex(A::SymArray, i0::Integer, i1::Integer) + i0, i1 = minmax(i0, i1) + @boundscheck checkbounds(A, i0, i1) + r1 = Int(@inbounds getcolptr(A.uppertrian)[i1]) + A.uppertrian.nzval[r1 + i0 - 1] +end + +# faster indexing by avoiding search, modified from SparseArrays +Base.@propagate_inbounds function setindex!(A::SymArray, v, i::Int, j::Int) + i, j = minmax(i, j) + @boundscheck checkbounds(A, i, j) + r1 = Int(@inbounds getcolptr(A.uppertrian)[j]) + A.uppertrian.nzval[r1 + i - 1] = v +end + +function similar(a::SymArray, ::Type{T} = eltype(a), dims::Dims{2} = size(a)) where {T} + return SymArray{T}(undef, dims) +end + +function copy!(dest::SymArray{F}, src::SymArray{F}) where {F} + size(dest) == size(src) || throw(DimensionMismatch("arrays must have the same size")) + copy!(dest.uppertrian.nzval, src.uppertrian.nzval) + return nothing +end + +function convert(::Type{SymArray{F}}, a::AbstractMatrix{F}) where {F} + k, n = size(a) + @assert k==n "Input matrix must be square, got size $(size(a))" + + # Directly build upper triangle sparse matrix + # Pre-allocate with exact size needed + m, n, colptr, rowval, nzval = make_csc_format(k, F) + idx = 1 + @inbounds for j in 1:k + for i in 1:j + nzval[idx] = a[i, j] + idx += 1 + end + end + return SymArray(SparseMatrixCSC{F, Int}(m, n, colptr, rowval, nzval)) +end + +function deepcopy!(dest::SymArray{F}, src::SymArray{F}) where {F <: AbstractArray} + dest_ = dest.uppertrian.nzval + src_ = 
src.uppertrian.nzval + @inbounds for index in eachindex(src_) + if isassigned(dest_, index) + copy!(dest_[index], src_[index]) + else + dest_[index] = copy(src_[index]) + end + end + return dest +end + +deepcopy!(dest::SymArray{F}, src::SymArray{F}) where {F <: Real} = copy!(dest, src) + +# Broadcasting support - custom style to maintain symmetric structure +# struct SymArrayStyle <: Broadcast.AbstractArrayStyle{2} end +# SymArrayStyle(::Val{2}) = SymArrayStyle() + +const SymArrayStyle = Broadcast.ArrayStyle{SymArray} + +Base.BroadcastStyle(::Type{<:SymArray}) = Broadcast.ArrayStyle{SymArray}() # SymArrayStyle() + +# When broadcasting with scalars or other styles, keep SymArrayStyle +Base.BroadcastStyle(::SymArrayStyle, ::Broadcast.DefaultArrayStyle{0}) = SymArrayStyle() +Base.BroadcastStyle(::Broadcast.DefaultArrayStyle{0}, ::SymArrayStyle) = SymArrayStyle() + +# When broadcasting with regular arrays (not scalars), defer to the array's style +# This ensures SymArray .+ Matrix returns Matrix, not SymArray +Base.BroadcastStyle(::SymArrayStyle, s::Broadcast.DefaultArrayStyle) = s +Base.BroadcastStyle(s::Broadcast.DefaultArrayStyle, ::SymArrayStyle) = s + +# When broadcasting between SymArrays, keep SymArrayStyle +Base.BroadcastStyle(::SymArrayStyle, ::SymArrayStyle) = SymArrayStyle() + +# Custom similar for broadcasted SymArrays +function Base.similar( + bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{SymArray}}, ::Type{ElType}) where {ElType} + return SymArray{ElType}(undef, size(bc)...) 
+end + +# Helper function to find a SymArray in the broadcast tree +find_symarray(bc::Broadcast.Broadcasted) = find_symarray(bc.args) +find_symarray(args::Tuple) = find_symarray(args[1], Base.tail(args)) +find_symarray(x) = x +find_symarray(args::Tuple{}) = nothing +find_symarray(a::SymArray, rest) = a +find_symarray(::Any, rest) = find_symarray(rest) + +# Override broadcasted to eagerly evaluate when SymArrayStyle is involved +# This prevents issues with nested broadcasts losing the SymArray type +# hack, needs to be fixed later +function Broadcast.broadcasted(::SymArrayStyle, f, args...) + # Eagerly materialize any nested Broadcasted{SymArrayStyle} to maintain type stability + materialized_args = map(args) do arg + if arg isa Broadcast.Broadcasted{SymArrayStyle} + # Materialize nested SymArray broadcasts immediately + return copy(arg) + else + return arg + end + end + # Now create the broadcast with materialized args + return Broadcast.Broadcasted{SymArrayStyle}(f, materialized_args) +end + +# Specialized copyto! for efficient in-place broadcasting into SymArray +# This maintains the symmetric structure during broadcast operations +function Base.copyto!(dest::SymArray, bc::Broadcast.Broadcasted{SymArrayStyle}) + axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) + _copyto_nzval!(dest, bc) + return dest + # # Try to use optimized nzval path for simple operations + # if _can_use_nzval_broadcast(bc) + # return _copyto_nzval!(dest, bc) + # end + + # # Fallback: iterate using CartesianIndices but only over upper triangle + # bc′ = Broadcast.preprocess(dest, bc) + # @inbounds for j in 1:size(dest, 2) + # for i in 1:j + # dest[i, j] = bc′[i, j] + # end + # end + # return dest +end + +# Optimized copyto! 
that works directly on nzval arrays +function _copyto_nzval!( + dest::SymArray{T}, bc::Broadcast.Broadcasted{SymArrayStyle}) where {T} + # Replace SymArrays in the broadcast tree with their nzval arrays + bc_nzval = _replace_with_nzval(bc) + + # Broadcast directly on the nzval array + dest_nzval = dest.uppertrian.nzval + copyto!(dest_nzval, bc_nzval) + return dest +end + +# Replace SymArrays in broadcast tree with their nzval arrays +function _replace_with_nzval(bc::Broadcast.Broadcasted{SymArrayStyle}) + # Create new broadcasted with transformed arguments + new_args = map(_replace_with_nzval, bc.args) + # Don't specify style - let it be inferred + return Broadcast.Broadcasted(bc.f, new_args) +end + +function _replace_with_nzval(sa::SymArray) + return sa.uppertrian.nzval +end + +function _replace_with_nzval(bc::Broadcast.Broadcasted) + # Recursively process nested broadcasts + new_args = map(_replace_with_nzval, bc.args) + return Broadcast.Broadcasted(bc.f, new_args) +end + +function _replace_with_nzval(x) + # For scalars and other types, return as-is + return x +end + +end diff --git a/src/api.jl b/src/api.jl new file mode 100644 index 0000000..329a645 --- /dev/null +++ b/src/api.jl @@ -0,0 +1,101 @@ +function nethist_categorical( + A, k, + labels_start = ordered_start_labels(size(A, 1), k); + params::GreedyParams = GreedyParams(stalled_iters = 50_000, max_iter = 2_000_000)) + convertor = CategoricalConvertor(A) + @info "Using $(num_bins(convertor)) discrete categories for edge values" + return _nethist( + A, labels_start, + convertor, + Val(:categorical), + params; + num_categories = num_bins(convertor) + ) +end + +function nethist_continuous( + A, k, + labels_start = ordered_start_labels(size(A, 1), k); + bins::Int = 10, + params::GreedyParams = GreedyParams()) + convertor = UnitIntervalConvertor(bins) + @info "Using $(num_bins(convertor)) discrete categories for edge values" + return _nethist( + A, labels_start, + convertor, + Val(:categorical), + params; + 
num_categories = num_bins(convertor) + ) +end + +function nethist_binary( + A, k, + labels_start = ordered_start_labels(size(A, 1), k); + params::GreedyParams = GreedyParams()) + return _nethist( + A, labels_start, + BinaryConvertor(), + Val(:binary), + params + ) +end + +function _nethist( + A, labels_start, convertor, type_suff_stats, + params::GreedyParams = GreedyParams(); kwargs...) + if !params.warm_start + reset!(params) + end + data = convertor.(A) + es = make_greedy_suffstats_estimator( + data, + labels_start; + type_suff_stats = type_suff_stats, + max_iter = params.max_iter, + node_swap_rule = params.node_swap_rule, + stop_rule = params.stop_rule, + kwargs... + ) + node_labels, + parameters = estimate!( + es, data, labels_start; + progress = params.display_progress, + iter_progress = params.progress_freq) + + return convert_to_result(node_labels, convertor, parameters) +end + +function oracle_estimator( + A, oracle_labels, convertor; type_suff_stats = Val(:categorical), + name = "oracle", kwargs...) + + # prepare data + k = length(unique(oracle_labels)) + data = convertor.(A) + # prepare suff stats + block_ss = make_k_block( + k, type_suff_stats; data = data, num_categories = num_bins(convertor), kwargs...) 
+ + # compute oracle suff stats + es_dummy = GreedySuffStats(block_ss, copy(block_ss), RandomGroupSwap(), + PreviousBestValue(1_000, Inf, :min), 1) + init!(es_dummy, data, oracle_labels) + + # retrieve parameters + @info "$name estimator loss: $(loss(es_dummy, norm = get_num_obs(data)))" + parameters = to_params.(es_dummy.block_ss) + return convert_to_result(oracle_labels, convertor, parameters) +end + +function convert_to_result(node_labels, convertor, parameters) + model = DecoratedSBM(to_distribution.(convertor, parameters), + counts(node_labels) ./ length(node_labels)) + return NethistResult(node_labels, model) +end + +function convert_to_result(node_labels, convertor::BinaryConvertor, parameters) + model = SBM(to_distribution.(convertor, parameters), + counts(node_labels) ./ length(node_labels)) + return NethistResult(node_labels, model) +end diff --git a/src/assignment.jl b/src/assignment.jl deleted file mode 100644 index ea06117..0000000 --- a/src/assignment.jl +++ /dev/null @@ -1,195 +0,0 @@ -mutable struct Assignment{T, M} - const group_size::GroupSize{T} - - const node_labels::Vector{Int} - const counts::Matrix{Int} - const realized::Array{Float64, M} - const estimated_theta::Array{Float64, M} - const number_layers::Int - - likelihood::Float64 - - function Assignment(A, node_labels, group_size::GroupSize{T}) where {T} - M = ndims(A) - number_groups = length(group_size) - - counts = zeros(Int64, number_groups, number_groups) - realized = zeros(Int64, number_groups, number_groups) - - @inbounds @simd for k in 1:number_groups - for l in k:number_groups - realized[k, l] = sum(A[node_labels .== k, node_labels .== l]) - realized[l, k] = realized[k, l] - counts[k, l] = group_size[k] * group_size[l] - counts[l, k] = counts[k, l] - end - end - - @inbounds @simd for k in 1:number_groups - counts[k, k] = group_size[k] * (group_size[k] - 1) ÷ 2 - realized[k, k] = sum(A[node_labels .== k, node_labels .== k]) ÷ 2 - end - - estimated_theta = realized ./ counts - 
likelihood = compute_log_likelihood(number_groups, estimated_theta, counts, - size(A, 1)) - - new{T, M}(group_size, - node_labels, - counts, - realized, - estimated_theta, - 1, - likelihood) - end - - function Assignment(A::Array{I, 3}, - node_labels, - group_size::GroupSize{T}) where {I, T} - M = ndims(A) - number_groups = length(group_size) - - counts = zeros(Int64, number_groups, number_groups) - realized = zeros(Int64, number_groups, number_groups, 2^size(A, 3)) - - A_updated = zeros(Int64, size(A, 1), size(A, 2)) - for i in 1:size(A, 1) - for j in (i + 1):size(A, 2) - A_updated[i, j] = _binary_to_index(A[i, j, :]) - A_updated[j, i] = A_updated[i, j] - end - end - - @inbounds @simd for m in 1:size(realized, 3) - for k in 1:number_groups - for l in k:number_groups - realized[k, l, m] = sum(A_updated[node_labels .== k, - node_labels .== l] .== m) - realized[l, k, m] = realized[k, l, m] - counts[k, l] = group_size[k] * group_size[l] - counts[l, k] = counts[k, l] - end - end - end - - @inbounds @simd for m in 1:size(realized, 3) - for k in 1:number_groups - counts[k, k] = group_size[k] * (group_size[k] - 1) ÷ 2 - realized[k, k, m] = sum(A_updated[node_labels .== k, node_labels .== k] .== - m) ÷ 2 - end - end - estimated_theta = realized ./ counts - likelihood = compute_multivariate_log_likelihood(number_groups, - estimated_theta, - realized) - - new{T, M}(group_size, - node_labels, - counts, - realized, - estimated_theta, - size(A, 3), - likelihood) - end -end - -function _binary_to_index(binary_vector::Vector{Int}) - total = 1 - for i in 1:length(binary_vector) - total += binary_vector[i] * 2^(i - 1) - end - return total -end - -function _index_to_binary(index::Int, M) - binary_vector = zeros(Int, M) - index -= 1 - for i in 1:M - binary_vector[i] = index % 2 - index = index ÷ 2 - end - return binary_vector -end - -""" - compute_log_likelihood(number_groups, estimated_theta, counts, number_nodes) - -Compute the scaled log-likelihood in terms of communities: 
-```math -l(z;A) = \\frac{1}{n} \\sum_{g_1 = 1}^{G} \\sum_{g_2 \\geq g_1}^{G} \\left[ \\theta_{g_1g_2} \\log(\\theta_{g_1g_2}) + (1 - \\theta_{g_1g_2}) \\log(1 - \\theta_{g_1g_2}) \\right] \\cdot c_{g_1g_2}, -``` - -where ``c_{g_1g_2}`` (``\\theta_{g_1g_2}``) is the number of possible edges (estimated -probability) between communities ``g_1`` and ``g_2``, ``n`` is the number of nodes, and -``z_i ∈ \\{1, \\dots, G\\}`` is the community assignment of node ``i``. -""" -function compute_log_likelihood(number_groups, estimated_theta, counts, number_nodes) - loglik = 0.0 - @inbounds @simd for i in 1:number_groups - for j in i:number_groups - θ = estimated_theta[i, j] - θ_c = θ <= 0 ? 1e-14 : (θ >= 1 ? 1 - 1e-14 : θ) - loglik += (θ_c * log(θ_c) + (1 - θ_c) * log(1 - θ_c)) * counts[i, j] - end - end - return loglik -end - -function compute_multivariate_log_likelihood(number_groups, estimated_theta, realized) - loglik = 0.0 - @inbounds @simd for i in 1:number_groups - for j in i:number_groups - for m in 1:size(realized, 3) - if realized[i, j, m] != 0 - θ = estimated_theta[i, j, m] - θ_c = θ <= 0 ? 1e-14 : (θ >= 1 ? 1 - 1e-14 : θ) - loglik += log(θ_c) * realized[i, j, m] - end - end - end - end - return loglik -end - -""" - compute_log_likelihood(assignment::Assignment) - -Compute the scaled log-likelihood of the assignment. 
- -```math - l(z;A) = \\frac{1}{n}\\sum\\limits_{i=1}^n \\sum\\limits_{j>i}^n \\left[ A_{ij} \\log(\\hat{\\theta}_{z_i z_j}) + (1 - A_{ij}) \\log(1 - \\hat{\\theta}_{z_i z_j}) \\right], -``` - -where ``\\hat{\\theta}_{ab}`` is the estimated probability of an edge between communities -``a`` and ``b`` - -```math - \\hat{\\theta}_{ab} = \\frac{\\sum\\limits_{i current.likelihood - deepcopy!(current, proposal) - end - - update_current!(history, iteration, current.likelihood) - return current -end diff --git a/src/config_rules/bandwidth_selection_rule.jl b/src/config_rules/bandwidth_selection_rule.jl deleted file mode 100644 index fc0eb40..0000000 --- a/src/config_rules/bandwidth_selection_rule.jl +++ /dev/null @@ -1,61 +0,0 @@ - -function select_bandwidth(A::Array{T, 2}; type = "degs", alpha = 1, c = 1)::Int where {T} - h = oracle_bandwidth(A, type, alpha, c) - return max(2, min(size(A)[1], round(Int, h))) -end - -function select_bandwidth(A::Array{T, 3}; type = "degs", alpha = 1, c = 1)::Int where {T} - hs = [select_bandwidth(A[:, :, i]; type, alpha, c) for i in 1:size(A, 3)] - @warn "Naive bandwidth selection for multilayer graph histogram: using minimum over layers" - h = max(2, min(size(A, 1), round(Int, minimum(hs)))) - return h -end - -""" - oracle_bandwidth(A, type = "degs", alpha = 1, c = min(4, sqrt(size(A, 1)) / 8)) - -Oracle bandwidth selection for graph histogram, using - -```math -\\widehat{h^*}=\\left(2\\left(\\left(d^T d\\right)^{+}\\right)^2 d^T A d \\cdot \\hat{m} \\hat{b}\\right)^{-\\frac{1}{2}} \\hat{\\rho}_n^{\\frac{1}{4}}, -``` - -where ``d`` is the vector of degree sorted in increasing order,``\\hat{\\rho}_n`` is the empirical edge density, and ``m``, ``b`` are the slope and intercept fitted on ``d[n/2-c\\sqrt{n}:n/2+c\\sqrt{n}]`` for some ``c``. 
-""" -function oracle_bandwidth(A, type = "degs", alpha = 1, c = min(4, sqrt(size(A, 1)) / 8)) - if type ∉ ["eigs", "degs"] - error("Invalid input type $(type)") - end - - if alpha != 1 - error("Currently only supports alpha = 1") - end - - n = size(A, 1) - midPt = collect(max(1, round(Int, (n ÷ 2 - c * sqrt(n)))):round(Int, - (n ÷ 2 + c * sqrt(n)))) - rhoHat_inv = inv(sum(A) / (n * (n - 1))) - - # Rank-1 graphon estimate via fhat(x,y) = mult*u(x)*u(y)*pinv(rhoHat); - if type == "eigs" - eig_res = eigs(A, nev = 1, which = :LM) - u = eig_res.vectors - mult = eig_res.values[1] - elseif type == "degs" - u = sum(A, dims = 2) - mult = (u' * A * u) / (sum(u .^ 2))^2 - else - error("Invalid input type $(type)") - end - - # Calculation bandwidth - u = sort(u, dims = 1) - uMid = u[midPt] - lmfit_coef = hcat(ones(length(uMid)), 1:length(uMid)) \ uMid - - h = (2^(alpha + 1) * alpha * mult^2 * - (lmfit_coef[2] * length(uMid) / 2 + lmfit_coef[1])^2 * lmfit_coef[2]^2 * - rhoHat_inv)^(-1 / (2 * (alpha + 1))) - #estMSqrd = 2*mult^2*(lmfit_coef[2]*length(uMid)/2+lmfit_coef[1])^2*lmfit_coef[2]^2*rhoHat_inv^2*(n+1)^2 - return h[1] -end diff --git a/src/config_rules/include.jl b/src/config_rules/include.jl new file mode 100644 index 0000000..d4a03db --- /dev/null +++ b/src/config_rules/include.jl @@ -0,0 +1,19 @@ +include("swap_rule.jl") +include("stop_rule.jl") + +abstract type ParamsType end + +@kwdef struct GreedyParams{N <: NodeSwapRule, S <: StopRule} <: ParamsType + max_iter::Int = 1_000_000 + stalled_iters::Int = 5_000 + node_swap_rule::N = RandomGroupSwap() + stop_rule::S = PreviousBestValue(stalled_iters, Inf, :min) + display_progress::Bool = true + progress_freq::Int = 10_000 + warm_start::Bool = false +end + +function reset!(params::GreedyParams) + reset!(params.stop_rule) + return params +end diff --git a/src/config_rules/starting_assignment_rule.jl b/src/config_rules/starting_assignment_rule.jl deleted file mode 100644 index 1ee0261..0000000 --- 
a/src/config_rules/starting_assignment_rule.jl +++ /dev/null @@ -1,57 +0,0 @@ -abstract type StartingAssignment end -struct OrderedStart <: StartingAssignment end -struct RandomStart <: StartingAssignment end -struct EigenStart <: StartingAssignment end -struct DistStart <: StartingAssignment end - -""" - initialize_node_labels(A, h, starting_assignment_rule::StartingAssignment) - -initialize node labels based on the `starting_assignment_rule`, and return a vector of -node labels and a `GroupSize` object. - -# Implemenented rules -- `OrderedStart()`: Sequentially assign nodes to groups based on the ordering of `A`. -- `RandomStart()`: Randomly assign nodes to groups. -- `EigenStart()`: Assign nodes to groups based on the second eigenvector of the normalized Laplacian. -- `DistStart()`: Assign nodes to groups based on the Hamming distance between rows of `A`. -""" -initialize_node_labels - -function initialize_node_labels(A, h, ::OrderedStart) - group_size = GroupSize(size(A, 1), h) - node_labels = inverse_rle(1:length(group_size), group_size) - return node_labels, group_size -end - -function initialize_node_labels(A, h, ::RandomStart) - node_labels, group_size = initialize_node_labels(A, h, OrderedStart()) - node_labels = shuffle!(node_labels) - return node_labels, group_size -end - -function initialize_node_labels(A, h, ::EigenStart) - group_size = GroupSize(size(A, 1), h) - node_labels = zeros(Int, size(A, 1)) - - laplacian = normalized_laplacian(A) - _, eigenvectors = eigs(laplacian, nev = 2, which = :LR, tol = 1e-2) - #_, eigenvectors = eigen(Symmetric(laplacian), (size(A, 1) - 1):(size(A, 1) - 1)) - - # get 2nd eigenvector, sort its components - indices = sortperm(eigenvectors[:, 1]) - # bin them into groups of correct size - start = 1 - for (i, group) in enumerate(group_size) - stop = start + group - 1 - node_labels[indices[start:stop]] .= i - start = stop + 1 - end - return node_labels, group_size -end - -function initialize_node_labels(A, h, ::DistStart) - 
group_size = GroupSize(size(A, 1), h) - node_labels = spectral_clustering(A, h) - return node_labels, group_size -end diff --git a/src/config_rules/stop_rule.jl b/src/config_rules/stop_rule.jl index d2ae852..7d7fa13 100644 --- a/src/config_rules/stop_rule.jl +++ b/src/config_rules/stop_rule.jl @@ -1,14 +1,41 @@ abstract type StopRule end -struct PreviousBestValue <: StopRule - k::Int - function PreviousBestValue(k::Int) - @assert k > 0 - new(k) - end + +function info_to_print(::StopRule) + return nothing +end + +mutable struct PreviousBestValue{T, S} <: StopRule + const k::Int + previous_best_value::T + iterations_since_best::Int +end + +function PreviousBestValue(k::Int, x::T = -Inf, best = :max) where {T <: Real} + @argcheck k > 0 + PreviousBestValue{T, Val(best)}(k, x, 0) +end + +const PreviousMaxValue{T} = PreviousBestValue{T, Val(:max)} +const PreviousMinValue{T} = PreviousBestValue{T, Val(:min)} + +function reset!(stop_rule::PreviousBestValue{T}, loss_value::T) where {T} + stop_rule.previous_best_value = loss_value + stop_rule.iterations_since_best = 0 +end + +reset!(stop_rule::PreviousMaxValue) = reset!(stop_rule, -Inf) +reset!(stop_rule::PreviousMinValue) = reset!(stop_rule, Inf) + +function compare_to_best(current, past, ::PreviousMaxValue) + return current > past +end + +function compare_to_best(current, past, ::PreviousMinValue) + return current < past end """ - stopping_rule(history, stop_rule::StopRule) + stopping_rule(assignment::Assignment,g, stop_rule::StopRule) Returns a Bool with true if we should stop the optimization based on the `stop_rule`. 
@@ -18,8 +45,15 @@ Returns a Bool with true if we should stop the optimization based on the `stop_r """ stopping_rule -function stopping_rule(history::GraphOptimizationHistory, stop_rule::PreviousBestValue) - current_itr = get_currentitr(history) - prev_best_itr = get_bestitr(history) - return current_itr[1] - prev_best_itr[1] > stop_rule.k +function stopping_rule(loss::T, stop_rule::PreviousBestValue{T}) where {T <: Real} + if compare_to_best(loss, stop_rule.previous_best_value, stop_rule) + reset!(stop_rule, loss) + else + stop_rule.iterations_since_best += 1 + end + return stop_rule.iterations_since_best >= stop_rule.k +end + +function info_to_print(stop_rule::PreviousBestValue) + ("stalled iter: ", stop_rule.iterations_since_best) end diff --git a/src/config_rules/swap_rule.jl b/src/config_rules/swap_rule.jl index dcd21a7..cc22198 100644 --- a/src/config_rules/swap_rule.jl +++ b/src/config_rules/swap_rule.jl @@ -1,27 +1,27 @@ abstract type NodeSwapRule end struct RandomNodeSwap <: NodeSwapRule end - +struct RandomGroupSwap <: NodeSwapRule end """ - select_swap(node_assignment::Assignment, A, ::NodeSwapRule) + select_indices_swap(node_assignment::Assignment, ::NodeSwapRule) Selects two nodes to swap based on the `NodeSwapRule`, the adjacency matrix `A` and the current assignment `node_assignment`. # Implemented rules - `RandomNodeSwap()`: Select two nodes at random. +- `RandomGroupSwap()`: Select two nodes from two different groups at random. 
""" select_swap -function select_swap(node_assignment::Assignment, A, ::RandomNodeSwap) - index1 = rand(1:size(A, 1)) - label1 = node_assignment.node_labels[index1] - index2 = index1 - for _ in 1:10 - index2 = rand(1:size(A, 1)) - if node_assignment.node_labels[index2] != label1 - break - end - end - return (index1, index2) +function select_indices_swap(node_labels::AbstractVector{Int}, ::RandomNodeSwap) + return Tuple(StatsBase.samplepair(1:length(node_labels))) +end + +function select_indices_swap(node_labels::AbstractVector{Int}, ::RandomGroupSwap, + k::Int = length(unique(node_labels))) + groups = StatsBase.sample(1:k, 2; replace = false) + index1 = rand(findall(x -> x == groups[1], node_labels)) + index2 = rand(findall(x -> x == groups[2], node_labels)) + return index1, index2 end diff --git a/src/data/datasets.jl b/src/data/datasets.jl deleted file mode 100644 index 8c7e071..0000000 --- a/src/data/datasets.jl +++ /dev/null @@ -1,17 +0,0 @@ -using HTTP, CodecZstd, TranscodingStreams - -const url_ref = "https://networks.skewed.de" - -include("utils.jl") - -function get_netzschleuder_network(name::String) - url = joinpath(url_ref, "net", name, "files", name * ".gt.zst") - res = HTTP.get(url) - - if res.status != 200 - error("Error downloading network" * res.status) - end - - decompressed = Base.IOBuffer(transcode(ZstdDecompressor, res.body)) - return readgt(decompressed) -end diff --git a/src/data/gt.jl b/src/data/gt.jl deleted file mode 100644 index 77e283f..0000000 --- a/src/data/gt.jl +++ /dev/null @@ -1,40 +0,0 @@ -"""Utils to read .gt files - -Inspired from the library Erdos.jl from CarloLucibello (precisely the file -`src/persistence.jl`) -""" - -const start_gt_format = "⛾ gt" - -function readgt_simple_network!(io::IO, adj, ::Type{T}) where {T} - n = size(adj, 1) - for i in 1:n - k = read(io, UInt64) - for _ in 1:k - j = read(io, T) + 1 - if i != j - adj[i, j] = 1 - adj[j, i] = 1 - end - end - end -end - -function readgt(io::IO) - @assert String(read(io, 
6))==start_gt_format "gt file not correctly formatted" - ver = read(io, UInt8) ## version - indian = read(io, Bool) - @assert indian == false - lencomment = read(io, UInt64) - read(io, lencomment) - isdir = read(io, Bool) - n = read(io, UInt64) - T = minutype(n) - if isdir - @warn "Directed graphs are not supported, automatically converting to undirected." - end - adj = zeros(Int, n, n) - - readgt_simple_network!(io, adj, T) - return adj -end diff --git a/src/data/utils.jl b/src/data/utils.jl deleted file mode 100644 index 083fa67..0000000 --- a/src/data/utils.jl +++ /dev/null @@ -1,20 +0,0 @@ -function drop_isolated_vertices(A) - degrees = vec(sum(A, dims = 2)) - return A[degrees .> 0, degrees .> 0] -end - -####### From Erdos.jl ####### - -function minutype(n::Integer) - @assert n ≥ 0 - if n < 2^8 - return UInt8 - elseif n < 2^16 - return UInt16 - elseif n < 2^32 - return UInt32 - elseif n < 2^64 - return UInt64 - end - error("No type big enough") -end diff --git a/src/distributions/hist_dist.jl b/src/distributions/hist_dist.jl new file mode 100644 index 0000000..a500b37 --- /dev/null +++ b/src/distributions/hist_dist.jl @@ -0,0 +1,46 @@ +struct HistDistribution{B, P, P2} <: ContinuousUnivariateDistribution + bins::B + probs::P + cum_probs::P2 +end + +Base.broadcastable(d::HistDistribution) = Ref(d) + +params(d::HistDistribution) = (d.bins, d.probs) + +function rand(rng::AbstractRNG, d::HistDistribution) + u = rand(rng) + bin_idx = searchsortedfirst(d.cum_probs, u) + return rand(rng, d.bins[bin_idx]) +end + +function HistDistribution(bins, ps) + cum_ps = SVector(cumsum(ps)...) 
+ return HistDistribution{typeof(bins), typeof(ps), typeof(cum_ps)}(bins, ps, cum_ps) +end + +function logpdf(d::HistDistribution, x::Real) + # potentially slow + bin_idx = findfirst(b -> x ∈ b, d.bins) + p = d.probs[bin_idx] + bin_idx == 1 && return log(p) + return log(p) - log(width(d.bins[bin_idx])) +end + +function pdf(d::HistDistribution, x::Real) + # potentially slow + bin_idx = findfirst(b -> x ∈ b, d.bins) + p = d.probs[bin_idx] + bin_idx == 1 && return p + return p / width(d.bins[bin_idx]) +end + +### For Graphons compatibility +support(d::HistDistribution) = d.bins +_extract_param(d::HistDistribution, k) = d.probs[k] + +function convert_to_params(centers, + sbm::DecoratedSBM{HistDistribution{B, P, P2}}) where {B, P, P2} + s = sbm.θ[1, 1].bins + return [HistDistribution(s, convert(P, centers[:, i])) for i in axes(centers, 2)] +end diff --git a/src/group_numbering.jl b/src/group_numbering.jl deleted file mode 100644 index 12d9691..0000000 --- a/src/group_numbering.jl +++ /dev/null @@ -1,43 +0,0 @@ -""" -Array-like storage for the number of nodes in each group. -""" -struct GroupSize{T} <: AbstractVector{Int} - group_number::T - number_groups::Int - - function GroupSize(number_nodes, h::Real) - @assert 0 < h < 1 - standard_group = floor(Int, number_nodes * h) - GroupSize(number_nodes, standard_group) - end - - function GroupSize(number_nodes, standard_group::Integer) - @assert 1 < standard_group < number_nodes - number_groups = number_nodes ÷ standard_group # number of standard groups! 
- if number_groups * standard_group == number_nodes - new{Int}(standard_group, number_groups) - else - remainder_group = number_nodes - number_groups * standard_group - if remainder_group == 1 - @warn "h has to be changed, only one node in remainder group" - standard_group -= 1 - remainder_group = number_groups + 1 # because equal to 1+number_groups because we take 1 from each standard group, and there are number_groups of them - if standard_group == 1 - error("Standard group size now 1, please choose a new value for h.") - end - end - new{Tuple{Int, Int}}((standard_group, remainder_group), number_groups + 1) - end - end -end - -Base.size(g::GroupSize) = (g.number_groups,) -Base.@propagate_inbounds function Base.getindex(g::GroupSize{Int}, i::Int) - @boundscheck checkbounds(g, i) - return g.group_number -end - -Base.@propagate_inbounds function Base.getindex(g::GroupSize{Tuple{Int, Int}}, i::Int) - @boundscheck checkbounds(g, i) - return i < length(g) ? g.group_number[1] : g.group_number[2] -end diff --git a/src/histogram.jl b/src/histogram.jl deleted file mode 100644 index 126d1ad..0000000 --- a/src/histogram.jl +++ /dev/null @@ -1,43 +0,0 @@ -struct GraphHist{T, M} - θ::Array{T, M} - node_labels::Vector{Int} - num_layers::Int - function GraphHist(a::Assignment{T, M}) where {T, M} - θ = a.estimated_theta - node_labels = a.node_labels - new{typeof(θ[1]), M}(θ, node_labels, a.number_layers) - end -end - -""" -Network Histogram approximation [1]. - -Contains the estimated network histogram and the node labels. - -# Fields -- `θ::Matrix{T}`: Estimated stochastic block model parameters. -- `node_labels::Vector{Int}`: Node labels for each node in the adjacency matrix used - to estimate the network histogram. - -# References -[1] - Olhede, Sofia C., and Patrick J. Wolfe. "Network histograms and universality of -blockmodel approximation." Proceedings of the National Academy of Sciences 111.41 (2014): 14722-14727. 
-""" -GraphHist - -function get_moment_representation(g::GraphHist{T, 2}) where {T} - return g.θ -end - -function get_moment_representation(g::GraphHist{T, 3}) where {T} - moments = zeros(size(g.θ, 1), size(g.θ, 2), 2^g.num_layers - 1) - transition = collect(kronecker([1 1; 0 1], g.num_layers)) - for i in 1:size(g.θ, 1) - for j in 1:size(g.θ, 2) - moments[i, j, :] .= (transition * g.θ[i, j, :])[2:end] - end - end - indices_for_moments = [findall(x -> x == 1, _index_to_binary(e, g.num_layers)) - for e in 2:size(g.θ, 3)] - return moments, indices_for_moments -end diff --git a/src/history.jl b/src/history.jl deleted file mode 100644 index 381f175..0000000 --- a/src/history.jl +++ /dev/null @@ -1,88 +0,0 @@ -abstract type GraphOptimizationHistory end -struct TraceHistory{M <: MVHistory} <: GraphOptimizationHistory - history::M -end -mutable struct NoTraceHistory <: GraphOptimizationHistory - current_iteration::Int - best_iteration::Int - best_likelihood::Float64 -end - -""" - initialize_history(best, current, proposal, ::Val{true}) - - initialize the history when `record_trace=true` is passed to `graphhist`. -""" -function initialize_history(best, current, proposal, ::Val{true}) - history = MVHistory(Dict([ - :proposal_likelihood => QHistory(Float64), - :current_likelihood => QHistory(Float64), - :best_likelihood => QHistory(Float64), - ])) - push!(history, :proposal_likelihood, 0, proposal.likelihood) - push!(history, :current_likelihood, 0, current.likelihood) - push!(history, :best_likelihood, 0, best.likelihood) - return TraceHistory(history) -end - -""" - initialize_history(best, current, proposal, ::Val{false}) - -initialize the history when `record_trace=false` is passed to `graphhist`. -""" -function initialize_history(best, current, proposal, ::Val{false}) - return NoTraceHistory(0, 0, best.likelihood) -end - -""" - get_currentitr(history::GraphOptimizationHistory) - -Return the current iteration of the optimization from the history. 
-""" -get_currentitr(history::TraceHistory) = last(history.history, :current_likelihood) -get_currentitr(history::NoTraceHistory) = history.current_iteration - -""" - get_bestitr(history::GraphOptimizationHistory) - -Return the best iteration of the optimization from the history. -""" -get_bestitr(history::TraceHistory) = last(history.history, :best_likelihood) -get_bestitr(history::NoTraceHistory) = history.best_iteration - -""" - update_current!(history::GraphOptimizationHistory, iteration, likelihood) - -Updates the current value and iteration in history. -""" -function update_current!(history::TraceHistory, iteration, likelihood) - push!(history.history, :current_likelihood, iteration, likelihood) -end -function update_current!(history::NoTraceHistory, iteration, likelihood) - history.current_iteration = iteration -end - -""" - update_best!(history::GraphOptimizationHistory, iteration, likelihood) - -Updates the best value and iteration in history. -""" -function update_best!(history::TraceHistory, iteration, likelihood) - push!(history.history, :best_likelihood, iteration, likelihood) -end -function update_best!(history::NoTraceHistory, iteration, likelihood) - history.best_iteration = iteration - history.best_likelihood = likelihood -end - -""" - update_previous!(history::GraphOptimizationHistory, iteration, likelihood) - -Updates the previous value and iteration in history. - -Note this does not apply is `history` is a `NoTraceHistory`, so nothign happens. 
-""" -function update_proposal!(history::TraceHistory, iteration, likelihood) - push!(history.history, :proposal_likelihood, iteration, likelihood) -end -update_proposal!(history::NoTraceHistory, iteration, likelihood) = nothing diff --git a/src/optimize.jl b/src/optimize.jl deleted file mode 100644 index 751ffcf..0000000 --- a/src/optimize.jl +++ /dev/null @@ -1,154 +0,0 @@ -function checkadjacency(A) - @assert eltype(A) <: Real - if !(eltype(A) === Bool) - @assert all(a ∈ [zero(eltype(A)), one(eltype(A))] for a in A) "All elements of the ajacency matrix should be zero or one." - end - check_symmetry_and_diag(A) - return nothing -end - -function check_symmetry_and_diag(A) - @assert issymmetric(A) - @assert all(A[i, i] == zero(eltype(A)) for i in 1:size(A, 1)) "The diagonal of the adjacency matrix should all be zeros." -end - -function check_symmetry_and_diag(A::Array{T, 3}) where {T} - for layer in eachslice(A, dims = 3) - check_symmetry_and_diag(layer) - @assert all(layer[i, i] == zero(eltype(layer)) for i in 1:size(layer, 1)) "The diagonal of the adjacency matrix should all be zeros." - end -end - -function update_adj(A::Array{T, 2}) where {T} - return A -end - -function update_adj(A::Array{T, 3}) where {T} - A_updated = zeros(Int64, size(A, 1), size(A, 2)) - for i in 1:size(A, 1) - for j in (i + 1):size(A, 2) - A_updated[i, j] = _binary_to_index(A[i, j, :]) - A_updated[j, i] = A_updated[i, j] - end - end - return A_updated -end - -""" - graphhist(A; h = select_bandwidth(A), maxitr = 1000, swap_rule = RandomNodeSwap(), - starting_assignment_rule = RandomStart(), accept_rule = Strict(), - stop_rule = PreviousBestValue(3), record_trace=true) - -Computes the graph histogram approximation. 
- -# Arguments -- `A`: adjacency matrix of a simple graph - -- `h`: bandwidth of the graph histogram (number of nodes in a group or percentage (in [0,1]) of - nodes in a group) - -- `record_trace` (optional): whether to record the trace of the optimization process and return - it as part of the output. Default is `true`. - -# Returns -named tuple with the following fields: -- `graphhist`: the graph histogram approximation -- `trace`: the trace of the optimization process (if `record_trace` is `true`) -- `likelihood`: the loglikelihood of the graph histogram approximation - -# Examples -```julia -julia> A = [0 0 1 0 1 0 1 1 0 1 - 0 0 1 1 1 1 1 1 0 0 - 1 1 0 1 0 0 0 0 1 0 - 0 1 1 0 1 0 1 0 0 0 - 1 1 0 1 0 0 1 0 0 1 - 0 1 0 0 0 0 0 1 0 0 - 1 1 0 1 1 0 0 1 0 1 - 1 1 0 0 0 1 1 0 0 1 - 0 0 1 0 0 0 0 0 0 1 - 1 0 0 0 1 0 1 1 1 0] -julia> out = graphhist(A); -julia> graphist_approx = out.graphist -... -julia> trace = out.trace -NetworkHistogram.TraceHistory{...} - :best_likelihood => 1 elements {Int64,Float64} - :proposal_likelihood => 5 elements {Int64,Float64} - :current_likelihood => 5 elements {Int64,Float64}) -julia> loglikelihood = out.likelihood --22.337057781338277 -``` -""" -function graphhist(A; h = select_bandwidth(A), maxitr = 10000, - swap_rule::NodeSwapRule = RandomNodeSwap(), - starting_assignment_rule::StartingAssignment = EigenStart(), - accept_rule::AcceptRule = Strict(), - stop_rule::StopRule = PreviousBestValue(100), record_trace = true) - checkadjacency(A) - @assert maxitr > 0 - A = drop_disconnected_components(A) - - return _graphhist(A, Val{record_trace}(), h = h, maxitr = maxitr, swap_rule = swap_rule, - starting_assignment_rule = starting_assignment_rule, - accept_rule = accept_rule, - stop_rule = stop_rule) -end - -""" - _graphhist(A, record_trace=Val{true}(); h, maxitr, swap_rule, starting_assignment_rule, accept_rule, stop_rule) - -Internal version of `graphhist` which is type stable. 
-""" -function _graphhist(A, record_trace = Val{true}(); h, maxitr, swap_rule, - starting_assignment_rule, accept_rule, stop_rule) - best, current, proposal, history, A = initialize(A, h, starting_assignment_rule, - record_trace) - - for i in 1:maxitr - proposal = create_proposal!(history, i, proposal, current, A, swap_rule) - current = accept_reject_update!(history, i, proposal, current, accept_rule) - best = update_best!(history, i, current, best) - if stopping_rule(history, stop_rule) - break - end - end - - return graphhist_format_output(best, history) -end - -""" - graphhist_format_output(best, history) - -Formates the `graphhist` output depending on the type of `history` requested by the user. -""" -function graphhist_format_output(best, history::TraceHistory) - return (graphhist = GraphHist(best), trace = history, likelihood = best.likelihood) -end -function graphhist_format_output(best, history::NoTraceHistory) - return (graphhist = GraphHist(best), likelihood = history.best_likelihood) -end - -function update_best!(history::GraphOptimizationHistory, iteration::Int, - current::Assignment, - best::Assignment) - if current.likelihood > best.likelihood - update_best!(history, iteration, current.likelihood) - deepcopy!(best, current) - end - return best -end - -""" - initialize(A, h, starting_assignment_rule, record_trace) - -Initialize the memory required for finding optimal graph histogram. 
-""" -function initialize(A, h, starting_assignment_rule, record_trace) - node_labels, group_size = initialize_node_labels(A, h, starting_assignment_rule) - proposal = Assignment(A, node_labels, group_size) - current = deepcopy(proposal) - best = deepcopy(proposal) - history = initialize_history(best, current, proposal, record_trace) - return best, current, proposal, history, update_adj(A) -end diff --git a/src/preprocessor/abstractConvertor.jl b/src/preprocessor/abstractConvertor.jl new file mode 100644 index 0000000..f035903 --- /dev/null +++ b/src/preprocessor/abstractConvertor.jl @@ -0,0 +1,34 @@ +abstract type AbstractConvertor end + +Base.broadcastable(o::AbstractConvertor) = Ref(o) + +""" + Convert data from its original form to a processed form suitable for SBM estimation. +""" +function (c::AbstractConvertor)(A; kwargs...) + @error "to be implemented" +end + +function to_distribution(c::AbstractConvertor, ps; kwargs...) + @error "to be implemented" +end + +get_convertor(s::String, ; kwargs...) = get_convertor(Symbol(s); kwargs...) +get_convertor(s::Symbol; kwargs...) = get_convertor(Val(s); kwargs...) +get_convertor(::T; kwargs...) where {T} = @error "No convertor found for type $T" + +include("binary.jl") +include("categorical.jl") +include("continuous.jl") + +function get_convertor(::Val{:categorical}; kwargs...) + return CategoricalConvertor(kwargs[:num_categories]) +end + +function get_convertor(::Val{:continuous}; kwargs...) + return UnitIntervalConvertor(kwargs[:num_bins]) +end + +function get_convertor(::Val{:binary}; kwargs...) + return BinaryConvertor() +end diff --git a/src/preprocessor/binary.jl b/src/preprocessor/binary.jl new file mode 100644 index 0000000..b1f5f27 --- /dev/null +++ b/src/preprocessor/binary.jl @@ -0,0 +1,12 @@ +struct BinaryConvertor <: AbstractConvertor end + +function (c::BinaryConvertor)(obs::T) where {T <: Union{Real, Bool}} + return obs == 1 ? 
true : false +end + +function to_distribution( + c::BinaryConvertor, p::T; kwargs...) where {T <: Real} + return p +end + +num_bins(::BinaryConvertor) = 2 diff --git a/src/preprocessor/categorical.jl b/src/preprocessor/categorical.jl new file mode 100644 index 0000000..7cf3fdd --- /dev/null +++ b/src/preprocessor/categorical.jl @@ -0,0 +1,36 @@ +### ======================================================================================= +### Categorical Convertor +### ======================================================================================= + +struct CategoricalConvertor{T} <: AbstractConvertor + m::Int # number of categories + map::Dict{T, Int} +end + +function CategoricalConvertor(data::AbstractArray{T}) where {T} + categories = sort(unique(data)) + m = length(categories) + map = Dict{T, Int}(categories[i] => i for i in 1:m) + return CategoricalConvertor{T}(m, map) +end + +function CategoricalConvertor(num_categories::Int) + map = Dict{Int, Int}(i => i for i in 1:num_categories) + return CategoricalConvertor{Int}(num_categories, map) +end + +function num_bins(c::CategoricalConvertor) + return c.m +end + +function (c::CategoricalConvertor)(obs::T) where {T} + return c.map[obs] +end + +function to_distribution( + c::CategoricalConvertor{T}, ps::AbstractVector{T2}; kwargs...) 
where {T, T2} + @argcheck length(ps)==c.m "Length of probabilities must match number of categories" + support = sort(collect(keys(c.map))) + probabilities = SVector{c.m, T2}(ps[c.map[s]] for s in support) + return DiscreteNonParametric(support, probabilities) +end diff --git a/src/preprocessor/continuous.jl b/src/preprocessor/continuous.jl new file mode 100644 index 0000000..4aceb22 --- /dev/null +++ b/src/preprocessor/continuous.jl @@ -0,0 +1,71 @@ + +### ======================================================================================= +### [0,1] Continuous Convertor +### ======================================================================================= + +abstract type UnitIntervalConvertorType <: AbstractConvertor end + +struct UnitIntervalConvertor{B <: AbstractVector} <: UnitIntervalConvertorType + bins::B +end + +function UnitIntervalConvertor(n::Int) + zero_interval = Interval{:closed, :closed}(0.0, 0.0) + edges = range(0.0, stop = 1.0, length = n + 1) + bins = [Interval{:closed, :closed}(edges[i], edges[i + 1]) for i in 1:n] + bins = vcat(zero_interval, bins) + return UnitIntervalConvertor{typeof(bins)}(bins) +end + +function num_bins(c::UnitIntervalConvertor) + return length(c.bins) +end + +function (c::UnitIntervalConvertor)(x::Real) + return findfirst(b -> x ∈ b, c.bins) +end + +function to_distribution( + c::UnitIntervalConvertor, ps::AbstractVector{T}; kwargs...) 
where {T} + @argcheck length(ps)==length(c.bins) "Length of probabilities must match number of bins" + return HistDistribution(c.bins, SVector{length(ps), T}(ps)) +end + +# struct RegularUnitIntervalConvertor{N} <: UnitIntervalConvertorType +# num_bins::Int +# end + +### ======================================================================================= +### Continuous Convertor +### ======================================================================================= +# struct ContinuousConvertor{B, N, V <: AbstractVector{B}} <: AbstractConvertor +# zero_index::Int +# bins::V +# end + +# function num_bins(c::ContinuousConvertor{B, N}) where {B, N} +# return N +# end + +# ## assume no singleton bins +# function ContinuousConvertor(bins::AbstractVector{B}) where {B <: +# Union{Interval, BareInterval}} +# bins = sort(bins, lt = lt = strictprecedes) +# N = length(bins) + 1 +# zero_index = 1 +# ContinuousConvertor{B, N, typeof(bins)}(zero_index, bins) +# end + +# # assume bins are sorted and correctly cover the whole support +# function (c::ContinuousConvertor{<:Union{Interval, BareInterval}})(x) +# iszero(x) && return c.zero_index +# x >= sup(c.bins[end]) && return length(c.bins) + 1 +# x <= inf(c.bins[1]) && return c.zero_index + 1 +# return findfirst(b -> in_interval(x, b), c.bins) + 1 +# end + +# function ContinuousConvertor(l, u, num_bins::Int) +# edges = collect(range(l, stop = u, length = num_bins + 1)) +# bins = [bareinterval(edges[i], edges[i + 1]) for i in 1:num_bins] +# ContinuousConvertor(bins) +# end diff --git a/src/proposal.jl b/src/proposal.jl deleted file mode 100644 index 848a2fb..0000000 --- a/src/proposal.jl +++ /dev/null @@ -1,135 +0,0 @@ -"""Functions to create and evaluate possible labels update.""" - -""" - create_proposal!(history::GraphOptimizationHistory, iteration::Int, proposal::Assignment, - current::Assignment, A, swap_rule) - -Create a new proposal by swapping the labels of two nodes. The new assignment is stored in -`proposal`. 
The swap is selected using the `swap_rule` function. The likelihood of the new -proposal is stored in the history. - -!!! warning - The `proposal` assignment is modified in place to avoid unnecessary memory allocation. -""" -function create_proposal!(history::GraphOptimizationHistory, iteration::Int, - proposal::Assignment, - current::Assignment, A, swap_rule) - swap = select_swap(current, A, swap_rule) - make_proposal!(proposal, current, swap, A) - update_proposal!(history, iteration, proposal.likelihood) - return proposal -end - -""" - make_proposal!(proposal::Assignment, current::Assignment, swap::Tuple{Int, Int}, A) - -From the current assignment, create a new assignment by swapping the labels of the nodes -specified in `swap`. The new assignment is stored in `proposal`. -""" -function make_proposal!(proposal::Assignment, current::Assignment, swap::Tuple{Int, Int}, A) - # copy current in proposal - deepcopy!(proposal, current) - # update realized, estimated_theta - update_observed!(proposal, swap, A) - # update node labels (has to happen after!!!) - update_labels!(proposal, swap, current) - # update ll - updateLL!(proposal) -end - -""" - update_labels!(proposal::Assignment, swap::Tuple{Int, Int}, current::Assignment) - -Update the labels of the nodes specified in `swap` in the `proposal` assignment. -""" -function update_labels!(proposal::Assignment, swap::Tuple{Int, Int}, current::Assignment) - proposal.node_labels[swap[1]] = current.node_labels[swap[2]] - proposal.node_labels[swap[2]] = current.node_labels[swap[1]] -end - -""" - updateLL!(proposal::Assignment) - -Update the likelihood of the `proposal` assignment based on its observed and estimated -attributes. 
-""" -function updateLL!(proposal::Assignment) - # O(G^2) where G is the number of groups - proposal.likelihood = NetworkHistogram.compute_log_likelihood(proposal) -end - -""" - update_observed!(proposal::Assignment, swap::Tuple{Int, Int}, A) - -Update the observed and estimated attributes of the `proposal` assignment based on the -swap of the nodes specified in `swap`. - -NOTE labels of the nodes before the swap -""" - -function update_observed!(proposal::Assignment{T, 2}, swap::Tuple{Int, Int}, A) where {T} - group_node_1 = proposal.node_labels[swap[1]] - group_node_2 = proposal.node_labels[swap[2]] - - for i in axes(A, 1) - if i == swap[1] || i == swap[2] || A[swap[1], i] == A[swap[2], i] - continue - end - group_i = proposal.node_labels[i] - if A[i, swap[1]] == 1 - proposal.realized[group_node_1, group_i] -= 1 - proposal.realized[group_i, group_node_1] = proposal.realized[group_node_1, - group_i] - - proposal.realized[group_node_2, group_i] += 1 - proposal.realized[group_i, group_node_2] = proposal.realized[group_node_2, - group_i] - end - if A[i, swap[2]] == 1 - proposal.realized[group_node_2, group_i] -= 1 - proposal.realized[group_i, group_node_2] = proposal.realized[group_node_2, - group_i] - - proposal.realized[group_node_1, group_i] += 1 - proposal.realized[group_i, group_node_1] = proposal.realized[group_node_1, - group_i] - end - end - - @. 
proposal.estimated_theta = proposal.realized / proposal.counts - - return nothing -end - -function update_observed!(proposal::Assignment{T, 3}, swap::Tuple{Int, Int}, A) where {T} - group_node_1 = proposal.node_labels[swap[1]] - group_node_2 = proposal.node_labels[swap[2]] - if group_node_1 == group_node_2 - return nothing - end - - for i in axes(A, 1) - if i == swap[1] || i == swap[2] || A[swap[1], i] == A[swap[2], i] - continue - end - group_i = proposal.node_labels[i] - - proposal.realized[group_node_1, group_i, A[i, swap[1]]] -= 1 - proposal.realized[group_i, group_node_1, A[i, swap[1]]] = proposal.realized[group_node_1, - group_i, A[i, swap[1]]] - proposal.realized[group_node_2, group_i, A[i, swap[1]]] += 1 - proposal.realized[group_i, group_node_2, A[i, swap[1]]] = proposal.realized[group_node_2, - group_i, A[i, swap[1]]] - - proposal.realized[group_node_1, group_i, A[i, swap[2]]] += 1 - proposal.realized[group_i, group_node_1, A[i, swap[2]]] = proposal.realized[group_node_1, - group_i, A[i, swap[2]]] - proposal.realized[group_node_2, group_i, A[i, swap[2]]] -= 1 - proposal.realized[group_i, group_node_2, A[i, swap[2]]] = proposal.realized[group_node_2, - group_i, A[i, swap[2]]] - end - - @. 
proposal.estimated_theta = proposal.realized / proposal.counts - - return nothing -end diff --git a/src/pseudo_suff_stats/abstract_suffstat.jl b/src/pseudo_suff_stats/abstract_suffstat.jl new file mode 100644 index 0000000..883ce97 --- /dev/null +++ b/src/pseudo_suff_stats/abstract_suffstat.jl @@ -0,0 +1,17 @@ +abstract type SuffStats end + +function add_sample end +function remove_sample end +function make_k_block end + +# loss will be minimized +function loss end +function to_params end + +# some suffstat may need the edge index (i,j) to update properly +add_sample(suffstats::SuffStats, sample, i, j) = add_sample(suffstats, sample) +remove_sample(suffstats::SuffStats, sample, i, j) = remove_sample(suffstats, sample) + +include("categorical.jl") +include("bernoulli.jl") +include("generic.jl") diff --git a/src/pseudo_suff_stats/bernoulli.jl b/src/pseudo_suff_stats/bernoulli.jl new file mode 100644 index 0000000..c0c0d6e --- /dev/null +++ b/src/pseudo_suff_stats/bernoulli.jl @@ -0,0 +1,47 @@ + +struct BernoulliSuffStats{T} <: SuffStats + h::T + n::T +end + +function BernoulliSuffStats() + return BernoulliSuffStats{Int}(0, 0) +end + +function add_sample(ss::BernoulliSuffStats, sample::Bool) + sample && (@reset ss.h += 1) + @reset ss.n += 1 + return ss +end + +function add_sample(ss::BernoulliSuffStats, ::Nothing) + @reset ss.n += 1 + return ss +end + +function remove_sample(ss::BernoulliSuffStats, sample::Bool) + sample && (@reset ss.h -= 1) + @reset ss.n -= 1 + return ss +end + +function remove_sample(ss::BernoulliSuffStats, ::Nothing) + @reset ss.n -= 1 + return ss +end + +function make_k_block(k, ::Val{:binary}; kwargs...) 
+ k_block = SymArray{BernoulliSuffStats{Int}}(undef, k, k) + fill!(k_block, BernoulliSuffStats()) + return k_block +end + +function loss(ss::BernoulliSuffStats) + n = max(ss.n, 1) + p = ss.h / n + return -n * (xlogx(1 - p) + xlogx(p)) +end + +function to_params(ss::BernoulliSuffStats) + return ss.h / max(ss.n, 1) +end diff --git a/src/pseudo_suff_stats/categorical.jl b/src/pseudo_suff_stats/categorical.jl new file mode 100644 index 0000000..4931e3b --- /dev/null +++ b/src/pseudo_suff_stats/categorical.jl @@ -0,0 +1,39 @@ +struct CategoricalSuffStats{M, T} <: SuffStats + h::SVector{M, T} +end + +function CategoricalSuffStats(num_categories::Int) + h = SVector{num_categories, Int}(zeros(Int, num_categories)) + return CategoricalSuffStats{num_categories, Int}(h) +end + +function add_sample(ss::CategoricalSuffStats, sample::Int) + ss = @set ss.h[sample] += 1 + return ss +end + +function remove_sample(ss::CategoricalSuffStats, sample::Int) + ss = @set ss.h[sample] -= 1 + return ss +end + +function make_k_block(k, ::Val{:categorical}; num_categories, kwargs...) 
+ k_block = SymArray{CategoricalSuffStats{num_categories, Int}}(undef, k, k) + fill!(k_block, CategoricalSuffStats(num_categories)) + return k_block +end + +function loss(ss::CategoricalSuffStats) + n = sum(ss.h) + return n - sum(abs2, ss.h) / max(n, 1) +end + +function to_params(ss::CategoricalSuffStats) + return custom_normalize(ss.h) +end + +function custom_normalize(ps::SVector{M, T}) where {M, T} + n = sum(ps) + n == 0 && return zero(SVector{M, T}) + return ps / n +end diff --git a/src/pseudo_suff_stats/generic.jl b/src/pseudo_suff_stats/generic.jl new file mode 100644 index 0000000..0444c6e --- /dev/null +++ b/src/pseudo_suff_stats/generic.jl @@ -0,0 +1,66 @@ +struct GenericSuffStats{T, D} <: SuffStats + samples::Vector{T} + dist::D +end + +function GenericSuffStats(::AbstractArray{T}, dist::D) where {T, D} + return GenericSuffStats{T, D}(Vector{T}(), dist) +end + +function get_samples(ss::GenericSuffStats) + return ss.samples +end + +function add_sample(ss::GenericSuffStats, sample) + push!(ss.samples, sample) # push!, not append!: a sample is stored as ONE element even if it is itself iterable + return ss +end + +function remove_sample(ss::GenericSuffStats, sample) + index = findfirst(==(sample), ss.samples) + if index !== nothing + deleteat!(ss.samples, index) + end + return ss +end + +function make_k_block(k, generic; data::AbstractArray, dist::D, kwargs...) where {D} + @warn "Using GenericSuffStats may be very slow even for small graphs. + Consider using more specialized sufficient statistics types when possible." + k_block = SymArray{GenericSuffStats{eltype(data), D}}(undef, k, k) + for j in 1:k, i in 1:k + + k_block[i, j] = GenericSuffStats(data, dist) + end + return k_block +end + +# use indices rather than pushing and deleting samples for better performance ?
+# struct GenericSuffStatsIndex{T} <: GenericSuffStatsType +# indices::Vector{Tuple{Int, Int}} +# data::T +# end + +# function get_samples(ss::GenericSuffStatsIndex) +# return [ss.data[i, j] for (i, j) in ss.indices] +# end + +# function GenericSuffStatsIndex{T}(data::T) where {T} +# return GenericSuffStatsIndex{T}(Vector{Tuple{Int, Int}}(), data) +# end + +# function add_sample(ss::GenericSuffStatsIndex, sample, i, j) +# push!(ss.indices, (i, j)) +# return ss +# end + +function loss(ss::GenericSuffStats{T, D}) where {T, D} + samples = get_samples(ss) + d = fit(D, samples) + return -mapreduce(Base.Fix1(logpdf, d), +, samples) # negative log-likelihood of the samples under the fitted distribution +end + +function to_params(ss::GenericSuffStats) + d = fit(typeof(ss.dist), get_samples(ss)) + return params(d) +end diff --git a/src/utils.jl b/src/utils.jl deleted file mode 100644 index 3b85a6f..0000000 --- a/src/utils.jl +++ /dev/null @@ -1,109 +0,0 @@ -function laplacian(A) - s = sum(A; dims = 1) - return diagm(vec(s)) - A -end - -function normalized_laplacian(A) - L = zeros(size(A)) - degrees = vec(sum(A, dims = 1)) - for j in 1:size(A, 1) - for i in 1:size(A, 2) - if i == j - L[i, j] = 1 - else - L[i, j] = A[i, j] / sqrt(degrees[i] * degrees[j]) - end - end - end - return L -end - -function normalized_laplacian(A::AbstractArray{T, 3}) where {T} - L = zeros(size(A, 1), size(A, 2)) - for layer in eachslice(A, dims = 3) - L .+= normalized_laplacian(layer) - end - return L ./ size(A, 3) -end - -function drop_disconnected_components(A::AbstractArray{T, 2}) where {T} - indices = findall(x -> x != 0, vec(sum(A, dims = 1))) - return A[indices, indices] -end - -function drop_disconnected_components(A::AbstractArray{T, 3}) where {T} - indices = findall(x -> x != 0, vec(sum(A, dims = (1, 3)))) - return A[indices, indices, :] -end - -""" - hamming_distance(x, y) - -Compute the normalized Hamming distance between two vectors `x` and `y`.
-""" -function hamming_distance(x::Vector{T}, y::Vector{T}) where {T} - return sum(x .!= y) / length(x) -end - -""" - pairwise_hamming_distance(A) - -Compute the pairwise Hamming distance between all rows of `A`. If `A` is a 3D -array, then the average Hamming distance for each layer of the array is returned. -""" -pairwise_hamming_distance - -function pairwise_hamming_distance(matrix::AbstractArray{T, 2}) where {T} - n = size(matrix, 1) - dist_matrix = zeros(n, n) - for i in 1:n, j in (i + 1):n - dist_matrix[i, j] = hamming_distance(matrix[i, :], matrix[j, :]) - dist_matrix[j, i] = dist_matrix[i, j] # Symmetric matrix - end - return dist_matrix -end - -function pairwise_hamming_distance(matrix::AbstractArray{T, 3}) where {T} - n = size(matrix, 1) - dist_matrix = zeros(n, n) - for layer in eachslice(matrix, dims = 3) - dist_matrix .+= pairwise_hamming_distance(layer) - end - return dist_matrix ./ size(matrix, 3) -end - -function spectral_clustering(A, h) - n = size(A, 1) - - L = 1 .- pairwise_hamming_distance(A) ./ n - - # Compute the degree matrix - d = sum(L, dims = 2) - - # Compute the normalized Laplacian - normalized_L = sum(1.0 ./ d) .* L .- sum(d) / sqrt(sum(d .^ 2)) - - # Compute eigenvalues and eigenvectors of the normalized Laplacian - - decomp, history = partialschur(normalized_L, nev = 2, which = LR()) - _, eigen_vecs = partialeigen(decomp) - - # Extract the second eigenvector - u = real.(eigen_vecs[:, 1]) - u = u .* sign(u[1]) # Set the first coordinate >= 0 wlog - - # Sort based on the embedding - ind = sortperm(u, alg = QuickSort, rev = false) - - # Determine the number of clusters - k = ceil(Int, n / h) - - # Initialize cluster assignments - idxInit = zeros(Int, n) - for i in 1:k - for j in ((i - 1) * h + 1):min(n, i * h) - idxInit[ind[j]] = i - end - end - return idxInit -end diff --git a/src/utils/utils_node_labels.jl b/src/utils/utils_node_labels.jl new file mode 100644 index 0000000..d5b50f0 --- /dev/null +++ b/src/utils/utils_node_labels.jl @@ 
-0,0 +1,61 @@ +function ordered_start_labels(n::Int, k::Int) + labels = Vector{Int}(undef, n) + base_size = n ÷ k + remainder = n % k + for group in 1:k + fill!(view(labels, ((group - 1) * base_size + 1):(group * base_size)), group) + end + if remainder > 0 + fill!(view(labels, (k * base_size + 1):(k * base_size + remainder)), k) + end + return labels +end + +function align_res_true_latents!(res::NethistResult, latents) + new_labels, mapping = order_groups(res.labels, latents) + res.labels .= new_labels + perm = [key for (key, val) in sort(collect(mapping), by = last)] + permute!(res.model, perm) +end + +#TODO: move to Graphons.jl see https://github.com/SDS-EPFL/Graphons.jl/pull/17 +function permute!(sbm, perm) + permuted_theta = copy(sbm.θ) + sbm.θ .= permuted_theta[perm, perm] + sbm.size .= sbm.size[perm] + sbm.cumsize .= cumsum(sbm.size) +end + +function order_groups(node_labels, latents::AbstractVector) + n = length(node_labels) + k = length(unique(node_labels)) + sort_perm = sortperm(latents) + sorted_group_labels = node_labels[sort_perm] + dummy_group_labels = repeat(1:k, inner = n ÷ k + 1)[1:n] + counts = Dict( + group => countmap(dummy_group_labels[sorted_group_labels .== group]) + for + group in 1:k + ) + perm = sort(1:k, by = x -> Tuple(get(counts[x], g, 0) for g in 1:k), rev = true) + new_labels = map(x -> findfirst(==(x), perm), node_labels) + mapping = Dict(perm[i] => i for i in 1:k) + return new_labels, mapping +end + +function get_num_obs(A::AbstractMatrix) + n = size(A, 1) + return n * (n - 1) ÷ 2 +end + +""" +Align the source and target matrices using optimal transport. This function requires +the PythonCall.jl package to be loaded +""" +function align_matrices end + +""" +Get the permutation aligning source and target matrices using optimal transport. 
This function requires +the PythonCall.jl package to be loaded +""" +function get_perm_alignment end diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..e35dbe4 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,13 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +Bootstrap = "e28b5b4c-05e8-5b66-bc03-6f0c0a0a06e0" +DiscretizeDistributions = "1dbf0e27-43cd-4e03-8ecf-3f7be9d12b15" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/config_rules/accept_rule_test.jl b/test/config_rules/accept_rule_test.jl deleted file mode 100644 index 56de1f3..0000000 --- a/test/config_rules/accept_rule_test.jl +++ /dev/null @@ -1,33 +0,0 @@ -import NetworkHistogram: accept_reject_update!, initialize_history -@testset "accept rule" begin - iteration = 3 - A, node_labels, group_size, proposal = make_simple_example() - - proposal.likelihood = 0.0 - current = deepcopy(proposal) - best = deepcopy(proposal) - - histories = [ - initialize_history(best, current, proposal, Val{true}()), - initialize_history(best, current, proposal, Val{false}()), - ] - test_likelihoods = [-0.1, 0.1] - for history in histories, lik in test_likelihoods - proposal.likelihood = lik # set proposal - accept_reject_update!(history, iteration, proposal, current, Strict()) - - @testset "Strict with history is $(typeof(history).name.name), likelihood=$lik" begin - @test current.likelihood == max(lik, 0.0) # should have accepted if better - if history isa NetworkHistogram.TraceHistory - @test get(history.history, :current_likelihood)[1][end] == iteration 
- @test get(history.history, :current_likelihood)[2][end] == - current.likelihood - else - @test history.current_iteration == iteration - end - end - - current.likelihood = 0.0 # reset for next example - iteration += 1 # otherwise will get an error from history - end -end diff --git a/test/config_rules/config_rule_test.jl b/test/config_rules/config_rule_test.jl deleted file mode 100644 index 7c7c95c..0000000 --- a/test/config_rules/config_rule_test.jl +++ /dev/null @@ -1,6 +0,0 @@ -@testset "config rules" begin - include("accept_rule_test.jl") - include("starting_assigment_rule_test.jl") - include("stop_rule_test.jl") - include("swap_rule_test.jl") -end diff --git a/test/config_rules/starting_assigment_rule_test.jl b/test/config_rules/starting_assigment_rule_test.jl deleted file mode 100644 index 5ff2617..0000000 --- a/test/config_rules/starting_assigment_rule_test.jl +++ /dev/null @@ -1,33 +0,0 @@ -import NetworkHistogram: initialize_node_labels - -@testset "starting assignment rules" begin - @testset "starting assigment rule simple graphs" begin - A, _, _, _ = make_simple_example() - for method in (OrderedStart(), - RandomStart(), - EigenStart(), - DistStart()) - node_labels, group_size = initialize_node_labels(A, 4, method) - if method isa OrderedStart - @test sort(node_labels) == node_labels - end - @test all(sum(n -> n == j, node_labels) == group_size[j] - for j in unique(node_labels)) - end - end - - @testset "starting assigment rule multilayer graphs" begin - A, _, _, _ = make_multivariate_example() - for method in (OrderedStart(), - RandomStart(), - EigenStart(), - DistStart()) - node_labels, group_size = initialize_node_labels(A, 4, method) - if method isa OrderedStart - @test sort(node_labels) == node_labels - end - @test all(sum(n -> n == j, node_labels) == group_size[j] - for j in unique(node_labels)) - end - end -end diff --git a/test/config_rules/stop_rule_test.jl b/test/config_rules/stop_rule_test.jl deleted file mode 100644 index e814c70..0000000 
--- a/test/config_rules/stop_rule_test.jl +++ /dev/null @@ -1,20 +0,0 @@ -import NetworkHistogram: stopping_rule -@testset "stop rule" begin - A, node_labels, group_size, proposal = make_simple_example() - - proposal.likelihood = -0.1 - current = deepcopy(proposal) - best = deepcopy(proposal) - - histories = [ - initialize_history(best, current, proposal, Val{true}()), - initialize_history(best, current, proposal, Val{false}()), - ] - for history in histories - for i in 1:4 - NetworkHistogram.update_current!(history, i, 0.0) - end - @test stopping_rule(history, PreviousBestValue(3)) == true - @test stopping_rule(history, PreviousBestValue(4)) == false - end -end diff --git a/test/config_rules/swap_rule_test.jl b/test/config_rules/swap_rule_test.jl deleted file mode 100644 index bcf4075..0000000 --- a/test/config_rules/swap_rule_test.jl +++ /dev/null @@ -1,7 +0,0 @@ -import NetworkHistogram: select_swap -@testset "swap rule" begin - A, node_labels, group_size, assignment = make_simple_example() - x = select_swap(assignment, A, RandomNodeSwap()) - @test x isa Tuple{Int, Int} - @test all(1 .≤ x .≤ size(A, 1)) -end diff --git a/test/data_tests/utils.jl b/test/data_tests/utils.jl deleted file mode 100644 index d37db8a..0000000 --- a/test/data_tests/utils.jl +++ /dev/null @@ -1,13 +0,0 @@ -@testset "Data utils" begin - A = [0 0 0 1 - 0 0 0 0 - 1 0 0 1 - 0 0 1 0] - - @testset "drop isolated vertices" begin - B = NetworkHistogram.drop_isolated_vertices(A) - @test B == [0 1 1 - 1 0 1 - 0 1 0] - end -end diff --git a/test/error_handling_tests.jl b/test/error_handling_tests.jl deleted file mode 100644 index 76fae2b..0000000 --- a/test/error_handling_tests.jl +++ /dev/null @@ -1,27 +0,0 @@ -@testset "Error handling" begin - @testset "Adjacency matrix" begin - As = [ - [0 1 - 0 0], [1 1 - 1 0], [0 2 - 2 0], [0 1 - 1 0 - 0 1], - ] - for A in As - @test_throws AssertionError graphhist(A, h = 2) - @test_throws AssertionError graphhist(Bool.(min.(A, 1)), h = 2) - @test_throws 
AssertionError graphhist(Float64.(A), h = 2) - end - @test_throws AssertionError graphhist(["0" "1"; "1" "0"], h = 2) - end - @testset "maxitr" begin - @test_throws AssertionError graphhist([0 1; 1 0], h = 2, - maxitr = -1) - end - @testset "h" begin - for h in (3, -1, 1.1, -0.1) - @test_throws AssertionError graphhist([0 1; 1 0], h = h) - end - end -end diff --git a/test/oracle_bandwidth_test.jl b/test/oracle_bandwidth_test.jl deleted file mode 100644 index 27e1f26..0000000 --- a/test/oracle_bandwidth_test.jl +++ /dev/null @@ -1,18 +0,0 @@ -@testset "oracle bandwidth test" begin - A = [0 0 1 0 1 0 1 1 0 1 - 0 0 1 1 1 1 1 1 0 0 - 1 1 0 1 0 0 0 0 1 0 - 0 1 1 0 1 0 1 0 0 0 - 1 1 0 1 0 0 1 0 0 1 - 0 1 0 0 0 0 0 1 0 0 - 1 1 0 1 1 0 0 1 0 1 - 1 1 0 0 0 1 1 0 0 1 - 0 0 1 0 0 0 0 0 0 1 - 1 0 0 0 1 0 1 1 1 0] - h = NetworkHistogram.oracle_bandwidth(A) - rho = sum(A) / (size(A, 1) * (size(A, 1) - 1)) - h_true_nethist = 2.643731 # version 0.2.3 from nethist package - h_clean = 3 - @test h≈h_true_nethist atol=1e-4 - @test NetworkHistogram.select_bandwidth(A) == h_clean -end diff --git a/test/pipeline_test.jl b/test/pipeline_test.jl deleted file mode 100644 index 7f86664..0000000 --- a/test/pipeline_test.jl +++ /dev/null @@ -1,84 +0,0 @@ -@testset "Pipeline" begin - A = [0 0 1 0 1 0 1 1 0 1 - 0 0 1 1 1 1 1 1 0 0 - 1 1 0 1 0 0 0 0 1 0 - 0 1 1 0 1 0 1 0 0 0 - 1 1 0 1 0 0 1 0 0 1 - 0 1 0 0 0 0 0 1 0 0 - 1 1 0 1 1 0 0 1 0 1 - 1 1 0 0 0 1 1 0 0 1 - 0 0 1 0 0 0 0 0 0 1 - 1 0 0 0 1 0 1 1 1 0] - @testset "dummy run" begin - @testset "run bandwidth float" begin - estimated = graphhist(A; h = 0.5) - @test all(estimated.graphhist.θ .>= 0.0) - @test all(estimated.graphhist.θ .<= 1.0) - @test size(estimated.graphhist.θ) == (2, 2) - end - @testset "run bandwidth int" begin - estimated = graphhist(A; h = 5) - @test all(estimated.graphhist.θ .>= 0.0) - @test all(estimated.graphhist.θ .<= 1.0) - @test size(estimated.graphhist.θ) == (2, 2) - end - @testset "run with automatic bandwidth" begin - 
estimated = graphhist(A) - @test all(estimated.graphhist.θ .>= 0.0) - @test all(estimated.graphhist.θ .<= 1.0) - end - end - - @testset "associative stochastic block model" begin - adjacencies = load(pwd() * "/test_files/sbm.jld") - - for (name, adjacency) in adjacencies - @testset "$name" begin - estimated, history = graphhist(adjacency; h = 0.3, - stop_rule = PreviousBestValue(100), - starting_assignment_rule = OrderedStart()) - @test all(estimated.θ .>= 0.0) - estimated, history = graphhist(adjacency; h = 0.3, - stop_rule = PreviousBestValue(100), - starting_assignment_rule = OrderedStart(), - record_trace = false) - @test all(estimated.θ .>= 0.0) - end - end - end - - @testset "multilayer run" begin - @testset "2 layers perfectly correlated" begin - A_2 = cat(A, A, dims = 3) - estimated, history = graphhist(A_2; h = 0.5) - @test all(estimated.θ .>= 0.0) - @test all(estimated.θ .<= 1.0) - @test size(estimated.θ) == (2, 2, 4) - end - @testset "run with automatic bandwidth" begin - A_2 = cat(A, A, dims = 3) - estimated, history = graphhist(A_2) - @test all(estimated.θ .>= 0.0) - @test all(estimated.θ .<= 1.0) - end - - @testset "2 layers perfectly anti-correlated" begin - A_2 = cat(A, abs.(A .- 1), dims = 3) - for i in 1:size(A, 1) - A_2[i, i, 2] = 0 - end - estimated, history = graphhist(A_2; h = 0.5) - @test all(estimated.θ .>= 0.0) - @test all(estimated.θ .<= 1.0) - @test size(estimated.θ) == (2, 2, 4) - end - - @testset "3 layers" begin - A_3 = cat(A, A, A, dims = 3) - estimated, history = graphhist(A_3; h = 0.5) - @test all(estimated.θ .>= 0.0) - @test all(estimated.θ .<= 1.0) - @test size(estimated.θ) == (2, 2, 8) - end - end -end diff --git a/test/proposal_test.jl b/test/proposal_test.jl deleted file mode 100644 index 0328b8e..0000000 --- a/test/proposal_test.jl +++ /dev/null @@ -1,30 +0,0 @@ -@testset "Proposal" begin - A, node_labels, group_size, assignment = make_simple_example() - h = 0.5 - swap = (2, 5) - proposal = deepcopy(assignment) - 
NetworkHistogram.make_proposal!(proposal, assignment, swap, A) - reference_proposal = NetworkHistogram.Assignment(A, [1, 2, 1, 1, 1, 2, 2, 2], - group_size) - - @testset "update labels" begin - @test proposal.node_labels[swap[1]] == reference_proposal.node_labels[swap[1]] == 2 - @test proposal.node_labels[swap[2]] == reference_proposal.node_labels[swap[2]] == 1 - end - - @testset "update realized edges" begin - @test proposal.realized[1, 2] == reference_proposal.realized[1, 2] == 8 - @test proposal.realized[2, 1] == reference_proposal.realized[2, 1] == 8 - @test proposal.realized[1, 1] == reference_proposal.realized[1, 1] == 2 - @test proposal.realized[2, 2] == reference_proposal.realized[2, 2] == 2 - end - - @testset "fast likelihood update" begin - # inside each group likelihood contribution - theoretical_after_update = 2 * (2 * log(2 / 6) + log(4 / 6) * 4) - # between group likelihood contribution - theoretical_after_update += 8 * log(8 / 16) * 2 - @test proposal.likelihood == theoretical_after_update == - reference_proposal.likelihood - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 643bfc7..bdbb9d6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,15 +1,9 @@ -using NetworkHistogram using Test +using LinearAlgebra, SparseArrays +using NetworkHistogram -using JLD -include("simple_test_example.jl") - -@testset "NetworkHistogram.jl" begin - include("pipeline_test.jl") - include("test_multilayer.jl") - include("proposal_test.jl") - include("starting_labels_test.jl") - include("oracle_bandwidth_test.jl") - include("error_handling_tests.jl") - include("config_rules/config_rule_test.jl") +@testset "Tests" begin + include("test_symarray.jl") + include("test_pseudo_suff_stats.jl") + include("test_hist_dist.jl") end diff --git a/test/simple_test_example.jl b/test/simple_test_example.jl deleted file mode 100644 index 010f511..0000000 --- a/test/simple_test_example.jl +++ /dev/null @@ -1,39 +0,0 @@ -""" - make_simple_example() - -Makes the simple 
example used in many tests. -Returns A, node_labels, group_size, assignment -""" -function make_simple_example() - A = [0 1 1 1 0 0 1 0 - 1 0 1 1 0 0 0 0 - 1 1 0 0 0 0 0 0 - 1 1 0 0 0 0 0 1 - 0 0 0 0 0 1 1 1 - 0 0 0 0 1 0 1 1 - 1 0 0 0 1 1 0 0 - 0 0 0 1 1 1 0 0] - node_labels = [1, 1, 1, 1, 2, 2, 2, 2] - group_size = NetworkHistogram.GroupSize(8, 4) - assignment = NetworkHistogram.Assignment(A, node_labels, group_size) - return A, node_labels, group_size, assignment -end - -function make_multivariate_example() - A = [0 1 1 1 0 0 1 0 - 1 0 1 1 0 0 0 0 - 1 1 0 0 0 0 0 0 - 1 1 0 0 0 0 0 1 - 0 0 0 0 0 1 1 1 - 0 0 0 0 1 0 1 1 - 1 0 0 0 1 1 0 0 - 0 0 0 1 1 1 0 0] - A = cat(A, abs.(A .- 1), dims = 3) - for i in size(A, 1) - A[i, i, :] .= 0 - end - node_labels = [1, 1, 1, 1, 2, 2, 2, 2] - group_size = NetworkHistogram.GroupSize(8, 4) - assignment = NetworkHistogram.Assignment(A, node_labels, group_size) - return A, node_labels, group_size, assignment -end diff --git a/test/starting_labels_test.jl b/test/starting_labels_test.jl deleted file mode 100644 index e94c3e4..0000000 --- a/test/starting_labels_test.jl +++ /dev/null @@ -1,41 +0,0 @@ - -""" - test_basic_node_labels(node_labels, group_size) - -Test that the node labels are valid: - - correct number of labels - - labels are positive - - labels are within the range of number of groups -""" -function test_basic_node_labels(node_labels, group_size) - @test length(node_labels) == sum(group_size) - @test all(node_labels .> 0) - @test all(node_labels .<= length(group_size)) - for (i, group_s) in enumerate(group_size) - @test count(x -> x == i, node_labels) == group_s - end -end - -@testset "Initial node labels" begin - A, _, _, _ = make_simple_example() - h = 0.5 - - @testset "random start" begin - node_labels, group_size = NetworkHistogram.initialize_node_labels(A, h, - RandomStart()) - test_basic_node_labels(node_labels, group_size) - end - - @testset "ordered start" begin - node_labels, group_size = 
NetworkHistogram.initialize_node_labels(A, h, - OrderedStart()) - test_basic_node_labels(node_labels, group_size) - @test node_labels == sort(node_labels) - end - - @testset "eigenvalue start" begin - node_labels, group_size = NetworkHistogram.initialize_node_labels(A, h, - EigenStart()) - test_basic_node_labels(node_labels, group_size) - end -end diff --git a/test/test_files/sbm.jld b/test/test_files/sbm.jld deleted file mode 100644 index dcd8b9f..0000000 Binary files a/test/test_files/sbm.jld and /dev/null differ diff --git a/test/test_hist_dist.jl b/test/test_hist_dist.jl new file mode 100644 index 0000000..8860223 --- /dev/null +++ b/test/test_hist_dist.jl @@ -0,0 +1,7 @@ +using Test +using NetworkHistogram +using StaticArrays +using Distributions +import NetworkHistogram as NH + +@testset "Histogram-based Distribution" begin end diff --git a/test/test_multilayer.jl b/test/test_multilayer.jl deleted file mode 100644 index aba8e04..0000000 --- a/test/test_multilayer.jl +++ /dev/null @@ -1,10 +0,0 @@ -@testset "multilayer" begin - @testset "test initialisation of assignments" begin - A, labels, group_size, assignment = make_simple_example() - A2, _, _, assignment2 = make_multivariate_example() - @test all(assignment.estimated_theta .== assignment2.estimated_theta[:, :, 2]) - @test all(assignment.realized .== assignment2.realized[:, :, 2]) - @test assignment.likelihood == assignment2.likelihood - @test sum(assignment2.estimated_theta) ≈ size(assignment2.estimated_theta, 1)^2 - end -end diff --git a/test/test_pseudo_suff_stats.jl b/test/test_pseudo_suff_stats.jl new file mode 100644 index 0000000..a22509c --- /dev/null +++ b/test/test_pseudo_suff_stats.jl @@ -0,0 +1,50 @@ +using Test +using NetworkHistogram +using StaticArrays +using Distributions +import NetworkHistogram as NH + +function _one_hot_vector(sample::Int, num_categories::Int) + v = zeros(Int, num_categories) + v[sample] = 1 + return v +end + +@testset "Bernoulli" begin + @testset "loss" begin + ss = 
NH.BernoulliSuffStats() + samples = [true, false, true, true, false, true, false, false, true, true] + for s in samples + ss = NH.add_sample(ss, s) + end + d = fit_mle(Bernoulli, samples) + @test NH.loss(ss) ≈ -sum(map(Base.Fix1(logpdf, d), samples)) + @test NH.to_params(ss) == d.p + end +end + +@testset "Categorical" begin + @testset "loss" begin + ss = NH.CategoricalSuffStats(3) + samples = [1, 2, 1, 2, 3, 1, 2, 3, 1, 2] + s_vec = _one_hot_vector.(samples, 3) + for s in samples + ss = NH.add_sample(ss, s) + end + d = fit_mle(Categorical, samples) + p = probs(d) + loss_val = 0.0 + for s in s_vec + loss_val += sum(abs2, s - p) + end + @test NH.loss(ss) ≈ loss_val + @test NH.to_params(ss) == p + + samples = ones(Int, 10) + ss_unique = NH.CategoricalSuffStats(3) + for s in samples + ss_unique = NH.add_sample(ss_unique, s) + end + @assert NH.loss(ss_unique) == 0.0 + end +end diff --git a/test/test_symarray.jl b/test/test_symarray.jl new file mode 100644 index 0000000..ce950ea --- /dev/null +++ b/test/test_symarray.jl @@ -0,0 +1,422 @@ +using Test +using NetworkHistogram +using SparseArrays +using LinearAlgebra +using StaticArrays + +@testset "SymArray Array Interface" begin + @testset "Construction and basic properties" begin + # Test construction with scalar + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 1.0) + @test a isa AbstractArray{Float64, 2} + @test size(a) == (3, 3) + @test length(a) == 9 + @test axes(a) == (1:3, 1:3) + @test eltype(a) == Float64 + + # Test construction with zeros + b = SymArray{Float64}(undef, 5, 5) + fill!(b, 0.0) + @test size(b) == (5, 5) + @test all(b[i, j] == 0.0 for i in 1:5 for j in 1:5) + + # Test dimension validation + @test_throws ArgumentError SymArray{Float64}(undef, 3, 4) + end + + @testset "Indexing - getindex and setindex!" begin + a = SymArray{Float64}(undef, 4, 4) + fill!(a, 0.0) + + # Test setindex! in upper triangle + a[1, 2] = 5.0 + @test a[1, 2] == 5.0 + @test a[2, 1] == 5.0 # Symmetry + + # Test setindex! 
in lower triangle (should set upper) + a[3, 2] = 7.0 + @test a[2, 3] == 7.0 + @test a[3, 2] == 7.0 + + # Test diagonal + a[2, 2] = 3.0 + @test a[2, 2] == 3.0 + + # Test bounds checking + @test_throws BoundsError a[0, 1] + @test_throws BoundsError a[5, 1] + @test_throws BoundsError a[1, 5] + end + + @testset "Symmetry property" begin + a = SymArray{Float64}(undef, 5, 5) + fill!(a, 0.0) + + # Set values and verify symmetry + for i in 1:5 + for j in 1:5 + val = i * 10 + j + a[i, j] = val + @test a[i, j] == a[j, i] + end + end + end + + @testset "Construction from matrix" begin + # Test from symmetric matrix + M = [1.0 2.0 3.0; + 2.0 4.0 5.0; + 3.0 5.0 6.0] + a = SymArray(M) + + @test size(a) == (3, 3) + for i in 1:3, j in 1:3 + + @test a[i, j] == M[i, j] + end + + # Test non-square matrix throws error + @test_throws ArgumentError SymArray([1.0 2.0; 3.0 4.0; 5.0 6.0]) + end + + @testset "convert functions" begin + # Test conversion to SymArray + M = [1.0 2.0; 2.0 4.0] + a = convert(SymArray{Float64}, M) + @test a isa SymArray{Float64} + @test a[1, 1] == 1.0 + @test a[1, 2] == 2.0 + @test a[2, 2] == 4.0 + + # Test conversion to AbstractMatrix + b = convert(Matrix{Float64}, a) + @test b isa Matrix{Float64} + @test b == M + @test b[1, 2] == b[2, 1] # Verify symmetry + end + + @testset "similar function" begin + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 5.0) + + # Test similar without type + b = similar(a) + @test size(b) == size(a) + @test eltype(b) == eltype(a) + @test b isa SymArray{Float64} + + # Test similar with type + c = similar(a, Int) + @test size(c) == size(a) + @test eltype(c) == Int + @test c isa SymArray{Int} + + # Test similar with type and dimensions + d = similar(a, Float32, (4, 4)) + @test size(d) == (4, 4) + @test eltype(d) == Float32 + + # Test non-square dimensions throw error + @test_throws ArgumentError similar(a, Float64, (3, 4)) + end + + @testset "copy! and deepcopy!" 
begin + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 0.0) + a[1, 1] = 1.0 + a[1, 2] = 2.0 + a[2, 3] = 5.0 + + b = similar(a) + copy!(b, a) + + @test b[1, 1] == 1.0 + @test b[1, 2] == 2.0 + @test b[2, 1] == 2.0 + @test b[2, 3] == 5.0 + @test b[3, 2] == 5.0 + + # Test dimension mismatch + d = SymArray{Float64}(undef, 4, 4) + fill!(d, 0.0) + @test_throws DimensionMismatch copy!(d, a) + + # Test deepcopy! + src = SymArray{Vector{Int}}(undef, 4, 4) + for j in 1:4, i in j:4 + + src[i, j] = [i, j] + end + + # on unassigned dest + dest = similar(src) + deepcopy!(dest, src) + for j in 1:4, i in j:4 + + @test dest[i, j] == src[i, j] + @test !(dest[i, j] === src[i, j]) # Ensure deep copy + end + + # on assigned dest + dest2 = similar(src) + for j in 1:4, i in j:4 + + dest2[i, j] = [-1, -1] + end + deepcopy!(dest2, src) + for j in 1:4, i in j:4 + + @test dest2[i, j] == src[i, j] + @test !(dest2[i, j] === src[i, j]) # Ensure deep copy + end + end + + @testset "Array operations" begin + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 2.0) + + # Test iteration + count = 0 + for val in a + @test val == 2.0 + count += 1 + end + @test count == 9 + + # Test sum + @test sum(a) == 18.0 + + # Test all/any + @test all(x -> x == 2.0, a) + @test any(x -> x == 2.0, a) + + # Test maximum/minimum + b = SymArray{Float64}(undef, 3, 3) + fill!(b, 0.0) + b[1, 1] = 5.0 + b[2, 3] = -3.0 + @test maximum(b) == 5.0 + @test minimum(b) == -3.0 + end + + @testset "Mathematical operations" begin + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 2.0) + b = SymArray{Float64}(undef, 3, 3) + fill!(b, 3.0) + + # Element-wise operations (using broadcasting) + c = a .+ b + @test c isa SymArray + @test all(c[i, j] == 5.0 for i in 1:3, j in 1:3) + + d = a .* 2 + @test d isa SymArray + @test all(d[i, j] == 4.0 for i in 1:3, j in 1:3) + + # Test subtraction + e = b .- a + @test e isa SymArray + @test all(e[i, j] == 1.0 for i in 1:3, j in 1:3) + + # Test division + f = b ./ 2.0 + @test f isa SymArray + @test all(f[i, j] 
== 1.5 for i in 1:3, j in 1:3) + + # Test unary operations + g = SymArray{Float64}(undef, 3, 3) + fill!(g, -2.0) + h = abs.(g) + @test h isa SymArray + @test all(h[i, j] == 2.0 for i in 1:3, j in 1:3) + + # Test with mixed values + m = SymArray{Float64}(undef, 3, 3) + fill!(m, 0.0) + m[1, 1] = 1.0 + m[1, 2] = 2.0 + m[2, 2] = 3.0 + m[1, 3] = 4.0 + m[2, 3] = 5.0 + m[3, 3] = 6.0 + + n = m .+ 10.0 + @test n isa SymArray + @test n[1, 1] == 11.0 + @test n[1, 2] == 12.0 + @test n[2, 1] == 12.0 # Symmetry + @test n[2, 2] == 13.0 + @test n[3, 3] == 16.0 + + # Test operations between two SymArrays with different values + p = SymArray{Float64}(undef, 3, 3) + fill!(p, 0.0) + p[1, 1] = 10.0 + p[2, 2] = 20.0 + p[3, 3] = 30.0 + + q = m .+ p + @test q isa SymArray + @test q[1, 1] == 11.0 + @test q[2, 2] == 23.0 + @test q[3, 3] == 36.0 + @test q[1, 2] == 2.0 + @test q[2, 1] == 2.0 + end + + @testset "Type stability" begin + # Float64 + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 1.0) + @test typeof(a[1, 1]) == Float64 + + # Int + b = SymArray{Int}(undef, 3, 3) + fill!(b, 1) + @test typeof(b[1, 1]) == Int + + # Float32 + c = SymArray{Float32}(undef, 3, 3) + fill!(c, 1.0f0) + @test typeof(c[1, 1]) == Float32 + end + + @testset "Sparse matrix properties" begin + a = SymArray{Float64}(undef, 10, 10) + fill!(a, 0.0) + # Initially all elements are stored (including zeros) + # Set only a few elements to non-zero + a[1, 5] = 3.0 + a[3, 7] = 4.0 + a[9, 9] = 5.0 + + # Verify values are correct (symmetry) + @test a[1, 5] == 3.0 + @test a[5, 1] == 3.0 + @test a[3, 7] == 4.0 + @test a[7, 3] == 4.0 + @test a[9, 9] == 5.0 + @test a[2, 2] == 0.0 + end + + @testset "Edge cases" begin + # 1x1 matrix + a = SymArray{Float64}(undef, 1, 1) + fill!(a, 5.0) + @test size(a) == (1, 1) + @test a[1, 1] == 5.0 + a[1, 1] = 10.0 + @test a[1, 1] == 10.0 + + # Large diagonal + b = SymArray{Float64}(undef, 100, 100) + fill!(b, 0.0) + for i in 1:100 + b[i, i] = Float64(i) + end + @test b[50, 50] == 50.0 + @test 
b[99, 99] == 99.0 + end + + @testset "Broadcasting" begin + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 2.0) + b = @. a + 2.0 + @test b isa SymArray + @test all(b[i, j] == 4.0 for i in 1:3, j in 1:3) + + c = b ./ a + @test c isa SymArray + @test all(c[i, j] == 2.0 for i in 1:3, j in 1:3) + + sin_a = @. sin(a) + sin_a_bis = sin.(a) + for sin_test in (sin_a, sin_a_bis) + @test sin_test isa SymArray + @test all(sin_test[i, j] == sin(2.0) for i in 1:3, j in 1:3) + end + end + + @testset "Broadcasting with regular arrays" begin + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 2.0) + M = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0] + + # SymArray + Matrix should return Matrix (follows Matrix type) + result1 = a .+ M + @test result1 isa Matrix{Float64} + @test !(result1 isa SymArray) + + # Matrix + SymArray should also return Matrix + result2 = M .+ a + @test result2 isa Matrix{Float64} + @test !(result2 isa SymArray) + + # Check values are correct + for i in 1:3, j in 1:3 + + @test result1[i, j] ≈ 2.0 + M[i, j] + @test result2[i, j] ≈ M[i, j] + 2.0 + end + + # SymArray + scalar should still return SymArray + result3 = a .+ 5.0 + @test result3 isa SymArray + + # SymArray + SymArray should return SymArray + b = SymArray{Float64}(undef, 3, 3) + fill!(b, 3.0) + result4 = a .+ b + @test result4 isa SymArray + end + + @testset "SymArray broadcast with Matrix returns Matrix" begin + # Create a SymArray and a regular Matrix + a = SymArray{Float64}(undef, 3, 3) + fill!(a, 2.0) + M = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0] + + # SymArray + Matrix should return Matrix + result1 = a .+ M + @test result1 isa Matrix{Float64} + @test !(result1 isa SymArray) + @test size(result1) == (3, 3) + + # Matrix + SymArray should also return Matrix + result2 = M .+ a + @test result2 isa Matrix{Float64} + @test !(result2 isa SymArray) + + # Check values are correct + for i in 1:3, j in 1:3 + + @test result1[i, j] ≈ 2.0 + M[i, j] + @test result2[i, j] ≈ M[i, j] + 2.0 + end + + # SymArray + scalar 
should still return SymArray + result3 = a .+ 5.0 + @test result3 isa SymArray + @test all(result3[i, j] ≈ 7.0 for i in 1:3, j in 1:3) + + # SymArray + SymArray should return SymArray + b = SymArray{Float64}(undef, 3, 3) + fill!(b, 3.0) + result4 = a .+ b + @test result4 isa SymArray + @test all(result4[i, j] ≈ 5.0 for i in 1:3, j in 1:3) + + # Chained operations with scalars should still work + result5 = (a .+ 1) .* 2 + @test result5 isa SymArray + @test all(result5[i, j] ≈ 6.0 for i in 1:3, j in 1:3) + + a_ones = SymArray{Float64}(undef, 3, 3) + fill!(a_ones, 1.0) + result_sum_two_matrices = a_ones .+ M .+ M + @test result_sum_two_matrices isa Matrix{Float64} + @test all(result_sum_two_matrices[i, j] ≈ 1 + 2 * M[i, j] for i in 1:3, j in 1:3) + end +end