Skip to content

Commit abbe7a3

Browse files
authored
test(workflow-operator): add unit test coverage for Sklearn tree-based classifier descriptors (#5939)
### What changes were proposed in this PR?

Pin behavior of four previously-untested Sklearn tree-based classifier descriptors in `common/workflow-operator`. No production-code changes.

| Spec | Source class | Tests |
| --- | --- | --- |
| `SklearnDecisionTreeOpDescSpec` | `SklearnDecisionTreeOpDesc` | 5 |
| `SklearnExtraTreeOpDescSpec` | `SklearnExtraTreeOpDesc` | 5 |
| `SklearnExtraTreesOpDescSpec` | `SklearnExtraTreesOpDesc` | 5 |
| `SklearnRandomForestOpDescSpec` | `SklearnRandomForestOpDesc` | 5 |

**Behavior pinned**

| Surface | Contract |
| --- | --- |
| `operatorInfo` | exact model name + `Sklearn <name> Operator` description; Sklearn group; training/testing input ports + one blocking output |
| field defaults | `countVectorizer`/`tfidfTransformer` `false`; `target`/`text` `null` |
| `getOutputSchemas` | `model_name` (STRING) + `model` (BINARY) keyed by the declared output port |
| `generatePythonCode` | imports the matching sklearn estimator and builds the `make_pipeline` model |
| Round-trip | config fields preserved through the polymorphic `LogicalOp` base, with the correct `operatorType` discriminator |

### Any related issues, documentation, discussions?

Part of the ongoing `workflow-operator` unit-test coverage effort (follow-up to the Sklearn Naive Bayes coverage in #5925).

### How was this PR tested?

- `sbt "WorkflowOperator/testOnly *SklearnDecisionTreeOpDescSpec *SklearnExtraTreeOpDescSpec *SklearnExtraTreesOpDescSpec *SklearnRandomForestOpDescSpec"` — 20 tests, all green
- `sbt "WorkflowOperator/Test/scalafmtCheck"` and `sbt "WorkflowOperator/scalafixAll --check"` — clean
- CI to confirm

### Was this PR authored or co-authored using generative AI tooling?

Generated-by: Claude Code (Opus 4.8 [1M context])
1 parent 2779df9 commit abbe7a3

4 files changed

Lines changed: 316 additions & 0 deletions

File tree

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
/**
  * Unit tests pinning the observable behavior of [[SklearnDecisionTreeOpDesc]]:
  * operator metadata, config-field defaults, output schema, generated Python
  * code, and JSON round-tripping through the polymorphic [[LogicalOp]] base.
  */
class SklearnDecisionTreeOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnDecisionTreeOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val info = (new SklearnDecisionTreeOpDesc).operatorInfo
    info.userFriendlyName shouldBe "Decision Tree"
    info.operatorDescription shouldBe "Sklearn Decision Tree Operator"
    info.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP
    info.inputPorts.map(_.displayName) shouldBe List("training", "testing")
    // Exactly one output port is expected, and it must be flagged as blocking.
    info.outputPorts should have length 1
    info.outputPorts.head.blocking shouldBe true
  }

  "SklearnDecisionTreeOpDesc" should "default its config fields" in {
    val d = new SklearnDecisionTreeOpDesc
    d.countVectorizer shouldBe false
    d.tfidfTransformer shouldBe false
    d.target shouldBe null
    d.text shouldBe null
  }

  "SklearnDecisionTreeOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val d = new SklearnDecisionTreeOpDesc
    // Look the schema up by the id of the descriptor's own (first) output port.
    val schema = d.getOutputSchemas(Map.empty)(d.operatorInfo.outputPorts.head.id)
    schema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    schema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnDecisionTreeOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val d = new SklearnDecisionTreeOpDesc
    // A placeholder target column is assigned before the code is generated.
    d.target = "y"
    val code = d.generatePythonCode()
    code should include("from sklearn.tree import DecisionTreeClassifier")
    code should include("make_pipeline")
    code should include("Decision Tree")
  }

  "SklearnDecisionTreeOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val d = new SklearnDecisionTreeOpDesc
    d.target = "label"
    d.countVectorizer = true
    val json = objectMapper.writeValueAsString(d)
    json should include("\"operatorType\":\"SklearnDecisionTree\"")
    // Deserialize via the base type so the operatorType discriminator picks the subclass.
    val restored = objectMapper.readValue(json, classOf[LogicalOp])
    // A typed pattern checks and narrows the runtime type in one step, replacing the
    // separate type assertion followed by an asInstanceOf downcast.
    restored match {
      case r: SklearnDecisionTreeOpDesc =>
        r.target shouldBe "label"
        r.countVectorizer shouldBe true
      case other =>
        fail(s"expected SklearnDecisionTreeOpDesc but deserialized ${other.getClass.getName}")
    }
  }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
/**
  * Unit tests pinning the observable behavior of [[SklearnExtraTreeOpDesc]]:
  * operator metadata, config-field defaults, output schema, generated Python
  * code, and JSON round-tripping through the polymorphic [[LogicalOp]] base.
  */
class SklearnExtraTreeOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnExtraTreeOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val info = (new SklearnExtraTreeOpDesc).operatorInfo
    info.userFriendlyName shouldBe "Extra Tree"
    info.operatorDescription shouldBe "Sklearn Extra Tree Operator"
    info.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP
    info.inputPorts.map(_.displayName) shouldBe List("training", "testing")
    // Exactly one output port is expected, and it must be flagged as blocking.
    info.outputPorts should have length 1
    info.outputPorts.head.blocking shouldBe true
  }

  "SklearnExtraTreeOpDesc" should "default its config fields" in {
    val d = new SklearnExtraTreeOpDesc
    d.countVectorizer shouldBe false
    d.tfidfTransformer shouldBe false
    d.target shouldBe null
    d.text shouldBe null
  }

  "SklearnExtraTreeOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val d = new SklearnExtraTreeOpDesc
    // Look the schema up by the id of the descriptor's own (first) output port.
    val schema = d.getOutputSchemas(Map.empty)(d.operatorInfo.outputPorts.head.id)
    schema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    schema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnExtraTreeOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val d = new SklearnExtraTreeOpDesc
    // A placeholder target column is assigned before the code is generated.
    d.target = "y"
    val code = d.generatePythonCode()
    code should include("from sklearn.tree import ExtraTreeClassifier")
    code should include("make_pipeline")
    code should include("Extra Tree")
  }

  "SklearnExtraTreeOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val d = new SklearnExtraTreeOpDesc
    d.target = "label"
    d.countVectorizer = true
    val json = objectMapper.writeValueAsString(d)
    json should include("\"operatorType\":\"SklearnExtraTree\"")
    // Deserialize via the base type so the operatorType discriminator picks the subclass.
    val restored = objectMapper.readValue(json, classOf[LogicalOp])
    // A typed pattern checks and narrows the runtime type in one step, replacing the
    // separate type assertion followed by an asInstanceOf downcast.
    restored match {
      case r: SklearnExtraTreeOpDesc =>
        r.target shouldBe "label"
        r.countVectorizer shouldBe true
      case other =>
        fail(s"expected SklearnExtraTreeOpDesc but deserialized ${other.getClass.getName}")
    }
  }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
/**
  * Unit tests pinning the observable behavior of [[SklearnExtraTreesOpDesc]]:
  * operator metadata, config-field defaults, output schema, generated Python
  * code, and JSON round-tripping through the polymorphic [[LogicalOp]] base.
  */
class SklearnExtraTreesOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnExtraTreesOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val info = (new SklearnExtraTreesOpDesc).operatorInfo
    info.userFriendlyName shouldBe "Extra Trees"
    info.operatorDescription shouldBe "Sklearn Extra Trees Operator"
    info.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP
    info.inputPorts.map(_.displayName) shouldBe List("training", "testing")
    // Exactly one output port is expected, and it must be flagged as blocking.
    info.outputPorts should have length 1
    info.outputPorts.head.blocking shouldBe true
  }

  "SklearnExtraTreesOpDesc" should "default its config fields" in {
    val d = new SklearnExtraTreesOpDesc
    d.countVectorizer shouldBe false
    d.tfidfTransformer shouldBe false
    d.target shouldBe null
    d.text shouldBe null
  }

  "SklearnExtraTreesOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val d = new SklearnExtraTreesOpDesc
    // Look the schema up by the id of the descriptor's own (first) output port.
    val schema = d.getOutputSchemas(Map.empty)(d.operatorInfo.outputPorts.head.id)
    schema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    schema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnExtraTreesOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val d = new SklearnExtraTreesOpDesc
    // A placeholder target column is assigned before the code is generated.
    d.target = "y"
    val code = d.generatePythonCode()
    code should include("from sklearn.ensemble import ExtraTreesClassifier")
    code should include("make_pipeline")
    code should include("Extra Trees")
  }

  "SklearnExtraTreesOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val d = new SklearnExtraTreesOpDesc
    d.target = "label"
    d.countVectorizer = true
    val json = objectMapper.writeValueAsString(d)
    json should include("\"operatorType\":\"SklearnExtraTrees\"")
    // Deserialize via the base type so the operatorType discriminator picks the subclass.
    val restored = objectMapper.readValue(json, classOf[LogicalOp])
    // A typed pattern checks and narrows the runtime type in one step, replacing the
    // separate type assertion followed by an asInstanceOf downcast.
    restored match {
      case r: SklearnExtraTreesOpDesc =>
        r.target shouldBe "label"
        r.countVectorizer shouldBe true
      case other =>
        fail(s"expected SklearnExtraTreesOpDesc but deserialized ${other.getClass.getName}")
    }
  }
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.operator.sklearn
21+
22+
import org.apache.texera.amber.core.tuple.AttributeType
23+
import org.apache.texera.amber.operator.LogicalOp
24+
import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
25+
import org.apache.texera.amber.util.JSONUtils.objectMapper
26+
import org.scalatest.flatspec.AnyFlatSpec
27+
import org.scalatest.matchers.should.Matchers
28+
29+
/**
  * Unit tests pinning the observable behavior of [[SklearnRandomForestOpDesc]]:
  * operator metadata, config-field defaults, output schema, generated Python
  * code, and JSON round-tripping through the polymorphic [[LogicalOp]] base.
  */
class SklearnRandomForestOpDescSpec extends AnyFlatSpec with Matchers {

  "SklearnRandomForestOpDesc.operatorInfo" should
    "advertise the model name, Sklearn group, and the training/testing port shape" in {
    val info = (new SklearnRandomForestOpDesc).operatorInfo
    info.userFriendlyName shouldBe "Random Forest"
    info.operatorDescription shouldBe "Sklearn Random Forest Operator"
    info.operatorGroupName shouldBe OperatorGroupConstants.SKLEARN_GROUP
    info.inputPorts.map(_.displayName) shouldBe List("training", "testing")
    // Exactly one output port is expected, and it must be flagged as blocking.
    info.outputPorts should have length 1
    info.outputPorts.head.blocking shouldBe true
  }

  "SklearnRandomForestOpDesc" should "default its config fields" in {
    val d = new SklearnRandomForestOpDesc
    d.countVectorizer shouldBe false
    d.tfidfTransformer shouldBe false
    d.target shouldBe null
    d.text shouldBe null
  }

  "SklearnRandomForestOpDesc.getOutputSchemas" should
    "emit the model_name/model schema keyed by the declared output port" in {
    val d = new SklearnRandomForestOpDesc
    // Look the schema up by the id of the descriptor's own (first) output port.
    val schema = d.getOutputSchemas(Map.empty)(d.operatorInfo.outputPorts.head.id)
    schema.getAttribute("model_name").getType shouldBe AttributeType.STRING
    schema.getAttribute("model").getType shouldBe AttributeType.BINARY
  }

  "SklearnRandomForestOpDesc.generatePythonCode" should "import the configured sklearn estimator" in {
    val d = new SklearnRandomForestOpDesc
    // A placeholder target column is assigned before the code is generated.
    d.target = "y"
    val code = d.generatePythonCode()
    code should include("from sklearn.ensemble import RandomForestClassifier")
    code should include("make_pipeline")
    code should include("Random Forest")
  }

  "SklearnRandomForestOpDesc" should "round-trip its config fields through the polymorphic base" in {
    val d = new SklearnRandomForestOpDesc
    d.target = "label"
    d.countVectorizer = true
    val json = objectMapper.writeValueAsString(d)
    json should include("\"operatorType\":\"SklearnRandomForest\"")
    // Deserialize via the base type so the operatorType discriminator picks the subclass.
    val restored = objectMapper.readValue(json, classOf[LogicalOp])
    // A typed pattern checks and narrows the runtime type in one step, replacing the
    // separate type assertion followed by an asInstanceOf downcast.
    restored match {
      case r: SklearnRandomForestOpDesc =>
        r.target shouldBe "label"
        r.countVectorizer shouldBe true
      case other =>
        fail(s"expected SklearnRandomForestOpDesc but deserialized ${other.getClass.getName}")
    }
  }
}

0 commit comments

Comments
 (0)