Skip to content

Commit 033127d

Browse files
committed
document said script's functions
1 parent 54a9a84 commit 033127d

1 file changed

Lines changed: 59 additions & 4 deletions

File tree

qstack/regression/hyperparameters2.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020

2121
def fit_quadratic(x1,x2,x3, y1,y2,y3):
22-
"""Compute the three coefficients of a quadratic polynomial going through three given points."""
22+
"""
23+
Compute the three coefficients of a quadratic polynomial going through three given points.
24+
Could probably be replaced by `np.polyfit` now that I know about it. Fluff it, we ball.
25+
"""
2326
# we need to change coordinates around for this
2427

2528
# first, slopes at 0.5(x1+x2) and 0.5(x2+x3)
@@ -38,7 +41,21 @@ def fit_quadratic(x1,x2,x3, y1,y2,y3):
3841
return curv, slope, intercept
3942

4043
def parabolic_search(x_left, x_right, get_err, n_iter=10, x_thres=0.1, y_thres=0.01):
41-
"""A 1D optimisation function, assuming the loss function in question is convex, It first checks to see the bounds are correct, then refines them by fitting quadratic polynomials"""
44+
"""
45+
Gradient-less line search of the minimum of `get_err`, supposedly between `x_left` and `x_right`.
46+
Fits quadratic polynomials to perform this search, meaning `get_err` is assumed to be convex.
47+
48+
Args:
49+
x_left (float): supposed left bound of the minimum of `get_err`
50+
x_right (float): supposed right bound of the minimum of `get_err`
51+
get_err (callable float->float): the function to minimise.
52+
n_iter (int): the number of function calls allowed
53+
x_thres (float): the acceptable error threshold for the argmin to find
54+
y_thres (float): the acceptable error threshold for the min to find
55+
56+
Returns:
57+
the (argmin, min) tuple characterising the minimum of the function (2x float)
58+
"""
4259

4360
y_left = get_err(x_left)
4461
y_right = get_err(x_right)
@@ -122,6 +139,21 @@ def parabolic_search(x_left, x_right, get_err, n_iter=10, x_thres=0.1, y_thres=0
122139

123140

124141
def kfold_alpha_eval(K_all, y, n_splits, alpha_grid, parallel=None, on_compute=(lambda eta,err,stderr:None)):
142+
"""Module-internal function: optimise alpha (regularisation parameter) of a KRR learning model, using a K-fold validation.
143+
144+
Args:
145+
K_all: matrix of kernel values (can be n_total*n_total for naive KRR or n_total*n_references for sparse KRR)
146+
y: learnable properties for all inputs (n_total-length vector)
147+
n_splits: number of folds for k-fold validation
148+
alpha_grid: all the values of alpha to try (array-like)
149+
parallel: optional joblib.Parallel instance to use to parallelise this function (by default one is constructed)
150+
on_compute: function to call for the error summaries of each value of alpha
151+
(callable: alpha, error_mean, error_stddev -> None)
152+
153+
Returns:
154+
- optimal value of alpha
155+
- validation error list for all k-fold evaluations for this value of alpha
156+
"""
125157
if parallel is None:
126158
parallel = Parallel(n_jobs=-1, return_as="generator_unordered")
127159
kfold = KFold(n_splits=n_splits, shuffle=False)
@@ -181,9 +213,32 @@ def search_sigma(
181213
sigma_bounds, alpha_grid,
182214
n_iter, n_splits,
183215
stddev_portion=+1.0, sparse_idx=None,
184-
parallel=None, on_compute=(lambda sigma,eta,err,stderr:None)
216+
parallel=None, on_compute=(lambda sigma,alpha,err,stderr:None)
185217
):
186-
"""Search"""
218+
"""Search the optimal values of sigma and alpha for a KRR model with known representations.
219+
Sigma is the width parameter of the kernel function used,
220+
and alpha is the regularisation parameter of the resulting matrix equation.
221+
222+
Internally, calls the line-search routine for sigma,
223+
where the function to minimise performs its own grid-based optimisation of alpha.
224+
225+
Args:
226+
X (np.ndarray[n_total,n_features]): feature vectors for the combined train-validation dataset
227+
y (np.ndarray[n_total]): learnable properties for all inputs
228+
sigma_bounds (tuple(float,float)): presumed bounds of the optimal value of sigma
229+
alpha_grid (array-like of floats): values of alpha to try
230+
n_iter (int): number of iterations for the sigma line-search
231+
n_splits (int): number of folds for k-fold validation
232+
stddev_portion (float): contribution of the error's standard deviation to compare error distributions
233+
sparse_idx (optional np.ndarray[int, n_references]): selection of reference inputs for sparse KRR.
234+
parallel (optional joblib.Parallel): tool to make the optimisation more parallel. by default, one will be (re)created as often as necessary.
235+
on_compute (callable sigma,alpha,err_mean,err_stddev -> None)
236+
237+
Returns:
238+
sigma (float): optimal value of sigma
239+
alpha (float): optimal value of alpha
240+
costs (np.ndarray[n_splits]): validation error distribution for these values of sigma,alpha
241+
"""
187242

188243
sigma_left, sigma_right = sigma_bounds
189244

0 commit comments

Comments
 (0)