Skip to content

Commit b7da263

Browse files
committed
up
1 parent d15ee3c commit b7da263

4 files changed

Lines changed: 42 additions & 95 deletions

File tree

backends/mlx/builder/op_helpers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,11 @@ def parse_dequant_node(
321321
quantized_dim, group_size = non_one[0]
322322
if group_size not in [32, 64, 128]:
323323
return None
324-
if qmin == -8 and qmax == 7:
325-
bits = 4
326-
elif qmin == -128 and qmax == 127:
327-
bits = 8
328-
else:
324+
325+
# TODO: MLX also supports 3-, 5-, and 7-bit quantization, but we first need to
326+
# work out how to_mlx_qparams should pack those widths before enabling them
327+
bits = (qmax - qmin + 1).bit_length() - 1
328+
if bits not in [2, 4, 8]:
329329
return None
330330
return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim
331331

backends/mlx/runtime/MLXInterpreter.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -149,24 +149,6 @@ exec_addmm(const AddmmNode& n, ExecutionState& st, StreamOrDevice s) {
149149
st.set_tensor(n.out, std::move(Y));
150150
}
151151

152-
inline void
153-
exec_linear(const LinearNode& n, ExecutionState& st, StreamOrDevice s) {
154-
const auto& X = st.const_tensor_ref(n.x);
155-
auto W = st.const_tensor_ref(n.weight);
156-
W = transpose(W, {1, 0}, s);
157-
158-
array Y = n.bias ? addmm(
159-
st.const_tensor_ref(*n.bias),
160-
X,
161-
W,
162-
/*alpha=*/1.0f,
163-
/*beta=*/1.0f,
164-
s)
165-
: matmul(X, W, s);
166-
167-
st.set_tensor(n.out, std::move(Y));
168-
}
169-
170152
inline void
171153
exec_item_int(const ItemIntNode& n, ExecutionState& st, StreamOrDevice) {
172154
// Intentional sync: item() requires a concrete scalar value for SymInt
@@ -1601,9 +1583,6 @@ class Interpreter {
16011583
case OpCode::ADDMM:
16021584
ops::exec_addmm(std::get<AddmmNode>(instr.node), st, s);
16031585
break;
1604-
case OpCode::LINEAR:
1605-
ops::exec_linear(std::get<LinearNode>(instr.node), st, s);
1606-
break;
16071586
case OpCode::ITEM_INT:
16081587
ops::exec_item_int(std::get<ItemIntNode>(instr.node), st, s);
16091588
break;

backends/mlx/serialization/schema.fbs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,6 @@ table AddmmNode {
8686
beta: float = 1.0; // Scalar multiplier for bias
8787
}
8888

89-
table LinearNode {
90-
x: Tid (required);
91-
weight: Tid (required);
92-
out: Tid (required);
93-
bias: Tid; // optional
94-
}
95-
9689
table ItemIntNode {
9790
x: Tid (required);
9891
out: Vid (required);
@@ -916,7 +909,6 @@ union OpNode {
916909
NoopNode,
917910
IdCopyNode,
918911
AddmmNode,
919-
LinearNode,
920912
ItemIntNode,
921913
ExpandDimsNode,
922914
TileNode,

backends/mlx/test/test_ops.py

Lines changed: 37 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -5385,19 +5385,6 @@ def get_dynamic_shapes(self) -> Optional[Dict[str, any]]:
53855385
}
53865386

53875387

5388-
class QuantizedLinearModel(nn.Module):
5389-
"""Simple linear layer that will be quantized."""
5390-
5391-
def __init__(
5392-
self, in_features: int = 64, out_features: int = 128, bias: bool = True
5393-
):
5394-
super().__init__()
5395-
self.linear = nn.Linear(in_features, out_features, bias=bias)
5396-
5397-
def forward(self, x: torch.Tensor) -> torch.Tensor:
5398-
return self.linear(x)
5399-
5400-
54015388
@register_test
54025389
class QuantizedLinearTest(OpTestCase):
54035390
"""Test case for TorchAO weight-only quantized nn.Linear (int4 by default; see qdtype)."""
@@ -5408,13 +5395,14 @@ class QuantizedLinearTest(OpTestCase):
54085395

54095396
def __init__(
54105397
self,
5411-
in_features: int = 64,
5398+
in_features: int = 128,
54125399
out_features: int = 128,
54135400
batch_size: int = 2,
54145401
seq_len: int = 16,
54155402
bias: bool = True,
54165403
group_size: int = 32,
54175404
dtype: torch.dtype = torch.bfloat16,
5405+
qdtype: torch.dtype = torch.int4,
54185406
):
54195407
self.in_features = in_features
54205408
self.out_features = out_features
@@ -5423,8 +5411,9 @@ def __init__(
54235411
self.bias = bias
54245412
self.group_size = group_size
54255413
self.dtype = dtype
5414+
self.qdtype = qdtype
54265415

5427-
parts = ["quantized_linear", f"g{group_size}"]
5416+
parts = ["quantized_linear", f"{qdtype}", f"g{group_size}"]
54285417
if not bias:
54295418
parts.append("no_bias")
54305419
self.name = "_".join(parts)
@@ -5434,26 +5423,25 @@ def get_test_configs(cls) -> List["QuantizedLinearTest"]:
54345423
return [
54355424
cls(),
54365425
cls(bias=False),
5426+
cls(group_size=64),
5427+
cls(group_size=128),
5428+
cls(qdtype=torch.int2),
5429+
cls(qdtype=torch.int8),
54375430
]
54385431

54395432
def create_model(self) -> nn.Module:
    """Build the linear model under test and quantize its weights.

    A ``LinearModel`` is created from ``self.in_features``,
    ``self.out_features`` and ``self.bias``, cast to ``self.dtype``, and then
    passed to TorchAO's ``quantize_`` with an ``IntxWeightOnlyConfig`` whose
    weight dtype is ``self.qdtype`` and whose granularity is
    ``PerGroup(self.group_size)``.

    Returns:
        The model object that was handed to ``quantize_``.
    """
    linear = LinearModel(
        self.in_features, self.out_features, bias=self.bias
    ).to(self.dtype)

    # torchao is imported at call time, so it is only required when this
    # test actually builds its model.
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

    weight_only_config = IntxWeightOnlyConfig(
        weight_dtype=self.qdtype,
        granularity=PerGroup(self.group_size),
    )
    # quantize_ is called for its effect on `linear`; its return value is
    # not used.
    quantize_(linear, weight_only_config)

    return linear
54595447

@@ -5464,21 +5452,6 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
54645452
return (x,)
54655453

54665454

5467-
class QuantizedEmbeddingModel(nn.Module):
5468-
"""Simple embedding layer that will be quantized."""
5469-
5470-
def __init__(
5471-
self,
5472-
num_embeddings: int = 1000,
5473-
embedding_dim: int = 64,
5474-
):
5475-
super().__init__()
5476-
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
5477-
5478-
def forward(self, x: torch.Tensor) -> torch.Tensor:
5479-
return self.embedding(x)
5480-
5481-
54825455
@register_test
54835456
class QuantizedEmbeddingTest(OpTestCase):
54845457
"""Test case for TorchAO int4 quantized nn.Embedding."""
@@ -5490,48 +5463,51 @@ class QuantizedEmbeddingTest(OpTestCase):
54905463
def __init__(
    self,
    num_embeddings: int = 1000,
    embedding_dim: int = 128,
    batch_size: int = 2,
    seq_len: int = 16,
    group_size: int = 32,
    dtype: torch.dtype = torch.bfloat16,
    qdtype: torch.dtype = torch.int4,
):
    """Record the test configuration and derive the test name.

    Args:
        num_embeddings: Stored as ``self.num_embeddings``.
        embedding_dim: Stored as ``self.embedding_dim``.
        batch_size: Stored as ``self.batch_size``.
        seq_len: Stored as ``self.seq_len``.
        group_size: Stored as ``self.group_size``; also part of the name.
        dtype: Stored as ``self.dtype``.
        qdtype: Stored as ``self.qdtype``; also part of the name.
    """
    # Embedding table dimensions.
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim

    # Input sizing.
    self.batch_size = batch_size
    self.seq_len = seq_len

    # Quantization / precision settings.
    self.group_size = group_size
    self.dtype = dtype
    self.qdtype = qdtype

    # Name encodes the quantized dtype and group size so each config is
    # distinguishable, e.g. "quantized_embedding_<qdtype>_g<group_size>".
    self.name = f"quantized_embedding_{qdtype}_g{group_size}"
55085483

55095484
@classmethod
def get_test_configs(cls) -> List["QuantizedEmbeddingTest"]:
    """Return the configurations to run for this test.

    The list holds, in order: the default configuration, one configuration
    for each of the group sizes 64 and 128, and one for each of the
    quantized dtypes ``torch.int2`` and ``torch.int8``.
    """
    default_config = cls()
    group_size_configs = [cls(group_size=size) for size in (64, 128)]
    qdtype_configs = [cls(qdtype=qd) for qd in (torch.int2, torch.int8)]
    return [default_config, *group_size_configs, *qdtype_configs]
55145493

55155494
def create_model(self) -> nn.Module:
    """Build the embedding model under test and quantize its weights.

    An ``EmbeddingModel`` is created from ``self.num_embeddings`` and
    ``self.embedding_dim``, cast to ``self.dtype``, and then passed to
    TorchAO's ``quantize_`` with an ``IntxWeightOnlyConfig`` whose weight
    dtype is ``self.qdtype`` and whose granularity is
    ``PerGroup(self.group_size)``.

    Returns:
        The model object that was handed to ``quantize_``.
    """
    model = EmbeddingModel(self.num_embeddings, self.embedding_dim)
    model = model.to(self.dtype)

    # torchao is imported at call time, so it is only required when this
    # test actually builds its model.
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

    def embedding_filter(module: nn.Module, fqn: str) -> bool:
        # Passed to quantize_ as its third argument; selects nn.Embedding
        # modules only. `fqn` is part of the callback signature and unused.
        return isinstance(module, nn.Embedding)

    quantize_(
        model,
        IntxWeightOnlyConfig(
            # Use the configured qdtype rather than a hard-coded torch.int4:
            # the test name and get_test_configs() vary qdtype, so a fixed
            # dtype here would make every config quantize identically.
            weight_dtype=self.qdtype,
            granularity=PerGroup(self.group_size),
        ),
        embedding_filter,
    )

    return model
55375513

0 commit comments

Comments
 (0)