Skip to content

Commit cc2d587

Browse files
committed
Fix type ignore errors
1 parent b75478f commit cc2d587

8 files changed

Lines changed: 36 additions & 33 deletions

File tree

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ jobs:
248248

249249
- run:
250250
name: typeignore-check
251-
command: ./ci/check_typeignore.py
251+
command: ./ci/check_typeignore.py ./src/ ./tests/
252252

253253
- run:
254254
name: flake8

ci/check_typeignore.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
without explicitly indicating the reason for the ignore. This is to ensure that we
77
don't accidentally ignore errors that we should be fixing.
88
"""
9-
10-
9+
import argparse
1110
import os
1211
import pathlib
1312
import re
@@ -51,29 +50,37 @@ def check_files(files: List[pathlib.Path]):
5150
check_file(file)
5251

5352

54-
def get_files_to_check(root_dir: pathlib.Path) -> List[pathlib.Path]:
53+
def get_files_to_check(root_dirs: List[pathlib.Path]) -> List[pathlib.Path]:
5554
"""Returns a list of files that should be checked for "# type: ignore" comments."""
5655
# Get the list of files that should be checked.
5756
files = []
58-
for root, _, filenames in os.walk(root_dir):
59-
for filename in filenames:
60-
if filename.endswith(".py"):
61-
files.append(pathlib.Path(root) / filename)
57+
for root_dir in root_dirs:
58+
for root, _, filenames in os.walk(root_dir):
59+
for filename in filenames:
60+
if filename.endswith(".py"):
61+
files.append(pathlib.Path(root) / filename)
6262

6363
return files
6464

6565

66-
@click.command()
67-
@click.option(
68-
"--root-dir",
69-
type=click.Path(exists=True, file_okay=False, dir_okay=True),
70-
default="src",
71-
)
72-
def main(root_dir: str):
66+
def parse_args():
67+
"""Parse command-line arguments."""
68+
parser = argparse.ArgumentParser()
69+
parser.add_argument(
70+
"files",
71+
nargs="+",
72+
type=pathlib.Path,
73+
help="List of files or paths to check for invalid '# type: ignore' comments.",
74+
)
75+
args = parser.parse_args()
76+
return parser, args
77+
78+
def main():
7379
"""Check for invalid "# type: ignore" comments."""
74-
files = get_files_to_check(pathlib.Path(root_dir))
80+
parser, args = parse_args()
81+
file_list = get_files_to_check(args.files)
7582
try:
76-
check_files(files)
83+
check_files(file_list)
7784
except InvalidTypeIgnore as e:
7885
print(e)
7986
sys.exit(1)

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _samples_to_reward_fn_input(
2323
)
2424

2525

26-
class ReplayBufferRewardWrapper(BaseBuffer):
26+
class ReplayBufferRewardWrapper(ReplayBuffer):
2727
"""Relabel the rewards in transitions sampled from a ReplayBuffer."""
2828

2929
def __init__(

tests/algorithms/test_density_baselines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,4 @@ def test_density_trainer_raises(
158158
)
159159

160160
with pytest.raises(TypeError, match="Unsupported demonstration type"):
161-
density_trainer.set_demonstrations("foo") # type: ignore
161+
density_trainer.set_demonstrations("foo")

tests/data/test_rollout.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def test_generate_trajectories_type_error(rng):
300300
sample_until = rollout.make_min_episodes(1)
301301
with pytest.raises(TypeError, match="Policy must be.*got <class 'str'> instead"):
302302
rollout.generate_trajectories(
303-
"strings_are_not_valid_policies", # type: ignore
303+
"strings_are_not_valid_policies",
304304
venv,
305305
rng=rng,
306306
sample_until=sample_until,

tests/policies/test_replay_buffer_wrapper.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,13 @@ def make_algo_with_wrapped_buffer(
3838
policy_kwargs=dict(),
3939
env=venv,
4040
seed=42,
41-
# TODO(juan) we ignore the type below due to
42-
# https://github.com/DLR-RM/stable-baselines3/issues/1039
43-
# PR fixing this has been merged to master,
44-
# remove the type ignore in the next sb3 release.
45-
replay_buffer_class=ReplayBufferRewardWrapper, # type: ignore
41+
replay_buffer_class=ReplayBufferRewardWrapper,
4642
replay_buffer_kwargs=dict(
4743
replay_buffer_class=replay_buffer_class,
4844
reward_fn=zero_reward_fn,
4945
),
5046
buffer_size=buffer_size,
51-
) # type: ignore
47+
) # type: ignore[call-arg]
5248
return rl_algo
5349

5450

@@ -60,7 +56,7 @@ def test_invalid_args(rng):
6056
# we ignore the type because we are intentionally
6157
# passing the wrong type for the test
6258
make_algo_with_wrapped_buffer(
63-
rl_cls=sb3.PPO, # type: ignore
59+
rl_cls=sb3.PPO,
6460
policy_cls=policies.ActorCriticPolicy,
6561
replay_buffer_class=buffers.ReplayBuffer,
6662
rng=rng,

tests/test_regularization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ def test_interval_param_scaler_raises(interval_param_scaler):
8787
with pytest.raises(ValueError, match="train_loss must be a scalar"):
8888
scaler(1.0, th.Tensor([1.0, 2.0]), 1.0)
8989
with pytest.raises(ValueError, match="train_loss must be a scalar"):
90-
scaler(1.0, "random value", th.tensor(1.0)) # type: ignore
90+
scaler(1.0, "random value", th.tensor(1.0))
9191
with pytest.raises(ValueError, match="val_loss must be a scalar"):
92-
scaler(1.0, 1.0, "random value") # type: ignore
92+
scaler(1.0, 1.0, "random value")
9393
with pytest.raises(ValueError, match="lambda_ must be a float"):
94-
scaler(th.tensor(1.0), 1.0, 1.0) # type: ignore
94+
scaler(th.tensor(1.0), 1.0, 1.0)
9595
with pytest.raises(ValueError, match="lambda_ must not be zero.*"):
9696
scaler(0.0, 1.0, 1.0)
9797
with pytest.raises(ValueError, match="lambda_ must be non-negative.*"):
@@ -131,12 +131,12 @@ def test_interval_param_scaler_init_raises():
131131
ValueError,
132132
match="tolerable_interval must be a tuple of length 2",
133133
):
134-
updaters.IntervalParamScaler(0.5, (0.1, 0.9, 0.5)) # type: ignore
134+
updaters.IntervalParamScaler(0.5, (0.1, 0.9, 0.5))
135135
with pytest.raises(
136136
ValueError,
137137
match="tolerable_interval must be a tuple of length 2",
138138
):
139-
updaters.IntervalParamScaler(0.5, (0.1,)) # type: ignore
139+
updaters.IntervalParamScaler(0.5, (0.1,))
140140

141141
# the first element of the interval must be at least 0.
142142
with pytest.raises(

tests/util/test_wb_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def finish(self):
8989

9090
# we ignore the type below as one should technically not access the
9191
# __init__ method directly but only by creating an instance.
92-
@mock.patch.object(wandb, "__init__", mock_wandb.__init__) # type: ignore
92+
@mock.patch.object(wandb, "__init__", mock_wandb.__init__) # type: ignore[misc]
9393
@mock.patch.object(wandb, "init", mock_wandb.init)
9494
@mock.patch.object(wandb, "log", mock_wandb.log)
9595
@mock.patch.object(wandb, "finish", mock_wandb.finish)

0 commit comments

Comments
 (0)