Skip to content

Commit 3d9f518

Browse files
authored
refactor: remove task metadata property (#216)
* remove metadata * docs and tests * don't need that test * use static method for generation strategy * update docs * add docstring
1 parent 8b751b9 commit 3d9f518

29 files changed

Lines changed: 129 additions & 432 deletions

File tree

docs/plugins/example.md

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,12 @@ The implementation class defines the actual business logic of the plugin. For co
6464
import logging
6565
import pandas as pd
6666

67-
from data_designer.engine.column_generators.generators.base import (
68-
ColumnGenerator,
69-
GenerationStrategy,
70-
GeneratorMetadata,
71-
)
67+
from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn, GenerationStrategy
7268

7369
# Data Designer uses the standard Python logging module for logging
7470
logger = logging.getLogger(__name__)
7571

76-
class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig]):
77-
@staticmethod
78-
def metadata() -> GeneratorMetadata:
79-
"""Define metadata about this generator."""
80-
return GeneratorMetadata(
81-
name="index-multiplier",
82-
description="Generates values by multiplying the row index by a user-specified multiplier",
83-
generation_strategy=GenerationStrategy.FULL_COLUMN,
84-
)
72+
class IndexMultiplierColumnGenerator(ColumnGeneratorFullColumn[IndexMultiplierColumnConfig]):
8573

8674
def generate(self, data: pd.DataFrame) -> pd.DataFrame:
8775
"""Generate the column data.
@@ -105,20 +93,20 @@ class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig
10593

10694
**Key points:**
10795

108-
- Generic type `ColumnGenerator[IndexMultiplierColumnConfig]` connects the task to its config
109-
- `metadata()` describes your generator and its requirements
110-
- `generation_strategy` can be `FULL_COLUMN`, `CELL_BY_CELL`
96+
- Generic type `ColumnGeneratorFullColumn[IndexMultiplierColumnConfig]` connects the task to its config
11197
- You have access to the configuration parameters via `self.config`
11298

11399
!!! info "Understanding generation_strategy"
114100
The `generation_strategy` specifies how the column generator will generate data.
115101

116102
- **`FULL_COLUMN`**: Generates the full column (at the batch level) in a single call to `generate`
117-
- `generate` must take as input a `pd.DataFrame` with all previous columns and return a `pd.DataFrame` with the generated column appended
103+
- `generate` must take as input a `pd.DataFrame` with all previous columns and return a `pd.DataFrame` with the generated column appended.
104+
- Inherit from `ColumnGeneratorFullColumn` for this strategy, as we do in the example above.
118105

119106
- **`CELL_BY_CELL`**: Generates one cell at a time
120107
- `generate` must take as input a `dict` with key/value pairs for all previous columns and return a `dict` with an additional key/value for the generated cell
121108
- Supports concurrent workers via a `max_parallel_requests` parameter on the configuration
109+
- Inherit from `ColumnGeneratorCellByCell` for this strategy.
122110

123111
## Step 4: Create the plugin object
124112

@@ -147,11 +135,8 @@ from typing import Literal
147135
import pandas as pd
148136

149137
from data_designer.config.column_configs import SingleColumnConfig
150-
from data_designer.engine.column_generators.generators.base import (
151-
ColumnGenerator,
152-
GenerationStrategy,
153-
GeneratorMetadata,
154-
)
138+
from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn
139+
155140
from data_designer.plugins import Plugin, PluginType
156141

157142
# Data Designer uses the standard Python logging module for logging
@@ -169,15 +154,7 @@ class IndexMultiplierColumnConfig(SingleColumnConfig):
169154
column_type: Literal["index-multiplier"] = "index-multiplier"
170155

171156

172-
class IndexMultiplierColumnGenerator(ColumnGenerator[IndexMultiplierColumnConfig]):
173-
@staticmethod
174-
def metadata() -> GeneratorMetadata:
175-
"""Define metadata about this generator."""
176-
return GeneratorMetadata(
177-
name="index-multiplier",
178-
description="Generates values by multiplying the row index by a user-specified multiplier",
179-
generation_strategy=GenerationStrategy.FULL_COLUMN,
180-
)
157+
class IndexMultiplierColumnGenerator(ColumnGeneratorFullColumn[IndexMultiplierColumnConfig]):
181158

182159
def generate(self, data: pd.DataFrame) -> pd.DataFrame:
183160
"""Generate the column data.

src/data_designer/engine/analysis/column_profilers/base.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from data_designer.config.base import ConfigBase
1414
from data_designer.config.column_configs import SingleColumnConfig
1515
from data_designer.config.column_types import DataDesignerColumnType
16-
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, TaskConfigT
16+
from data_designer.engine.configurable_task import ConfigurableTask, TaskConfigT
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -32,17 +32,14 @@ def as_tuple(self) -> tuple[SingleColumnConfig, pd.DataFrame]:
3232
return (self.column_config, self.df)
3333

3434

35-
class ColumnProfilerMetadata(ConfigurableTaskMetadata):
36-
applicable_column_types: list[DataDesignerColumnType]
37-
38-
3935
class ColumnProfiler(ConfigurableTask[TaskConfigT], ABC):
4036
@staticmethod
4137
@abstractmethod
42-
def metadata() -> ColumnProfilerMetadata: ...
38+
def get_applicable_column_types() -> list[DataDesignerColumnType]:
39+
"""Returns a list of column types that this profiler can be applied to during dataset profiling."""
4340

4441
@abstractmethod
4542
def profile(self, column_config_with_df: ColumnConfigWithDataFrame) -> BaseModel: ...
4643

4744
def _initialize(self) -> None:
48-
logger.info(f"💫 Initializing column profiler: '{self.metadata().name}'")
45+
logger.info(f"💫 Initializing column profiler: '{self.name}'")

src/data_designer/engine/analysis/column_profilers/judge_score_profiler.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,7 @@
2020
NumericalDistribution,
2121
)
2222
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType
23-
from data_designer.engine.analysis.column_profilers.base import (
24-
ColumnConfigWithDataFrame,
25-
ColumnProfiler,
26-
ColumnProfilerMetadata,
27-
)
23+
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
2824
from data_designer.engine.analysis.utils.judge_score_processing import (
2925
extract_judge_score_distributions,
3026
sample_scores_and_reasoning,
@@ -37,12 +33,8 @@
3733

3834
class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
3935
@staticmethod
40-
def metadata() -> ColumnProfilerMetadata:
41-
return ColumnProfilerMetadata(
42-
name="judge_score_profiler",
43-
description="Analyzes LLM-as-judge score distributions in a Data Designer dataset.",
44-
applicable_column_types=[DataDesignerColumnType.LLM_JUDGE],
45-
)
36+
def get_applicable_column_types() -> list[DataDesignerColumnType]:
37+
return [DataDesignerColumnType.LLM_JUDGE]
4638

4739
def get_model(self, model_alias: str) -> ModelFacade:
4840
return self.resource_provider.model_registry.get_model(model_alias=model_alias)

src/data_designer/engine/analysis/dataset_profiler.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
1414
from data_designer.config.base import ConfigBase
1515
from data_designer.config.column_configs import SingleColumnConfig
16-
from data_designer.config.column_types import (
17-
COLUMN_TYPE_EMOJI_MAP,
18-
ColumnConfigT,
19-
)
16+
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, ColumnConfigT
2017
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
2118
from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
2219
from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
@@ -81,14 +78,14 @@ def profile_dataset(
8178
column_profiles = []
8279
for profiler_config in self.config.column_profiler_configs or []:
8380
profiler = self._create_column_profiler(profiler_config)
84-
applicable_column_types = profiler.metadata().applicable_column_types
81+
applicable_column_types = profiler.get_applicable_column_types()
8582
for c in self.config.column_configs:
8683
if c.column_type in applicable_column_types:
8784
params = ColumnConfigWithDataFrame(column_config=c, df=dataset)
8885
column_profiles.append(profiler.profile(params))
8986
if len(column_profiles) == 0:
9087
logger.warning(
91-
f"⚠️ No applicable column types found for the '{profiler.metadata().name}' profiler. "
88+
f"⚠️ No applicable column types found for the '{profiler.name}' profiler. "
9289
f"This profiler is applicable to the following column types: {applicable_column_types}"
9390
)
9491

src/data_designer/engine/column_generators/generators/base.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pandas as pd
1313

14-
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
14+
from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT
1515

1616
if TYPE_CHECKING:
1717
from data_designer.config.models import BaseInferenceParams, ModelConfig
@@ -27,22 +27,14 @@ class GenerationStrategy(str, Enum):
2727
FULL_COLUMN = "full_column"
2828

2929

30-
class GeneratorMetadata(ConfigurableTaskMetadata):
31-
generation_strategy: GenerationStrategy
32-
33-
3430
class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC):
3531
@property
3632
def can_generate_from_scratch(self) -> bool:
3733
return False
3834

39-
@property
40-
def generation_strategy(self) -> GenerationStrategy:
41-
return self.metadata().generation_strategy
42-
4335
@staticmethod
4436
@abstractmethod
45-
def metadata() -> GeneratorMetadata: ...
37+
def get_generation_strategy() -> GenerationStrategy: ...
4638

4739
@overload
4840
@abstractmethod
@@ -108,3 +100,21 @@ def log_pre_generation(self) -> None:
108100
logger.info(f" |-- model alias: {self.config.model_alias!r}")
109101
logger.info(f" |-- model provider: {self.get_model_provider_name(model_alias=self.config.model_alias)!r}")
110102
logger.info(f" |-- inference parameters: {self.inference_parameters.format_for_display()}")
103+
104+
105+
class ColumnGeneratorCellByCell(ColumnGenerator[TaskConfigT], ABC):
106+
@staticmethod
107+
def get_generation_strategy() -> GenerationStrategy:
108+
return GenerationStrategy.CELL_BY_CELL
109+
110+
@abstractmethod
111+
def generate(self, data: dict) -> dict: ...
112+
113+
114+
class ColumnGeneratorFullColumn(ColumnGenerator[TaskConfigT], ABC):
115+
@staticmethod
116+
def get_generation_strategy() -> GenerationStrategy:
117+
return GenerationStrategy.FULL_COLUMN
118+
119+
@abstractmethod
120+
def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...

src/data_designer/engine/column_generators/generators/embedding.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55
from pydantic import BaseModel, computed_field
66

77
from data_designer.config.column_configs import EmbeddingColumnConfig
8-
from data_designer.engine.column_generators.generators.base import (
9-
ColumnGeneratorWithModel,
10-
GenerationStrategy,
11-
GeneratorMetadata,
12-
)
8+
from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
139
from data_designer.engine.processing.utils import deserialize_json_values, parse_list_string
1410

1511

@@ -27,12 +23,8 @@ def dimension(self) -> int:
2723

2824
class EmbeddingCellGenerator(ColumnGeneratorWithModel[EmbeddingColumnConfig]):
2925
@staticmethod
30-
def metadata() -> GeneratorMetadata:
31-
return GeneratorMetadata(
32-
name="embedding_cell_generator",
33-
description="Generate embeddings for a text column.",
34-
generation_strategy=GenerationStrategy.CELL_BY_CELL,
35-
)
26+
def get_generation_strategy() -> GenerationStrategy:
27+
return GenerationStrategy.CELL_BY_CELL
3628

3729
def generate(self, data: dict) -> dict:
3830
deserialized_record = deserialize_json_values(data)

src/data_designer/engine/column_generators/generators/expression.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,15 @@
88
import pandas as pd
99

1010
from data_designer.config.column_configs import ExpressionColumnConfig
11-
from data_designer.engine.column_generators.generators.base import (
12-
ColumnGenerator,
13-
GenerationStrategy,
14-
GeneratorMetadata,
15-
)
11+
from data_designer.engine.column_generators.generators.base import ColumnGeneratorFullColumn
1612
from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError
1713
from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
1814
from data_designer.engine.processing.utils import deserialize_json_values
1915

2016
logger = logging.getLogger(__name__)
2117

2218

23-
class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGenerator[ExpressionColumnConfig]):
24-
@staticmethod
25-
def metadata() -> GeneratorMetadata:
26-
return GeneratorMetadata(
27-
name="expression_generator",
28-
description="Generate a column from a jinja2 expression.",
29-
generation_strategy=GenerationStrategy.FULL_COLUMN,
30-
)
31-
19+
class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGeneratorFullColumn[ExpressionColumnConfig]):
3220
def generate(self, data: pd.DataFrame) -> pd.DataFrame:
3321
logger.info(f"🧩 Generating column `{self.config.name}` from expression")
3422

src/data_designer/engine/column_generators/generators/llm_completion.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
LLMTextColumnConfig,
1212
)
1313
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
14-
from data_designer.engine.column_generators.generators.base import (
15-
ColumnGeneratorWithModel,
16-
GenerationStrategy,
17-
GeneratorMetadata,
18-
)
14+
from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModel, GenerationStrategy
1915
from data_designer.engine.column_generators.utils.prompt_renderer import (
2016
PromptType,
2117
RecordBasedPromptRenderer,
@@ -29,6 +25,10 @@
2925

3026

3127
class ColumnGeneratorWithModelChatCompletion(ColumnGeneratorWithModel[TaskConfigT]):
28+
@staticmethod
29+
def get_generation_strategy() -> GenerationStrategy:
30+
return GenerationStrategy.CELL_BY_CELL
31+
3232
@functools.cached_property
3333
def response_recipe(self) -> ResponseRecipe:
3434
return create_response_recipe(self.config, self.model_config)
@@ -87,41 +87,13 @@ def generate(self, data: dict) -> dict:
8787
return data
8888

8989

90-
class LLMTextCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMTextColumnConfig]):
91-
@staticmethod
92-
def metadata() -> GeneratorMetadata:
93-
return GeneratorMetadata(
94-
name="llm_text_generator",
95-
description="Generate a new dataset cell from a prompt template",
96-
generation_strategy=GenerationStrategy.CELL_BY_CELL,
97-
)
90+
class LLMTextCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMTextColumnConfig]): ...
9891

9992

100-
class LLMCodeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMCodeColumnConfig]):
101-
@staticmethod
102-
def metadata() -> GeneratorMetadata:
103-
return GeneratorMetadata(
104-
name="llm_code_generator",
105-
description="Generate a new dataset cell from a prompt template",
106-
generation_strategy=GenerationStrategy.CELL_BY_CELL,
107-
)
93+
class LLMCodeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMCodeColumnConfig]): ...
10894

10995

110-
class LLMStructuredCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMStructuredColumnConfig]):
111-
@staticmethod
112-
def metadata() -> GeneratorMetadata:
113-
return GeneratorMetadata(
114-
name="llm_structured_generator",
115-
description="Generate a new dataset cell from a prompt template",
116-
generation_strategy=GenerationStrategy.CELL_BY_CELL,
117-
)
96+
class LLMStructuredCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMStructuredColumnConfig]): ...
11897

11998

120-
class LLMJudgeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMJudgeColumnConfig]):
121-
@staticmethod
122-
def metadata() -> GeneratorMetadata:
123-
return GeneratorMetadata(
124-
name="llm_judge_generator",
125-
description="Judge a new dataset cell based on a set of rubrics",
126-
generation_strategy=GenerationStrategy.CELL_BY_CELL,
127-
)
99+
class LLMJudgeCellGenerator(ColumnGeneratorWithModelChatCompletion[LLMJudgeColumnConfig]): ...

src/data_designer/engine/column_generators/generators/samplers.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
import pandas as pd
1212

1313
from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
14-
from data_designer.engine.column_generators.generators.base import (
15-
FromScratchColumnGenerator,
16-
GenerationStrategy,
17-
GeneratorMetadata,
18-
)
14+
from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator, GenerationStrategy
1915
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
2016
from data_designer.engine.processing.utils import concat_datasets
2117
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
@@ -28,12 +24,8 @@
2824

2925
class SamplerColumnGenerator(FromScratchColumnGenerator[SamplerMultiColumnConfig]):
3026
@staticmethod
31-
def metadata() -> GeneratorMetadata:
32-
return GeneratorMetadata(
33-
name="sampler_column_generator",
34-
description="Generate columns using sampling-based method.",
35-
generation_strategy=GenerationStrategy.FULL_COLUMN,
36-
)
27+
def get_generation_strategy() -> GenerationStrategy:
28+
return GenerationStrategy.FULL_COLUMN
3729

3830
def generate(self, data: pd.DataFrame) -> pd.DataFrame:
3931
df_samplers = self.generate_from_scratch(len(data))

0 commit comments

Comments
 (0)