Skip to content

Commit afdd40b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0c6a251 commit afdd40b

1 file changed

Lines changed: 21 additions & 9 deletions

File tree

machine_learning/apriori_algorithm.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def load_data() -> list[list[str]]:
2525
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]
2626

2727

28-
def prune(frequent_itemsets: list[list[str]], candidates: list[list[str]]) -> list[list[str]]:
28+
def prune(
29+
frequent_itemsets: list[list[str]], candidates: list[list[str]]
30+
) -> list[list[str]]:
2931
"""
3032
Prunes candidate itemsets by ensuring all (k-1)-subsets exist in previous frequent itemsets.
3133
@@ -67,15 +69,21 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in
6769
for item in transaction:
6870
item_counts[item] += 1
6971

70-
current_frequents = [[item] for item, count in item_counts.items() if count >= min_support]
71-
frequent_itemsets = [([item], count) for item, count in item_counts.items() if count >= min_support]
72+
current_frequents = [
73+
[item] for item, count in item_counts.items() if count >= min_support
74+
]
75+
frequent_itemsets = [
76+
([item], count) for item, count in item_counts.items() if count >= min_support
77+
]
7278

7379
k = 2
7480
while current_frequents:
75-
candidates = [sorted(list(set(i) | set(j)))
76-
for i in current_frequents
77-
for j in current_frequents
78-
if len(set(i).union(j)) == k]
81+
candidates = [
82+
sorted(list(set(i) | set(j)))
83+
for i in current_frequents
84+
for j in current_frequents
85+
if len(set(i).union(j)) == k
86+
]
7987

8088
candidates = [list(c) for c in {frozenset(c) for c in candidates}]
8189

@@ -88,10 +96,14 @@ def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], in
8896
if set(candidate).issubset(t_set):
8997
candidate_counts[tuple(sorted(candidate))] += 1
9098

91-
current_frequents = [list(key) for key, count in candidate_counts.items() if count >= min_support]
99+
current_frequents = [
100+
list(key) for key, count in candidate_counts.items() if count >= min_support
101+
]
92102
frequent_itemsets.extend(
93103
[
94-
(list(key), count) for key, count in candidate_counts.items() if count >= min_support
104+
(list(key), count)
105+
for key, count in candidate_counts.items()
106+
if count >= min_support
95107
]
96108
)
97109

0 commit comments

Comments
 (0)