Skip to content

Commit 962087b

Browse files
authored
Adding the possibility to inject precomputed results in study (#4457)
1 parent 5d5ddbe commit 962087b

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,39 @@ def remove_cases(self, case_keys):
292292
self.remove_benchmark(key)
293293
(self.folder / "cases.pickle").write_bytes(pickle.dumps(self.cases))
294294

295+
def set_precomputed_results(self, precomputed_results, verbose=False):
296+
"""Set precomputed results for some cases. This is useful when you want to compute results outside of the benchmark and
297+
then set them in the benchmark.
298+
299+
Parameters
300+
----------
301+
precomputed_results : dict
302+
A dict with the same keys as cases and values are dict with the results to set for each case.
303+
The keys of the inner dict must be the same as the keys of the benchmark result.
304+
'run_time' is a special key that will be set to 0.0 if not present in the precomputed results.
305+
verbose : bool, default: False
306+
Whether to print the keys of the precomputed results when setting them.
307+
"""
308+
309+
for key in precomputed_results.keys():
310+
assert key in self.cases, f"Key {key} in precomputed_results is not in cases"
311+
benchmark = self.create_benchmark(key)
312+
if verbose:
313+
print("### Set benchmark", key, "###")
314+
315+
for k, v in precomputed_results[key].items():
316+
benchmark.result[k] = v
317+
if "run_time" not in benchmark.result:
318+
benchmark.result["run_time"] = 0.0
319+
if verbose:
320+
print(f"Warning: 'run_time' is not in the precomputed results for key {key}, setting it to 0.0")
321+
322+
self.benchmarks[key] = benchmark
323+
bench_folder = self.folder / "results" / self.key_to_str(key)
324+
bench_folder.mkdir(exist_ok=True)
325+
benchmark.save_run(bench_folder)
326+
benchmark.save_main(bench_folder)
327+
295328
def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):
296329
if case_keys is None:
297330
case_keys = list(self.cases.keys())

0 commit comments

Comments
 (0)