diff --git a/backends/xnnpack/_passes/__init__.py b/backends/xnnpack/_passes/__init__.py index 45560124f57..6633f03253e 100644 --- a/backends/xnnpack/_passes/__init__.py +++ b/backends/xnnpack/_passes/__init__.py @@ -46,6 +46,8 @@ from torch.export import ExportedProgram +import torch + class XNNPACKRemoveCloneOpsTransform(RemoveCloneOpsTransform): def __init__(self): @@ -98,14 +100,18 @@ def transform(self) -> ExportedProgram: Returns a transformed ExportedProgram """ ep = self.exported_program - for pass_ in self.passes: - if issubclass(pass_, XNNPACKPass): - transform_pass = pass_(ep) - elif issubclass(pass_, ExportPass): - transform_pass = pass_() - else: - raise RuntimeError( - f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}" - ) - ep = _transform(ep, transform_pass) + + with torch.fx.experimental._config.patch( + backed_size_oblivious=True + ): + for pass_ in self.passes: + if issubclass(pass_, XNNPACKPass): + transform_pass = pass_(ep) + elif issubclass(pass_, ExportPass): + transform_pass = pass_() + else: + raise RuntimeError( + f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}" + ) + ep = _transform(ep, transform_pass) return ep diff --git a/backends/xnnpack/test/ops/test_batch_norm.py b/backends/xnnpack/test/ops/test_batch_norm.py index 2391d781d9a..c7ed41529b7 100644 --- a/backends/xnnpack/test/ops/test_batch_norm.py +++ b/backends/xnnpack/test/ops/test_batch_norm.py @@ -147,6 +147,34 @@ def test_fp32_batch_norm_nc(self): """Test BatchNorm1d with NC input is lowered to XNNPACK.""" self._test_batch_norm(self.BatchNorm1dNC(num_features=3)) + def test_fp32_batch_norm_nc_dynamic_batch(self): + """Test BatchNorm1d NC with dynamic batch, inference at batch=20.""" + model = self.BatchNorm1dNC(num_features=3) + model.eval() + with torch.no_grad(): + for _ in range(5): + model(*model.get_inputs()) + + batch = torch.export.Dim("batch", min=1, max=32) + ( + Tester( + model, + model.get_inputs(), + dynamic_shapes=({0: batch},), + ) + .export() + .to_edge_transform_and_lower() + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .serialize() + .run_method_and_compare_outputs(inputs=(torch.randn(20, 3),)) + ) + def test_fp32_batch_norm_ncl(self): """Test BatchNorm1d with NCL input is lowered to XNNPACK.""" self._test_batch_norm(self.BatchNorm1dNCL(num_features=3))