Skip to content

Commit 0865a52

Browse files
authored
Fix #1351: Merge small improvements from PR #1281 (PR #1352)
1 parent 0752578 commit 0865a52

3 files changed

Lines changed: 8 additions & 3 deletions

File tree

pdr_backend/aimodel/aimodel_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def _build_wrapped_regr(
8080
ss = self.ss
8181
assert ss.do_regr
8282
assert ycont is not None
83+
assert X.shape[0] == ycont.shape[0], (X.shape[0], ycont.shape[0])
8384
do_constant = min(ycont) == max(ycont) or ss.approach == "RegrConstant"
8485

8586
# weight newest sample 10x, and 2nd-newest sample 5x
@@ -145,6 +146,7 @@ def _build_direct_classif(
145146
) -> Aimodel:
146147
ss = self.ss
147148
assert not ss.do_regr
149+
assert X.shape[0] == len(ytrue), (X.shape[0], len(ytrue))
148150
n_True, n_False = sum(ytrue), sum(np.invert(ytrue))
149151
smallest_n = min(n_True, n_False)
150152
do_constant = (smallest_n == 0) or ss.approach == "ClassifConstant"

pdr_backend/aimodel/aimodel_plotdata.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def __init__(
2020
model: Aimodel,
2121
X_train: np.ndarray,
2222
ytrue_train: np.ndarray,
23-
ycont_train: np.ndarray,
24-
y_thr: float,
23+
ycont_train: Optional[np.ndarray],
24+
y_thr: Optional[float],
2525
colnames: List[str],
2626
slicing_x: np.ndarray,
2727
sweep_vars: Optional[List[int]] = None,
@@ -45,7 +45,8 @@ def __init__(
4545
assert len(colnames) == n, (len(colnames), n)
4646
assert slicing_x.shape[0] == n, (slicing_x.shape[0], n)
4747
assert ytrue_train.shape[0] == N, (ytrue_train.shape[0], N)
48-
assert ycont_train.shape[0] == N, (ycont_train.shape[0], N)
48+
if ycont_train is not None:
49+
assert ycont_train.shape[0] == N, (ycont_train.shape[0], N)
4950
assert sweep_vars is None or len(sweep_vars) in [1, 2]
5051

5152
# set values

pdr_backend/aimodel/aimodel_plotter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ def _plot_lineplot_1var(aimodel_plotdata: AimodelPlotdata):
138138
# line plot: regressor response, training data
139139
if d.model.do_regr:
140140
assert mesh_ycont_hat is not None
141+
assert y_thr is not None
142+
assert ycont is not None
141143
fig.add_trace(
142144
go.Scatter(
143145
x=mesh_chosen_x,

0 commit comments

Comments
 (0)