@@ -5385,19 +5385,6 @@ def get_dynamic_shapes(self) -> Optional[Dict[str, any]]:
53855385 }
53865386
53875387
5388- class QuantizedLinearModel (nn .Module ):
5389- """Simple linear layer that will be quantized."""
5390-
5391- def __init__ (
5392- self , in_features : int = 64 , out_features : int = 128 , bias : bool = True
5393- ):
5394- super ().__init__ ()
5395- self .linear = nn .Linear (in_features , out_features , bias = bias )
5396-
5397- def forward (self , x : torch .Tensor ) -> torch .Tensor :
5398- return self .linear (x )
5399-
5400-
54015388@register_test
54025389class QuantizedLinearTest (OpTestCase ):
54035390 """Test case for TorchAO int4 quantized nn.Linear."""
@@ -5408,13 +5395,14 @@ class QuantizedLinearTest(OpTestCase):
54085395
54095396 def __init__ (
54105397 self ,
5411- in_features : int = 64 ,
5398+ in_features : int = 128 ,
54125399 out_features : int = 128 ,
54135400 batch_size : int = 2 ,
54145401 seq_len : int = 16 ,
54155402 bias : bool = True ,
54165403 group_size : int = 32 ,
54175404 dtype : torch .dtype = torch .bfloat16 ,
5405+ qdtype : torch .dtype = torch .int4 ,
54185406 ):
54195407 self .in_features = in_features
54205408 self .out_features = out_features
@@ -5423,8 +5411,9 @@ def __init__(
54235411 self .bias = bias
54245412 self .group_size = group_size
54255413 self .dtype = dtype
5414+ self .qdtype = qdtype
54265415
5427- parts = ["quantized_linear" , f"g{ group_size } " ]
5416+ parts = ["quantized_linear" , f"{ qdtype } " , f" g{ group_size } " ]
54285417 if not bias :
54295418 parts .append ("no_bias" )
54305419 self .name = "_" .join (parts )
@@ -5434,26 +5423,25 @@ def get_test_configs(cls) -> List["QuantizedLinearTest"]:
54345423 return [
54355424 cls (),
54365425 cls (bias = False ),
5426+ cls (group_size = 64 ),
5427+ cls (group_size = 128 ),
5428+ cls (qdtype = torch .int2 ),
5429+ cls (qdtype = torch .int8 ),
54375430 ]
54385431
54395432 def create_model (self ) -> nn .Module :
5440- model = QuantizedLinearModel (
5441- self .in_features , self .out_features , bias = self .bias
5442- )
5433+ model = LinearModel (self .in_features , self .out_features , bias = self .bias )
54435434 model = model .to (self .dtype )
54445435
5445- try :
5446- from torchao .quantization .granularity import PerGroup
5447- from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
5436+ from torchao .quantization .granularity import PerGroup
5437+ from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
54485438
5449- quantize_ (
5450- model ,
5451- IntxWeightOnlyConfig (
5452- weight_dtype = torch .int4 , granularity = PerGroup (self .group_size )
5453- ),
5454- )
5455- except ImportError :
5456- raise RuntimeError ("TorchAO not installed. Run: pip install torchao" )
5439+ quantize_ (
5440+ model ,
5441+ IntxWeightOnlyConfig (
5442+ weight_dtype = self .qdtype , granularity = PerGroup (self .group_size )
5443+ ),
5444+ )
54575445
54585446 return model
54595447
@@ -5464,21 +5452,6 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]:
54645452 return (x ,)
54655453
54665454
5467- class QuantizedEmbeddingModel (nn .Module ):
5468- """Simple embedding layer that will be quantized."""
5469-
5470- def __init__ (
5471- self ,
5472- num_embeddings : int = 1000 ,
5473- embedding_dim : int = 64 ,
5474- ):
5475- super ().__init__ ()
5476- self .embedding = nn .Embedding (num_embeddings , embedding_dim )
5477-
5478- def forward (self , x : torch .Tensor ) -> torch .Tensor :
5479- return self .embedding (x )
5480-
5481-
54825455@register_test
54835456class QuantizedEmbeddingTest (OpTestCase ):
54845457 """Test case for TorchAO int4 quantized nn.Embedding."""
@@ -5490,48 +5463,51 @@ class QuantizedEmbeddingTest(OpTestCase):
54905463 def __init__ (
54915464 self ,
54925465 num_embeddings : int = 1000 ,
5493- embedding_dim : int = 64 ,
5466+ embedding_dim : int = 128 ,
54945467 batch_size : int = 2 ,
54955468 seq_len : int = 16 ,
54965469 group_size : int = 32 ,
54975470 dtype : torch .dtype = torch .bfloat16 ,
5471+ qdtype : torch .dtype = torch .int4 ,
54985472 ):
54995473 self .num_embeddings = num_embeddings
55005474 self .embedding_dim = embedding_dim
55015475 self .batch_size = batch_size
55025476 self .seq_len = seq_len
55035477 self .group_size = group_size
55045478 self .dtype = dtype
5479+ self .qdtype = qdtype
55055480
5506- parts = ["quantized_embedding" , f"g{ group_size } " ]
5481+ parts = ["quantized_embedding" , f"{ qdtype } " , f" g{ group_size } " ]
55075482 self .name = "_" .join (parts )
55085483
55095484 @classmethod
55105485 def get_test_configs (cls ) -> List ["QuantizedEmbeddingTest" ]:
55115486 return [
55125487 cls (),
5488+ cls (group_size = 64 ),
5489+ cls (group_size = 128 ),
5490+ cls (qdtype = torch .int2 ),
5491+ cls (qdtype = torch .int8 ),
55135492 ]
55145493
55155494 def create_model (self ) -> nn .Module :
5516- model = QuantizedEmbeddingModel (self .num_embeddings , self .embedding_dim )
5495+ model = EmbeddingModel (self .num_embeddings , self .embedding_dim )
55175496 model = model .to (self .dtype )
55185497
5519- try :
5520- from torchao .quantization .granularity import PerGroup
5521- from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
5498+ from torchao .quantization .granularity import PerGroup
5499+ from torchao .quantization .quant_api import IntxWeightOnlyConfig , quantize_
55225500
5523- def embedding_filter (module : nn .Module , fqn : str ) -> bool :
5524- return isinstance (module , nn .Embedding )
5501+ def embedding_filter (module : nn .Module , fqn : str ) -> bool :
5502+ return isinstance (module , nn .Embedding )
55255503
5526- quantize_ (
5527- model ,
5528- IntxWeightOnlyConfig (
5529- weight_dtype = torch .int4 , granularity = PerGroup (self .group_size )
5530- ),
5531- embedding_filter ,
5532- )
5533- except ImportError :
5534- raise RuntimeError ("TorchAO not installed. Run: pip install torchao" )
5504+ quantize_ (
5505+ model ,
5506+ IntxWeightOnlyConfig (
5507+ weight_dtype = self .qdtype , granularity = PerGroup (self .group_size )  # honor the configured qdtype so the int2/int8 configs are not silently quantized as int4
5508+ ),
5509+ embedding_filter ,
5510+ )
55355511
55365512 return model
55375513
0 commit comments