Skip to content

Commit 2d14157

Browse files
committed
fix code style and formatting
1 parent 136d471 commit 2d14157

4 files changed

Lines changed: 22 additions & 19 deletions

File tree

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import sys
1414
from pathlib import Path
1515

16-
sys.path.insert(0, str(Path('..', 'src').resolve()))
16+
sys.path.insert(0, str(Path("..", "src").resolve()))
1717

1818
# -- Project information -----------------------------------------------------
1919

examples/visualization_examples/sankey_diagram/abalone_sankey_diagram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,5 @@
2929
print("\nMetrics:", metrics)
3030

3131
# Visualize sankey diagram
32-
fig = sankey_diagram(rules=rules, interestingness_measure="support", M=4)
32+
fig = sankey_diagram(rules=rules, interestingness_measure="support", max_rules=4)
3333
fig.show()

examples/visualization_examples/sankey_diagram/weather_data_sankey_diagram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,5 @@
3636
print("\nMetrics:", metrics)
3737

3838
# Visualize sankey diagram
39-
fig = sankey_diagram(rules=rules, interestingness_measure="support", M=4)
39+
fig = sankey_diagram(rules=rules, interestingness_measure="support", max_rules=4)
4040
fig.show()

src/niaarm/visualize.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -593,15 +593,16 @@ def prepare_data(rules, metrics):
593593
return plt
594594

595595

596-
def sankey_diagram(rules, interestingness_measure, M=4):
596+
def sankey_diagram(rules, interestingness_measure, max_rules=4):
597597
"""
598598
Visualize rules as a sankey diagram.
599599
600600
Args:
601601
rules (Rule): Association rule or rules to visualize.
602-
interestingness_measures (str): Interestingness measure Z = {supp, cons,
602+
interestingness_measure (str): Interestingness measure Z = {supp, cons,
603603
lift},reflecting the quality of a particular connection.
604-
m (int): Maximum number of rules to be selected for visualization. Default: 4
604+
max_rules (int): Maximum number of rules to be selected for visualization.
605+
Default: 4
605606
606607
Returns:
607608
Figure or plot.
@@ -626,22 +627,22 @@ def build_adjacency_matrix(rules):
626627

627628
return adjacency_matrix
628629

629-
def knapsack_selection(adj_matrix, rules, M):
630+
def knapsack_selection(adj_matrix, rules, max_rules):
630631
fitness_scores = np.array([rule.fitness for rule in rules])
631-
N = len(rules) # number of rules
632-
weights = np.ones(N) # all rules have the same weight
632+
n_rules = len(rules) # number of rules
633+
weights = np.ones(n_rules) # all rules have the same weight
633634
similarity_weight = 1.0
634635
fitness_weight = 0.5
635636
combined_profits = (
636637
similarity_weight * np.sum(adj_matrix) + fitness_weight * fitness_scores
637638
) # combined similarities with fitness for values
638639

639-
selected = np.zeros(N, dtype=int)
640+
selected = np.zeros(n_rules, dtype=int)
640641

641642
# Initialize DP table
642-
dp = np.zeros((N + 1, M + 1))
643-
for i in range(1, N + 1):
644-
for w in range(1, M + 1):
643+
dp = np.zeros((n_rules + 1, max_rules + 1))
644+
for i in range(1, n_rules + 1):
645+
for w in range(1, max_rules + 1):
645646
if weights[i - 1] <= w:
646647
dp[i, w] = max(
647648
dp[i - 1, w], dp[i - 1, w - 1] + combined_profits[i - 1]
@@ -650,22 +651,22 @@ def knapsack_selection(adj_matrix, rules, M):
650651
dp[i, w] = dp[i - 1, w]
651652

652653
# Backtrack to find selected rules
653-
w = M
654-
for i in range(N, 0, -1):
654+
w = max_rules
655+
for i in range(n_rules, 0, -1):
655656
if dp[i, w] != dp[i - 1, w]:
656657
selected[i - 1] = 1
657658
w -= 1
658659

659-
selected_rules = [rules[i] for i in range(N) if selected[i]]
660+
selected_rules = [rules[i] for i in range(n_rules) if selected[i]]
660661

661662
return selected_rules
662663

663-
def prepare_data(rules, M, interestingness_measure):
664+
def prepare_data(rules, max_rules, interestingness_measure):
664665
if not rules:
665666
return [], [], [], []
666667

667668
adj_matrix = build_adjacency_matrix(rules)
668-
selected_rules = knapsack_selection(adj_matrix, rules, M)
669+
selected_rules = knapsack_selection(adj_matrix, rules, max_rules)
669670

670671
sources = []
671672
targets = []
@@ -696,7 +697,9 @@ def prepare_data(rules, M, interestingness_measure):
696697

697698
return labels, sources, targets, values
698699

699-
labels, sources, targets, values = prepare_data(rules, M, interestingness_measure)
700+
labels, sources, targets, values = prepare_data(
701+
rules, max_rules, interestingness_measure
702+
)
700703

701704
# Visualization using Plotly
702705
fig = go.Figure(

0 commit comments

Comments
 (0)