Skip to content

Commit 507b661

Browse files
cursoragent and benjibc committed
Refactor CI, pre-commit, and type checking with minor code improvements
Co-authored-by: bchen <bchen@fireworks.ai>
1 parent 7c10e74 commit 507b661

5 files changed

Lines changed: 15 additions & 13 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -48,13 +48,9 @@ jobs:
4848
- name: Ruff lint
4949
run: uv run ruff check .
5050

51-
- name: Type check with pyright
51+
- name: Run pre-commit (format, lint, type check)
5252
run: |
53-
# 'set +e' disables immediate exit on error so we can capture and report errors but exit 0
54-
# Note: We currently suppress pyright failures to allow CI to pass while we iteratively fix all type issues.
55-
# Once all type errors are resolved, we will remove this suppression and enforce strict type checking.
56-
set +e
57-
uv run basedpyright || true
53+
uv run pre-commit run --all-files
5854
5955
test-core:
6056
name: Core Tests (Python ${{ matrix.python-version }})

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,3 +31,4 @@ repos:
3131
NODE_OPTIONS: "--max-old-space-size=4096"
3232
# Only check Python files in the main package to reduce memory usage
3333
files: ^eval_protocol/.*\.py$
34+
additional_dependencies: ["pre-commit>=3.7.0"]

eval_protocol/mcp_servers/tau2/tests/test_tau2_e2e.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -747,8 +747,10 @@ def tau2_airline_eval(
747747
elif role == "user":
748748
trajectory_objects.append(UserMessage(role=role, content=content))
749749
elif role == "tool":
750-
tool_id = msg.tool_call_id or ""
751-
trajectory_objects.append(ToolMessage(id=tool_id, role=role, content=content, requestor="assistant"))
750+
tool_id = msg.tool_call_id if isinstance(msg.tool_call_id, str) else ""
751+
trajectory_objects.append(
752+
ToolMessage(id=tool_id, role=role, content=content, requestor="assistant")
753+
)
752754

753755
reward = 1.0
754756

eval_protocol/rewards/function_calling.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import os
33
import re
44
import warnings
5-
from typing import Any, Dict, List, Optional, Set, Tuple, Union
5+
from typing import Any, Dict, List, Optional, Set, Tuple, Union, Callable, cast
66

77
# Import OpenAI at module level for mocking in tests
88
try:
@@ -451,7 +451,8 @@ def schema_jaccard_reward(
451451
DeprecationWarning,
452452
stacklevel=2,
453453
)
454-
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
454+
_exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward)
455+
return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs)
455456

456457

457458
@reward_function
@@ -493,7 +494,8 @@ def llm_judge_reward(
493494
DeprecationWarning,
494495
stacklevel=2,
495496
)
496-
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
497+
_exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward)
498+
return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs)
497499

498500

499501
@reward_function
@@ -537,7 +539,8 @@ def composite_function_call_reward(
537539
DeprecationWarning,
538540
stacklevel=2,
539541
)
540-
return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs)
542+
_exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward)
543+
return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs)
541544

542545

543546
# JSON schema reward functions have been moved to json_schema.py module

eval_protocol/rewards/lean_prover.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Dict, List, Optional
44

55
from eval_protocol.models import EvaluateResult, Message, MetricResult
6-
from eval_protocol.reward_function import reward_function
6+
from eval_protocol.typed_interface import reward_function
77

88

99
@reward_function

0 commit comments

Comments (0)