Skip to content

Commit 7175814

Browse files
atomassiAndrea Tomassilli
authored andcommitted
updated tests
1 parent a6535cf commit 7175814

2 files changed

Lines changed: 39 additions & 58 deletions

File tree

networkx/algorithms/approximation/kcutsets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def minimum_k_cut(G, k, weight=None):
196196
if not 1 <= k <= len(G):
197197
raise nx.NetworkXError(f"k should be within 1 and {len(G)}")
198198

199-
# extract edges weight, and set edges weights with no attribute to 1
199+
# extract edges weights, and set edges weights with no attribute to 1
200200
edges_weights = G.edges(data=weight, default=1)
201201
# create a new Graph G2
202202
G2 = nx.Graph()

networkx/algorithms/approximation/tests/test_kcutsets.py

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Unit tests for the :mod:`networkx.algorithms.approximation.kcutsets` module."""
22

3+
import itertools
34
import pytest
45
import networkx as nx
56
from networkx.algorithms.approximation import minimum_multiway_cut, minimum_k_cut
@@ -72,15 +73,15 @@ def test_complete_graph(self):
7273
nx.tutte_graph,
7374
],
7475
)
75-
def test_compare_min_cut(self, graph_class):
76+
@pytest.mark.parametrize("s,t", itertools.combinations(range(5), 2))
77+
def test_compare_min_cut(self, graph_class, s, t):
7678
"""Compare minimum_cut_value and minimum_multiway_cut considering 2 nodes.
7779
7880
For two nodes the minimum_cut_value(G, s, t) should be equivalent to
7981
minimum_multiway_cut(G, {s,t}).
8082
"""
8183
G = graph_class()
8284
nx.set_edge_attributes(G, values=10, name="weight")
83-
s, t = min(G), max(G)
8485
cut_value, cutset = minimum_multiway_cut(G, {s, t}, weight="weight")
8586
assert cut_value == minimum_cut_value(G, s, t, capacity="weight")
8687

@@ -113,64 +114,23 @@ def test_invalid_k(self):
113114
with pytest.raises(nx.NetworkXError, match="k should be within 1 and 10"):
114115
minimum_k_cut(G, 11)
115116

116-
def test_path_graph_unweighted(self):
117-
"""Test min k-cut for a path graph."""
118-
G = nx.path_graph(2)
119-
cut_value, cutset = minimum_k_cut(G, 2)
120-
assert cut_value == 1
121-
G.remove_edges_from(cutset)
122-
assert len(list(nx.connected_components(G))) == 2
123-
124-
def test_path_graph_weighted_k2(self):
125-
"""Test min k-cut for a path graph with weights."""
126-
G = nx.Graph()
127-
G.add_weighted_edges_from(
128-
[(0, 1, 10), (1, 2, 10), (2, 3, 5)], weight="capacity"
129-
)
130-
cut_value, cutset = minimum_k_cut(G, 2, weight="capacity")
131-
assert cut_value == 5
117+
@pytest.mark.parametrize("k,expected", [(1, 0), (2, 1), (3, 2), (4, 3), (5, 4)])
118+
def test_path_graph(self, k, expected):
119+
"""Test various k for a path graph of 5 nodes."""
120+
G = nx.path_graph(n=5)
121+
cut_value, cutset = minimum_k_cut(G, k, weight="capacity")
122+
assert cut_value == expected
132123
G.remove_edges_from(cutset)
133-
assert len(list(nx.connected_components(G))) == 2
124+
assert len(list(nx.connected_components(G))) == k
134125

135-
def test_path_graph_weighted_k3(self):
136-
"""Test min k-cut for a path graph with weights."""
137-
G = nx.Graph()
138-
G.add_weighted_edges_from(
139-
[(0, 1, 10), (1, 2, 10), (2, 3, 5)], weight="capacity"
140-
)
141-
cut_value, cutset = minimum_k_cut(G, 3, weight="capacity")
142-
assert cut_value == 15
143-
G.remove_edges_from(cutset)
144-
assert len(list(nx.connected_components(G))) == 3
145-
146-
def test_complete_graph_k2(self):
147-
"""Test min k-cut for a complete graph for k=2."""
126+
@pytest.mark.parametrize("k,expected", [(1, 0), (2, 4), (3, 7), (4, 9), (5, 10)])
127+
def test_complete_graph(self, k, expected):
128+
"""Test various k for a complete graph of 5 nodes."""
148129
G = nx.complete_graph(5)
149-
cut_value, cutset = minimum_k_cut(G, 2)
150-
# it should contain all the edges incident to a node
151-
assert cut_value == 4
152-
# remove the edges
153-
G.remove_edges_from(cutset)
154-
assert len(list(nx.connected_components(G))) == 2
155-
156-
def test_complete_graph_all(self):
157-
"""Test min k-cut for a complete graph."""
158-
G = nx.complete_graph(5)
159-
cut_value, cutset = minimum_k_cut(G, 5)
160-
assert cut_value == 10
161-
# remove the edges
162-
G.remove_edges_from(cutset)
163-
assert set(G.edges()) == set()
164-
165-
def test_complete_graph_weighted(self):
166-
"""Test min k-cut for a weighted complete graph."""
167-
G = nx.complete_graph(5)
168-
nx.set_edge_attributes(G, values=10, name="weight")
169-
cut_value, cutset = minimum_k_cut(G, 5, weight="weight")
170-
assert cut_value == 100
171-
# remove the edges
130+
cut_value, cutset = minimum_k_cut(G, k)
131+
assert cut_value == expected
172132
G.remove_edges_from(cutset)
173-
assert set(G.edges()) == set()
133+
assert len(list(nx.connected_components(G))) >= k
174134

175135
@pytest.mark.parametrize(
176136
"graph_class",
@@ -184,9 +144,30 @@ def test_complete_graph_weighted(self):
184144
],
185145
)
186146
@pytest.mark.parametrize("k", list(range(1, 11)))
187-
def test_compare_min_cut(self, graph_class, k):
147+
def test_connected_components(self, graph_class, k):
188148
"""Test multiple graph types and k."""
189149
G = graph_class()
190150
cut_value, cutset = minimum_k_cut(G, k)
191151
G.remove_edges_from(cutset)
192152
assert len(list(nx.connected_components(G))) >= k
153+
154+
def test_complete_graph_weighted(self):
155+
"""Test min k-cut for a weighted complete graph."""
156+
G = nx.complete_graph(5)
157+
nx.set_edge_attributes(G, values=10, name="weight")
158+
cut_value, cutset = minimum_k_cut(G, 5, weight="weight")
159+
assert cut_value == 100
160+
# remove the edges
161+
G.remove_edges_from(cutset)
162+
assert set(G.edges()) == set()
163+
164+
def test_path_graph_weighted(self):
165+
"""Test min k-cut for a weighted path graph."""
166+
G = nx.Graph()
167+
G.add_weighted_edges_from(
168+
[(0, 1, 10), (1, 2, 10), (2, 3, 5)], weight="capacity"
169+
)
170+
cut_value, cutset = minimum_k_cut(G, 3, weight="capacity")
171+
assert cut_value == 15
172+
G.remove_edges_from(cutset)
173+
assert len(list(nx.connected_components(G))) == 3

0 commit comments

Comments
 (0)