Skip to content

Commit fef81a5

Browse files
fix: apply ufmt formatting to test_interleaved_groups_no_false_merge
1 parent 10ea9c1 commit fef81a5

1 file changed

Lines changed: 10 additions & 18 deletions

File tree

exir/backend/test/test_group_partitioner.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,46 +1729,38 @@ def __init__(self):
17291729
self.layer0 = DecoderLayer()
17301730
self.layer1 = DecoderLayer()
17311731

1732-
def forward(self, query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
1732+
def forward(
1733+
self, query: torch.Tensor, memory: torch.Tensor
1734+
) -> torch.Tensor:
17331735
x = self.layer0(query, memory)
17341736
x = self.layer1(x, memory)
17351737
return x
17361738

1737-
from torchao.quantization.pt2e.quantize_pt2e import (
1738-
convert_pt2e,
1739-
prepare_pt2e,
1739+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
1740+
XnnpackDynamicallyQuantizedPartitioner,
17401741
)
17411742
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
1742-
XNNPACKQuantizer,
17431743
get_symmetric_quantization_config,
1744-
)
1745-
1746-
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
1747-
XnnpackDynamicallyQuantizedPartitioner,
1744+
XNNPACKQuantizer,
17481745
)
17491746
from executorch.exir import to_edge_transform_and_lower
1747+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
17501748

17511749
model = TwoLayerDecoder().eval()
17521750
query = torch.randn(1, 10, 256)
17531751
memory = torch.randn(1, 20, 256)
17541752

1755-
exported = torch.export.export(
1756-
model, (query, memory), strict=False
1757-
)
1753+
exported = torch.export.export(model, (query, memory), strict=False)
17581754

17591755
quantizer = XNNPACKQuantizer().set_global(
1760-
get_symmetric_quantization_config(
1761-
is_per_channel=True, is_dynamic=True
1762-
)
1756+
get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
17631757
)
17641758
prepared = prepare_pt2e(exported.module(), quantizer)
17651759
with torch.no_grad():
17661760
prepared(query, memory)
17671761
converted = convert_pt2e(prepared)
17681762

1769-
re_exported = torch.export.export(
1770-
converted, (query, memory), strict=False
1771-
)
1763+
re_exported = torch.export.export(converted, (query, memory), strict=False)
17721764

17731765
# Before the fix this raised:
17741766
# AssertionError: Invalid partition, found dependency cycles

0 commit comments

Comments
 (0)