Skip to content

Commit 54a9a84

Browse files
committed
Add new hyperparameter optimisation script
(and parallelise the existing one through joblib)
1 parent 51bdd73 commit 54a9a84

2 files changed

Lines changed: 370 additions & 13 deletions

File tree

qstack/regression/hyperparameters.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import scipy
66
from sklearn.model_selection import KFold
7+
from sklearn.utils.parallel import Parallel, delayed
78
from qstack.mathutils.fps import do_fps
89
from qstack.tools import correct_num_threads
910
from .kernel_utils import get_kernel, defaults, train_test_split_idx, sparse_regression_kernel
@@ -69,18 +70,21 @@ def k_fold_opt(K_all, eta):
6970

7071
def hyper_loop(sigma, eta):
7172
errors = []
72-
for s in sigma:
73-
if read_kernel is False:
74-
K_all = kernel(X_train, X_train, 1.0/s)
75-
else:
76-
K_all = X_train
77-
78-
for e in eta:
79-
mean, std = k_fold_opt(K_all, e)
80-
if printlevel>0 :
81-
sys.stderr.flush()
82-
print(s, e, mean, std, flush=True)
83-
errors.append((mean, std, e, s))
73+
with Parallel(n_jobs=-1) as parallel:
74+
for s in sigma:
75+
if read_kernel is False:
76+
K_all = kernel(X_train, X_train, 1.0/s)
77+
else:
78+
K_all = X_train
79+
80+
def inner_loop(s,e):
81+
mean, std = k_fold_opt(K_all, e)
82+
if printlevel>0 :
83+
sys.stderr.flush()
84+
print(s, e, mean, std, flush=True)
85+
return (mean, std, e, s)
86+
87+
errors += parallel(delayed(inner_loop)(s,e) for e in eta)
8488
return errors
8589
if gkernel is None:
8690
gwrap = None
@@ -139,7 +143,6 @@ def hyper_loop(sigma, eta):
139143
def _get_arg_parser():
140144
"""Parse CLI arguments."""
141145
parser = RegressionParser(description='This program finds the optimal hyperparameters.', hyperparameters_set='array')
142-
parser.remove_argument("random_state")
143146
parser.remove_argument("train_size")
144147
return parser
145148

0 commit comments

Comments
 (0)