Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion libs/core/kiln_ai/adapters/fine_tune/dataset_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from kiln_ai.datamodel import DatasetSplit, TaskRun
from kiln_ai.datamodel.datamodel_enums import THINKING_DATA_STRATEGIES, ChatStrategy
from kiln_ai.datamodel.run_config import KilnAgentRunConfigProperties
from kiln_ai.datamodel.tool_id import SKILL_TOOL_ID_PREFIX
from kiln_ai.datamodel.tool_id import (
SKILL_SEARCH_TOOL_ID_PREFIX,
SKILL_TOOL_ID_PREFIX,
)
from kiln_ai.tools.base_tool import ToolCallDefinition
from kiln_ai.tools.tool_registry import tool_from_id
from kiln_ai.utils.exhaustive_error import raise_exhaustive_enum_error
Expand Down Expand Up @@ -433,10 +436,16 @@ async def _get_tool_definitions_from_config(
skill_tool_ids = [
tid for tid in tools_config.tools if tid.startswith(SKILL_TOOL_ID_PREFIX)
]
skill_search_tool_ids = [
tid
for tid in tools_config.tools
if tid.startswith(SKILL_SEARCH_TOOL_ID_PREFIX)
]
non_skill_tool_ids = [
tid
for tid in tools_config.tools
if not tid.startswith(SKILL_TOOL_ID_PREFIX)
and not tid.startswith(SKILL_SEARCH_TOOL_ID_PREFIX)
]

for tool_id in non_skill_tool_ids:
Expand Down Expand Up @@ -475,4 +484,24 @@ async def _get_tool_definitions_from_config(
self._tool_cache[cache_key] = skill_def
tool_definitions.append(skill_def)

if skill_search_tool_ids:
cache_key = "search::" + "::".join(sorted(skill_search_tool_ids))
if cache_key in self._tool_cache:
tool_definitions.append(self._tool_cache[cache_key])
else:
from kiln_ai.adapters.adapter_registry import (
load_skills_from_tool_ids,
)
from kiln_ai.tools.skill_search_tool import SkillSearchTool

skills_dict = load_skills_from_tool_ids(task, skill_search_tool_ids)
if skills_dict:
search_tool = SkillSearchTool(
f"{SKILL_SEARCH_TOOL_ID_PREFIX}_combined",
list(skills_dict.values()),
)
search_def = await search_tool.toolcall_definition()
self._tool_cache[cache_key] = search_def
tool_definitions.append(search_def)

return tool_definitions if tool_definitions else None
102 changes: 99 additions & 3 deletions libs/core/kiln_ai/adapters/model_adapters/base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@
)
from kiln_ai.datamodel.skill import Skill
from kiln_ai.datamodel.task import RunConfigProperties
from kiln_ai.datamodel.tool_id import SKILL_TOOL_ID_PREFIX, skill_id_from_tool_id
from kiln_ai.datamodel.tool_id import (
SKILL_SEARCH_TOOL_ID_PREFIX,
SKILL_TOOL_ID_PREFIX,
skill_id_from_skill_search_tool_id,
skill_id_from_tool_id,
)

# Import agent run context for run lifecycle management
from kiln_ai.run_context import (
Expand All @@ -62,6 +67,7 @@
)
from kiln_ai.tools import KilnToolInterface
from kiln_ai.tools.mcp_session_manager import MCPSessionManager
from kiln_ai.tools.skill_search_tool import SkillSearchTool
from kiln_ai.tools.skill_tool import SkillTool
from kiln_ai.tools.tool_registry import tool_from_id
from kiln_ai.utils.config import Config
Expand Down Expand Up @@ -158,6 +164,7 @@ def __init__(
self.prompt_builder = None
self._model_provider: KilnModelProvider | None = None
self._resolved_skills: list[Skill] | None = None
self._resolved_skill_search_skills: list[Skill] | None = None

self.output_schema = task.output_json_schema
self.input_schema = task.input_json_schema
Expand Down Expand Up @@ -509,9 +516,13 @@ def build_prompt(self) -> str:
== StructuredOutputMode.json_instruction_and_object
)

skill_tool_skills = self._resolve_skills()
skill_search_skills = self._resolve_skill_search_tool_skills()
combined_skills = self._merge_skills(skill_tool_skills, skill_search_skills)
return self.prompt_builder.build_prompt(
include_json_instructions=add_json_instructions,
skills=self._resolve_skills(),
skills=combined_skills,
skill_search_enabled=bool(skill_search_skills),
)

def _resolve_skills(self) -> list[Skill]:
Expand Down Expand Up @@ -566,6 +577,73 @@ def _resolve_skills(self) -> list[Skill]:
self._resolved_skills = skills
return self._resolved_skills

def _resolve_skill_search_tool_skills(self) -> list[Skill]:
    """Return the skills targeted by the run config's ``skill_search`` tool IDs.

    Only tool IDs starting with ``SKILL_SEARCH_TOOL_ID_PREFIX`` are
    considered. The result is memoised on
    ``self._resolved_skill_search_skills``, so repeated calls hand back the
    same list object. An empty list is returned (and cached) when the run
    config is not of type ``"kiln_agent"``, has no tools configured, or
    lists no skill-search tool IDs.

    Raises:
        ValueError: if skill-search tool IDs are present but the adapter
            config carries no skills dict, or if a referenced skill ID is
            missing from that dict. Nothing is cached in that case.
    """
    cached = self._resolved_skill_search_skills
    if cached is not None:
        return cached

    # Gather the skill-search tool IDs. Anything other than a kiln_agent run
    # config with a populated tools list leaves this empty.
    search_tool_ids: list[str] = []
    if self.run_config.type == "kiln_agent":
        tool_config = as_kiln_agent_run_config(self.run_config).tools_config
        if tool_config is not None and tool_config.tools is not None:
            search_tool_ids = [
                candidate
                for candidate in tool_config.tools
                if candidate.startswith(SKILL_SEARCH_TOOL_ID_PREFIX)
            ]

    by_id: dict[str, Skill] = {}
    if search_tool_ids:
        available = self.base_adapter_config.skills
        if available is None:
            raise ValueError(
                "Run config references skills but no skills dict was provided via "
                "AdapterConfig(skills=...). Use load_skills_for_task() to pre-load "
                "skills and pass them to the adapter."
            )

        for tool_id in search_tool_ids:
            sid = skill_id_from_skill_search_tool_id(tool_id)
            # Validate before de-duplicating so an unknown ID always raises.
            if sid not in available:
                raise ValueError(
                    f"Skill {sid} referenced in run config but not found in the "
                    "injected skills dict."
                )
            # Dict insertion order keeps first-seen ordering while dropping
            # repeated references to the same skill.
            if sid not in by_id:
                by_id[sid] = available[sid]

    self._resolved_skill_search_skills = list(by_id.values())
    return self._resolved_skill_search_skills

@staticmethod
def _merge_skills(
skill_tool_skills: list[Skill], skill_search_skills: list[Skill]
) -> list[Skill]:
"""Union two skill lists, preserving order and de-duplicating by id."""
merged: list[Skill] = []
seen: set[str | None] = set()
for s in list(skill_tool_skills) + list(skill_search_skills):
if s.id in seen:
continue
seen.add(s.id)
merged.append(s)
return merged

def build_chat_formatter(
self,
input: InputType,
Expand Down Expand Up @@ -734,7 +812,10 @@ async def available_tools(self) -> list[KilnToolInterface]:
return []

non_skill_tool_ids = [
tid for tid in tool_config.tools if not tid.startswith(SKILL_TOOL_ID_PREFIX)
tid
for tid in tool_config.tools
if not tid.startswith(SKILL_TOOL_ID_PREFIX)
and not tid.startswith(SKILL_SEARCH_TOOL_ID_PREFIX)
]

tools: list[KilnToolInterface] = [
Expand All @@ -752,6 +833,21 @@ async def available_tools(self) -> list[KilnToolInterface]:
seen_names.add(skill.name)
tools.append(SkillTool(f"{SKILL_TOOL_ID_PREFIX}_combined", skills))

search_skills = self._resolve_skill_search_tool_skills()
if search_skills:
seen_names = set()
for skill in search_skills:
if skill.name in seen_names:
raise ValueError(
f"Duplicate skill name '{skill.name}'. Each skill must have a unique name."
)
seen_names.add(skill.name)
tools.append(
SkillSearchTool(
f"{SKILL_SEARCH_TOOL_ID_PREFIX}_combined", search_skills
)
)

tool_names = [await tool.name() for tool in tools]
if len(tool_names) != len(set(tool_names)):
raise ValueError(
Expand Down
153 changes: 153 additions & 0 deletions libs/core/kiln_ai/adapters/model_adapters/test_base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ async def test_prompt_builder_json_instructions(
mock_prompt_builder.build_prompt.assert_called_with(
include_json_instructions=expected_json_instructions,
skills=[],
skill_search_enabled=False,
)


Expand Down Expand Up @@ -1770,3 +1771,155 @@ def test_build_prompt_includes_skills(self, base_task, _run_config_with_tools):
def test_build_prompt_no_skills_section_without_skills(self, adapter):
    """The prompt built by the ``adapter`` fixture has no "## Skills" heading."""
    assert "## Skills" not in adapter.build_prompt()


class TestResolveSkillSearchSkills:
    """Tests for ``_resolve_skill_search_tool_skills`` and its effect on prompts."""

    @pytest.fixture
    def _run_config_with_tools(self):
        """Provide a factory building a kiln-agent run config for given tool IDs."""

        def build(tool_ids: list[str]) -> KilnAgentRunConfigProperties:
            tools_config = ToolsRunConfig(tools=tool_ids)
            return KilnAgentRunConfigProperties(
                model_name="test_model",
                model_provider_name="openai",
                prompt_id="simple_prompt_builder",
                structured_output_mode="json_schema",
                tools_config=tools_config,
            )

        return build

    def test_returns_empty_for_non_kiln_agent(self, base_task):
        """An MCP run config resolves to no skill-search skills."""
        from kiln_ai.datamodel.run_config import (
            McpRunConfigProperties,
            MCPToolReference,
        )

        mcp_run_config = McpRunConfigProperties(
            tool_reference=MCPToolReference(tool_id="mcp::local::s::t"),
        )
        mcp_adapter = MockAdapter(task=base_task, run_config=mcp_run_config)
        assert mcp_adapter._resolve_skill_search_tool_skills() == []

    def test_returns_empty_when_no_tools_config(self, adapter):
        """The shared ``adapter`` fixture resolves to an empty list.

        NOTE(review): presumably that fixture has no tools config; it is
        defined outside this class.
        """
        resolved = adapter._resolve_skill_search_tool_skills()
        assert resolved == []

    def test_returns_empty_when_no_search_ids(self, base_task, _run_config_with_tools):
        """A plain skill tool ID is not picked up by the skill-search resolver."""
        run_config = _run_config_with_tools(["kiln_tool::skill::skill_123"])
        skill_only_adapter = MockAdapter(
            task=base_task,
            run_config=run_config,
            config=AdapterConfig(skills={}),
        )
        assert skill_only_adapter._resolve_skill_search_tool_skills() == []

    def test_raises_when_skills_dict_not_provided(
        self, base_task, _run_config_with_tools
    ):
        """Resolving without an injected skills dict raises ``ValueError``."""
        run_config = _run_config_with_tools(["kiln_tool::skill_search::skill_123"])
        bare_adapter = MockAdapter(task=base_task, run_config=run_config)
        with pytest.raises(ValueError, match="no skills dict was provided"):
            bare_adapter._resolve_skill_search_tool_skills()

    def test_raises_when_skill_missing_from_dict(
        self, base_task, _run_config_with_tools
    ):
        """A skill ID absent from the injected dict raises ``ValueError``."""
        run_config = _run_config_with_tools(["kiln_tool::skill_search::skill_123"])
        empty_dict_adapter = MockAdapter(
            task=base_task,
            run_config=run_config,
            config=AdapterConfig(skills={}),
        )
        with pytest.raises(ValueError, match="not found in the injected skills dict"):
            empty_dict_adapter._resolve_skill_search_tool_skills()

    def test_resolves_search_only_tool_config(self, base_task, _run_config_with_tools):
        """A single skill-search tool ID resolves to its injected skill."""
        skill = Skill(name="my-skill", description="A skill")
        search_tool_id = f"kiln_tool::skill_search::{skill.id}"
        search_adapter = MockAdapter(
            task=base_task,
            run_config=_run_config_with_tools([search_tool_id]),
            config=AdapterConfig(skills={skill.id: skill}),
        )
        resolved = search_adapter._resolve_skill_search_tool_skills()
        assert [s.name for s in resolved] == ["my-skill"]

    def test_resolves_both_prefixes_independently(
        self, base_task, _run_config_with_tools
    ):
        """The same skill referenced under both prefixes shows up in each resolver."""
        skill = Skill(name="my-skill", description="A skill")
        tool_ids = [
            f"kiln_tool::skill::{skill.id}",
            f"kiln_tool::skill_search::{skill.id}",
        ]
        dual_adapter = MockAdapter(
            task=base_task,
            run_config=_run_config_with_tools(tool_ids),
            config=AdapterConfig(skills={skill.id: skill}),
        )
        via_skill_tool = dual_adapter._resolve_skills()
        via_search_tool = dual_adapter._resolve_skill_search_tool_skills()
        assert [s.id for s in via_skill_tool] == [skill.id]
        assert [s.id for s in via_search_tool] == [skill.id]

    def test_deduplicates_search_ids(self, base_task, _run_config_with_tools):
        """Listing the same skill-search tool ID twice yields one skill."""
        skill = Skill(name="my-skill", description="A skill", body="do things")
        search_tool_id = f"kiln_tool::skill_search::{skill.id}"
        repeated_adapter = MockAdapter(
            task=base_task,
            run_config=_run_config_with_tools([search_tool_id, search_tool_id]),
            config=AdapterConfig(skills={skill.id: skill}),
        )
        resolved = repeated_adapter._resolve_skill_search_tool_skills()
        assert [s.name for s in resolved] == ["my-skill"]

    def test_caches_result(self, base_task, _run_config_with_tools):
        """Two consecutive calls return the very same list object."""
        skill = Skill(name="my-skill", description="A skill")
        search_tool_id = f"kiln_tool::skill_search::{skill.id}"
        caching_adapter = MockAdapter(
            task=base_task,
            run_config=_run_config_with_tools([search_tool_id]),
            config=AdapterConfig(skills={skill.id: skill}),
        )
        first = caching_adapter._resolve_skill_search_tool_skills()
        second = caching_adapter._resolve_skill_search_tool_skills()
        assert first is second

    def test_build_prompt_mentions_skill_search_when_enabled(
        self, base_task, _run_config_with_tools
    ):
        """With a skill-search tool configured, the prompt names the tool and skill."""
        skill = Skill(name="my-skill", description="A test skill")
        search_tool_id = f"kiln_tool::skill_search::{skill.id}"
        search_adapter = MockAdapter(
            task=base_task,
            run_config=_run_config_with_tools([search_tool_id]),
            config=AdapterConfig(skills={skill.id: skill}),
        )
        built_prompt = search_adapter.build_prompt()
        assert "skill_search(name, pattern)" in built_prompt
        assert "my-skill" in built_prompt

    def test_build_prompt_omits_skill_search_when_only_skill_tool(
        self, base_task, _run_config_with_tools
    ):
        """With only a skill tool configured, the prompt lists the skill but not search."""
        skill = Skill(name="my-skill", description="A test skill")
        skill_tool_id = f"kiln_tool::skill::{skill.id}"
        skill_adapter = MockAdapter(
            task=base_task,
            run_config=_run_config_with_tools([skill_tool_id]),
            config=AdapterConfig(skills={skill.id: skill}),
        )
        built_prompt = skill_adapter.build_prompt()
        assert "skill_search(name, pattern)" not in built_prompt
        assert "my-skill" in built_prompt

    def test_merge_skills_dedups_overlapping_sets(self):
        """A skill present in both inputs appears once, in first-seen order."""
        first = Skill(name="s1", description="d1")
        second = Skill(name="s2", description="d2")
        combined = BaseAdapter._merge_skills([first, second], [second])
        assert [s.id for s in combined] == [first.id, second.id]
Loading
Loading