Skip to content

Commit b814f84

Browse files
additional tests for multi-objective strategies
1 parent 24165ce commit b814f84

2 files changed

Lines changed: 43 additions & 0 deletions

File tree

test/context.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@
8080
except ImportError:
8181
pyatf_present = False
8282

83+
try:
84+
import pymoo
85+
pymoo_present = True
86+
except ImportError:
87+
pymoo_present = False
88+
8389
try:
8490
from autotuning_methodology.report_experiments import get_strategy_scores
8591
methodology_present = True
@@ -110,6 +116,7 @@
110116
skip_if_no_hip = pytest.mark.skipif(not hip_present, reason="No HIP Python found")
111117
skip_if_no_pyatf = pytest.mark.skipif(not pyatf_present, reason="PyATF not installed")
112118
skip_if_no_methodology = pytest.mark.skipif(not methodology_present, reason="Autotuning Methodology not found")
119+
skip_if_no_pymoo = pytest.mark.skipif(not pymoo_present, reason="No PyMOO found")
113120

114121

115122
def skip_backend(backend: str):
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import os
2+
3+
import numpy as np
4+
import pytest
5+
from pathlib import Path
6+
7+
import kernel_tuner
8+
from kernel_tuner import util
9+
10+
from ..context import skip_if_no_pymoo
11+
12+
cache_filename = Path(__file__).parent / "test_cache_time_energy.json"
13+
14+
strategies = ["nsga2", "nsga3"]
15+
16+
@skip_if_no_pymoo
17+
@pytest.mark.parametrize('strategy', strategies)
18+
def test_strategies(strategy):
19+
20+
options = dict(strategy=strategy,
21+
strategy_options = dict(popsize=5, max_fevals=15),
22+
objective = ["time", "energy"],
23+
objective_higher_is_better = [False, False],
24+
verbose=True,
25+
)
26+
27+
print(f"testing {strategy}")
28+
assert cache_filename.exists()
29+
results, env = kernel_tuner.tune_cache(cache_filename, **options)
30+
31+
# assert has results
32+
assert len(results) > 0
33+
34+
# assert pareto front is stored in env["best_config"]
35+
pareto_front = util.get_pareto_results(results, options["objective"], options["objective_higher_is_better"])
36+
assert len(pareto_front) == len(env["best_config"])

0 commit comments

Comments
 (0)