Skip to content

Commit 977e3fc

Browse files
authored
Align SystemPrompt alias with config fields (#1501)
1 parent cfc60ae commit 977e3fc

8 files changed

Lines changed: 45 additions & 31 deletions

File tree

docs/reference.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ v1 task loader return types. Override `load_tasks(split=...)` on a
7777
### SystemPrompt
7878

7979
```python
80-
SystemPrompt = str | Sequence[Message | JsonData]
80+
SystemPrompt = PromptInput | SystemPromptConfig | None
8181
SystemPromptStrategy = Literal["HT", "TH", "H_OR_T", "T_OR_H", "H", "T", "REJECT"]
8282

8383
class SystemPromptConfig:
@@ -1030,7 +1030,7 @@ class EnvConfig(Config):
10301030

10311031
class TasksetConfig(Config):
10321032
taskset_id: str | None = None # `id` shorthand accepted
1033-
system_prompt: PromptInput | SystemPromptConfig | None = None
1033+
system_prompt: SystemPrompt = None
10341034
user: UserConfig | None = None
10351035
bindings: BindingsConfig = BindingsConfig()
10361036
objects: ObjectsConfig = ObjectsConfig()
@@ -1040,7 +1040,7 @@ class HarnessConfig(Config):
10401040
harness_id: str | None = None # `id` shorthand accepted
10411041
program: ProgramConfig = ProgramConfig()
10421042
model: ModelConfig = ModelConfig()
1043-
system_prompt: PromptInput | SystemPromptConfig | None = None
1043+
system_prompt: SystemPrompt = None
10441044
system_prompt_strategy: SystemPromptStrategy = "HT"
10451045
sandbox: SandboxConfig | None = None
10461046
user: UserConfig | None = None

tests/test_v1_config_extension.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,16 +2310,14 @@ class ChildTaskset(BaseTaskset):
23102310

23112311
def test_taskset_class_loader_owns_split_loading() -> None:
23122312
class LoaderTasksetConfig(TasksetConfig):
2313-
system_prompt: vf.SystemPrompt | None = "class prompt"
2313+
system_prompt: vf.SystemPrompt = "class prompt"
23142314

23152315
class LoaderTaskset(Taskset[LoaderTasksetConfig]):
23162316
def load_tasks(self, split: vf.TaskSplit = "train") -> vf.Tasks:
23172317
answer = "class eval" if split == "eval" else "class tasks"
23182318
return [{"prompt": [], "answer": answer}]
23192319

2320-
def load_system_prompt(
2321-
self, config: LoaderTasksetConfig
2322-
) -> vf.SystemPrompt | None:
2320+
def load_system_prompt(self, config: LoaderTasksetConfig) -> vf.SystemPrompt:
23232321
return config.system_prompt
23242322

23252323
defaulted = LoaderTaskset(config=LoaderTasksetConfig())
@@ -2341,6 +2339,25 @@ def load_system_prompt(
23412339
assert disabled_prompt.system_prompt == []
23422340

23432341

2342+
def test_system_prompt_alias_accepts_config_data(tmp_path) -> None:
2343+
prompt_path = tmp_path / "system_prompt.txt"
2344+
prompt_path.write_text("alias path system prompt", encoding="utf-8")
2345+
2346+
class PromptTasksetConfig(TasksetConfig):
2347+
system_prompt: vf.SystemPrompt = None
2348+
2349+
config = PromptTasksetConfig.model_validate(
2350+
{"system_prompt": {"path": str(prompt_path)}}
2351+
)
2352+
assert isinstance(config.system_prompt, vf.SystemPromptConfig)
2353+
2354+
taskset = Taskset(config=config)
2355+
2356+
assert taskset.system_prompt == [
2357+
{"role": "system", "content": "alias path system prompt"}
2358+
]
2359+
2360+
23442361
def test_taskset_load_tasks_can_return_empty_dataset() -> None:
23452362
class LocalTasksetConfig(TasksetConfig):
23462363
enabled: bool = True

verifiers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.15.dev16"
1+
__version__ = "0.1.15.dev17"
22

33
import importlib
44
import os

verifiers/v1/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,13 @@
5959
)
6060
from .utils.endpoint_utils import Endpoint
6161
from .utils.binding_utils import BindingsConfig, ObjectsConfig
62-
from .utils.prompt_utils import SystemPromptConfig, SystemPromptStrategy
62+
from .utils.prompt_utils import SystemPrompt, SystemPromptConfig, SystemPromptStrategy
6363
from .types import (
6464
ConfigData,
6565
Handler,
6666
JsonData,
6767
Objects,
6868
PromptInput,
69-
SystemPrompt,
7069
TaskSplit,
7170
Tasks,
7271
)

verifiers/v1/harness.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@
7272
run_sandbox_python_program,
7373
)
7474
from .utils.prompt_utils import (
75+
SystemPrompt,
7576
SystemPromptStrategy,
76-
SystemPromptConfig,
7777
normalize_prompt,
7878
normalize_system_prompt,
7979
resolve_system_prompt,
@@ -88,7 +88,6 @@
8888
ConfigData,
8989
JsonData,
9090
Objects,
91-
PromptInput,
9291
)
9392

9493
if TYPE_CHECKING:
@@ -106,7 +105,7 @@ class HarnessConfig(LifecycleConfig):
106105
)
107106
program: ProgramConfig = ProgramConfig()
108107
model: ModelConfig = ModelConfig()
109-
system_prompt: PromptInput | SystemPromptConfig | None = None
108+
system_prompt: SystemPrompt = None
110109
system_prompt_strategy: SystemPromptStrategy = "HT"
111110
sandbox: SandboxConfig | None = None
112111
user: UserConfig | None = None
@@ -217,9 +216,7 @@ def __init__(
217216
self.endpoint = self.load_endpoint()
218217
self.program = self.compile_program(self.program_config)
219218

220-
def load_system_prompt(
221-
self, config: ConfigT
222-
) -> PromptInput | SystemPromptConfig | None:
219+
def load_system_prompt(self, config: ConfigT) -> SystemPrompt:
223220
return config.system_prompt
224221

225222
def load_sandbox(self, config: SandboxConfig | None) -> SandboxConfig | None:

verifiers/v1/taskset.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
BindingsConfig,
1919
ObjectsConfig,
2020
)
21-
from .utils.prompt_utils import SystemPromptConfig, normalize_system_prompt
21+
from .utils.prompt_utils import SystemPrompt, normalize_system_prompt
2222
from .utils.config_utils import (
2323
coerce_config,
2424
config_ref_context,
@@ -36,7 +36,6 @@
3636
from .types import (
3737
JsonData,
3838
Objects,
39-
PromptInput,
4039
TaskSplit,
4140
Tasks,
4241
)
@@ -48,7 +47,7 @@ class TasksetConfig(LifecycleConfig):
4847
default=None,
4948
validation_alias=AliasChoices("taskset_id", "id"),
5049
)
51-
system_prompt: PromptInput | SystemPromptConfig | None = None
50+
system_prompt: SystemPrompt = None
5251
user: UserConfig | None = None
5352
bindings: BindingsConfig = BindingsConfig()
5453
objects: ObjectsConfig = ObjectsConfig()
@@ -152,7 +151,5 @@ def __iter__(self):
152151
def __len__(self) -> int:
153152
return len(self.get_dataset())
154153

155-
def load_system_prompt(
156-
self, config: ConfigT
157-
) -> PromptInput | SystemPromptConfig | None:
154+
def load_system_prompt(self, config: ConfigT) -> SystemPrompt:
158155
return config.system_prompt

verifiers/v1/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141

4242
PromptMessage: TypeAlias = Message | JsonData
4343
PromptInput: TypeAlias = str | Sequence[PromptMessage]
44-
SystemPrompt: TypeAlias = PromptInput
4544

4645
ModelClient: TypeAlias = Client | ClientConfig
4746
RuntimeObject: TypeAlias = object

verifiers/v1/utils/prompt_utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import importlib.util
22
from dataclasses import dataclass
33
from pathlib import Path
4-
from typing import TYPE_CHECKING, Literal, cast
4+
from typing import TYPE_CHECKING, Literal, TypeAlias, cast
55

66
from pydantic import model_validator
77
from typing_extensions import Self
88
from verifiers.types import Messages, SystemMessage
99
from verifiers.utils.message_utils import normalize_messages
1010

1111
from ..config import Config
12-
from ..types import JsonData, PromptInput, SystemPrompt
12+
from ..types import JsonData, PromptInput
1313
from .config_utils import current_config_ref_module
1414

1515
if TYPE_CHECKING:
@@ -64,13 +64,15 @@ class SystemPromptConfig(Config):
6464
messages: list[JsonData] = []
6565

6666
@model_validator(mode="after")
67-
def validate_one_source(self) -> Self:
68-
sources = [
67+
def validate_one_input(self) -> Self:
68+
inputs = [
6969
self.path is not None,
7070
bool(self.messages),
7171
]
72-
if sum(sources) != 1:
73-
raise ValueError("SystemPromptConfig requires exactly one source.")
72+
if sum(inputs) != 1:
73+
raise ValueError(
74+
"SystemPromptConfig requires exactly one of path or messages."
75+
)
7476
return self
7577

7678
def load(self, field_name: str) -> PromptInput | None:
@@ -81,6 +83,9 @@ def load(self, field_name: str) -> PromptInput | None:
8183
return self.messages
8284

8385

86+
SystemPrompt: TypeAlias = PromptInput | SystemPromptConfig | None
87+
88+
8489
def normalize_prompt(
8590
value: PromptInput | None, field_name: str = "prompt"
8691
) -> list[JsonData]:
@@ -95,7 +100,7 @@ def normalize_prompt(
95100

96101

97102
def normalize_system_prompt(
98-
value: SystemPrompt | SystemPromptConfig | None,
103+
value: SystemPrompt,
99104
field_name: str = "system_prompt",
100105
) -> list[JsonData]:
101106
value = resolve_system_prompt_input(value, field_name=field_name)
@@ -111,7 +116,7 @@ def normalize_system_prompt(
111116

112117

113118
def resolve_system_prompt_input(
114-
value: PromptInput | SystemPromptConfig | None,
119+
value: SystemPrompt,
115120
*,
116121
field_name: str,
117122
) -> PromptInput | None:

0 commit comments

Comments
 (0)