Skip to content

Commit e390646

Browse files
authored
[tests] accept recompile_limit from the user in tests (#13150)
accept recompile_limit from the user in tests
1 parent 59e7a46 commit e390646

File tree

2 files changed

+2
-18
lines changed

2 files changed

+2
-18
lines changed

tests/models/testing_utils/compile.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_torch_compile_recompilation_and_graph_break(self):
8181
_ = model(**inputs_dict)
8282

8383
@torch.no_grad()
84-
def test_torch_compile_repeated_blocks(self):
84+
def test_torch_compile_repeated_blocks(self, recompile_limit=1):
8585
if self.model_class._repeated_blocks is None:
8686
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
8787

@@ -92,7 +92,6 @@ def test_torch_compile_repeated_blocks(self):
9292
model.eval()
9393
model.compile_repeated_blocks(fullgraph=True)
9494

95-
recompile_limit = 1
9695
if self.model_class.__name__ == "UNet2DConditionModel":
9796
recompile_limit = 2
9897

tests/models/transformers/test_models_transformer_wan_vace.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -147,22 +147,7 @@ class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCom
147147
def test_torch_compile_repeated_blocks(self):
148148
# WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
149149
# so we need recompile_limit=2 instead of the default 1.
150-
import torch._dynamo
151-
import torch._inductor.utils
152-
153-
init_dict = self.get_init_dict()
154-
inputs_dict = self.get_dummy_inputs()
155-
156-
model = self.model_class(**init_dict).to(torch_device)
157-
model.eval()
158-
model.compile_repeated_blocks(fullgraph=True)
159-
160-
with (
161-
torch._inductor.utils.fresh_inductor_cache(),
162-
torch._dynamo.config.patch(recompile_limit=2),
163-
):
164-
_ = model(**inputs_dict)
165-
_ = model(**inputs_dict)
150+
super().test_torch_compile_repeated_blocks(recompile_limit=2)
166151

167152

168153
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):

0 commit comments

Comments
 (0)