Skip to content

Commit 36d45ef

Browse files
Update
[ghstack-poisoned]
2 parents 9ce837a + ee865c3 commit 36d45ef

2 files changed

Lines changed: 8 additions & 8 deletions

File tree

backends/apple/metal/ops/gather_qmv.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ def gather_qmv(
4242
return y
4343

4444

45-
def _quantize_int4_affine(
46-
w: Tensor, group_size: int
47-
) -> tuple[Tensor, Tensor, Tensor]:
45+
def _quantize_int4_affine(w: Tensor, group_size: int) -> tuple[Tensor, Tensor, Tensor]:
4846
"""Quantize float weights to packed INT4 using MLX affine format.
4947
5048
Args:
@@ -67,8 +65,12 @@ def _quantize_int4_affine(
6765
scales = ((g_max - g_min) / 15.0).clamp(min=1e-8)
6866
biases = g_min
6967
w_int = (
70-
(w_groups - biases.unsqueeze(-1)) / scales.unsqueeze(-1)
71-
).round().clamp(0, 15).to(torch.uint8).reshape(*leading, K)
68+
((w_groups - biases.unsqueeze(-1)) / scales.unsqueeze(-1))
69+
.round()
70+
.clamp(0, 15)
71+
.to(torch.uint8)
72+
.reshape(*leading, K)
73+
)
7274
packed = w_int[..., 0::2] | (w_int[..., 1::2] << 4)
7375
return packed, scales, biases
7476

backends/apple/metal/tests/test_modules.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -702,9 +702,7 @@ class GatherQMV(nn.Module):
702702

703703
def __init__(self):
704704
super().__init__()
705-
from executorch.backends.apple.metal.ops.gather_qmv import (
706-
_quantize_int4_affine,
707-
)
705+
from executorch.backends.apple.metal.ops.gather_qmv import _quantize_int4_affine
708706

709707
E, N, K, gs = 4, 64, 128, 32
710708
torch.manual_seed(0)

0 commit comments

Comments
 (0)