Skip to content

Commit ece1eb1

Browse files
committed
performance improvement for flattening lists|
1 parent 379b498 commit ece1eb1

3 files changed

Lines changed: 16 additions & 5 deletions

File tree

ciw/auxiliary.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,10 @@ def truncated_normal(mean, sd):
4444
sample = random.normalvariate(mean, sd)
4545
while sample <= 0.0:
4646
sample = random.normalvariate(mean, sd)
47-
return sample
47+
return sample
48+
49+
def flatten_list(list_of_lists):
50+
flat = []
51+
for a_list in list_of_lists:
52+
flat += a_list
53+
return flat

ciw/node.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import networkx as nx
88

9-
from .auxiliary import random_choice
9+
from .auxiliary import random_choice, flatten_list
1010
from .data_record import DataRecord
1111
from .server import Server
1212

@@ -67,8 +67,7 @@ def __init__(self, id_, simulation):
6767

6868
@property
6969
def all_individuals(self):
70-
return [i for priority_class in self.individuals
71-
for i in priority_class]
70+
return flatten_list(self.individuals)
7271

7372
def __repr__(self):
7473
"""

ciw/tests/test_auxiliary.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,10 @@ def test_randomchoice(self):
7777
self.assertEqual(choice_counts, {'Exit Node': 100})
7878
self.assertEqual(r1, r2)
7979

80-
80+
def test_flatten_list(self):
81+
for seed in range(20):
82+
random.seed(seed)
83+
all_classes = [[random.random() for _ in range(random.randrange(3, 30, 1))] for _ in range(random.randrange(5, 20, 1))]
84+
A = [i for priority in all_classes for i in priority]
85+
B = ciw.flatten_list(all_classes)
86+
self.assertEqual(A, B)

0 commit comments

Comments
 (0)