Skip to content

Commit 776f7c4

Browse files
Fix vendor validation matrix for presets, pipeline_id, and deprecation path
1 parent bed29b6 commit 776f7c4

8 files changed

Lines changed: 478 additions & 19 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ jobs:
1515
curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1
1616
- name: Install dependencies
1717
run: poetry install
18+
- name: Release metadata check
19+
run: poetry run python scripts/check_release_workflow.py
1820
- name: Compile
1921
run: poetry run mypy .
2022
test:

scripts/check_release_workflow.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#!/usr/bin/env python3
2+
3+
import re
4+
import sys
5+
from pathlib import Path
6+
from typing import NoReturn
7+
8+
9+
def fail(message: str) -> NoReturn:
10+
print(message, file=sys.stderr)
11+
raise SystemExit(1)
12+
13+
14+
def read_version(path: str) -> str:
15+
text = Path(path).read_text()
16+
match = re.search(r'^version\s*=\s*"v?([^"]+)"', text, re.M)
17+
if not match:
18+
fail(f"version not found in {path}")
19+
return match.group(1)
20+
21+
22+
def read_compat_dependency(path: str) -> str:
23+
text = Path(path).read_text()
24+
match = re.search(r'^agora-agents\s*=\s*"([^"]+)"', text, re.M)
25+
if not match:
26+
fail(f"agora-agents dependency not found in {path}")
27+
return match.group(1)
28+
29+
30+
root_version = read_version("pyproject.toml")
31+
compat_pyproject = "compat/agora-agent-server-sdk/pyproject.toml"
32+
compat_version = read_version(compat_pyproject)
33+
compat_dependency = read_compat_dependency(compat_pyproject)
34+
35+
if compat_version != root_version:
36+
fail(f"Compat package version ({compat_version}) must match root package version ({root_version}).")
37+
38+
expected_dependency = f">={root_version},<3.0.0"
39+
if compat_dependency != expected_dependency:
40+
fail(f"Compat package dependency on agora-agents ({compat_dependency}) must be {expected_dependency}.")
41+
42+
release_workflow = Path(".github/workflows/release.yml").read_text()
43+
required_workflow_markers = [
44+
("contents: write", "release workflow must have contents: write so it can create GitHub releases"),
45+
("gh release create", "release workflow must create a GitHub release when one does not exist"),
46+
("gh release edit", "release workflow must update an existing GitHub release"),
47+
("release_notes.md", "release workflow must generate and use a release notes file"),
48+
]
49+
50+
for marker, message in required_workflow_markers:
51+
if marker not in release_workflow:
52+
fail(message)
53+
54+
print("Release metadata and workflow checks passed.")

src/agora_agent/agentkit/agent.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import typing
55
import typing_extensions
6+
import warnings
67

78
if typing.TYPE_CHECKING:
89
from .agent_session import AgentSession, AsyncAgentSession
@@ -815,6 +816,8 @@ def to_properties(
815816
app_certificate: typing.Optional[str] = None,
816817
expires_in: typing.Optional[int] = None,
817818
skip_vendor_validation: bool = False,
819+
skip_vendor_validation_categories: typing.Optional[typing.AbstractSet[str]] = None,
820+
allow_missing_vendor_categories: typing.Optional[typing.AbstractSet[str]] = None,
818821
) -> StartAgentsRequestProperties:
819822
# Validate the MLLM + enabled-avatar combination BEFORE generating the
820823
# RTC token so callers get a clear, actionable error first (matches the
@@ -895,19 +898,49 @@ def to_properties(
895898
base_kwargs["mllm"] = mllm_config
896899
return StartAgentsRequestProperties(**base_kwargs)
897900

898-
base_kwargs["asr"] = self._resolve_asr_config()
901+
if skip_vendor_validation:
902+
warnings.warn(
903+
"skip_vendor_validation is deprecated and will be removed in a future release. "
904+
"Use skip_vendor_validation_categories and allow_missing_vendor_categories instead.",
905+
DeprecationWarning,
906+
stacklevel=2,
907+
)
908+
909+
skip_categories = set(skip_vendor_validation_categories or ())
910+
allow_missing_categories = set(allow_missing_vendor_categories or ())
911+
if skip_vendor_validation:
912+
skip_categories.update({"asr", "llm", "tts"})
913+
allow_missing_categories.update({"asr", "llm", "tts"})
914+
915+
skip_asr_validation = skip_vendor_validation or "asr" in skip_categories
916+
skip_llm_validation = skip_vendor_validation or "llm" in skip_categories
917+
skip_tts_validation = skip_vendor_validation or "tts" in skip_categories
918+
allow_missing_asr = "asr" in allow_missing_categories
919+
allow_missing_llm = "llm" in allow_missing_categories
920+
allow_missing_tts = "tts" in allow_missing_categories
921+
922+
if not skip_asr_validation and (self._stt is not None or not allow_missing_asr):
923+
base_kwargs["asr"] = self._resolve_asr_config()
899924
base_kwargs["turn_detection"] = self._resolve_turn_detection_config()
900925

901926
if skip_vendor_validation:
902927
return StartAgentsRequestProperties(**base_kwargs)
903928

904-
if self._tts is None:
929+
if self._tts is None and not (skip_tts_validation or allow_missing_tts):
905930
raise ValueError("TTS configuration is required. Use with_tts() to set it.")
906931

907-
if self._llm is None:
932+
if self._llm is None and not (skip_llm_validation or allow_missing_llm):
908933
raise ValueError("LLM configuration is required. Use with_llm() to set it.")
909934

910-
llm_config = dict(self._llm)
935+
if self._llm is not None and not skip_llm_validation:
936+
base_kwargs["llm"] = self._resolve_llm_config()
937+
if self._tts is not None and not skip_tts_validation:
938+
base_kwargs["tts"] = self._tts
939+
940+
return StartAgentsRequestProperties(**base_kwargs)
941+
942+
def _resolve_llm_config(self) -> typing.Dict[str, typing.Any]:
943+
llm_config = dict(self._llm or {})
911944
# Agent-level fields take priority over the vendor's defaults.
912945
# This matches the TS SDK where agent-level values override vendor config.
913946
if self._instructions is not None:
@@ -920,11 +953,7 @@ def to_properties(
920953
llm_config["failure_message"] = self._failure_message
921954
if self._max_history is not None:
922955
llm_config["max_history"] = self._max_history
923-
924-
base_kwargs["llm"] = llm_config
925-
base_kwargs["tts"] = self._tts
926-
927-
return StartAgentsRequestProperties(**base_kwargs)
956+
return llm_config
928957

929958
def _resolve_asr_config(self) -> typing.Dict[str, typing.Any]:
930959
asr_config = dict(self._stt or {})

src/agora_agent/agentkit/agent_session.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@
2727
validate_avatar_config,
2828
validate_tts_sample_rate,
2929
)
30-
from .presets import resolve_session_presets
30+
from .presets import (
31+
get_preset_category,
32+
infer_asr_preset,
33+
infer_llm_preset,
34+
infer_tts_preset,
35+
normalize_preset_input,
36+
resolve_session_presets,
37+
)
3138
from .token import generate_convo_ai_token, _parse_numeric_uid
3239

3340

@@ -294,15 +301,17 @@ def _is_mllm_mode(self) -> bool:
294301
def _build_start_properties(
295302
self,
296303
token_opts: typing.Dict[str, typing.Any],
297-
skip_vendor_validation: bool,
304+
skip_vendor_validation_categories: typing.AbstractSet[str],
305+
allow_missing_vendor_categories: typing.AbstractSet[str],
298306
) -> typing.Dict[str, typing.Any]:
299307
base_properties = self._agent.to_properties(
300308
channel=self._channel,
301309
agent_uid=self._agent_uid,
302310
remote_uids=self._remote_uids,
303311
idle_timeout=self._idle_timeout,
304312
enable_string_uid=self._enable_string_uid,
305-
skip_vendor_validation=skip_vendor_validation,
313+
skip_vendor_validation_categories=skip_vendor_validation_categories,
314+
allow_missing_vendor_categories=allow_missing_vendor_categories,
306315
**token_opts,
307316
)
308317
properties = self._dump_model(base_properties)
@@ -340,6 +349,29 @@ def _build_start_properties(
340349

341350
return properties
342351

352+
def _vendor_validation_categories(
353+
self,
354+
pipeline_id: typing.Optional[str],
355+
) -> typing.Tuple[typing.Set[str], typing.Set[str]]:
356+
skip_categories: typing.Set[str] = set()
357+
allow_missing_categories: typing.Set[str] = {"asr", "llm", "tts"} if pipeline_id else set()
358+
359+
preset = normalize_preset_input(self._preset)
360+
if preset:
361+
for item in preset.split(","):
362+
category = get_preset_category(item)
363+
if category is not None:
364+
skip_categories.add(category)
365+
allow_missing_categories.add(category)
366+
367+
if infer_asr_preset(self._agent.stt):
368+
skip_categories.add("asr")
369+
if infer_llm_preset(self._agent.llm):
370+
skip_categories.add("llm")
371+
if infer_tts_preset(self._agent.tts):
372+
skip_categories.add("tts")
373+
return skip_categories, allow_missing_categories
374+
343375
@staticmethod
344376
def _page_value(pagination: typing.Any, field: str) -> typing.Any:
345377
if pagination is None:
@@ -460,7 +492,12 @@ def start(self) -> str:
460492
"expires_in": self._expires_in,
461493
}
462494

463-
properties = self._build_start_properties(token_opts, skip_vendor_validation=bool(self._preset or pipeline_id))
495+
skip_categories, allow_missing_categories = self._vendor_validation_categories(pipeline_id)
496+
properties = self._build_start_properties(
497+
token_opts,
498+
skip_vendor_validation_categories=skip_categories,
499+
allow_missing_vendor_categories=allow_missing_categories,
500+
)
464501
resolved_preset, resolved_properties = resolve_session_presets(
465502
self._preset,
466503
properties,
@@ -782,7 +819,12 @@ async def start(self) -> str:
782819
"expires_in": self._expires_in,
783820
}
784821

785-
properties = self._build_start_properties(token_opts, skip_vendor_validation=bool(self._preset or pipeline_id))
822+
skip_categories, allow_missing_categories = self._vendor_validation_categories(pipeline_id)
823+
properties = self._build_start_properties(
824+
token_opts,
825+
skip_vendor_validation_categories=skip_categories,
826+
allow_missing_vendor_categories=allow_missing_categories,
827+
)
786828
resolved_preset, resolved_properties = resolve_session_presets(
787829
self._preset,
788830
properties,

src/agora_agent/agentkit/presets.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@ class _AgentPresets:
3737
DeepgramPresetModels = ("nova-2", "nova-3")
3838
OpenAIPresetModels = ("gpt-4o-mini", "gpt-4.1-mini", "gpt-5-nano", "gpt-5-mini")
3939
OpenAITtsPresetModels = ("tts-1",)
40-
MiniMaxPresetModels = ("speech-2.6-turbo", "speech_2_6_turbo", "speech-2.8-turbo", "speech_2_8_turbo")
40+
MiniMaxPresetModels = (
41+
"speech-2.6-turbo",
42+
"speech_2_6_turbo",
43+
"speech-2.8-turbo",
44+
"speech_2_8_turbo",
45+
)
4146

4247
PresetInput = typing.Union[str, typing.Sequence[str]]
4348

@@ -61,7 +66,10 @@ class _AgentPresets:
6166

6267

6368
def _normalize_model_name(value: typing.Any) -> typing.Optional[str]:
64-
return value.strip().lower() if isinstance(value, str) else None
69+
if not isinstance(value, str):
70+
return None
71+
normalized = value.strip().lower()
72+
return normalized if normalized else None
6573

6674

6775
def _parse_preset_input(preset: typing.Optional[PresetInput]) -> typing.List[str]:
@@ -87,6 +95,10 @@ def _get_preset_category(preset: str) -> typing.Optional[str]:
8795
return None
8896

8997

98+
def get_preset_category(preset: str) -> typing.Optional[str]:
99+
return _get_preset_category(preset)
100+
101+
90102
def _omit_none(value: typing.Dict[str, typing.Any]) -> typing.Optional[typing.Dict[str, typing.Any]]:
91103
next_value = {k: v for k, v in value.items() if v is not None}
92104
return next_value or None

tests/custom/test_pipeline_id.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from agora_agent import Agent
3+
from agora_agent import Agent, OpenAI, OpenAITTS
44

55

66
def dump(value):
@@ -85,6 +85,75 @@ def test_agent_pipeline_id_skips_missing_vendor_validation() -> None:
8585
call = start_agent(Agent(name="support", pipeline_id="studio-pipeline-id"))
8686

8787
assert call["pipeline_id"] == "studio-pipeline-id"
88+
properties = dump(call["properties"])
89+
assert "asr" not in properties
90+
assert "llm" not in properties
91+
assert "tts" not in properties
92+
93+
94+
def test_pipeline_id_allows_single_llm_override_without_tts_or_asr() -> None:
95+
agent = Agent(name="support", pipeline_id="studio-pipeline-id").with_llm(
96+
OpenAI(
97+
api_key="openai-key",
98+
base_url="https://api.openai.com/v1/chat/completions",
99+
model="gpt-4o",
100+
)
101+
)
102+
103+
call = start_agent(agent)
104+
105+
assert call["pipeline_id"] == "studio-pipeline-id"
106+
properties = dump(call["properties"])
107+
assert "asr" not in properties
108+
assert "tts" not in properties
109+
assert properties["llm"]["api_key"] == "openai-key"
110+
assert properties["llm"]["params"]["model"] == "gpt-4o"
111+
112+
113+
def test_pipeline_id_allows_multiple_overrides_without_asr() -> None:
114+
agent = (
115+
Agent(name="support", pipeline_id="studio-pipeline-id")
116+
.with_llm(
117+
OpenAI(
118+
api_key="openai-key",
119+
base_url="https://api.openai.com/v1/chat/completions",
120+
model="gpt-4o",
121+
)
122+
)
123+
.with_tts(
124+
OpenAITTS(
125+
api_key="tts-key",
126+
base_url="https://api.openai.com/v1/audio/speech",
127+
model="tts-1-hd",
128+
voice="alloy",
129+
)
130+
)
131+
)
132+
133+
call = start_agent(agent)
134+
135+
assert call["pipeline_id"] == "studio-pipeline-id"
136+
properties = dump(call["properties"])
137+
assert "asr" not in properties
138+
assert properties["llm"]["api_key"] == "openai-key"
139+
assert properties["tts"]["vendor"] == "openai"
140+
assert properties["tts"]["params"]["api_key"] == "tts-key"
141+
142+
143+
def test_skip_vendor_validation_boolean_is_deprecated() -> None:
144+
with pytest.warns(DeprecationWarning, match="skip_vendor_validation is deprecated"):
145+
properties = Agent(name="support").to_properties(
146+
channel="channel",
147+
token="token",
148+
agent_uid="1",
149+
remote_uids=["100"],
150+
skip_vendor_validation=True,
151+
)
152+
153+
payload = dump(properties)
154+
assert "asr" not in payload
155+
assert "llm" not in payload
156+
assert "tts" not in payload
88157

89158

90159
def test_pipeline_id_is_not_sent_inside_properties() -> None:

0 commit comments

Comments
 (0)