Skip to content

Commit f839021

Browse files
committed
Fix general type ignore in src
1 parent 707ba66 commit f839021

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/imitation/testing/reward_improvement.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def is_significant_reward_improvement(
4545

4646

4747
def mean_reward_improved_by(
48-
old_rewards: Iterable[float],
49-
new_rewards: Iterable[float],
48+
old_rews: Iterable[float],
49+
new_rews: Iterable[float],
5050
min_improvement: float,
5151
):
5252
"""Checks if mean rewards improved wrt. to old rewards by a certain amount.
5353
5454
Args:
55-
old_rewards: Iterable of "old" trajectory rewards (e.g. before training).
56-
new_rewards: Iterable of "new" trajectory rewards (e.g. after training).
55+
old_rews: Iterable of "old" trajectory rewards (e.g. before training).
56+
new_rews: Iterable of "new" trajectory rewards (e.g. after training).
5757
min_improvement: The minimum amount of improvement that we expect.
5858
5959
Returns:
@@ -66,5 +66,5 @@ def mean_reward_improved_by(
6666
>>> mean_reward_improved_by([5, 8, 7], [8, 9, 10], 5)
6767
False
6868
"""
69-
improvement = np.mean(new_rewards) - np.mean(old_rewards) # type: ignore
69+
improvement = np.mean(new_rews) - np.mean(old_rews) # type: ignore[call-overload]
7070
return improvement >= min_improvement

0 commit comments

Comments
 (0)