|
2 | 2 | import os |
3 | 3 | import re |
4 | 4 | import warnings |
5 | | -from typing import Any, Dict, List, Optional, Set, Tuple, Union |
| 5 | +from typing import Any, Dict, List, Optional, Set, Tuple, Union, Callable, cast |
6 | 6 |
|
7 | 7 | # Import OpenAI at module level for mocking in tests |
8 | 8 | try: |
@@ -451,7 +451,8 @@ def schema_jaccard_reward( |
451 | 451 | DeprecationWarning, |
452 | 452 | stacklevel=2, |
453 | 453 | ) |
454 | | - return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) |
| 454 | + _exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward) |
| 455 | + return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs) |
455 | 456 |
|
456 | 457 |
|
457 | 458 | @reward_function |
@@ -493,7 +494,8 @@ def llm_judge_reward( |
493 | 494 | DeprecationWarning, |
494 | 495 | stacklevel=2, |
495 | 496 | ) |
496 | | - return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) |
| 497 | + _exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward) |
| 498 | + return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs) |
497 | 499 |
|
498 | 500 |
|
499 | 501 | @reward_function |
@@ -537,7 +539,8 @@ def composite_function_call_reward( |
537 | 539 | DeprecationWarning, |
538 | 540 | stacklevel=2, |
539 | 541 | ) |
540 | | - return exact_tool_match_reward(messages=messages, ground_truth=ground_truth, **kwargs) |
| 542 | + _exact_tool_match: Callable[..., EvaluateResult] = cast(Callable[..., EvaluateResult], exact_tool_match_reward) |
| 543 | + return _exact_tool_match(messages=messages, ground_truth=ground_truth, **kwargs) |
541 | 544 |
|
542 | 545 |
|
543 | 546 | # JSON schema reward functions have been moved to json_schema.py module |
0 commit comments