Skip to content

Commit 9fad249

Browse files
author
Mark Hale
committed
added clusters_from_cover to kmapper.
1 parent e351e60 commit 9fad249

2 files changed

Lines changed: 57 additions & 0 deletions

File tree

kmapper/kmapper.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,29 @@ def data_from_cluster_id(self, cluster_id, graph, data):
827827
else:
828828
return np.array([])
829829

830+
def clusters_from_cover(self, cube_ids, graph):
    """Return the clusters, and their members, belonging to the given hypercubes.

    A cluster is selected when its ID in ``graph["nodes"]`` starts with
    ``"cube<index>_"`` for one of the indices in `cube_ids`.

    Parameters
    ----------
    cube_ids : list of int
        List of hypercube indices.
    graph : dict
        The resulting dictionary after applying map().

    Returns
    -------
    clusters : dict
        Cluster member data indexed by cluster ID (subset of `graph["nodes"]`).
        Empty when no cluster ID matches any of `cube_ids`.

    """
    # The trailing underscore stops cube 1 from also matching cube 10, 11, ...
    # str.startswith only tests several prefixes at once when given a tuple.
    cluster_id_prefixes = tuple("cube{}_".format(i) for i in cube_ids)
    return {
        cluster_id: cluster_members
        for cluster_id, cluster_members in graph["nodes"].items()
        if cluster_id.startswith(cluster_id_prefixes)
    }
852+
830853
def _process_projection_tuple(self, projection):
831854
# Detect if projection is a tuple (for prediction functions)
832855
# TODO: multi-label models

test/test_mapper.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,24 @@ def test_wrong_id(self):
7575
mems = mapper.data_from_cluster_id("new node", graph, data)
7676
np.testing.assert_array_equal(mems, np.array([]))
7777

78+
def test_clusters_from_cover(self):
    """Clusters picked via a sample's cube ids must mirror the graph's nodes."""
    km = KeplerMapper(verbose=1)
    points = np.random.rand(100, 2)

    result = km.map(points)
    # Cube ids reported by the mapper's cover for the first sample.
    sample_cube_ids = km.cover.find(points[0])
    selected = km.clusters_from_cover(sample_cube_ids, result)

    assert len(selected) > 0
    # Every returned entry must equal the node stored under the same ID.
    for node_id in selected:
        np.testing.assert_array_equal(selected[node_id], result["nodes"][node_id])
88+
89+
def test_no_clusters_from_cover(self):
    """A cube index matching no cluster ID must give an empty result."""
    km = KeplerMapper(verbose=1)
    points = np.random.rand(100, 2)

    result = km.map(points)
    # 999 is used as a cube index that no cluster ID is expected to carry.
    selected = km.clusters_from_cover([999], result)
    assert len(selected) == 0
7896

7997
class TestMap:
8098
def test_simplices(self):
@@ -95,6 +113,22 @@ def test_simplices(self):
95113
assert len(nodes) == 3
96114
assert len(edges) == 3
97115

116+
def test_nodes(self):
    """With three cubes, expect three nodes named ``cube<i>_cluster0`` in order."""
    km = KeplerMapper()

    points = np.random.rand(100, 2)
    projected = km.fit_transform(points)

    three_cubes = Cover(n_cubes=3, perc_overlap=0.75)
    dbscan = cluster.DBSCAN(metric="euclidean", min_samples=3)
    result = km.map(projected, X=points, cover=three_cubes, clusterer=dbscan)

    node_ids = list(result["nodes"])
    assert len(node_ids) == 3
    # verify cluster ID format
    for position, node_id in enumerate(node_ids):
        assert node_id == "cube{}_cluster0".format(position)
131+
98132
def test_precomputed(self):
99133
mapper = KeplerMapper()
100134

0 commit comments

Comments
 (0)