@@ -1729,46 +1729,38 @@ def __init__(self):
17291729 self .layer0 = DecoderLayer ()
17301730 self .layer1 = DecoderLayer ()
17311731
1732- def forward (self , query : torch .Tensor , memory : torch .Tensor ) -> torch .Tensor :
1732+ def forward (
1733+ self , query : torch .Tensor , memory : torch .Tensor
1734+ ) -> torch .Tensor :
17331735 x = self .layer0 (query , memory )
17341736 x = self .layer1 (x , memory )
17351737 return x
17361738
1737- from torchao .quantization .pt2e .quantize_pt2e import (
1738- convert_pt2e ,
1739- prepare_pt2e ,
1739+ from executorch .backends .xnnpack .partition .xnnpack_partitioner import (
1740+ XnnpackDynamicallyQuantizedPartitioner ,
17401741 )
17411742 from executorch .backends .xnnpack .quantizer .xnnpack_quantizer import (
1742- XNNPACKQuantizer ,
17431743 get_symmetric_quantization_config ,
1744- )
1745-
1746- from executorch .backends .xnnpack .partition .xnnpack_partitioner import (
1747- XnnpackDynamicallyQuantizedPartitioner ,
1744+ XNNPACKQuantizer ,
17481745 )
17491746 from executorch .exir import to_edge_transform_and_lower
1747+ from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
17501748
17511749 model = TwoLayerDecoder ().eval ()
17521750 query = torch .randn (1 , 10 , 256 )
17531751 memory = torch .randn (1 , 20 , 256 )
17541752
1755- exported = torch .export .export (
1756- model , (query , memory ), strict = False
1757- )
1753+ exported = torch .export .export (model , (query , memory ), strict = False )
17581754
17591755 quantizer = XNNPACKQuantizer ().set_global (
1760- get_symmetric_quantization_config (
1761- is_per_channel = True , is_dynamic = True
1762- )
1756+ get_symmetric_quantization_config (is_per_channel = True , is_dynamic = True )
17631757 )
17641758 prepared = prepare_pt2e (exported .module (), quantizer )
17651759 with torch .no_grad ():
17661760 prepared (query , memory )
17671761 converted = convert_pt2e (prepared )
17681762
1769- re_exported = torch .export .export (
1770- converted , (query , memory ), strict = False
1771- )
1763+ re_exported = torch .export .export (converted , (query , memory ), strict = False )
17721764
17731765 # Before the fix this raised:
17741766 # AssertionError: Invalid partition, found dependency cycles
0 commit comments