Skip to content

Commit 01de5c6

Browse files
committed
Add CI workflow and validate model providers
1 parent 43f3804 commit 01de5c6

2 files changed

Lines changed: 86 additions & 2 deletions

File tree

.github/workflows/ci.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
pull_request:
6+
7+
jobs:
8+
test:
9+
runs-on: ubuntu-latest
10+
strategy:
11+
matrix:
12+
python-version: [3.11]
13+
steps:
14+
- uses: actions/checkout@v4
15+
16+
- name: Set up Python
17+
uses: actions/setup-python@v4
18+
with:
19+
python-version: ${{ matrix.python-version }}
20+
21+
- name: Cache pip
22+
uses: actions/cache@v4
23+
with:
24+
path: ~/.cache/pip
25+
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
26+
restore-keys: |
27+
${{ runner.os }}-pip-
28+
29+
- name: Install dependencies
30+
run: |
31+
python -m pip install --upgrade pip
32+
pip install -r requirements.txt
33+
34+
- name: Run tests
35+
run: pytest -q
36+
37+
- name: Upload tests directory as artifact
38+
if: always()
39+
uses: actions/upload-artifact@v4
40+
with:
41+
name: tests-directory
42+
path: tests

src/agent.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4906,6 +4906,8 @@ def main() -> None:
49064906
help='Quick override for improvements model. Format: "provider:model" or just "model"')
49074907
parser.add_argument('--n8n', type=str, metavar='URL',
49084908
help='Register an n8n webhook URL to receive agent notifications')
4909+
parser.add_argument('--model', action='append', metavar='AGENT=SPEC',
4910+
help='Repeatable. Specify per-agent model overrides. Format: agent=provider:model or agent=model')
49094911
parser.add_argument('--rate-limit', type=float, metavar='RPS',
49104912
help='Rate limit API calls to RPS requests per second')
49114913
parser.add_argument('--enable-file-locking', action='store_true',
@@ -4941,13 +4943,41 @@ def _parse_quick_flag(val: str) -> dict:
49414943
return {}
49424944
if ':' in val:
49434945
provider, model = val.split(':', 1)
4944-
return {'provider': provider, 'model': model}
4945-
return {'model': val}
4946+
return {'provider': provider.strip(), 'model': model.strip()}
4947+
return {'model': val.strip()}
4948+
4949+
def parse_model_overrides(raw_list: list[str] | None) -> dict[str, dict]:
4950+
"""Parse repeatable `--model` entries of form `agent=provider:model` or `agent=model`.
4951+
4952+
Returns mapping agent -> spec dict.
4953+
"""
4954+
out: dict[str, dict] = {}
4955+
if not raw_list:
4956+
return out
4957+
for entry in raw_list:
4958+
try:
4959+
if '=' not in entry:
4960+
logging.warning(f"Ignoring malformed --model entry: {entry}")
4961+
continue
4962+
agent_key, spec = entry.split('=', 1)
4963+
agent_key = agent_key.strip()
4964+
if ':' in spec:
4965+
provider, model = spec.split(':', 1)
4966+
out[agent_key] = {'provider': provider.strip(), 'model': model.strip()}
4967+
else:
4968+
out[agent_key] = {'model': spec.strip()}
4969+
except Exception:
4970+
logging.warning(f"Failed to parse --model entry: {entry}")
4971+
return out
49464972

49474973
if getattr(args, 'model_coder', None):
49484974
cli_models.setdefault('coder', {}).update(_parse_quick_flag(args.model_coder))
49494975
if getattr(args, 'model_improvements', None):
49504976
cli_models.setdefault('improvements', {}).update(_parse_quick_flag(args.model_improvements))
4977+
# Parse repeatable --model flags
4978+
if getattr(args, 'model', None):
4979+
parsed = parse_model_overrides(args.model)
4980+
cli_models.update(parsed)
49514981

49524982
# Health check mode
49534983
if args.health_check:
@@ -4987,6 +5017,18 @@ def _parse_quick_flag(val: str) -> dict:
49875017
try:
49885018
existing = getattr(agent, 'models', {}) or {}
49895019
# CLI overrides win
5020+
# Validate providers for simple known list
5021+
allowed_providers = {'openai', 'google', 'anthropic'}
5022+
def _validate_spec(spec: dict) -> bool:
5023+
prov = spec.get('provider')
5024+
if prov and prov not in allowed_providers:
5025+
logging.warning(f"Unknown provider '{prov}' in model spec; continuing")
5026+
return True
5027+
5028+
for k, v in cli_models.items():
5029+
if isinstance(v, dict):
5030+
_validate_spec(v)
5031+
49905032
merged = {**existing, **cli_models}
49915033
agent.models = merged
49925034
except Exception:

0 commit comments

Comments
 (0)