From 1eb085370f72acf98d1153b3e9d4e2916bd4941e Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Fri, 12 Mar 2021 15:29:09 -0500 Subject: [PATCH 1/8] add two-way ZDD --- ZDD/jl_zdd/node_auxillaries.jl | 4 ++ ZDD/jl_zdd/weighted.jl | 1 + ZDD/jl_zdd/weighted_node.jl | 3 +- ZDD/jl_zdd/weightless.jl | 1 + ZDD/jl_zdd/zdd_jl.ipynb | 83 ++++++++++++++++++++++++++++++---- 5 files changed, 82 insertions(+), 10 deletions(-) diff --git a/ZDD/jl_zdd/node_auxillaries.jl b/ZDD/jl_zdd/node_auxillaries.jl index 3e9ee23..a8394f8 100644 --- a/ZDD/jl_zdd/node_auxillaries.jl +++ b/ZDD/jl_zdd/node_auxillaries.jl @@ -38,6 +38,10 @@ function readable(arr::Vector{UInt8})::Array{Int64, 1} Array{Int, 1}([Int64(x) for x in arr]) end +function readable(arr::Vector{UInt32})::Array{Int64, 1} + Array{Int, 1}([Int64(x) for x in arr]) +end + function readable(cc::UInt8)::Int64 Int64(cc) end diff --git a/ZDD/jl_zdd/weighted.jl b/ZDD/jl_zdd/weighted.jl index 045a27c..ac27409 100644 --- a/ZDD/jl_zdd/weighted.jl +++ b/ZDD/jl_zdd/weighted.jl @@ -11,6 +11,7 @@ include("edge_ordering.jl") include("zdd.jl") include("count_enumerate.jl") include("visualization.jl") +include("two_way_zdd.jl") function make_new_node(g::SimpleGraph, diff --git a/ZDD/jl_zdd/weighted_node.jl b/ZDD/jl_zdd/weighted_node.jl index 35b35cf..fae4cff 100644 --- a/ZDD/jl_zdd/weighted_node.jl +++ b/ZDD/jl_zdd/weighted_node.jl @@ -43,7 +43,8 @@ function custom_deepcopy(n::Node, recycler::Stack{Node}, x::Int8)::Node return n end if isempty(recycler) - comp_weights = Vector{UInt32}(undef, length(n.comp_weights)) + # comp_weights = Vector{UInt32}(undef, length(n.comp_weights)) + comp_weights = zeros(UInt32, length(n.comp_weights)) comp_assign = zeros(UInt8, length(n.comp_assign)) fps = Vector{ForbiddenPair}(undef, length(n.fps)) diff --git a/ZDD/jl_zdd/weightless.jl b/ZDD/jl_zdd/weightless.jl index 70f7d32..7566f9f 100644 --- a/ZDD/jl_zdd/weightless.jl +++ b/ZDD/jl_zdd/weightless.jl @@ -10,6 +10,7 @@ include("edge_ordering.jl") include("zdd.jl") include("count_enumerate.jl") include("visualization.jl") +include("two_way_zdd.jl") function make_new_node(g::SimpleGraph, diff --git a/ZDD/jl_zdd/zdd_jl.ipynb b/ZDD/jl_zdd/zdd_jl.ipynb index 2fd0ca8..b9737c0 100644 --- a/ZDD/jl_zdd/zdd_jl.ipynb +++ b/ZDD/jl_zdd/zdd_jl.ipynb @@ -9,11 +9,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m environment at `~/.julia/environments/zdd/Project.toml`\n", - "┌ Info: Precompiling GraphPlot [a2cc645c-3eea-5389-862e-a155d0052231]\n", - "└ @ Base loading.jl:1278\n", - "┌ Info: Precompiling BenchmarkTools [6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf]\n", - "└ @ Base loading.jl:1278\n" + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m new environment at `~/.julia/environments/zdd/Project.toml`\n" ] } ], @@ -29,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 58, "metadata": {}, "outputs": [ { @@ -38,7 +34,7 @@ "adjust_node! (generic function with 1 method)" ] }, - "execution_count": 40, + "execution_count": 58, "metadata": {}, "output_type": "execute_result" } @@ -48,6 +44,75 @@ "include(\"weighted.jl\")" ] }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Constructing a half-ZDD...\n", + " 15.644982 seconds (16.00 M allocations: 1.242 GiB, 6.91% gc time)\n", + "Constructing a half-ZDD...\n", + " 16.198733 seconds (16.00 M allocations: 1.242 GiB, 8.26% gc time)\n" + ] + } + ], + "source": [ + "m = 7\n", + "dims = [m,m]\n", + "k = m\n", + "d = 0\n", + "g = grid(dims)\n", + "g_edges = optimal_grid_edge_order_diags(g, dims[1], dims[2])\n", + "forwards = convert_lightgraphs_edges_to_node_edges(g_edges)\n", + "backwards = reverse(forwards)\n", + "\n", + "frontiers = compute_all_frontiers(g, forwards)\n", + "middle_frontier = frontiers[Int(ceil(length(frontiers)/2))]\n", + "\n", + "# @time zdd = construct_zdd(g, k, d, forwards)\n", + "@time fnodes = construct_half_zdd(g, k, d, forwards)\n", + "@time bnodes = construct_half_zdd(g, k, d, backwards); nothing\n", + "# count_paths(zdd)" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "number_compatible (generic function with 2 methods)" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "all_three = [\"weights\", \"fps\", \"cc\"]\n", + "weights = [\"weights\"] # none seem to work\n", + "fps = [\"fps\"] # all seem to work\n", + "cc = [\"cc\"] # 40/81 work\n", + "weights_fps = [\"weights\", \"fps\"]\n", + "weights_cc = [\"weights\", \"cc\"]\n", + "fps_cc = [\"fps\", \"cc\"]\n", + "@time number_compatible(fnodes, bnodes, middle_frontier, 4:4, w, k, all_three)" + ] + }, { "cell_type": "code", "execution_count": 47, @@ -133,7 +198,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Julia 1.5.3", + "display_name": "Julia 1.5.1", "language": "julia", "name": "julia-1.5" }, @@ -141,7 +206,7 @@ "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", - "version": "1.5.3" + "version": "1.5.1" } }, "nbformat": 4, From 33f884d2fd67e3737e403d3987cbf2188363a850 Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Fri, 12 Mar 2021 15:30:47 -0500 Subject: [PATCH 2/8] add two_way_zdd.jl file --- ZDD/jl_zdd/two_way_zdd.jl | 171 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 ZDD/jl_zdd/two_way_zdd.jl diff --git a/ZDD/jl_zdd/two_way_zdd.jl b/ZDD/jl_zdd/two_way_zdd.jl new file mode 100644 index 0000000..401b066 --- /dev/null +++ b/ZDD/jl_zdd/two_way_zdd.jl @@ -0,0 +1,171 @@ +include("zdd.jl") + +function construct_half_zdd(g::SimpleGraph, + k::Int64, # can we change this to Int8? should we? + d::Int64, + g_edges::Array{NodeEdge,1}, + weights::Vector{Int64}=Vector{Int64}([1 for i in 1:nv(g)]), + viz::Bool=false, + save_fp::String="zdd_tree.txt")::Set{Node} + + # delete file if it already exists + if isfile(save_fp) + rm(save_fp) + end + + weights = Vector{UInt32}([convert(UInt32,i) for i in weights]) + root = Node(g_edges[1], g, weights) + + lower_bound = Int32(floor(sum(weights)/k - d)) + upper_bound = Int32(floor(sum(weights)/k + d)) + + println("Constructing a half-ZDD...") + + zdd = ZDD(g, root, viz=viz) + halfway = Int(ne(g)/2) + N = Vector{Set{Node}}([Set{Node}([]) for a in 1:halfway+1]) + N[1] = Set([root]) # why not Set(root) + frontiers = compute_all_frontiers(g, g_edges) # only need to do half... + xs = Vector{Int8}([0,1]) + zero_terminal = Node(0) + one_terminal = Node(1) + fp_container = Vector{ForbiddenPair}([]) + rm_container = Vector{ForbiddenPair}([]) + reusable_set = Set{ForbiddenPair}([]) + recycler = Stack{Node}() # what is this? + lower_vs = Vector{UInt8}([]) + + for i = 1:halfway + for n in N[i] + n_idx = zdd.nodes[n.hash] + for x in xs + n′ = make_new_node(g, g_edges, k, n, i, x, d, frontiers, + lower_bound, upper_bound, + zero_terminal, one_terminal, + fp_container, rm_container, lower_vs, recycler) + + if n′ === one_terminal + zdd.paths += n.paths + end + + if !(n′.label == NodeEdge(0,0) || n′.label == NodeEdge(1,1)) # if not a Terminal Node + n′.label = g_edges[i+1] # update the label of n′ + reusable_unique!(n′.fps, reusable_set) + sort!(n′.fps, alg=QuickSort) + n′.hash = hash(n′) + + if n′ in N[i+1] + index = Base.ht_keyindex2!(N[i+1].dict, n′) + N[i+1].dict.keys[index].paths += n.paths + else + add_zdd_node_and_edge!(zdd, n′, n, n_idx, x) + push!(N[i+1], n′) + continue + end + end + add_zdd_edge!(zdd, n, n′, n_idx, x) # the order of n and n′ are switched, but probably ok + end + end + if i == halfway + return N[i+1] + end + zdd.deleted_nodes += length(N[i]) + save_tree_so_far!(zdd, save_fp, length(N[i])) + erase_upper_levels!(zdd, N[i+1], zero_terminal, one_terminal, length(N[i])) # release memory + N[i] = Set{Node}([]) # release memory + # println(i, ": ", Base.summarysize(zdd)) + end + # return zdd +end + +function all_vertices_connected_to(v, node) + representative_vertex = node.comp_assign[v] + return Set(findall(==(representative_vertex), node.comp_assign)) +end + +function setup(fnode, bnode, frontier) + fcomp, bcomp, intcomp = Dict(), Dict(), Dict() # add type annotations + ffrontier, bfrontier, frontier_span = Dict(), Dict(), Dict() # add type annotations + isolated_frontier_vtxs = Set() + + for v ∈ frontier + fcomp[v] = all_vertices_connected_to(v, fnode) + bcomp[v] = all_vertices_connected_to(v, bnode) + intcomp[v] = intersect(fcomp[v],bcomp[v]) + + ffrontier[v] = intersect(fcomp[v], frontier) + bfrontier[v] = intersect(bcomp[v], frontier) + frontier_span[v] = union(ffrontier[v], bfrontier[v]) + + push!(isolated_frontier_vtxs, maximum(frontier_span[v])) + end + return fcomp, bcomp, intcomp, isolated_frontier_vtxs +end + +function check_weights(fcomp, bcomp, intcomp, frontier, acceptable, w) + for v ∈ frontier + if !(w(fcomp[v]) + w(bcomp[v]) - w(intcomp[v]) ∈ acceptable) + return false + end + end + return true +end + +function check_fps(fnode, bnode, fcomp, bcomp, intcomp, frontier) + for v ∈ frontier + for x ∈ fcomp[v] + for y ∈ bcomp[v] + if (x,y) ∈ union(fnode.fps, bnode.fps) + return false + end + end + end + end + return true +end + +function check_cc(fnode, bnode, k, isolated_frontier_vtxs) + if fnode.cc + bnode.cc + length(isolated_frontier_vtxs) == k + return true + end + return false +end + +function is_compatible(fnode, bnode, frontier, acceptable, w, k, checking) + fcomp, bcomp, intcomp, isolated_frontier_vtxs = setup(fnode, bnode, frontier) + weights = check_weights(fcomp, bcomp, intcomp, frontier, acceptable, w) + fps = check_fps(fnode, bnode, fcomp, bcomp, intcomp, frontier) + cc = check_cc(fnode, bnode, k, isolated_frontier_vtxs) + + checklist = [] + if "weights" ∈ checking + push!(checklist, weights) + end + if "fps" ∈ checking + push!(checklist, fps) + end + if "cc" ∈ checking + push!(checklist, cc) + end + if all(checklist) + return true + end + return false +end + +function w(v) + return 1 +end + +function number_compatible(fnodes, bnodes, frontier, acceptable, w, k, checking) + println("Total # of node pairs: $(length(fnodes)) * $(length(bnodes)) = $(length(fnodes) * length(bnodes))") + num_partitions = 0 + for fnode ∈ fnodes + for bnode ∈ bnodes + if is_compatible(fnode, bnode, frontier, acceptable, w, k, checking) + num_partitions += 1 + end + end + end + return num_partitions +end \ No newline at end of file From 62364c46ed75a0fd23d8da7c0f130e16d015b942 Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Fri, 12 Mar 2021 15:36:59 -0500 Subject: [PATCH 3/8] remove comment --- ZDD/jl_zdd/weighted_node.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ZDD/jl_zdd/weighted_node.jl b/ZDD/jl_zdd/weighted_node.jl index fae4cff..02f1547 100644 --- a/ZDD/jl_zdd/weighted_node.jl +++ b/ZDD/jl_zdd/weighted_node.jl @@ -43,7 +43,6 @@ function custom_deepcopy(n::Node, recycler::Stack{Node}, x::Int8)::Node return n end if isempty(recycler) - # comp_weights = Vector{UInt32}(undef, length(n.comp_weights)) comp_weights = zeros(UInt32, length(n.comp_weights)) comp_assign = zeros(UInt8, length(n.comp_assign)) fps = Vector{ForbiddenPair}(undef, length(n.fps)) From 9c6c62662f67bbbd174146dc85fa02edd07a57e7 Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Thu, 1 Apr 2021 14:40:56 -0700 Subject: [PATCH 4/8] add first pass of two-way zdd --- ZDD/jl_zdd/old_two_way_zdd.jl | 93 +++++++++++++++++ ZDD/jl_zdd/two_way_algo_notes.txt | 55 ++++++++++ ZDD/jl_zdd/two_way_zdd.jl | 168 ++++++++++++++++++------------ 3 files changed, 248 insertions(+), 68 deletions(-) create mode 100644 ZDD/jl_zdd/old_two_way_zdd.jl create mode 100644 ZDD/jl_zdd/two_way_algo_notes.txt diff --git a/ZDD/jl_zdd/old_two_way_zdd.jl b/ZDD/jl_zdd/old_two_way_zdd.jl new file mode 100644 index 0000000..9ea96c9 --- /dev/null +++ b/ZDD/jl_zdd/old_two_way_zdd.jl @@ -0,0 +1,93 @@ +include("zdd.jl") + +function all_vertices_connected_to(v, node) + representative_vertex = node.comp_assign[v] + return Set(findall(==(representative_vertex), node.comp_assign)) +end + +function setup(fnode, bnode, frontier) + fcomp, bcomp, intcomp = Dict(), Dict(), Dict() # add type annotations + ffrontier, bfrontier, frontier_span = Dict(), Dict(), Dict() # add type annotations + isolated_frontier_vtxs = Set() + + for v ∈ frontier + fcomp[v] = all_vertices_connected_to(v, fnode) + bcomp[v] = all_vertices_connected_to(v, bnode) + intcomp[v] = intersect(fcomp[v],bcomp[v]) + + ffrontier[v] = intersect(fcomp[v], frontier) + bfrontier[v] = intersect(bcomp[v], frontier) + frontier_span[v] = union(ffrontier[v], bfrontier[v]) + + push!(isolated_frontier_vtxs, maximum(frontier_span[v])) + end + return fcomp, bcomp, intcomp, isolated_frontier_vtxs +end + +function check_weights(fcomp, bcomp, intcomp, frontier, acceptable, w) + for v ∈ frontier + if !(w(fcomp[v]) + w(bcomp[v]) - w(intcomp[v]) ∈ acceptable) + return false + end + end + return true +end + +function check_fps(fnode, bnode, fcomp, bcomp, intcomp, frontier) + for v ∈ frontier + for x ∈ fcomp[v] + for y ∈ bcomp[v] + if (x,y) ∈ union(fnode.fps, bnode.fps) + return false + end + end + end + end + return true +end + +function check_cc(fnode, bnode, k, isolated_frontier_vtxs) + if fnode.cc + bnode.cc + length(isolated_frontier_vtxs) == k + return true + end + return false +end + +function is_compatible(fnode, bnode, frontier, acceptable, w, k, checking) + fcomp, bcomp, intcomp, isolated_frontier_vtxs = setup(fnode, bnode, frontier) + weights = check_weights(fcomp, bcomp, intcomp, frontier, acceptable, w) + fps = check_fps(fnode, bnode, fcomp, bcomp, intcomp, frontier) + cc = check_cc(fnode, bnode, k, isolated_frontier_vtxs) + + checklist = [] + if "weights" ∈ checking + push!(checklist, weights) + end + if "fps" ∈ checking + push!(checklist, fps) + end + if "cc" ∈ checking + push!(checklist, cc) + end + if all(checklist) + return true + end + return false +end + +function w(v) + return 1 +end + +function number_compatible(fnodes, bnodes, frontier, acceptable, w, k, checking) + println("Total # of node pairs: $(length(fnodes)) * $(length(bnodes)) = $(length(fnodes) * length(bnodes))") + num_partitions = 0 + for fnode ∈ fnodes + for bnode ∈ bnodes + if is_compatible(fnode, bnode, frontier, acceptable, w, k, checking) + num_partitions += 1 + end + end + end + return num_partitions +end \ No newline at end of file diff --git a/ZDD/jl_zdd/two_way_algo_notes.txt b/ZDD/jl_zdd/two_way_algo_notes.txt new file mode 100644 index 0000000..bf2eb4f --- /dev/null +++ b/ZDD/jl_zdd/two_way_algo_notes.txt @@ -0,0 +1,55 @@ +1-2-3 + +4-5 6 +| +7 8 9 + +1 2 3 + | +4 5 6 + +7-8-9 + +f: {3}, {5,7} +b: {3}, {5}, {7} +--> {3}, {5,7} + +---- +{3,5,7} + + +f: {1,2} {3} {4,5,6} +b: {1,4} {2,3}, {5,6} + +d = {1:2, 2:2, 3:3, 4:6, 5:6, 6:6} # start with the forwards dictionary +weights = {} +for each bset in b: + find: find the 1 group, find the 4 group. + union those groups into U + add weights of 1group to 4group + weight of bset - intersection (which is the sum of weights of vertices in bset) + and find the max: max(d[u] for u in U) + relabel: d[u] = max for u in U + d = {1:6, 2:6, 3:3, 4:6, 5:6, 6:6} # after one iteration + d = {1:6, 2:6, 3:6, 4:6, 5:6, 6:6} # after second iteration + +{{1,2,3,4,5,6}} + +Union-Find: +{1,2} {1,4} -> 4 +{3} -> 3 +{2,3} -> 4 which changes 3 +{4,5,6} -> 6 which changes 1-4 +{5,6} = 6 + + + +--- (answer is {1,2,3,4,5,6}) +{1,2,4,3} +{3,2} +{4,5,6,1} +--> {1,2,3,4,5,6} + + + + + diff --git a/ZDD/jl_zdd/two_way_zdd.jl b/ZDD/jl_zdd/two_way_zdd.jl index 401b066..3905bc1 100644 --- a/ZDD/jl_zdd/two_way_zdd.jl +++ b/ZDD/jl_zdd/two_way_zdd.jl @@ -77,95 +77,127 @@ function construct_half_zdd(g::SimpleGraph, end # return zdd end - -function all_vertices_connected_to(v, node) - representative_vertex = node.comp_assign[v] - return Set(findall(==(representative_vertex), node.comp_assign)) -end -function setup(fnode, bnode, frontier) - fcomp, bcomp, intcomp = Dict(), Dict(), Dict() # add type annotations - ffrontier, bfrontier, frontier_span = Dict(), Dict(), Dict() # add type annotations - isolated_frontier_vtxs = Set() - - for v ∈ frontier - fcomp[v] = all_vertices_connected_to(v, fnode) - bcomp[v] = all_vertices_connected_to(v, bnode) - intcomp[v] = intersect(fcomp[v],bcomp[v]) - - ffrontier[v] = intersect(fcomp[v], frontier) - bfrontier[v] = intersect(bcomp[v], frontier) - frontier_span[v] = union(ffrontier[v], bfrontier[v]) - - push!(isolated_frontier_vtxs, maximum(frontier_span[v])) +function frontier_sets(node, frontier) + """ + Returns the sets of vertices in the frontier that are connected to each other. + """ + sets = Set() + seen_vertices = Set() + frontier_list = reverse(collect(frontier)) # reverse probably unnecessary + for v ∈ frontier_list + if v ∈ seen_vertices + continue + end + set = Set(findall(n -> n == v, node.comp_assign)) + if length(set) > 0 + push!(sets, set) + for seen_vertex ∈ set + push!(seen_vertices, seen_vertex) + end + end end - return fcomp, bcomp, intcomp, isolated_frontier_vtxs + return sets end -function check_weights(fcomp, bcomp, intcomp, frontier, acceptable, w) - for v ∈ frontier - if !(w(fcomp[v]) + w(bcomp[v]) - w(intcomp[v]) ∈ acceptable) - return false +function merge_nodes(fnode, bnode, frontier) + """ + Use Union-Find type of thing to merge nodes.... + """ + local_fnode_comp_weights = deepcopy(fnode.comp_weights) + + ffrontier_sets = frontier_sets(fnode, frontier) + bfrontier_sets = frontier_sets(bnode, frontier) + + labels = Dict() + weights = Dict(v => -1 for v in frontier) + for set ∈ ffrontier_sets + for v ∈ set + labels[v] = maximum(set) end end - return true -end - -function check_fps(fnode, bnode, fcomp, bcomp, intcomp, frontier) - for v ∈ frontier - for x ∈ fcomp[v] - for y ∈ bcomp[v] - if (x,y) ∈ union(fnode.fps, bnode.fps) - return false - end + for bset ∈ bfrontier_sets + U = Set() + w = 0 + for v ∈ bset + vgroup = findall(x -> (labels[x] == labels[v]), collect(frontier)) + for g ∈ vgroup + push!(U,collect(frontier)[g]) end + w += local_fnode_comp_weights[labels[v]] # labels[v] in frontier so has a weight + end + w_bset = maximum(bnode.comp_weights[v] for v ∈ bset) # if you just pick one, you might hit a weird 0 + w_intersection = length(bset) # TODO: generalize past unit weights + w += (w_bset - w_intersection) + U_frontier = intersect(U, frontier) # need this because `labels` only has frontier keys + max_label = maximum(labels[u] for u ∈ U_frontier) + for u ∈ U_frontier + labels[u] = max_label + weights[u] = w + local_fnode_comp_weights[labels[u]] = w end end - return true + + merged_fps = union(fnode.fps, bnode.fps) + connected_components = Set(values(labels)) + return connected_components, labels, weights, merged_fps end -function check_cc(fnode, bnode, k, isolated_frontier_vtxs) - if fnode.cc + bnode.cc + length(isolated_frontier_vtxs) == k +function check_cc(fnode, bnode, connected_components, k) + if fnode.cc + bnode.cc + length(connected_components) == k return true + else + return false end - return false end -function is_compatible(fnode, bnode, frontier, acceptable, w, k, checking) - fcomp, bcomp, intcomp, isolated_frontier_vtxs = setup(fnode, bnode, frontier) - weights = check_weights(fcomp, bcomp, intcomp, frontier, acceptable, w) - fps = check_fps(fnode, bnode, fcomp, bcomp, intcomp, frontier) - cc = check_cc(fnode, bnode, k, isolated_frontier_vtxs) - - checklist = [] - if "weights" ∈ checking - push!(checklist, weights) - end - if "fps" ∈ checking - push!(checklist, fps) - end - if "cc" ∈ checking - push!(checklist, cc) - end - if all(checklist) - return true +function check_fps(fnode, bnode, connected_components, labels, merged_fps, frontier) + for c ∈ connected_components + idxs = findall(x -> labels[x] == c, collect(frontier)) + vtxs = [collect(frontier)[i] for i ∈ idxs] + for v₁ ∈ vtxs + for v₂ ∈ vtxs + if v₁ != v₂ + maybe_forbidden = ForbiddenPair(v₁, v₂) + if maybe_forbidden ∈ merged_fps + return false + end + else + continue + end + end + end end - return false + return true end -function w(v) - return 1 +function check_weights(fnode, bnode, weights, acceptable) + for w ∈ values(weights) + if w ∉ acceptable + return false + end + end + return true end -function number_compatible(fnodes, bnodes, frontier, acceptable, w, k, checking) - println("Total # of node pairs: $(length(fnodes)) * $(length(bnodes)) = $(length(fnodes) * length(bnodes))") - num_partitions = 0 - for fnode ∈ fnodes - for bnode ∈ bnodes - if is_compatible(fnode, bnode, frontier, acceptable, w, k, checking) - num_partitions += 1 +function count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, verbose=false) + num_paths = 0 + if verbose + println("Comparing $(length(fnodes)) fnodes to $(length(bnodes)) bnodes...\n") + end + for (fi,fnode) ∈ enumerate(fnodes) + for (bi,bnode) ∈ enumerate(bnodes) + ccs, labels, weights, merged_fps = merge_nodes(fnode, bnode, middle_frontier) + cc = check_cc(fnode, bnode, ccs, k) + fps = check_fps(fnode, bnode, ccs, labels, merged_fps, middle_frontier) + ws = check_weights(fnode, bnode, weights, acceptable) + if cc & fps & ws + num_paths += fnode.paths * bnode.paths + if verbose + println("(fnode $fi, bnode $bi) contributes $(fnode.paths*bnode.paths) solns") + end end end end - return num_partitions + return num_paths end \ No newline at end of file From 6e12cd75850e1db585644e3977a509ccbd0aa46d Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Thu, 1 Apr 2021 14:42:47 -0700 Subject: [PATCH 5/8] two-way demo notebook --- ZDD/jl_zdd/zdd_jl.ipynb | 112 +++++----------------------------------- 1 file changed, 12 insertions(+), 100 deletions(-) diff --git a/ZDD/jl_zdd/zdd_jl.ipynb b/ZDD/jl_zdd/zdd_jl.ipynb index b9737c0..62a8441 100644 --- a/ZDD/jl_zdd/zdd_jl.ipynb +++ b/ZDD/jl_zdd/zdd_jl.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -34,7 +34,7 @@ "adjust_node! (generic function with 1 method)" ] }, - "execution_count": 58, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -54,17 +54,18 @@ "output_type": "stream", "text": [ "Constructing a half-ZDD...\n", - " 15.644982 seconds (16.00 M allocations: 1.242 GiB, 6.91% gc time)\n", + " 0.001409 seconds (785 allocations: 78.484 KiB)\n", "Constructing a half-ZDD...\n", - " 16.198733 seconds (16.00 M allocations: 1.242 GiB, 8.26% gc time)\n" + " 0.001356 seconds (899 allocations: 87.500 KiB)\n" ] } ], "source": [ - "m = 7\n", + "m = 3\n", "dims = [m,m]\n", "k = m\n", - "d = 0\n", + "d = 1\n", + "acceptable = m-d:m+d\n", "g = grid(dims)\n", "g_edges = optimal_grid_edge_order_diags(g, dims[1], dims[2])\n", "forwards = convert_lightgraphs_edges_to_node_edges(g_edges)\n", @@ -81,111 +82,22 @@ }, { "cell_type": "code", - "execution_count": 76, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "number_compatible (generic function with 2 methods)" - ] - }, - "execution_count": 76, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all_three = [\"weights\", \"fps\", \"cc\"]\n", - "weights = [\"weights\"] # none seem to work\n", - "fps = [\"fps\"] # all seem to work\n", - "cc = [\"cc\"] # 40/81 work\n", - "weights_fps = [\"weights\", \"fps\"]\n", - "weights_cc = [\"weights\", \"cc\"]\n", - "fps_cc = [\"fps\", \"cc\"]\n", - "@time number_compatible(fnodes, bnodes, middle_frontier, 4:4, w, k, all_three)" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Accepting districts with population in [3, 3]\n", - " 0.002601 seconds (1.17 k allocations: 105.688 KiB)\n" - ] - } - ], - "source": [ - "# grid graph\n", - "m = 3\n", - "dims = [m, m]\n", - "k = m\n", - "d = 0\n", - "contiguity = \"rook\"\n", - "\n", - "if contiguity == \"queen\"\n", - " g = queen_grid(dims)\n", - " g_edges = optimal_queen_grid_edge_order(g, dims[1], dims[2])\n", - "else\n", - " g = grid(dims)\n", - " g_edges = optimal_grid_edge_order_diags(g, dims[1], dims[2])\n", - "end\n", - "\n", - "g_edges = convert_lightgraphs_edges_to_node_edges(g_edges)\n", - "@time zdd = construct_zdd(g, k, d, g_edges)\n", - "nothing" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "10" - ] - }, - "execution_count": 48, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "count_paths(zdd)" - ] - }, - { - "cell_type": "code", - "execution_count": 49, + "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "79" + "54" ] }, - "execution_count": 49, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "num_nodes(zdd)" + "count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k)" ] }, { From 00ff371106f578d5de6ac6b7acd4d40aac7d1d6a Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Sun, 4 Apr 2021 15:52:35 -0700 Subject: [PATCH 6/8] two-way zdd works! but slow --- ZDD/jl_zdd/two_way_zdd.jl | 21 ++- ZDD/jl_zdd/zdd_jl.ipynb | 288 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 296 insertions(+), 13 deletions(-) diff --git a/ZDD/jl_zdd/two_way_zdd.jl b/ZDD/jl_zdd/two_way_zdd.jl index 3905bc1..8aeafdc 100644 --- a/ZDD/jl_zdd/two_way_zdd.jl +++ b/ZDD/jl_zdd/two_way_zdd.jl @@ -117,25 +117,38 @@ function merge_nodes(fnode, bnode, frontier) end end for bset ∈ bfrontier_sets + # println("bset is: $bset") + # println("Right now weights is: $weights") U = Set() w = 0 - for v ∈ bset + # println("w is: $w") + seen_fcomps = Set() + for (i,v) ∈ enumerate(bset) vgroup = findall(x -> (labels[x] == labels[v]), collect(frontier)) for g ∈ vgroup push!(U,collect(frontier)[g]) end - w += local_fnode_comp_weights[labels[v]] # labels[v] in frontier so has a weight + # println("inside enumerate bset, w is: $w") + if !(labels[v] in seen_fcomps) + w += local_fnode_comp_weights[labels[v]] + push!(seen_fcomps, labels[v]) + end + # println("still inside, w is: $w") end + # println("after local_fnode etc. w is: $w") w_bset = maximum(bnode.comp_weights[v] for v ∈ bset) # if you just pick one, you might hit a weird 0 w_intersection = length(bset) # TODO: generalize past unit weights + # println("w_bset is: $w_bset\nw_intersection is: $w_intersection") w += (w_bset - w_intersection) U_frontier = intersect(U, frontier) # need this because `labels` only has frontier keys + # println("U_frontier is $U_frontier") max_label = maximum(labels[u] for u ∈ U_frontier) for u ∈ U_frontier labels[u] = max_label weights[u] = w local_fnode_comp_weights[labels[u]] = w end + # println("End of loop block weights is: $weights") end merged_fps = union(fnode.fps, bnode.fps) @@ -172,6 +185,7 @@ function check_fps(fnode, bnode, connected_components, labels, merged_fps, front end function check_weights(fnode, bnode, weights, acceptable) + # println(weights) for w ∈ values(weights) if w ∉ acceptable return false @@ -181,6 +195,7 @@ function check_weights(fnode, bnode, weights, acceptable) end function count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, verbose=false) + flabels = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] num_paths = 0 if verbose println("Comparing $(length(fnodes)) fnodes to $(length(bnodes)) bnodes...\n") @@ -194,7 +209,7 @@ function count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k if cc & fps & ws num_paths += fnode.paths * bnode.paths if verbose - println("(fnode $fi, bnode $bi) contributes $(fnode.paths*bnode.paths) solns") + println("($(flabels[fi]),$bi) contributes $(fnode.paths*bnode.paths) solns") end end end diff --git a/ZDD/jl_zdd/zdd_jl.ipynb b/ZDD/jl_zdd/zdd_jl.ipynb index 62a8441..7475132 100644 --- a/ZDD/jl_zdd/zdd_jl.ipynb +++ b/ZDD/jl_zdd/zdd_jl.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 206, "metadata": {}, "outputs": [ { @@ -34,7 +34,7 @@ "adjust_node! (generic function with 1 method)" ] }, - "execution_count": 2, + "execution_count": 206, "metadata": {}, "output_type": "execute_result" } @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 219, "metadata": {}, "outputs": [ { @@ -54,17 +54,17 @@ "output_type": "stream", "text": [ "Constructing a half-ZDD...\n", - " 0.001409 seconds (785 allocations: 78.484 KiB)\n", + " 0.335368 seconds (540.15 k allocations: 44.193 MiB, 2.02% gc time)\n", "Constructing a half-ZDD...\n", - " 0.001356 seconds (899 allocations: 87.500 KiB)\n" + " 0.321185 seconds (540.21 k allocations: 44.195 MiB, 6.78% gc time)\n" ] } ], "source": [ - "m = 3\n", + "m = 6\n", "dims = [m,m]\n", "k = m\n", - "d = 1\n", + "d = 0\n", "acceptable = m-d:m+d\n", "g = grid(dims)\n", "g_edges = optimal_grid_edge_order_diags(g, dims[1], dims[2])\n", @@ -82,22 +82,290 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@time count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, false)" + ] + }, + { + "cell_type": "code", + "execution_count": 202, + "metadata": {}, + "outputs": [], + "source": [ + "_D = collect(fnodes)[4]\n", + "_10 = collect(bnodes)[10]\n", + "\n", + "_B = collect(fnodes)[2]\n", + "_12 = collect(bnodes)[12]\n", + "\n", + "fnode = _B\n", + "bnode = _12; nothing" + ] + }, + { + "cell_type": "code", + "execution_count": 203, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label: NodeEdge(3 -> 6)\n", + "cc: 0\n", + "fps: [\"ForbiddenPair(3 -> 7)\"]\n", + "comp_assign: [0, 0, 3, 0, 7, 6, 7, 8, 9]\n", + "comp_weights: [0, 0, 3, 0, 0, 1, 3, 1, 1]\n", + "\n" + ] + } + ], + "source": [ + "node_summary(fnode)" + ] + }, + { + "cell_type": "code", + "execution_count": 204, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label: NodeEdge(4 -> 7)\n", + "cc: 1\n", + "fps: String[]\n", + "comp_assign: [1, 2, 3, 4, 7, 0, 7, 0, 0]\n", + "comp_weights: [1, 1, 1, 1, 0, 0, 3, 0, 0]\n", + "\n" + ] + } + ], + "source": [ + "node_summary(bnode)" + ] + }, + { + "cell_type": "code", + "execution_count": 205, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bset is: Set([3])\n", + "Right now weights is: Dict{UInt8,Int64}(0x07 => -1,0x03 => -1,0x05 => -1)\n", + "inside enumerate bset, w is: 0\n", + "still inside, w is: 3\n", + "after local_fnode etc. w is: 3\n", + "w_bset is: 1\n", + "w_intersection is: 1\n", + "U_frontier is Set(Any[0x03])\n", + "End of loop block weights is: Dict{UInt8,Int64}(0x07 => -1,0x03 => 3,0x05 => -1)\n", + "bset is: Set([7, 5])\n", + "Right now weights is: Dict{UInt8,Int64}(0x07 => -1,0x03 => 3,0x05 => -1)\n", + "inside enumerate bset, w is: 0\n", + "still inside, w is: 3\n", + "inside enumerate bset, w is: 3\n", + "still inside, w is: 3\n", + "after local_fnode etc. w is: 3\n", + "w_bset is: 3\n", + "w_intersection is: 2\n", + "U_frontier is Set(Any[0x07, 0x05])\n", + "End of loop block weights is: Dict{UInt8,Int64}(0x07 => 4,0x03 => 3,0x05 => 4)\n", + "Dict{UInt8,Int64}(0x07 => 4,0x03 => 3,0x05 => 4)\n" + ] + }, + { + "data": { + "text/plain": [ + "true" + ] + }, + "execution_count": 205, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ccs, labels, weights, merged_fps = merge_nodes(fnode, bnode, middle_frontier)\n", + "cc = check_cc(fnode, bnode, ccs, k) \n", + "fps = check_fps(fnode, bnode, ccs, labels, merged_fps, middle_frontier)\n", + "ws = check_weights(fnode, bnode, weights, acceptable)" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Set{Any} with 1 element:\n", + " 7" + ] + }, + "execution_count": 111, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ccs" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "true" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fps" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "false" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ws" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n", + "3\n", + "4\n" + ] + } + ], + "source": [ + "for w in acceptable\n", + " println(w)\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Comparing 20 fnodes to 20 bnodes...\n", + "\n", + "(A,1) contributes 1 solns\n", + "(A,6) contributes 1 solns\n", + "(A,14) contributes 1 solns\n", + "(A,20) contributes 1 solns\n", + "(B,4) contributes 1 solns\n", + "(B,11) contributes 1 solns\n", + "(B,17) contributes 1 solns\n", + "(C,12) contributes 1 solns\n", + "(C,14) contributes 1 solns\n", + "(C,20) contributes 1 solns\n", + "(D,8) contributes 1 solns\n", + "(E,5) contributes 1 solns\n", + "(E,8) contributes 1 solns\n", + "(E,15) contributes 1 solns\n", + "(E,18) contributes 1 solns\n", + "(F,4) contributes 1 solns\n", + "(F,11) contributes 1 solns\n", + "(G,5) contributes 1 solns\n", + "(G,6) contributes 1 solns\n", + "(G,8) contributes 1 solns\n", + "(G,15) contributes 1 solns\n", + "(H,10) contributes 1 solns\n", + "(H,15) contributes 1 solns\n", + "(H,18) contributes 1 solns\n", + "(I,7) contributes 1 solns\n", + "(I,14) contributes 1 solns\n", + "(I,18) contributes 1 solns\n", + "(I,20) contributes 1 solns\n", + "(J,3) contributes 1 solns\n", + "(J,4) contributes 1 solns\n", + "(K,13) contributes 1 solns\n", + "(K,19) contributes 1 solns\n", + "(L,13) contributes 1 solns\n", + "(M,5) contributes 1 solns\n", + "(M,6) contributes 1 solns\n", + "(N,7) contributes 1 solns\n", + "(N,14) contributes 1 solns\n", + "(N,15) contributes 1 solns\n", + "(N,18) contributes 1 solns\n", + "(O,9) contributes 1 solns\n", + "(O,13) contributes 1 solns\n", + "(O,16) contributes 1 solns\n", + "(O,19) contributes 1 solns\n", + "(P,1) contributes 1 solns\n", + "(P,5) contributes 1 solns\n", + "(P,6) contributes 1 solns\n", + "(P,20) contributes 1 solns\n", + "(Q,1) contributes 1 solns\n", + "(R,9) contributes 1 solns\n", + "(R,16) contributes 1 solns\n", + "(S,2) contributes 1 solns\n", + "(S,3) contributes 1 solns\n", + "(S,4) contributes 1 solns\n", + "(T,16) contributes 1 solns\n" + ] + }, { "data": { "text/plain": [ "54" ] }, - "execution_count": 15, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k)" + "count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, true)" ] }, { From 86aa9fdc88da7bc607c58feacf6b0858b1d86763 Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Mon, 5 Apr 2021 08:18:39 -0700 Subject: [PATCH 7/8] annotate two-way functions --- ZDD/jl_zdd/two_way_zdd.jl | 100 ++++++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/ZDD/jl_zdd/two_way_zdd.jl b/ZDD/jl_zdd/two_way_zdd.jl index 8aeafdc..7cd03b9 100644 --- a/ZDD/jl_zdd/two_way_zdd.jl +++ b/ZDD/jl_zdd/two_way_zdd.jl @@ -80,75 +80,91 @@ end function frontier_sets(node, frontier) """ - Returns the sets of vertices in the frontier that are connected to each other. + Returns Set(C₁, C₂, ..., Cₙ) where every v ∈ Cᵢ is in the same connected component, + and Cᵢ ⊆ F where F is the set of all vertices in the frontier. """ - sets = Set() + frontier_sets = Set() seen_vertices = Set() - frontier_list = reverse(collect(frontier)) # reverse probably unnecessary - for v ∈ frontier_list + for v ∈ collect(frontier) if v ∈ seen_vertices continue end - set = Set(findall(n -> n == v, node.comp_assign)) - if length(set) > 0 - push!(sets, set) - for seen_vertex ∈ set + Cᵥ = Set(findall(n -> n == v, node.comp_assign)) + if length(Cᵥ) > 0 + push!(frontier_sets, Cᵥ) + for seen_vertex ∈ Cᵥ push!(seen_vertices, seen_vertex) end end end - return sets + return frontier_sets end function merge_nodes(fnode, bnode, frontier) """ - Use Union-Find type of thing to merge nodes.... + Merges two nodes and returns the following: + - labels :: Dict({v:l}) where v ∈ F and l ∈ F is v's connected component number + - weights :: Dict({v:n}) where v ∈ F and n ∈ N is the weight of v's component + - connected_components :: Set(C₁, C₂, ..., Cₙ) frontier sets of the merged node + - merged_fps :: Array(ForbiddenPair) union of the FPS from each node + + Outline of algorithm: + 0. Make a copy of fnode.comp_weights since we'll be modifying it later + 1. Initialize `weights` dictionary to -1 for all v ∈ F + 2. Initialize `labels` dictionary from fnode's perspective + 3. For each connected component `bcomp` in bnode: + # We want to figure out the merged component `M`, which is all the vertices in F + # that are connected by the merge of bcomp and fnode. + a. Initialize `M` = Set() and `w` = 0 (the weight of all vertices in `M`) + b. For each vertex v ∈ `bcomp`: + i. Find the set `fvcomp`: {x ∈ F | x touches v from fnode's perspective}, + and add each x ∈ `fvcomp` to M + ii. Once per fcomp, add the weight of fcomp to `w` + c. The weight added to `w` by `bcomp` is the weight of bcomp minus the weight + of the intersection of bcomp and any components it touches in fnode. + d. Now that `M` and `w` are finalized, we store the information by: + i. Relabel `labels` such that every v ∈ M has is connected + ii. In `weights`, assign `w` to every v ∈ M + iii. Update our copy of fnode.comp_weights to `w` for future bcomps + 4. Make `merged_fps` — just the union of fnode.fps and bnode.fps + 5. Make `connected_components` — just the Set(values(labels)) """ local_fnode_comp_weights = deepcopy(fnode.comp_weights) - ffrontier_sets = frontier_sets(fnode, frontier) - bfrontier_sets = frontier_sets(bnode, frontier) - + ### Initialize labels, weights ### labels = Dict() - weights = Dict(v => -1 for v in frontier) - for set ∈ ffrontier_sets + weights = Dict(v => -1 for v ∈ frontier) + for set ∈ frontier_sets(fnode, frontier) for v ∈ set labels[v] = maximum(set) end end - for bset ∈ bfrontier_sets - # println("bset is: $bset") - # println("Right now weights is: $weights") - U = Set() + + for bcomp ∈ frontier_sets(bnode, frontier) + M = Set() w = 0 - # println("w is: $w") seen_fcomps = Set() - for (i,v) ∈ enumerate(bset) - vgroup = findall(x -> (labels[x] == labels[v]), collect(frontier)) - for g ∈ vgroup - push!(U,collect(frontier)[g]) + for v ∈ bcomp + fvcomp = findall(x -> (labels[x] == labels[v]), collect(frontier)) + for x ∈ fvcomp + push!(M,collect(frontier)[x]) end - # println("inside enumerate bset, w is: $w") - if !(labels[v] in seen_fcomps) + if !(labels[v] ∈ seen_fcomps) w += local_fnode_comp_weights[labels[v]] push!(seen_fcomps, labels[v]) end - # println("still inside, w is: $w") end - # println("after local_fnode etc. w is: $w") - w_bset = maximum(bnode.comp_weights[v] for v ∈ bset) # if you just pick one, you might hit a weird 0 - w_intersection = length(bset) # TODO: generalize past unit weights - # println("w_bset is: $w_bset\nw_intersection is: $w_intersection") - w += (w_bset - w_intersection) - U_frontier = intersect(U, frontier) # need this because `labels` only has frontier keys - # println("U_frontier is $U_frontier") - max_label = maximum(labels[u] for u ∈ U_frontier) - for u ∈ U_frontier - labels[u] = max_label - weights[u] = w - local_fnode_comp_weights[labels[u]] = w + + w_bcomp = maximum(bnode.comp_weights[v] for v ∈ bcomp) # if you just pick one, you might hit a weird 0 + w_intersection = length(bcomp) # TODO: generalize past unit weights + w += (w_bcomp - w_intersection) + + max_label = maximum(labels[v] for v ∈ M) + for v ∈ M + labels[v] = max_label + weights[v] = w + local_fnode_comp_weights[labels[v]] = w end - # println("End of loop block weights is: $weights") end merged_fps = union(fnode.fps, bnode.fps) @@ -195,7 +211,7 @@ function check_weights(fnode, bnode, weights, acceptable) end function count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, verbose=false) - flabels = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] + # flabels = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] num_paths = 0 if verbose println("Comparing $(length(fnodes)) fnodes to $(length(bnodes)) bnodes...\n") @@ -209,7 +225,7 @@ function count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k if cc & fps & ws num_paths += fnode.paths * bnode.paths if verbose - println("($(flabels[fi]),$bi) contributes $(fnode.paths*bnode.paths) solns") + println("($fi,$bi) contributes $(fnode.paths*bnode.paths) solns") end end end From 2e8d92902ba8ffc552ab9c3087df2b8847edc1f0 Mon Sep 17 00:00:00 2001 From: Gabe Schoenbach Date: Wed, 5 May 2021 14:21:25 -0400 Subject: [PATCH 8/8] add pre-computation speedups to two-way --- ZDD/jl_zdd/two_way_zdd.jl | 190 ++++++++++++++++++++-------- ZDD/jl_zdd/zdd_jl.ipynb | 259 +++++--------------------------------- 2 files changed, 173 insertions(+), 276 deletions(-) diff --git a/ZDD/jl_zdd/two_way_zdd.jl b/ZDD/jl_zdd/two_way_zdd.jl index 7cd03b9..86bbedf 100644 --- a/ZDD/jl_zdd/two_way_zdd.jl +++ b/ZDD/jl_zdd/two_way_zdd.jl @@ -100,8 +100,96 @@ function frontier_sets(node, frontier) return frontier_sets end -function merge_nodes(fnode, bnode, frontier) +function compute_frontier_sets(fnodes, bnodes, frontier) + frontier_sets_dictionary = Dict() + for fnode ∈ fnodes + frontier_sets_dictionary[fnode] = frontier_sets(fnode, frontier) + end + for bnode ∈ bnodes + frontier_sets_dictionary[bnode] = frontier_sets(bnode, frontier) + end + return frontier_sets_dictionary +end + +function generate_dictionaries(fnodes, bnodes, frontier) + """ Pre-compute necessary lookup tables """ + allnodes = union(fnodes, bnodes) + frontier_projection_dict = Dict(node => frontier_sets(node, frontier) for node ∈ allnodes) + fps_dict = Dict(node => node.fps for node ∈ allnodes) + ffps_dict = Dict(fnode => fnode.fps for fnode ∈ fnodes) + + merged_fps_dict = Dict() + for ffps ∈ values(ffps_dict) + for bfps ∈ values(ffps_dict) + merged_fps_dict[(ffps, bfps)] = union(ffps, bfps) + end + end + + merged_frontier_projection_dict = Dict() + fsets = Set([frontier_projection_dict[fnode] for fnode ∈ fnodes]) + for fset ∈ fsets + for bset ∈ fsets + merged_frontier_projection_dict[(fset, bset)] = union_find(fset, bset, frontier) + end + end + + intersection_dict = Dict() + for connected_components ∈ Set(values(merged_frontier_projection_dict)) + for comp ∈ connected_components + for sets_list ∈ Set(values(frontier_projection_dict)) + for set ∈ sets_list + intersection_dict[(set, comp)] = intersect(set, comp) + end + end + end + end + return frontier_projection_dict, merged_frontier_projection_dict, fps_dict, merged_fps_dict, intersection_dict +end + +function union_find(fsets, bsets, frontier) + """ Given the projections for each node, compute the merged projection """ + labels = Dict() + for fset ∈ fsets + for v ∈ fset + labels[v] = maximum(fset) + end + end + for bset ∈ bsets + M = Set() + for v ∈ bset + fvcomp = findall(x -> (labels[x] == labels[v]), collect(frontier)) + for x ∈ fvcomp + push!(M, collect(frontier)[x]) + end + end + max_label = maximum(labels[v] for v ∈ M) + for v ∈ M + labels[v] = max_label + end + end + connected_components = Set() + for v ∈ frontier + push!(connected_components, Set(findall(x -> (labels[x] == labels[v]), labels))) + end + return connected_components +end + + +function initialize(fnode, frontier, frontier_sets_dictionary) + """ DEPRECATED """ + labels = Dict() + weights = Dict(v => -1 for v ∈ frontier) + for set ∈ frontier_sets_dictionary[fnode] + for v ∈ set + labels[v] = maximum(set) + end + end + return labels, weights +end + +function merge_nodes(fnode, bnode, frontier, frontier_sets_dictionary) #, labels, weights, local_fnode_comp_weights) """ + DEPRECATED Merges two nodes and returns the following: - labels :: Dict({v:l}) where v ∈ F and l ∈ F is v's connected component number - weights :: Dict({v:n}) where v ∈ F and n ∈ N is the weight of v's component @@ -129,18 +217,11 @@ function merge_nodes(fnode, bnode, frontier) 4. Make `merged_fps` — just the union of fnode.fps and bnode.fps 5. Make `connected_components` — just the Set(values(labels)) """ - local_fnode_comp_weights = deepcopy(fnode.comp_weights) - - ### Initialize labels, weights ### - labels = Dict() - weights = Dict(v => -1 for v ∈ frontier) - for set ∈ frontier_sets(fnode, frontier) - for v ∈ set - labels[v] = maximum(set) - end - end - for bcomp ∈ frontier_sets(bnode, frontier) + labels, weights = initialize(fnode, frontier, frontier_sets_dictionary) + + local_fnode_comp_weights = deepcopy(fnode.comp_weights) + for bcomp ∈ frontier_sets_dictionary[bnode] M = Set() w = 0 seen_fcomps = Set() @@ -166,10 +247,8 @@ function merge_nodes(fnode, bnode, frontier) local_fnode_comp_weights[labels[v]] = w end end - - merged_fps = union(fnode.fps, bnode.fps) - connected_components = Set(values(labels)) - return connected_components, labels, weights, merged_fps + + return labels, weights end function check_cc(fnode, bnode, connected_components, k) @@ -180,53 +259,64 @@ function check_cc(fnode, bnode, connected_components, k) end end -function check_fps(fnode, bnode, connected_components, labels, merged_fps, frontier) - for c ∈ connected_components - idxs = findall(x -> labels[x] == c, collect(frontier)) - vtxs = [collect(frontier)[i] for i ∈ idxs] - for v₁ ∈ vtxs - for v₂ ∈ vtxs - if v₁ != v₂ - maybe_forbidden = ForbiddenPair(v₁, v₂) - if maybe_forbidden ∈ merged_fps - return false - end - else - continue +function check_weights(fnode, bnode, frontier_projection_dict, intersection_dict, connected_components, acceptable) + for comp ∈ connected_components + w = 0 + if length(comp) > ceil(maximum(acceptable)/2) # needs rook contiguity and unit pop. + return false + end + for node ∈ [fnode, bnode] + for set ∈ frontier_projection_dict[node] + if length(intersection_dict[(set, comp)]) == length(set) # there should be no partial intersections + wset = maximum(node.comp_weights[v] for v ∈ set) + w += wset - length(set) # assumes unit pop end end end + w += length(comp) + if w ∉ acceptable + return false + end end return true end -function check_weights(fnode, bnode, weights, acceptable) - # println(weights) - for w ∈ values(weights) - if w ∉ acceptable - return false +function check_fps(fnode, bnode, fps_dict, merged_fps_dict, connected_components) + merged_fps = merged_fps_dict[(fps_dict[fnode], fps_dict[bnode])] + for comp ∈ connected_components + for v₁ ∈ comp + for v₂ ∈ comp + if v₁ != v₂ + if ForbiddenPair(v₁, v₂) ∈ merged_fps + return false + end + end + end end end return true end -function count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, verbose=false) - # flabels = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] - num_paths = 0 - if verbose - println("Comparing $(length(fnodes)) fnodes to $(length(bnodes)) bnodes...\n") - end - for (fi,fnode) ∈ enumerate(fnodes) - for (bi,bnode) ∈ enumerate(bnodes) - ccs, labels, weights, merged_fps = merge_nodes(fnode, bnode, middle_frontier) - cc = check_cc(fnode, bnode, ccs, k) - fps = check_fps(fnode, bnode, ccs, labels, merged_fps, middle_frontier) - ws = check_weights(fnode, bnode, weights, acceptable) - if cc & fps & ws +function count_paths_from_halfway(fnodes, bnodes, + frontier_projection_dict, + merged_frontier_projection_dict, + fps_dict, merged_fps_dict, + intersection_dict, + acceptable, k, verbose=false) + num_paths = 0 + for (fi, fnode) ∈ enumerate(fnodes) + fset = frontier_projection_dict[fnode] + for (bi, bnode) ∈ enumerate(bnodes) + bset = frontier_projection_dict[bnode] + connected_components = merged_frontier_projection_dict[(fset, bset)] + cc = check_cc(fnode, bnode, connected_components, k) + w = check_weights(fnode, bnode, frontier_projection_dict, intersection_dict, connected_components, acceptable) + fps = check_fps(fnode, bnode, fps_dict, merged_fps_dict, connected_components) + if cc & w & fps num_paths += fnode.paths * bnode.paths - if verbose - println("($fi,$bi) contributes $(fnode.paths*bnode.paths) solns") - end + end + if verbose + println("($fi,$bi) contributes $(fnode.paths*bnode.paths) solns") end end end diff --git a/ZDD/jl_zdd/zdd_jl.ipynb b/ZDD/jl_zdd/zdd_jl.ipynb index 7475132..77788dd 100644 --- a/ZDD/jl_zdd/zdd_jl.ipynb +++ b/ZDD/jl_zdd/zdd_jl.ipynb @@ -9,7 +9,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m new environment at `~/.julia/environments/zdd/Project.toml`\n" + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m environment at `~/.julia/environments/zdd/Project.toml`\n" ] } ], @@ -20,12 +20,13 @@ "using Compose\n", "using Random\n", "using Traceur\n", - "using BenchmarkTools" + "using BenchmarkTools\n", + "using StatProfilerHTML" ] }, { "cell_type": "code", - "execution_count": 206, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -34,7 +35,7 @@ "adjust_node! (generic function with 1 method)" ] }, - "execution_count": 206, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -46,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 219, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -54,14 +55,14 @@ "output_type": "stream", "text": [ "Constructing a half-ZDD...\n", - " 0.335368 seconds (540.15 k allocations: 44.193 MiB, 2.02% gc time)\n", + " 0.017077 seconds (24.74 k allocations: 2.123 MiB)\n", "Constructing a half-ZDD...\n", - " 0.321185 seconds (540.21 k allocations: 44.195 MiB, 6.78% gc time)\n" + " 0.014161 seconds (24.76 k allocations: 2.159 MiB)\n" ] } ], "source": [ - "m = 6\n", + "m = 5\n", "dims = [m,m]\n", "k = m\n", "d = 0\n", @@ -75,297 +76,103 @@ "middle_frontier = frontiers[Int(ceil(length(frontiers)/2))]\n", "\n", "# @time zdd = construct_zdd(g, k, d, forwards)\n", - "@time fnodes = construct_half_zdd(g, k, d, forwards)\n", + "@time fnodes = construct_half_zdd(g, k, d, forwards); nothing\n", "@time bnodes = construct_half_zdd(g, k, d, backwards); nothing\n", "# count_paths(zdd)" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "@time count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, false)" - ] - }, - { - "cell_type": "code", - "execution_count": 202, - "metadata": {}, - "outputs": [], - "source": [ - "_D = collect(fnodes)[4]\n", - "_10 = collect(bnodes)[10]\n", - "\n", - "_B = collect(fnodes)[2]\n", - "_12 = collect(bnodes)[12]\n", - "\n", - "fnode = _B\n", - "bnode = _12; nothing" - ] - }, - { - "cell_type": "code", - "execution_count": 203, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Label: NodeEdge(3 -> 6)\n", - "cc: 0\n", - "fps: [\"ForbiddenPair(3 -> 7)\"]\n", - "comp_assign: [0, 0, 3, 0, 7, 6, 7, 8, 9]\n", - "comp_weights: [0, 0, 3, 0, 0, 1, 3, 1, 1]\n", - "\n" + " 1.507757 seconds (16.31 M allocations: 790.128 MiB, 9.94% gc time)\n" ] } ], "source": [ - "node_summary(fnode)" + "@time frontier_projection_dict, merged_frontier_projection_dict, fps_dict, merged_fps_dict, intersection_dict = generate_dictionaries(fnodes, bnodes, middle_frontier); nothing" ] }, { "cell_type": "code", - "execution_count": 204, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Label: NodeEdge(4 -> 7)\n", - "cc: 1\n", - "fps: String[]\n", - "comp_assign: [1, 2, 3, 4, 7, 0, 7, 0, 0]\n", - "comp_weights: [1, 1, 1, 1, 0, 0, 3, 0, 0]\n", - "\n" - ] - } - ], - "source": [ - "node_summary(bnode)" - ] - }, - { - "cell_type": "code", - "execution_count": 205, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "bset is: Set([3])\n", - "Right now weights is: Dict{UInt8,Int64}(0x07 => -1,0x03 => -1,0x05 => -1)\n", - "inside enumerate bset, w is: 0\n", - "still inside, w is: 3\n", - "after local_fnode etc. w is: 3\n", - "w_bset is: 1\n", - "w_intersection is: 1\n", - "U_frontier is Set(Any[0x03])\n", - "End of loop block weights is: Dict{UInt8,Int64}(0x07 => -1,0x03 => 3,0x05 => -1)\n", - "bset is: Set([7, 5])\n", - "Right now weights is: Dict{UInt8,Int64}(0x07 => -1,0x03 => 3,0x05 => -1)\n", - "inside enumerate bset, w is: 0\n", - "still inside, w is: 3\n", - "inside enumerate bset, w is: 3\n", - "still inside, w is: 3\n", - "after local_fnode etc. w is: 3\n", - "w_bset is: 3\n", - "w_intersection is: 2\n", - "U_frontier is Set(Any[0x07, 0x05])\n", - "End of loop block weights is: Dict{UInt8,Int64}(0x07 => 4,0x03 => 3,0x05 => 4)\n", - "Dict{UInt8,Int64}(0x07 => 4,0x03 => 3,0x05 => 4)\n" + " 10.856397 seconds (110.11 M allocations: 2.480 GiB, 3.63% gc time)\n" ] }, { "data": { "text/plain": [ - "true" + "4006" ] }, - "execution_count": 205, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "ccs, labels, weights, merged_fps = merge_nodes(fnode, bnode, middle_frontier)\n", - "cc = check_cc(fnode, bnode, ccs, k) \n", - "fps = check_fps(fnode, bnode, ccs, labels, merged_fps, middle_frontier)\n", - "ws = check_weights(fnode, bnode, weights, acceptable)" + "@time count_paths_from_halfway(fnodes, bnodes, \n", + " frontier_projection_dict, \n", + " merged_frontier_projection_dict,\n", + " fps_dict, merged_fps_dict,\n", + " intersection_dict,\n", + " acceptable, k, false)" ] }, { "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Set{Any} with 1 element:\n", - " 7" - ] - }, - "execution_count": 111, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ccs" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "true" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fps" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "false" - ] - }, - "execution_count": 51, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ws" - ] - }, - { - "cell_type": "code", - "execution_count": 54, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "2\n", - "3\n", - "4\n" + "Accepting districts with population in [5, 5]\n", + " 0.068781 seconds (131.75 k allocations: 9.191 MiB)\n" ] } ], "source": [ - "for w in acceptable\n", - " println(w)\n", - "end" + "@time zdd = construct_zdd(g, k, d, forwards); nothing" ] }, { "cell_type": "code", - "execution_count": 43, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 19, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Comparing 20 fnodes to 20 bnodes...\n", - "\n", - "(A,1) contributes 1 solns\n", - "(A,6) contributes 1 solns\n", - "(A,14) contributes 1 solns\n", - "(A,20) contributes 1 solns\n", - "(B,4) contributes 1 solns\n", - "(B,11) contributes 1 solns\n", - "(B,17) contributes 1 solns\n", - "(C,12) contributes 1 solns\n", - "(C,14) contributes 1 solns\n", - "(C,20) contributes 1 solns\n", - "(D,8) contributes 1 solns\n", - "(E,5) contributes 1 solns\n", - "(E,8) contributes 1 solns\n", - "(E,15) contributes 1 solns\n", - "(E,18) contributes 1 solns\n", - "(F,4) contributes 1 solns\n", - "(F,11) contributes 1 solns\n", - "(G,5) contributes 1 solns\n", - "(G,6) contributes 1 solns\n", - "(G,8) contributes 1 solns\n", - "(G,15) contributes 1 solns\n", - "(H,10) contributes 1 solns\n", - "(H,15) contributes 1 solns\n", - "(H,18) contributes 1 solns\n", - "(I,7) contributes 1 solns\n", - "(I,14) contributes 1 solns\n", - "(I,18) contributes 1 solns\n", - "(I,20) contributes 1 solns\n", - "(J,3) contributes 1 solns\n", - "(J,4) contributes 1 solns\n", - "(K,13) contributes 1 solns\n", - "(K,19) contributes 1 solns\n", - "(L,13) contributes 1 solns\n", - "(M,5) contributes 1 solns\n", - "(M,6) contributes 1 solns\n", - "(N,7) contributes 1 solns\n", - "(N,14) contributes 1 solns\n", - "(N,15) contributes 1 solns\n", - "(N,18) contributes 1 solns\n", - "(O,9) contributes 1 solns\n", - "(O,13) contributes 1 solns\n", - "(O,16) contributes 1 solns\n", - "(O,19) contributes 1 solns\n", - "(P,1) contributes 1 solns\n", - "(P,5) contributes 1 solns\n", - "(P,6) contributes 1 solns\n", - "(P,20) contributes 1 solns\n", - "(Q,1) contributes 1 solns\n", - "(R,9) contributes 1 solns\n", - "(R,16) contributes 1 solns\n", - "(S,2) contributes 1 solns\n", - "(S,3) contributes 1 solns\n", - "(S,4) contributes 1 solns\n", - "(T,16) contributes 1 solns\n" + " 0.004653 seconds (1.34 k allocations: 83.282 KiB)\n" ] }, { "data": { "text/plain": [ - "54" + "4006" ] }, - "execution_count": 43, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "count_paths_from_halfway(fnodes, bnodes, middle_frontier, acceptable, k, true)" + "@time count_paths(zdd)" ] }, {