Skip to content

Commit 9214637

Browse files
authored
fix(engine): validate processor plugin impls (#609)
* fix(engine): validate processor plugin impls Add the processor implementation base to assert_valid_plugin so processor plugins are checked against Processor instead of only the generic config contract. Keep plugin type validation table-driven and raise explicit AssertionError messages so checks are not skipped under optimized Python. Signed-off-by: Johnny Greco <jogreco@nvidia.com> * test(engine): require plugin base map coverage --------- Signed-off-by: Johnny Greco <jogreco@nvidia.com>
1 parent f73da19 commit 9214637

2 files changed

Lines changed: 155 additions & 8 deletions

File tree

packages/data-designer-engine/src/data_designer/engine/testing/utils.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,31 @@
55

66
from data_designer.config.base import ConfigBase
77
from data_designer.engine.configurable_task import ConfigurableTask
8+
from data_designer.engine.processing.processors.base import Processor
89
from data_designer.engine.resources.seed_reader import SeedReader
910
from data_designer.plugins.plugin import Plugin, PluginType
1011

12+
_PLUGIN_IMPLEMENTATION_BASES: dict[PluginType, type[object]] = {
13+
PluginType.COLUMN_GENERATOR: ConfigurableTask,
14+
PluginType.SEED_READER: SeedReader,
15+
PluginType.PROCESSOR: Processor,
16+
}
17+
if set(_PLUGIN_IMPLEMENTATION_BASES) != set(PluginType):
18+
raise AssertionError("Plugin implementation base map must cover all plugin types")
19+
20+
21+
def _assert_subclass(cls: type[object], base_cls: type[object], message: str) -> None:
22+
if not issubclass(cls, base_cls):
23+
raise AssertionError(message)
24+
1125

1226
def assert_valid_plugin(plugin: Plugin) -> None:
13-
assert issubclass(plugin.config_cls, ConfigBase), "Plugin config class is not a subclass of ConfigBase"
14-
15-
if plugin.plugin_type == PluginType.COLUMN_GENERATOR:
16-
assert issubclass(plugin.impl_cls, ConfigurableTask), (
17-
"Column generator plugin impl class must be a subclass of ConfigurableTask"
18-
)
19-
elif plugin.plugin_type == PluginType.SEED_READER:
20-
assert issubclass(plugin.impl_cls, SeedReader), "Seed reader plugin impl class must be a subclass of SeedReader"
27+
_assert_subclass(plugin.config_cls, ConfigBase, "Plugin config class is not a subclass of ConfigBase")
28+
29+
implementation_base = _PLUGIN_IMPLEMENTATION_BASES[plugin.plugin_type]
30+
_assert_subclass(
31+
plugin.impl_cls,
32+
implementation_base,
33+
f"{plugin.plugin_type.display_name.capitalize()} plugin impl class must be a subclass of "
34+
f"{implementation_base.__name__}",
35+
)
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from typing import Any, Literal
7+
8+
import pytest
9+
10+
from data_designer.config.base import ConfigBase, ProcessorConfig, SingleColumnConfig
11+
from data_designer.engine.configurable_task import ConfigurableTask
12+
from data_designer.engine.processing.processors.base import Processor
13+
from data_designer.engine.resources.seed_reader import SeedReader
14+
from data_designer.engine.testing.utils import assert_valid_plugin
15+
from data_designer.plugins.plugin import Plugin, PluginType
16+
17+
MODULE_NAME = __name__
18+
19+
20+
class ValidColumnGeneratorConfig(SingleColumnConfig):
21+
name: str = "valid-column-generator"
22+
column_type: Literal["valid-column-generator"] = "valid-column-generator"
23+
24+
@property
25+
def required_columns(self) -> list[str]:
26+
return []
27+
28+
@property
29+
def side_effect_columns(self) -> list[str]:
30+
return []
31+
32+
33+
class ValidColumnGenerator(ConfigurableTask[ValidColumnGeneratorConfig]):
34+
pass
35+
36+
37+
class ValidSeedReaderConfig(ConfigBase):
38+
seed_type: Literal["valid-seed-reader"] = "valid-seed-reader"
39+
40+
41+
class ValidSeedReader(SeedReader):
42+
def get_dataset_uri(self) -> str:
43+
return "unused"
44+
45+
def create_duckdb_connection(self) -> Any:
46+
raise NotImplementedError
47+
48+
49+
class ValidProcessorConfig(ProcessorConfig):
50+
name: str = "valid-processor"
51+
processor_type: Literal["valid-processor"] = "valid-processor"
52+
53+
54+
class ValidProcessor(Processor[ValidProcessorConfig]):
55+
def process_before_batch(self, data: dict[str, Any]) -> dict[str, Any]:
56+
return data
57+
58+
59+
class NonProcessor:
60+
pass
61+
62+
63+
class TaskButNotProcessor(ConfigurableTask[ValidProcessorConfig]):
64+
pass
65+
66+
67+
@pytest.mark.parametrize(
68+
("plugin_type", "config_class_name", "implementation_class_name"),
69+
[
70+
(PluginType.COLUMN_GENERATOR, "ValidColumnGeneratorConfig", "ValidColumnGenerator"),
71+
(PluginType.SEED_READER, "ValidSeedReaderConfig", "ValidSeedReader"),
72+
(PluginType.PROCESSOR, "ValidProcessorConfig", "ValidProcessor"),
73+
],
74+
)
75+
def test_assert_valid_plugin_accepts_supported_plugin_types(
76+
plugin_type: PluginType,
77+
config_class_name: str,
78+
implementation_class_name: str,
79+
) -> None:
80+
plugin = Plugin(
81+
config_qualified_name=f"{MODULE_NAME}.{config_class_name}",
82+
impl_qualified_name=f"{MODULE_NAME}.{implementation_class_name}",
83+
plugin_type=plugin_type,
84+
)
85+
86+
assert_valid_plugin(plugin)
87+
88+
89+
@pytest.mark.parametrize(
90+
("plugin_type", "config_class_name", "expected_message"),
91+
[
92+
(
93+
PluginType.COLUMN_GENERATOR,
94+
"ValidColumnGeneratorConfig",
95+
"Column generator plugin impl class must be a subclass of ConfigurableTask",
96+
),
97+
(
98+
PluginType.SEED_READER,
99+
"ValidSeedReaderConfig",
100+
"Seed reader plugin impl class must be a subclass of SeedReader",
101+
),
102+
(
103+
PluginType.PROCESSOR,
104+
"ValidProcessorConfig",
105+
"Processor plugin impl class must be a subclass of Processor",
106+
),
107+
],
108+
)
109+
def test_assert_valid_plugin_rejects_invalid_impl_for_supported_plugin_types(
110+
plugin_type: PluginType,
111+
config_class_name: str,
112+
expected_message: str,
113+
) -> None:
114+
plugin = Plugin(
115+
config_qualified_name=f"{MODULE_NAME}.{config_class_name}",
116+
impl_qualified_name=f"{MODULE_NAME}.NonProcessor",
117+
plugin_type=plugin_type,
118+
)
119+
120+
with pytest.raises(AssertionError, match=expected_message):
121+
assert_valid_plugin(plugin)
122+
123+
124+
def test_assert_valid_plugin_rejects_processor_plugin_with_configurable_task_impl() -> None:
125+
plugin = Plugin(
126+
config_qualified_name=f"{MODULE_NAME}.ValidProcessorConfig",
127+
impl_qualified_name=f"{MODULE_NAME}.TaskButNotProcessor",
128+
plugin_type=PluginType.PROCESSOR,
129+
)
130+
131+
with pytest.raises(AssertionError, match="Processor plugin impl class must be a subclass of Processor"):
132+
assert_valid_plugin(plugin)

0 commit comments

Comments
 (0)