Skip to content

Commit 0c179ce

Browse files
authored
test(workflow-operator): add unit test coverage for Sklearn SVM and neighbor classifier descriptors (#5945)
### What changes were proposed in this PR?

Pin behavior of four previously-untested Sklearn support-vector and neighbor classifier descriptors in `common/workflow-operator`. No production-code changes.

| Spec | Source class | Tests |
| --- | --- | --- |
| `SklearnSVMOpDescSpec` | `SklearnSVMOpDesc` | 5 |
| `SklearnLinearSVMOpDescSpec` | `SklearnLinearSVMOpDesc` | 5 |
| `SklearnKNNOpDescSpec` | `SklearnKNNOpDesc` | 5 |
| `SklearnNearestCentroidOpDescSpec` | `SklearnNearestCentroidOpDesc` | 5 |

**Behavior pinned**

| Surface | Contract |
| --- | --- |
| `operatorInfo` | exact model name + `Sklearn <name> Operator` description; Sklearn group; training/testing input ports + one blocking output |
| field defaults | `countVectorizer`/`tfidfTransformer` `false`; `target`/`text` `null` |
| `getOutputSchemas` | `model_name` (STRING) + `model` (BINARY) keyed by the declared output port |
| `generatePythonCode` | imports the matching sklearn estimator and builds the `make_pipeline` model |
| Round-trip | config fields preserved through the polymorphic `LogicalOp` base, with the correct `operatorType` discriminator |

### Any related issues, documentation, discussions?

Part of the ongoing `workflow-operator` unit-test coverage effort (follow-up to the Sklearn classifier coverage in #5925, #5939, #5940, #5941).

### How was this PR tested?

- `sbt "WorkflowOperator/testOnly *SklearnSVMOpDescSpec *SklearnLinearSVMOpDescSpec *SklearnKNNOpDescSpec *SklearnNearestCentroidOpDescSpec"` — 20 tests, all green
- `sbt "WorkflowOperator/Test/scalafmtCheck"` and `sbt "WorkflowOperator/scalafixAll --check"` — clean
- CI to confirm

### Was this PR authored or co-authored using generative AI tooling?

Generated-by: Claude Code (Opus 4.8 [1M context])
1 parent abbe7a3 commit 0c179ce

4 files changed

Lines changed: 316 additions & 0 deletions

File tree

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
// Behavioral tests for SklearnKNNOpDesc. They pin the operator metadata, the
// default values of the config fields, the output schema, the generated Python
// source, and JSON round-tripping through the polymorphic LogicalOp base type.
class SklearnKNNOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnKNNOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val descriptor = new SklearnKNNOpDesc
    val metadata = descriptor.operatorInfo

    // Naming and grouping.
    metadata.userFriendlyName shouldBe "K-nearest Neighbors"
    metadata.operatorDescription shouldBe "Sklearn K-nearest Neighbors Operator"
    metadata.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP

    // Port shape: two named inputs, exactly one blocking output.
    val inputNames = metadata.inputPorts.map(_.displayName)
    inputNames shouldBe List("training", "testing")
    val outputs = metadata.outputPorts
    outputs should have length 1
    outputs.head.blocking shouldBe true
  }

  "SklearnKNNOpDesc" should "default its config fields" in {
    val fresh = new SklearnKNNOpDesc

    fresh.countVectorizer shouldBe false
    fresh.tfidfTransformer shouldBe false
    fresh.target shouldBe null
    fresh.text shouldBe null
  }

  "SklearnKNNOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val descriptor = new SklearnKNNOpDesc

    // Look the schema up under the id of the single declared output port.
    val schemasByPort = descriptor.getOutputSchemas(Map.empty)
    val outputPortId = descriptor.operatorInfo.outputPorts.head.id
    val modelSchema = schemasByPort(outputPortId)

    modelSchema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    modelSchema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnKNNOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val descriptor = new SklearnKNNOpDesc
    descriptor.target = "y"
    val pythonSource = descriptor.generatePythonCode()

    // Each fragment must appear somewhere in the generated source.
    val expectedFragments = Seq(
      "from sklearn.neighbors import KNeighborsClassifier",
      "make_pipeline",
      "K-nearest Neighbors"
    )
    expectedFragments.foreach(fragment => pythonSource should include(fragment))
  }

  "SklearnKNNOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val original = new SklearnKNNOpDesc
    original.countVectorizer = true
    original.target = "label"

    val serialized = objectMapper.writeValueAsString(original)
    serialized should include("\"operatorType\":\"SklearnKNN\"")

    // Deserialize via the base type so the operatorType discriminator is exercised.
    val decoded = objectMapper.readValue(serialized, classOf[LogicalOp])
    decoded shouldBe a[SklearnKNNOpDesc]

    val roundTripped = decoded.asInstanceOf[SklearnKNNOpDesc]
    roundTripped.target shouldBe "label"
    roundTripped.countVectorizer shouldBe true
  }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
// Behavioral tests for SklearnLinearSVMOpDesc. They pin the operator metadata,
// the default values of the config fields, the output schema, the generated
// Python source, and JSON round-tripping through the polymorphic LogicalOp base.
class SklearnLinearSVMOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnLinearSVMOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val descriptor = new SklearnLinearSVMOpDesc
    val metadata = descriptor.operatorInfo

    // Naming and grouping.
    metadata.userFriendlyName shouldBe "Linear Support Vector Machine"
    metadata.operatorDescription shouldBe "Sklearn Linear Support Vector Machine Operator"
    metadata.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP

    // Port shape: two named inputs, exactly one blocking output.
    val inputNames = metadata.inputPorts.map(_.displayName)
    inputNames shouldBe List("training", "testing")
    val outputs = metadata.outputPorts
    outputs should have length 1
    outputs.head.blocking shouldBe true
  }

  "SklearnLinearSVMOpDesc" should "default its config fields" in {
    val fresh = new SklearnLinearSVMOpDesc

    fresh.countVectorizer shouldBe false
    fresh.tfidfTransformer shouldBe false
    fresh.target shouldBe null
    fresh.text shouldBe null
  }

  "SklearnLinearSVMOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val descriptor = new SklearnLinearSVMOpDesc

    // Look the schema up under the id of the single declared output port.
    val schemasByPort = descriptor.getOutputSchemas(Map.empty)
    val outputPortId = descriptor.operatorInfo.outputPorts.head.id
    val modelSchema = schemasByPort(outputPortId)

    modelSchema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    modelSchema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnLinearSVMOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val descriptor = new SklearnLinearSVMOpDesc
    descriptor.target = "y"
    val pythonSource = descriptor.generatePythonCode()

    // Each fragment must appear somewhere in the generated source.
    val expectedFragments = Seq(
      "from sklearn.svm import LinearSVC",
      "make_pipeline",
      "Linear Support Vector Machine"
    )
    expectedFragments.foreach(fragment => pythonSource should include(fragment))
  }

  "SklearnLinearSVMOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val original = new SklearnLinearSVMOpDesc
    original.countVectorizer = true
    original.target = "label"

    val serialized = objectMapper.writeValueAsString(original)
    serialized should include("\"operatorType\":\"SklearnLinearSVM\"")

    // Deserialize via the base type so the operatorType discriminator is exercised.
    val decoded = objectMapper.readValue(serialized, classOf[LogicalOp])
    decoded shouldBe a[SklearnLinearSVMOpDesc]

    val roundTripped = decoded.asInstanceOf[SklearnLinearSVMOpDesc]
    roundTripped.target shouldBe "label"
    roundTripped.countVectorizer shouldBe true
  }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
// Behavioral tests for SklearnNearestCentroidOpDesc. They pin the operator
// metadata, the default values of the config fields, the output schema, the
// generated Python source, and JSON round-tripping through the polymorphic
// LogicalOp base type.
class SklearnNearestCentroidOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnNearestCentroidOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val descriptor = new SklearnNearestCentroidOpDesc
    val metadata = descriptor.operatorInfo

    // Naming and grouping.
    metadata.userFriendlyName shouldBe "Nearest Centroid"
    metadata.operatorDescription shouldBe "Sklearn Nearest Centroid Operator"
    metadata.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP

    // Port shape: two named inputs, exactly one blocking output.
    val inputNames = metadata.inputPorts.map(_.displayName)
    inputNames shouldBe List("training", "testing")
    val outputs = metadata.outputPorts
    outputs should have length 1
    outputs.head.blocking shouldBe true
  }

  "SklearnNearestCentroidOpDesc" should "default its config fields" in {
    val fresh = new SklearnNearestCentroidOpDesc

    fresh.countVectorizer shouldBe false
    fresh.tfidfTransformer shouldBe false
    fresh.target shouldBe null
    fresh.text shouldBe null
  }

  "SklearnNearestCentroidOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val descriptor = new SklearnNearestCentroidOpDesc

    // Look the schema up under the id of the single declared output port.
    val schemasByPort = descriptor.getOutputSchemas(Map.empty)
    val outputPortId = descriptor.operatorInfo.outputPorts.head.id
    val modelSchema = schemasByPort(outputPortId)

    modelSchema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    modelSchema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnNearestCentroidOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val descriptor = new SklearnNearestCentroidOpDesc
    descriptor.target = "y"
    val pythonSource = descriptor.generatePythonCode()

    // Each fragment must appear somewhere in the generated source.
    val expectedFragments = Seq(
      "from sklearn.neighbors import NearestCentroid",
      "make_pipeline",
      "Nearest Centroid"
    )
    expectedFragments.foreach(fragment => pythonSource should include(fragment))
  }

  "SklearnNearestCentroidOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val original = new SklearnNearestCentroidOpDesc
    original.countVectorizer = true
    original.target = "label"

    val serialized = objectMapper.writeValueAsString(original)
    serialized should include("\"operatorType\":\"SklearnNearestCentroid\"")

    // Deserialize via the base type so the operatorType discriminator is exercised.
    val decoded = objectMapper.readValue(serialized, classOf[LogicalOp])
    decoded shouldBe a[SklearnNearestCentroidOpDesc]

    val roundTripped = decoded.asInstanceOf[SklearnNearestCentroidOpDesc]
    roundTripped.target shouldBe "label"
    roundTripped.countVectorizer shouldBe true
  }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
// Behavioral tests for SklearnSVMOpDesc. They pin the operator metadata, the
// default values of the config fields, the output schema, the generated Python
// source, and JSON round-tripping through the polymorphic LogicalOp base type.
class SklearnSVMOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnSVMOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val descriptor = new SklearnSVMOpDesc
    val metadata = descriptor.operatorInfo

    // Naming and grouping.
    metadata.userFriendlyName shouldBe "Support Vector Machine"
    metadata.operatorDescription shouldBe "Sklearn Support Vector Machine Operator"
    metadata.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP

    // Port shape: two named inputs, exactly one blocking output.
    val inputNames = metadata.inputPorts.map(_.displayName)
    inputNames shouldBe List("training", "testing")
    val outputs = metadata.outputPorts
    outputs should have length 1
    outputs.head.blocking shouldBe true
  }

  "SklearnSVMOpDesc" should "default its config fields" in {
    val fresh = new SklearnSVMOpDesc

    fresh.countVectorizer shouldBe false
    fresh.tfidfTransformer shouldBe false
    fresh.target shouldBe null
    fresh.text shouldBe null
  }

  "SklearnSVMOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val descriptor = new SklearnSVMOpDesc

    // Look the schema up under the id of the single declared output port.
    val schemasByPort = descriptor.getOutputSchemas(Map.empty)
    val outputPortId = descriptor.operatorInfo.outputPorts.head.id
    val modelSchema = schemasByPort(outputPortId)

    modelSchema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    modelSchema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnSVMOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val descriptor = new SklearnSVMOpDesc
    descriptor.target = "y"
    val pythonSource = descriptor.generatePythonCode()

    // Each fragment must appear somewhere in the generated source.
    val expectedFragments = Seq(
      "from sklearn.svm import SVC",
      "make_pipeline",
      "Support Vector Machine"
    )
    expectedFragments.foreach(fragment => pythonSource should include(fragment))
  }

  "SklearnSVMOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val original = new SklearnSVMOpDesc
    original.countVectorizer = true
    original.target = "label"

    val serialized = objectMapper.writeValueAsString(original)
    serialized should include("\"operatorType\":\"SklearnSVM\"")

    // Deserialize via the base type so the operatorType discriminator is exercised.
    val decoded = objectMapper.readValue(serialized, classOf[LogicalOp])
    decoded shouldBe a[SklearnSVMOpDesc]

    val roundTripped = decoded.asInstanceOf[SklearnSVMOpDesc]
    roundTripped.target shouldBe "label"
    roundTripped.countVectorizer shouldBe true
  }
}

0 commit comments

Comments
 (0)