Skip to content

Commit 877a01e

Browse files
committed
Make CostFunc return self.invalid_return_value instead of sys.float_info.max
1 parent 0dad036 commit 877a01e

5 files changed

Lines changed: 25 additions & 25 deletions

File tree

kernel_tuner/runners/parallel.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from kernel_tuner.runners.runner import Runner
1010
from kernel_tuner.util import (
1111
Timer,
12-
disable_benchmark_timings,
12+
copy_without_benchmark_timings,
1313
ErrorConfig,
1414
TuningBudget,
1515
print_config_output,
@@ -313,7 +313,7 @@ def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
313313
if key in tuning_options.cache:
314314
# We must disable the timings as otherwise these will counted
315315
# as part of the total_compile/benchmark/verification_time
316-
result = disable_benchmark_timings(tuning_options.cache[key])
316+
result = copy_without_benchmark_timings(tuning_options.cache[key])
317317

318318
# recompute matrics for this entry
319319
result = process_metrics(result, metrics)
@@ -369,7 +369,7 @@ def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
369369
# as otherwise we would count them multiple times in the total
370370
for i, j in duplicate_entries:
371371
if results[j]:
372-
results[i] = disable_benchmark_timings(results[j])
372+
results[i] = copy_without_benchmark_timings(results[j])
373373

374374
# Count the number of valid results
375375
num_valid_results = sum(bool(r) for r in results)
@@ -388,7 +388,7 @@ def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
388388
for result in results:
389389
if result:
390390
# Time must be in ms
391-
result["strategy_time"] = strategy_time / num_valid_results
392-
result["framework_time"] = framework_time / num_valid_results
391+
result["strategy_time"] = 1000 * strategy_time / num_valid_results
392+
result["framework_time"] = 1000 * framework_time / num_valid_results
393393

394394
return results

kernel_tuner/runners/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ def __init__(self):
1313
self.timer = Timer()
1414
self.accumulated_strategy_time = 0
1515

16-
def add_strategy_time(self, seconds):
16+
def add_strategy_time(self, seconds: float):
17+
""" Notify this runner of the amount of time spent by the search strategy."""
1718
self.accumulated_strategy_time += seconds
1819

1920
def shutdown(self):

kernel_tuner/runners/sequential.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from kernel_tuner.core import DeviceInterface
77
from kernel_tuner.runners.runner import Runner
8-
from kernel_tuner.util import ErrorConfig, Timer, print_config_output, process_metrics, store_cache, disable_benchmark_timings
8+
from kernel_tuner.util import ErrorConfig, Timer, print_config_output, process_metrics, store_cache, copy_without_benchmark_timings
99

1010

1111
class SequentialRunner(Runner):
@@ -63,6 +63,7 @@ def run(self, parameter_space, tuning_options):
6363

6464
results = []
6565
worker_time = 0
66+
warmup_time = 0
6667

6768
# iterate over parameter space
6869
for element in parameter_space:
@@ -71,18 +72,16 @@ def run(self, parameter_space, tuning_options):
7172
if tuning_options.budget.is_done():
7273
results.append(None)
7374
continue
74-
75+
7576
tuning_options.budget.add_evaluations(1)
7677
params = dict(zip(tuning_options.tune_params.keys(), element))
77-
7878
result = None
79-
warmup_time = 0
8079

8180
# check if configuration is in the cache
8281
x_int = ",".join([str(i) for i in element])
8382
if tuning_options.cache and x_int in tuning_options.cache:
8483
cache_entry = tuning_options.cache[x_int]
85-
params.update(disable_benchmark_timings(cache_entry))
84+
params.update(copy_without_benchmark_timings(cache_entry))
8685
else:
8786
# attempt to warmup the GPU by running the first config in the parameter space and ignoring the result
8887
if not self.warmed_up:
@@ -136,6 +135,7 @@ def run(self, parameter_space, tuning_options):
136135
# Amortize the time over all the results
137136
for result in results:
138137
if result:
138+
# Time must be in ms
139139
result["strategy_time"] = 1000 * strategy_time / num_valid_results
140140
result["framework_time"] = 1000 * framework_time / num_valid_results
141141

kernel_tuner/runners/simulation.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob
6060

6161
def get_device_info(self):
6262
return self.dev
63-
63+
6464
def get_environment(self, tuning_options):
6565
env = self.dev.get_environment()
6666
env["simulation"] = True
@@ -90,7 +90,7 @@ def run(self, parameter_space, tuning_options):
9090
if tuning_options.budget.is_done():
9191
results.append(None)
9292
continue
93-
93+
9494
# check if element is in the cache
9595
key = ",".join([str(i) for i in element])
9696

@@ -109,33 +109,33 @@ def run(self, parameter_space, tuning_options):
109109
# is served from the cache beyond the first timel. That is, when the
110110
# configuration is already counted towards the unique_results.
111111
if key in self.visited_results:
112-
result = util.disable_benchmark_timings(result)
112+
result = util.copy_without_benchmark_timings(result)
113113
else:
114114
# configuration is evaluated for the first time, print to the console
115115
util.print_config_output(tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units)
116116
self.visited_results.add(key)
117117

118-
# Simulate the evaluation of this configuration
119-
tuning_options.budget.add_evaluations(1)
120-
tuning_options.budget.add_time(milliseconds=result["compile_time"])
121-
tuning_options.budget.add_time(milliseconds=result["verification_time"])
122-
tuning_options.budget.add_time(milliseconds=result["benchmark_time"])
123-
124118
try:
125119
self.total_simulated_time += result["compile_time"] + result["verification_time"] + result["benchmark_time"]
126120
except KeyError:
127121
raise RuntimeError(
128122
"Cannot use simulation mode with a time limit on a cache file that does not have full compile, verification, and benchmark timings on all configurations"
129123
)
130124

125+
# Simulate the evaluation of this configuration
126+
tuning_options.budget.add_evaluations(1)
127+
tuning_options.budget.add_time(milliseconds=result["compile_time"])
128+
tuning_options.budget.add_time(milliseconds=result["verification_time"])
129+
tuning_options.budget.add_time(milliseconds=result["benchmark_time"])
130+
131131
results.append(result)
132132
continue
133133

134134
# if the configuration is not in the cache and not within restrictions, simulate an InvalidConfig with warning
135135
params_dict = dict(zip(tuning_options['tune_params'].keys(), element))
136136
check = util.check_restrictions(tuning_options.restrictions, params_dict, True)
137137
if not check:
138-
result = util.disable_benchmark_timings(params_dict) # Set timings to zero
138+
result = util.copy_without_benchmark_timings(params_dict) # Set timings to zero
139139
result[tuning_options.objective] = util.InvalidConfig()
140140
results.append(result)
141141
warn(f"Configuration {element} not in cache, does not pass restrictions. Will be treated as an InvalidConfig, but make sure you are evaluating the correct cache file.")
@@ -159,8 +159,8 @@ def run(self, parameter_space, tuning_options):
159159
for result in results:
160160
if result:
161161
# Time must be in ms
162-
result["strategy_time"] = strategy_time / num_valid_results
163-
result["framework_time"] = framework_time / num_valid_results
162+
result["strategy_time"] = 1000 * strategy_time / num_valid_results
163+
result["framework_time"] = 1000 * framework_time / num_valid_results
164164

165165

166166
return results

kernel_tuner/strategies/common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def __init__(
7171
scaling=False,
7272
snap=True,
7373
return_invalid=False,
74-
return_raw=None,
7574
invalid_value=sys.float_info.max,
7675
):
7776
"""An abstract method to handle evaluation of configurations.
@@ -207,7 +206,7 @@ def eval_all(self, xs, check_restrictions=True):
207206
else:
208207
# this is not a valid configuration, replace with float max if needed
209208
if not self.return_invalid:
210-
return_value = sys.float_info.max
209+
return_value = self.invalid_return_value
211210

212211
# include raw data in return if requested
213212
return_values.append(return_value)

0 commit comments

Comments
 (0)