Commit eba14a6

test(workflow-operator): add unit test coverage for SklearnAdvanced trainer descriptors (#5768)
### What changes were proposed in this PR?

Pin behavior of four previously-uncovered sklearn-trainer descriptors in `common/workflow-operator/operator/machineLearning/sklearnAdvanced/`. Each is a 30-line override of `SklearnMLOperatorDescriptor` that contributes just two values: the Python `import` statement and the operator-info label. Drift in either silently breaks generated Python code or the UI label. No production-code changes.

| Spec | Source class | Tests |
| --- | --- | --- |
| `SklearnAdvancedKNNClassifierTrainerOpDescSpec` | `SklearnAdvancedKNNClassifierTrainerOpDesc` | 5 |
| `SklearnAdvancedKNNRegressorTrainerOpDescSpec` | `SklearnAdvancedKNNRegressorTrainerOpDesc` | 6 |
| `SklearnAdvancedSVCTrainerOpDescSpec` | `SklearnAdvancedSVCTrainerOpDesc` | 5 |
| `SklearnAdvancedSVRTrainerOpDescSpec` | `SklearnAdvancedSVRTrainerOpDesc` | 6 |

All four spec files follow the `<srcClassName>Spec.scala` one-to-one convention.

**Behavior pinned (per descriptor)**

| Surface | Contract |
| --- | --- |
| `getImportStatements` | exact canonical Python import (`KNeighborsClassifier` / `KNeighborsRegressor` / `SVC` / `SVR` from the appropriate sklearn module) |
| `getOperatorInfo` | exact canonical label (`"KNN Classifier"` / `"KNN Regressor"` / `"SVM Classifier"` / `"SVM Regressor"`) |
| Stability across two instances | both methods return the same string regardless of which instance is queried |
| Type assignability | extends `SklearnMLOperatorDescriptor[ParamsT]` (compile-time enforced through a typed `val` binding) |
| Type-pattern matching | `case _: SklearnMLOperatorDescriptor[_]` matches a concrete instance |

The Regressor spec additionally cross-checks against the Classifier sibling (and SVR vs SVC). This catches copy-paste regressions where one subclass accidentally returned the other's strings.

### Any related issues, documentation, discussions?

Closes #5765.

### How was this PR tested?

Pure unit-test additions; verified locally with:

- `sbt "WorkflowOperator/testOnly org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.KNNTrainer.SklearnAdvancedKNNClassifierTrainerOpDescSpec org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.KNNTrainer.SklearnAdvancedKNNRegressorTrainerOpDescSpec org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.SVCTrainer.SklearnAdvancedSVCTrainerOpDescSpec org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.SVRTrainer.SklearnAdvancedSVRTrainerOpDescSpec"`: 22 tests, all green
- `sbt scalafmtCheckAll`: clean
- CI to confirm

### Was this PR authored or co-authored using generative AI tooling?

Generated-by: Claude Code (Opus 4.7 [1M context])
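To make the pinned contract concrete, here is a minimal sketch of the descriptor shape the description refers to. `DescriptorSketch`, `KnnParamsSketch`, and the two subclasses are hypothetical stand-ins and not the real `SklearnMLOperatorDescriptor` hierarchy, whose source is not part of this commit; only the two returned strings per descriptor are taken from the description above.

```scala
// Hypothetical stand-in for the real base class: it assumes each subclass
// contributes exactly two strings, as the PR description states.
abstract class DescriptorSketch[P] {
  def getImportStatements: String
  def getOperatorInfo: String
}

// Hypothetical parameter type shared by the two KNN descriptors.
final class KnnParamsSketch

class KnnClassifierSketch extends DescriptorSketch[KnnParamsSketch] {
  override def getImportStatements: String =
    "from sklearn.neighbors import KNeighborsClassifier"
  override def getOperatorInfo: String = "KNN Classifier"
}

class KnnRegressorSketch extends DescriptorSketch[KnnParamsSketch] {
  override def getImportStatements: String =
    "from sklearn.neighbors import KNeighborsRegressor"
  override def getOperatorInfo: String = "KNN Regressor"
}

object DescriptorSketchDemo {
  def main(args: Array[String]): Unit = {
    val classifier = new KnnClassifierSketch
    val regressor = new KnnRegressorSketch
    // Exact-string checks: any drift in either value fails immediately.
    assert(classifier.getImportStatements == "from sklearn.neighbors import KNeighborsClassifier")
    assert(regressor.getOperatorInfo == "KNN Regressor")
    // Sibling check: a copy-pasted override would make these equal.
    assert(classifier.getImportStatements != regressor.getImportStatements)
  }
}
```

Because each override is a constant, exact-string equality is a reasonable way to pin it: a looser check such as `contains` would let a changed module path slip through.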
1 parent 0eb7baa commit eba14a6

4 files changed

Lines changed: 258 additions & 0 deletions

@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.KNNTrainer
+
+import org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.base.SklearnMLOperatorDescriptor
+import org.scalatest.flatspec.AnyFlatSpec
+
+class SklearnAdvancedKNNClassifierTrainerOpDescSpec extends AnyFlatSpec {
+
+  "SklearnAdvancedKNNClassifierTrainerOpDesc.getImportStatements" should
+    "return the canonical KNeighborsClassifier import" in {
+    val d = new SklearnAdvancedKNNClassifierTrainerOpDesc
+    assert(d.getImportStatements == "from sklearn.neighbors import KNeighborsClassifier")
+  }
+
+  "SklearnAdvancedKNNClassifierTrainerOpDesc.getOperatorInfo" should
+    "return 'KNN Classifier'" in {
+    val d = new SklearnAdvancedKNNClassifierTrainerOpDesc
+    assert(d.getOperatorInfo == "KNN Classifier")
+  }
+
+  it should "be stable across two instances (no instance-state interaction)" in {
+    val a = new SklearnAdvancedKNNClassifierTrainerOpDesc
+    val b = new SklearnAdvancedKNNClassifierTrainerOpDesc
+    assert(a.getImportStatements == b.getImportStatements)
+    assert(a.getOperatorInfo == b.getOperatorInfo)
+  }
+
+  "SklearnAdvancedKNNClassifierTrainerOpDesc" should
+    "extend SklearnMLOperatorDescriptor (compile-time enforced)" in {
+    val d: SklearnMLOperatorDescriptor[SklearnAdvancedKNNParameters] =
+      new SklearnAdvancedKNNClassifierTrainerOpDesc
+    assert(d.getImportStatements.contains("KNeighborsClassifier"))
+  }
+
+  it should "be matchable via the SklearnMLOperatorDescriptor type-pattern" in {
+    val any: AnyRef = new SklearnAdvancedKNNClassifierTrainerOpDesc
+    val matched = any match {
+      case _: SklearnMLOperatorDescriptor[_] => true
+      case _ => false
+    }
+    assert(matched)
+  }
+}
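The last two tests in this spec check the subtype relationship in two different ways, and the difference is easy to miss. The typed `val` binding is verified by the compiler, including the type argument, while the `case _: SklearnMLOperatorDescriptor[_]` pattern is a runtime class check that, because of JVM type erasure, cannot see the type argument at all. The following sketch illustrates this with hypothetical stand-in types (`Base`, `ParamsA`, `ParamsB`, `DescA`), none of which come from the Texera codebase.

```scala
// Hypothetical stand-ins, used only to illustrate the two kinds of check.
abstract class Base[P] { def label: String }
final class ParamsA
final class ParamsB
class DescA extends Base[ParamsA] { def label: String = "A" }

object ErasureSketch {
  // Runtime check: only the erased class Base is tested, so the wildcard
  // is the most specific type argument that can be matched here.
  def isBase(x: AnyRef): Boolean = x match {
    case _: Base[_] => true
    case _          => false
  }

  def main(args: Array[String]): Unit = {
    // Compile-time check: this line compiles only if DescA <: Base[ParamsA].
    val typed: Base[ParamsA] = new DescA
    // The next line would be rejected by the compiler (Base is invariant):
    // val wrong: Base[ParamsB] = new DescA
    assert(typed.label == "A")
    assert(isBase(typed))
    assert(!isBase("not a descriptor"))
  }
}
```

Seen this way, the typed binding pins the parameter type as well as the parent class, and the pattern match pins only what runtime dispatch code could rely on.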
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.KNNTrainer
+
+import org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.base.SklearnMLOperatorDescriptor
+import org.scalatest.flatspec.AnyFlatSpec
+
+class SklearnAdvancedKNNRegressorTrainerOpDescSpec extends AnyFlatSpec {
+
+  "SklearnAdvancedKNNRegressorTrainerOpDesc.getImportStatements" should
+    "return the canonical KNeighborsRegressor import" in {
+    val d = new SklearnAdvancedKNNRegressorTrainerOpDesc
+    assert(d.getImportStatements == "from sklearn.neighbors import KNeighborsRegressor")
+  }
+
+  "SklearnAdvancedKNNRegressorTrainerOpDesc.getOperatorInfo" should "return 'KNN Regressor'" in {
+    val d = new SklearnAdvancedKNNRegressorTrainerOpDesc
+    assert(d.getOperatorInfo == "KNN Regressor")
+  }
+
+  it should "be stable across two instances (no instance-state interaction)" in {
+    val a = new SklearnAdvancedKNNRegressorTrainerOpDesc
+    val b = new SklearnAdvancedKNNRegressorTrainerOpDesc
+    assert(a.getImportStatements == b.getImportStatements)
+    assert(a.getOperatorInfo == b.getOperatorInfo)
+  }
+
+  "SklearnAdvancedKNNRegressorTrainerOpDesc" should
+    "extend SklearnMLOperatorDescriptor (compile-time enforced)" in {
+    val d: SklearnMLOperatorDescriptor[SklearnAdvancedKNNParameters] =
+      new SklearnAdvancedKNNRegressorTrainerOpDesc
+    assert(d.getImportStatements.contains("KNeighborsRegressor"))
+  }
+
+  it should "be matchable via the SklearnMLOperatorDescriptor type-pattern" in {
+    val any: AnyRef = new SklearnAdvancedKNNRegressorTrainerOpDesc
+    val matched = any match {
+      case _: SklearnMLOperatorDescriptor[_] => true
+      case _ => false
+    }
+    assert(matched)
+  }
+
+  it should "differ from the Classifier sibling on both methods" in {
+    // Catches a copy-paste regression where the Regressor accidentally
+    // returned the Classifier's strings (or vice-versa).
+    val regressor = new SklearnAdvancedKNNRegressorTrainerOpDesc
+    val classifier = new SklearnAdvancedKNNClassifierTrainerOpDesc
+    assert(regressor.getImportStatements != classifier.getImportStatements)
+    assert(regressor.getOperatorInfo != classifier.getOperatorInfo)
+  }
+}
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.SVCTrainer
+
+import org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.base.SklearnMLOperatorDescriptor
+import org.scalatest.flatspec.AnyFlatSpec
+
+class SklearnAdvancedSVCTrainerOpDescSpec extends AnyFlatSpec {
+
+  "SklearnAdvancedSVCTrainerOpDesc.getImportStatements" should
+    "return the canonical SVC import (from sklearn.svm)" in {
+    val d = new SklearnAdvancedSVCTrainerOpDesc
+    assert(d.getImportStatements == "from sklearn.svm import SVC")
+  }
+
+  "SklearnAdvancedSVCTrainerOpDesc.getOperatorInfo" should "return 'SVM Classifier'" in {
+    val d = new SklearnAdvancedSVCTrainerOpDesc
+    assert(d.getOperatorInfo == "SVM Classifier")
+  }
+
+  it should "be stable across two instances (no instance-state interaction)" in {
+    val a = new SklearnAdvancedSVCTrainerOpDesc
+    val b = new SklearnAdvancedSVCTrainerOpDesc
+    assert(a.getImportStatements == b.getImportStatements)
+    assert(a.getOperatorInfo == b.getOperatorInfo)
+  }
+
+  "SklearnAdvancedSVCTrainerOpDesc" should
+    "extend SklearnMLOperatorDescriptor (compile-time enforced)" in {
+    val d: SklearnMLOperatorDescriptor[SklearnAdvancedSVCParameters] =
+      new SklearnAdvancedSVCTrainerOpDesc
+    assert(d.getImportStatements.contains("SVC"))
+  }
+
+  it should "be matchable via the SklearnMLOperatorDescriptor type-pattern" in {
+    val any: AnyRef = new SklearnAdvancedSVCTrainerOpDesc
+    val matched = any match {
+      case _: SklearnMLOperatorDescriptor[_] => true
+      case _ => false
+    }
+    assert(matched)
+  }
+}
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.SVRTrainer
+
+import org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.SVCTrainer.SklearnAdvancedSVCTrainerOpDesc
+import org.apache.texera.amber.operator.machineLearning.sklearnAdvanced.base.SklearnMLOperatorDescriptor
+import org.scalatest.flatspec.AnyFlatSpec
+
+class SklearnAdvancedSVRTrainerOpDescSpec extends AnyFlatSpec {
+
+  "SklearnAdvancedSVRTrainerOpDesc.getImportStatements" should
+    "return the canonical SVR import (from sklearn.svm)" in {
+    val d = new SklearnAdvancedSVRTrainerOpDesc
+    assert(d.getImportStatements == "from sklearn.svm import SVR")
+  }
+
+  "SklearnAdvancedSVRTrainerOpDesc.getOperatorInfo" should "return 'SVM Regressor'" in {
+    val d = new SklearnAdvancedSVRTrainerOpDesc
+    assert(d.getOperatorInfo == "SVM Regressor")
+  }
+
+  it should "be stable across two instances (no instance-state interaction)" in {
+    val a = new SklearnAdvancedSVRTrainerOpDesc
+    val b = new SklearnAdvancedSVRTrainerOpDesc
+    assert(a.getImportStatements == b.getImportStatements)
+    assert(a.getOperatorInfo == b.getOperatorInfo)
+  }
+
+  "SklearnAdvancedSVRTrainerOpDesc" should
+    "extend SklearnMLOperatorDescriptor (compile-time enforced)" in {
+    val d: SklearnMLOperatorDescriptor[SklearnAdvancedSVRParameters] =
+      new SklearnAdvancedSVRTrainerOpDesc
+    assert(d.getImportStatements.contains("SVR"))
+  }
+
+  it should "be matchable via the SklearnMLOperatorDescriptor type-pattern" in {
+    val any: AnyRef = new SklearnAdvancedSVRTrainerOpDesc
+    val matched = any match {
+      case _: SklearnMLOperatorDescriptor[_] => true
+      case _ => false
+    }
+    assert(matched)
+  }
+
+  it should "differ from the SVC sibling on both methods" in {
+    val regressor = new SklearnAdvancedSVRTrainerOpDesc
+    val classifier = new SklearnAdvancedSVCTrainerOpDesc
+    assert(regressor.getImportStatements != classifier.getImportStatements)
+    assert(regressor.getOperatorInfo != classifier.getOperatorInfo)
+  }
+}
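The sibling cross-checks in this commit compare one pair at a time (KNN Regressor vs KNN Classifier, SVR vs SVC). As an illustration only, and not something this commit does, the same copy-paste guard could be expressed once over all four descriptors by asserting that no two share a label or an import. The sketch below uses the four label and import strings asserted in the specs above; the `allDistinct` helper and the `DistinctnessSketch` object are hypothetical names.

```scala
object DistinctnessSketch {
  // (label, import) pairs, copied from the exact-string assertions in the four specs.
  val expected: Seq[(String, String)] = Seq(
    "KNN Classifier" -> "from sklearn.neighbors import KNeighborsClassifier",
    "KNN Regressor" -> "from sklearn.neighbors import KNeighborsRegressor",
    "SVM Classifier" -> "from sklearn.svm import SVC",
    "SVM Regressor" -> "from sklearn.svm import SVR"
  )

  // Hypothetical helper: true when no value appears more than once.
  def allDistinct[A](xs: Seq[A]): Boolean = xs.distinct.size == xs.size

  def main(args: Array[String]): Unit = {
    assert(allDistinct(expected.map(_._1))) // labels
    assert(allDistinct(expected.map(_._2))) // imports

    // Simulated copy-paste regression: SVR returns SVC's import.
    val regressed = expected.updated(3, "SVM Regressor" -> "from sklearn.svm import SVC")
    assert(!allDistinct(regressed.map(_._2)))
  }
}
```

The pairwise form used in the specs has the advantage that a failure names the exact sibling involved; the set-based form trades that for coverage of every pair.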