Skip to content

Commit 05bc4a6

Browse files
committed
Add tests for TuningBudget
1 parent d18bfdf commit 05bc4a6

6 files changed

Lines changed: 132 additions & 49 deletions

File tree

kernel_tuner/runners/parallel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,9 @@ def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
286286
params.update(tuning_options.cache[key])
287287

288288
# Simulate compile, verification, and benchmark time
289-
tuning_options.budget.add_time_spent(params["compile_time"])
290-
tuning_options.budget.add_time_spent(params["verification_time"])
291-
tuning_options.budget.add_time_spent(params["benchmark_time"])
289+
tuning_options.budget.add_time(milliseconds=params["compile_time"])
290+
tuning_options.budget.add_time(milliseconds=params["verification_time"])
291+
tuning_options.budget.add_time(milliseconds=params["benchmark_time"])
292292
results.append(params)
293293
else:
294294
assert key not in key2index, "duplicate jobs submitted"

kernel_tuner/runners/sequential.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from kernel_tuner.core import DeviceInterface
77
from kernel_tuner.runners.runner import Runner
8-
from kernel_tuner.util import ErrorConfig, print_config_output, process_metrics, store_cache, stop_criterion_reached
8+
from kernel_tuner.util import ErrorConfig, print_config_output, process_metrics, store_cache
99

1010

1111
class SequentialRunner(Runner):
@@ -70,12 +70,15 @@ def run(self, parameter_space, tuning_options):
7070

7171
# iterate over parameter space
7272
for element in parameter_space:
73+
# If the time limit is exceeded, just skip this element. Add `None` to
74+
# indicate to CostFunc that no result is available for this config.
75+
if tuning_options.budget.is_done():
76+
results.append(None)
77+
continue
78+
7379
tuning_options.budget.add_evaluations(1)
7480
params = dict(zip(tuning_options.tune_params.keys(), element))
7581

76-
if stop_criterion_reached(tuning_options):
77-
return results
78-
7982
result = None
8083
warmup_time = 0
8184

@@ -85,9 +88,9 @@ def run(self, parameter_space, tuning_options):
8588
params.update(tuning_options.cache[x_int])
8689

8790
# Simulate compile, verification, and benchmark time
88-
tuning_options.budget.add_time_spent(params["compile_time"])
89-
tuning_options.budget.add_time_spent(params["verification_time"])
90-
tuning_options.budget.add_time_spent(params["benchmark_time"])
91+
tuning_options.budget.add_time(milliseconds=params["compile_time"])
92+
tuning_options.budget.add_time(milliseconds=params["verification_time"])
93+
tuning_options.budget.add_time(milliseconds=params["benchmark_time"])
9194
else:
9295
# attempt to warmup the GPU by running the first config in the parameter space and ignoring the result
9396
if not self.warmed_up:

kernel_tuner/runners/simulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def run(self, parameter_space, tuning_options):
116116

117117
# Simulate the evaluation of this configuration
118118
tuning_options.budget.add_evaluations(1)
119-
tuning_options.budget.add_time_spent(result["compile_time"])
120-
tuning_options.budget.add_time_spent(result["verification_time"])
121-
tuning_options.budget.add_time_spent(result["benchmark_time"])
119+
tuning_options.budget.add_time(milliseconds=result["compile_time"])
120+
tuning_options.budget.add_time(milliseconds=result["verification_time"])
121+
tuning_options.budget.add_time(milliseconds=result["benchmark_time"])
122122

123123
try:
124124
self.total_simulated_time += result["compile_time"] + result["verification_time"] + result["benchmark_time"]

kernel_tuner/strategies/common.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def _run_configs(self, xs, check_restrictions=True):
145145
batch_keys = [] # The keys of the configs to run
146146
pending_indices_by_key = dict() # Maps key => where to store result in `final_results`
147147
final_results = [] # List returned to the user
148-
benchmark_config = []
149148

150149
# Loop over all configurations. For each configurations there are four cases:
151150
# 1. The configuration is valid, we can skip it
@@ -201,16 +200,11 @@ def _run_configs(self, xs, check_restrictions=True):
201200
self.tuning_options.unique_results[key] = result
202201
self.results.append(result)
203202

204-
205-
# check again for stop condition
206-
# this check is necessary because some strategies cannot handle partially completed requests
207-
# for example when only half of the configs in a population have been evaluated
208-
self.budget_spent_fraction = util.check_stop_criterion(self.tuning_options)
209-
210203
# upon returning from this function control will be given back to the strategy, so reset the start time
211204
self.runner.last_strategy_start_time = perf_counter()
212205

213-
# Check the tuning budget again
206+
# this check is necessary because some strategies cannot handle partially completed requests
207+
# for example when only half of the configs in a population have been evaluated
214208
self.tuning_options.budget.raise_exception_if_done()
215209
self.budget_spent_fraction = self.tuning_options.budget.get_fraction_consumed()
216210

kernel_tuner/util.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -190,19 +190,13 @@ def check_argument_list(kernel_name, kernel_string, args):
190190

191191
class TuningBudget:
192192
def __init__(self, time_limit=None, max_fevals=None):
193-
if max_fevals is None:
194-
max_fevals = float("inf")
195-
196-
if time_limit is None:
197-
time_limit = timedelta.max
198-
199-
if not isinstance(time_limit, timedelta):
193+
if time_limit is not None and not isinstance(time_limit, timedelta):
200194
time_limit = timedelta(seconds=time_limit)
201195

202-
if max_fevals <= 0:
196+
if max_fevals is not None and max_fevals <= 0:
203197
raise ValueError("max_fevals must be greater than zero")
204198

205-
if time_limit <= timedelta(seconds=0):
199+
if time_limit is not None and time_limit <= timedelta(seconds=0):
206200
raise ValueError("time_limit must be greater than zero")
207201

208202
self.start_time_seconds = time.perf_counter()
@@ -214,50 +208,55 @@ def __init__(self, time_limit=None, max_fevals=None):
214208
def add_evaluations(self, n=1):
215209
self.num_fevals += n
216210

217-
def add_time_spent(self, delta):
218-
if not isinstance(delta, timedelta):
219-
delta = timedelta(seconds=delta)
220-
self.time_spent_extra += delta
211+
def add_time(self, seconds=0, milliseconds=0):
212+
self.time_spent_extra += timedelta(seconds=seconds, milliseconds=milliseconds)
221213

222214
def get_time_spent(self) -> timedelta:
223215
seconds_passed = time.perf_counter() - self.start_time_seconds
224216
return timedelta(seconds=seconds_passed) + self.time_spent_extra
225217

226218
def get_time_remaining(self) -> timedelta:
227-
return max(self.time_limit - self.get_time_spent(), timedelta(seconds=0))
219+
if self.time_limit is not None:
220+
return max(self.time_limit - self.get_time_spent(), timedelta(seconds=0))
221+
else:
222+
return timedelta.max
228223

229224
def get_evaluations_spent(self) -> int:
230-
return max(self.max_fevals - self.num_fevals, 0)
225+
return self.num_fevals
231226

232227
def get_evaluations_remaining(self) -> int:
233-
return max(self.max_fevals - self.num_fevals, 0)
228+
if self.max_fevals is not None:
229+
return max(self.max_fevals - self.num_fevals, 0)
230+
else:
231+
return float("inf")
234232

235233
def is_done(self) -> bool:
236-
if self.num_fevals >= self.max_fevals:
234+
if self.max_fevals is not None and self.num_fevals >= self.max_fevals:
237235
return True
238236

239-
if self.get_time_spent() > self.time_limit:
237+
if self.time_limit is not None and self.get_time_spent() > self.time_limit:
240238
return True
241239

242240
return False
243241

244242
def raise_exception_if_done(self):
245-
if self.num_fevals >= self.max_fevals:
243+
if self.max_fevals is not None and self.num_fevals >= self.max_fevals:
246244
raise StopCriterionReached(f"max_fevals ({self.max_fevals}) reached")
247245

248-
if self.get_time_spent() > self.time_limit:
246+
if self.time_limit is not None and self.get_time_spent() > self.time_limit:
249247
raise StopCriterionReached("time limit exceeded")
250248

251249
def get_fraction_consumed(self) -> float:
252-
if self.num_fevals >= self.max_fevals:
253-
return 1.0
254-
255-
time_spent = self.get_time_spent()
250+
if self.max_fevals is not None and self.time_limit is not None:
251+
time_spent = self.get_time_spent()
252+
return min(1.0, time_spent / self.time_limit, self.num_fevals / self.max_fevals)
253+
elif self.max_fevals is not None:
254+
return min(1.0, self.num_fevals / self.max_fevals)
255+
elif self.time_limit is not None:
256+
return min(1.0, self.get_time_spent() / self.time_limit)
257+
else:
258+
return 0.0
256259

257-
if time_spent > self.time_limit:
258-
return 1.0
259-
260-
return max(time_spent / self.time_limit, self.num_fevals / self.max_fevals)
261260

262261

263262

test/test_util_functions.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
import warnings
6+
import datetime
67

78
import numpy as np
89
import pytest
@@ -429,6 +430,92 @@ def test_check_argument_list7():
429430
assert_user_warning(check_argument_list, [kernel_name, kernel_string, args])
430431

431432

433+
def test_tuning_budget1():
434+
budget = TuningBudget()
435+
assert budget.get_evaluations_spent() == 0
436+
assert budget.get_evaluations_remaining() == float("inf")
437+
assert not budget.is_done()
438+
budget.raise_exception_if_done() # Should not raise
439+
assert budget.get_fraction_consumed() == 0.0
440+
441+
budget.add_evaluations(9000)
442+
assert budget.get_evaluations_spent() == 9000
443+
assert budget.get_evaluations_remaining() == float("inf")
444+
assert not budget.is_done()
445+
budget.raise_exception_if_done() # Should not raise
446+
assert budget.get_fraction_consumed() == 0.0
447+
448+
budget.add_time(seconds=9000)
449+
assert budget.get_evaluations_spent() == 9000
450+
assert budget.get_evaluations_remaining() == float("inf")
451+
assert not budget.is_done()
452+
budget.raise_exception_if_done() # Should not raise
453+
assert budget.get_fraction_consumed() == 0.0
454+
455+
def test_tuning_budget2():
456+
budget = TuningBudget(max_fevals=5)
457+
assert budget.get_evaluations_spent() == 0
458+
assert budget.get_evaluations_remaining() == 5
459+
assert not budget.is_done()
460+
budget.raise_exception_if_done() # Should not raise
461+
assert budget.get_fraction_consumed() == 0.0
462+
463+
budget.add_evaluations(4)
464+
assert budget.get_evaluations_spent() == 4
465+
assert budget.get_evaluations_remaining() == 1
466+
assert not budget.is_done()
467+
budget.raise_exception_if_done() # Should not raise
468+
assert budget.get_fraction_consumed() == 4/5
469+
470+
budget.add_evaluations(1)
471+
assert budget.get_evaluations_spent() == 5
472+
assert budget.get_evaluations_remaining() == 0
473+
assert budget.is_done()
474+
assert pytest.raises(StopCriterionReached, budget.raise_exception_if_done)
475+
assert budget.get_fraction_consumed() == 1.0
476+
477+
478+
def test_tuning_budget3():
479+
# Two values are similar if they are within 0.01
480+
approx = lambda x: pytest.approx(x, abs=0.01)
481+
482+
budget = TuningBudget(time_limit=5)
483+
assert budget.get_time_spent().total_seconds() == approx(0)
484+
assert budget.get_time_remaining().total_seconds() == approx(5)
485+
assert budget.get_evaluations_spent() == 0
486+
assert budget.get_evaluations_remaining() == float("inf")
487+
assert not budget.is_done()
488+
budget.raise_exception_if_done() # Should not raise
489+
assert budget.get_fraction_consumed() == approx(0.0)
490+
491+
budget.add_evaluations(1)
492+
assert budget.get_time_spent().total_seconds() == approx(0)
493+
assert budget.get_time_remaining().total_seconds() == approx(5)
494+
assert budget.get_evaluations_spent() == 1
495+
assert budget.get_evaluations_remaining() == float("inf")
496+
assert not budget.is_done()
497+
budget.raise_exception_if_done() # Should not raise
498+
assert budget.get_fraction_consumed() == approx(0.0)
499+
500+
budget.add_time(seconds=2)
501+
assert budget.get_time_spent().total_seconds() == approx(2)
502+
assert budget.get_time_remaining().total_seconds() == approx(3)
503+
assert budget.get_evaluations_spent() == 1
504+
assert budget.get_evaluations_remaining() == float("inf")
505+
assert not budget.is_done()
506+
budget.raise_exception_if_done() # Should not raise
507+
assert budget.get_fraction_consumed() == approx(2/5)
508+
509+
budget.add_time(seconds=4)
510+
assert budget.get_time_spent().total_seconds() == approx(6)
511+
assert budget.get_time_remaining().total_seconds() == approx(0)
512+
assert budget.get_evaluations_spent() == 1
513+
assert budget.get_evaluations_remaining() == float("inf")
514+
assert budget.is_done()
515+
assert pytest.raises(StopCriterionReached, budget.raise_exception_if_done)
516+
assert budget.get_fraction_consumed() == 1.0
517+
518+
432519
def test_check_tune_params_list():
433520
tune_params = dict(
434521
zip(

0 commit comments

Comments
 (0)