
Commit 10ea9c1

fix: fix lint errors in test_interleaved_groups_no_false_merge
- Replace torch.ao.quantization imports with torchao.quantization.pt2e.quantize_pt2e and executorch.backends.xnnpack.quantizer.xnnpack_quantizer
- Fix UFMT formatting issues in forward() signatures and torch.bmm/norm2 calls
1 parent: cbc1d31

1 file changed: exir/backend/test/test_group_partitioner.py (6 additions & 14 deletions)
@@ -1712,21 +1712,15 @@ def __init__(self, d: int = 256):
         self.norm1 = torch.nn.LayerNorm(d)
         self.norm2 = torch.nn.LayerNorm(d)
 
-    def forward(
-        self, x: torch.Tensor, mem: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, mem: torch.Tensor) -> torch.Tensor:
         q = self.q_proj(x)
         k = self.k_proj(mem)
         v = self.v_proj(mem)
-        attn = torch.bmm(
-            q, k.transpose(-2, -1)
-        ) / math.sqrt(q.size(-1))
+        attn = torch.bmm(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
         attn = torch.softmax(attn, dim=-1)
         out = self.out_proj(torch.bmm(attn, v))
         x = self.norm1(x + out)
-        x = self.norm2(
-            x + self.ffn2(torch.relu(self.ffn1(x)))
-        )
+        x = self.norm2(x + self.ffn2(torch.relu(self.ffn1(x))))
         return x
 
 class TwoLayerDecoder(torch.nn.Module):
@@ -1735,18 +1729,16 @@ def __init__(self):
         self.layer0 = DecoderLayer()
         self.layer1 = DecoderLayer()
 
-    def forward(
-        self, query: torch.Tensor, memory: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
         x = self.layer0(query, memory)
         x = self.layer1(x, memory)
         return x
 
-from torch.ao.quantization.quantize_pt2e import (
+from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
 )
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
     XNNPACKQuantizer,
     get_symmetric_quantization_config,
 )
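
For reference, a minimal sketch of the PT2E quantization flow that these migrated imports feed into. The model and example inputs below are hypothetical placeholders for illustration, not taken from the test:

import torch

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Hypothetical stand-in model and inputs, for illustration only.
model = torch.nn.Linear(256, 256).eval()
example_inputs = (torch.randn(1, 256),)

# Export to an ATen graph, then annotate it with the XNNPACK quantizer.
exported = torch.export.export(model, example_inputs)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

prepared = prepare_pt2e(exported.module(), quantizer)
prepared(*example_inputs)  # calibration run to populate observer statistics
quantized = convert_pt2e(prepared)  # rewrite observers into quantize/dequantize ops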
