|
| 1 | +<!-- |
| 2 | +SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 3 | +SPDX-License-Identifier: Apache-2.0 |
| 4 | +
|
| 5 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +you may not use this file except in compliance with the License. |
| 7 | +You may obtain a copy of the License at |
| 8 | +
|
| 9 | +http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +
|
| 11 | +Unless required by applicable law or agreed to in writing, software |
| 12 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +See the License for the specific language governing permissions and |
| 15 | +limitations under the License. |
| 16 | +--> |
| 17 | + |
| 18 | +# Adding a Custom Optimizer |
| 19 | + |
| 20 | +:::{note} |
| 21 | +We recommend reading the [Optimizer](../../improve-workflows/optimizer.md) guide before proceeding with this documentation. |
| 22 | +::: |
| 23 | + |
| 24 | +NeMo Agent Toolkit provides a pluggable optimizer system for tuning workflow parameters and prompts. The built-in strategies include Optuna-based numeric optimization and a genetic algorithm (GA) for prompt optimization. You can add custom optimization strategies by implementing one of the optimizer base classes and registering it with the `@register_optimizer` decorator. |
| 25 | + |
| 26 | +## Key Interfaces |
| 27 | + |
| 28 | +* **Configuration Base Classes** |
| 29 | + - {py:class}`~nat.data_models.optimizer.OptimizerStrategyBaseConfig`: Base class that all optimizer strategy configuration models must extend. Provides an `enabled` field and integrates with the NeMo Agent Toolkit type registry. |
| 30 | + - {py:class}`~nat.data_models.optimizer.PromptOptimizationConfig`: Base for prompt optimization strategy configuration models. Adds `prompt_population_init_function` and `prompt_recombination_function` fields. |
| 31 | + - {py:class}`~nat.data_models.optimizer.OptunaParameterOptimizationConfig`: Built-in config for Optuna-based numeric parameter optimization. |
| 32 | + |
| 33 | +* **Optimizer ABCs** |
| 34 | + - {py:class}`~nat.plugins.config_optimizer.prompts.base.BasePromptOptimizer`: Abstract base class for prompt optimization strategies. Requires implementing an async `run()` method that persists optimized prompts to disk; the in-memory config is left unchanged. |
| 35 | + - {py:class}`~nat.plugins.config_optimizer.parameters.base.BaseParameterOptimizer`: Abstract base class for parameter optimization strategies. Requires implementing an async `run()` method that returns an optimized `Config`. |
| 36 | + |
| 37 | +* **Registration** |
| 38 | + - {py:deco}`~nat.cli.register_workflow.register_optimizer`: Decorator that registers an optimizer strategy with the global type registry so the optimizer runtime can resolve the strategy from the type of `cfg.optimizer.numeric` or `cfg.optimizer.prompt`. |
| 39 | + |
| 40 | +## Adding a Custom Prompt Optimizer |
| 41 | + |
| 42 | +### 1. Define a config class |
| 43 | + |
| 44 | +Create a config class extending {py:class}`~nat.data_models.optimizer.PromptOptimizationConfig` with a unique `name`: |
| 45 | + |
| 46 | +```python |
| 47 | +from pydantic import Field |
| 48 | + |
| 49 | +from nat.data_models.optimizer import PromptOptimizationConfig |
| 50 | + |
| 51 | + |
| 52 | +class IterativeRefinementPromptConfig(PromptOptimizationConfig, name="iterative"): |
| 53 | + max_iterations: int = Field(default=20, description="Maximum refinement iterations.") |
| 54 | + candidates_per_iteration: int = Field(default=5, description="Number of candidate prompts to generate per iteration.") |
| 55 | + improvement_threshold: float = Field(default=0.01, description="Minimum score improvement to continue iterating.") |
| 56 | +``` |
| 57 | + |
| 58 | +### 2. Implement the Optimizer |
| 59 | + |
| 60 | +Implement {py:class}`~nat.plugins.config_optimizer.prompts.base.BasePromptOptimizer`: |
| 61 | + |
| 62 | +```python |
| 63 | +from nat.plugins.config_optimizer.prompts.base import BasePromptOptimizer |
| 64 | +from nat.data_models.config import Config |
| 65 | +from nat.data_models.optimizable import SearchSpace |
| 66 | +from nat.data_models.optimizer import OptimizerConfig, OptimizerRunConfig |
| 67 | + |
| 68 | + |
| 69 | +class IterativeRefinementPromptOptimizer(BasePromptOptimizer): |
| 70 | + |
| 71 | + async def run( |
| 72 | + self, |
| 73 | + *, |
| 74 | + base_cfg: Config, |
| 75 | + full_space: dict[str, SearchSpace], |
| 76 | + optimizer_config: OptimizerConfig, |
| 77 | + opt_run_config: OptimizerRunConfig, |
| 78 | + ) -> None: |
| 79 | + ir_config = optimizer_config.prompt # Your IterativeRefinementPromptConfig instance |
| 80 | + |
| 81 | + # Extract prompt parameters from full_space |
| 82 | + prompt_space = {k: v for k, v in full_space.items() if v.is_prompt} |
| 83 | + if not prompt_space: |
| 84 | + return |
| 85 | + |
| 86 | + # Implement your optimization loop here |
| 87 | + # Use ir_config.max_iterations, ir_config.candidates_per_iteration, etc. |
| 88 | + ... |
| 89 | +``` |
| 90 | + |
| 91 | +The `run()` method receives: |
| 92 | +- `base_cfg`: The workflow configuration to optimize. |
| 93 | +- `full_space`: A dictionary of parameter names to {py:class}`~nat.data_models.optimizable.SearchSpace` definitions. Filter for `is_prompt=True` entries to find prompt parameters. |
| 94 | +- `optimizer_config`: The full {py:class}`~nat.data_models.optimizer.OptimizerConfig`. Access your strategy config via `optimizer_config.prompt`. |
| 95 | +- `opt_run_config`: Runtime parameters including dataset path, endpoint, and result JSON path. |
| 96 | + |
| 97 | +### 3. Register the Optimizer |
| 98 | + |
| 99 | +Use the {py:deco}`~nat.cli.register_workflow.register_optimizer` decorator to register your strategy: |
| 100 | + |
| 101 | +```python |
| 102 | +from nat.cli.register_workflow import register_optimizer |
| 103 | + |
| 104 | + |
| 105 | +@register_optimizer(config_type=IterativeRefinementPromptConfig) |
| 106 | +async def register_iterative_prompt_optimizer(config: IterativeRefinementPromptConfig): |
| 107 | + yield IterativeRefinementPromptOptimizer() |
| 108 | +``` |
| 109 | + |
| 110 | +### 4. Import for Discovery |
| 111 | + |
| 112 | +Import the module that contains the registration function in your project's `register.py` so the `@register_optimizer` decorator runs at startup: |
| 113 | + |
| 114 | +<!-- path-check-skip-next-line --> |
| 115 | +```python |
| 116 | +from . import iterative_prompt_optimizer # noqa: F401 — triggers @register_optimizer |
| 117 | +``` |
| 118 | + |
| 119 | +### 5. Configure Programmatically |
| 120 | + |
| 121 | +Custom strategy selection for `optimizer.prompt` is currently programmatic. After loading your workflow config, set `cfg.optimizer.prompt` to your custom config before calling `optimize_config`: |
| 122 | + |
| 123 | +```python |
| 124 | +from nat.plugins.config_optimizer.optimizer_runtime import optimize_config |
| 125 | +from nat.data_models.optimizer import OptimizerRunConfig |
| 126 | +from nat.runtime.loader import load_config |
| 127 | + |
| 128 | +cfg = load_config("workflow.yml") |
| 129 | +cfg.optimizer.prompt = IterativeRefinementPromptConfig( |
| 130 | + enabled=True, |
| 131 | + max_iterations=200, |
| 132 | + candidates_per_iteration=10, |
| 133 | + improvement_threshold=0.01, |
| 134 | + prompt_population_init_function="my_init_fn", |
| 135 | +) |
| 136 | + |
| 137 | +await optimize_config( |
| 138 | + OptimizerRunConfig( |
| 139 | + config_file=cfg, |
| 140 | + dataset="dataset.json", |
| 141 | + result_json_path="$", |
| 142 | + ) |
| 143 | +) |
| 144 | +``` |
| 145 | + |
| 146 | +## Adding a Custom Parameter Optimizer |
| 147 | + |
| 148 | +The pattern is the same, but parameter optimizers extend {py:class}`~nat.plugins.config_optimizer.parameters.base.BaseParameterOptimizer` and return an optimized {py:class}`~nat.data_models.config.Config`: |
| 149 | + |
| 150 | +### 1. Define a config class |
| 151 | + |
| 152 | +```python |
| 153 | +from pydantic import Field |
| 154 | + |
| 155 | +from nat.data_models.optimizer import OptimizerStrategyBaseConfig |
| 156 | + |
| 157 | + |
| 158 | +class RandomSearchConfig(OptimizerStrategyBaseConfig, name="random_search"): |
| 159 | + n_samples: int = Field(default=50, description="Number of random samples to evaluate.") |
| 160 | +``` |
| 161 | + |
| 162 | +### 2. Implement the Optimizer |
| 163 | + |
| 164 | +```python |
| 165 | +from nat.plugins.config_optimizer.parameters.base import BaseParameterOptimizer |
| 166 | +from nat.data_models.config import Config |
| 167 | +from nat.data_models.optimizable import SearchSpace |
| 168 | +from nat.data_models.optimizer import OptimizerConfig, OptimizerRunConfig |
| 169 | + |
| 170 | + |
| 171 | +class RandomSearchOptimizer(BaseParameterOptimizer): |
| 172 | + |
| 173 | + async def run( |
| 174 | + self, |
| 175 | + *, |
| 176 | + base_cfg: Config, |
| 177 | + full_space: dict[str, SearchSpace], |
| 178 | + optimizer_config: OptimizerConfig, |
| 179 | + opt_run_config: OptimizerRunConfig, |
| 180 | + ) -> Config: |
| 181 | + rs_config = optimizer_config.numeric # Your RandomSearchConfig instance |
| 182 | + |
| 183 | + # Filter out prompt parameters |
| 184 | + param_space = {k: v for k, v in full_space.items() if not v.is_prompt} |
| 185 | + if not param_space: |
| 186 | + return base_cfg |
| 187 | + |
| 188 | + # Implement random search logic here |
| 189 | + # Return the best config found |
| 190 | + ... |
| 191 | + return best_cfg |
| 192 | +``` |
| 193 | + |
| 194 | +### 3. Register and Configure |
| 195 | + |
| 196 | +```python |
| 197 | +from nat.cli.register_workflow import register_optimizer |
| 198 | + |
| 199 | + |
| 200 | +@register_optimizer(config_type=RandomSearchConfig) |
| 201 | +async def register_random_search(config: RandomSearchConfig): |
| 202 | + yield RandomSearchOptimizer() |
| 203 | +``` |
| 204 | + |
| 205 | +Custom strategy selection for `optimizer.numeric` is also programmatic: |
| 206 | + |
| 207 | +```python |
| 208 | +from nat.plugins.config_optimizer.optimizer_runtime import optimize_config |
| 209 | +from nat.data_models.optimizer import OptimizerRunConfig |
| 210 | +from nat.runtime.loader import load_config |
| 211 | + |
| 212 | +cfg = load_config("workflow.yml") |
| 213 | +cfg.optimizer.numeric = RandomSearchConfig(enabled=True, n_samples=100) |
| 214 | + |
| 215 | +await optimize_config( |
| 216 | + OptimizerRunConfig( |
| 217 | + config_file=cfg, |
| 218 | + dataset="dataset.json", |
| 219 | + result_json_path="$", |
| 220 | + ) |
| 221 | +) |
| 222 | +``` |