Skip to content

Commit 491d5b6

Browse files
Refactor code for improved readability in prune and apriori functions
2 parents 3340eee + afdd40b commit 491d5b6

1 file changed

Lines changed: 21 additions & 9 deletions

File tree

machine_learning/apriori_algorithm.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ def load_data() -> list[list[str]]:
3030
]
3131

3232

33-
def prune(frequent_itemsets: list[list[str]], candidates: list[list[str]]) -> list[list[str]]:
33+
def prune(
34+
frequent_itemsets: list[list[str]], candidates: list[list[str]]
35+
) -> list[list[str]]:
3436
"""
3537
Prunes candidate itemsets by ensuring all (k-1)-subsets exist in
3638
previous frequent itemsets.
@@ -74,15 +76,21 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in
7476
for item in transaction:
7577
item_counts[item] += 1
7678

77-
current_frequents = [[item] for item, count in item_counts.items() if count >= min_support]
78-
frequent_itemsets = [([item], count) for item, count in item_counts.items() if count >= min_support]
79+
current_frequents = [
80+
[item] for item, count in item_counts.items() if count >= min_support
81+
]
82+
frequent_itemsets = [
83+
([item], count) for item, count in item_counts.items() if count >= min_support
84+
]
7985

8086
k = 2
8187
while current_frequents:
82-
candidates = [sorted(set(i) | set(j))
83-
for i in current_frequents
84-
for j in current_frequents
85-
if len(set(i).union(j)) == k]
88+
candidates = [
89+
sorted(list(set(i) | set(j)))
90+
for i in current_frequents
91+
for j in current_frequents
92+
if len(set(i).union(j)) == k
93+
]
8694

8795
candidates = [list(c) for c in {frozenset(c) for c in candidates}]
8896

@@ -95,10 +103,14 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in
95103
if set(candidate).issubset(t_set):
96104
candidate_counts[tuple(sorted(candidate))] += 1
97105

98-
current_frequents = [list(key) for key, count in candidate_counts.items() if count >= min_support]
106+
current_frequents = [
107+
list(key) for key, count in candidate_counts.items() if count >= min_support
108+
]
99109
frequent_itemsets.extend(
100110
[
101-
(list(key), count) for key, count in candidate_counts.items() if count >= min_support
111+
(list(key), count)
112+
for key, count in candidate_counts.items()
113+
if count >= min_support
102114
]
103115
)
104116

0 commit comments

Comments
 (0)