modelopt/torch/quantization
@@ -583,7 +583,7 @@ def get_model(
     model_kwargs = config_kwargs.copy()
     # Don't set torch_dtype for VILA models as they handle it explicitly in their builder
     if "vila" not in ckpt_path.lower():
-        model_kwargs.setdefault("torch_dtype", "auto")
+        model_kwargs.setdefault("dtype", "auto")
 
     if "vila" in ckpt_path.lower():
         hf_vila = AutoModel.from_pretrained(
@@ -666,7 +666,7 @@ def has_pack_quantized_config(config):
     model_kwargs2 = model_kwargs.copy()
     if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
         model_kwargs2.pop("trust_remote_code", None)
-    model_kwargs2["torch_dtype"] = torch_dtype
+    model_kwargs2["dtype"] = torch_dtype
     model_kwargs2.pop("max_memory", None)
     model = from_config(hf_config, **model_kwargs2)
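For context, the keyword swap above tracks newer transformers releases, which accept dtype in place of the older torch_dtype argument when loading models. A minimal sketch of the resulting call, assuming a transformers release that supports the dtype keyword (older versions use torch_dtype) and using a placeholder checkpoint name:

from transformers import AutoModelForCausalLM

# Placeholder checkpoint path, for illustration only.
ckpt_path = "my-org/my-llm-checkpoint"

model_kwargs = {}
if "vila" not in ckpt_path.lower():
    # "auto" defers to the dtype recorded in the checkpoint's config.
    model_kwargs.setdefault("dtype", "auto")

model = AutoModelForCausalLM.from_pretrained(ckpt_path, **model_kwargs)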
@@ -420,15 +420,10 @@ def load_model(args: argparse.Namespace):
             attn_implementation=args.attn_implementation,
         )
     else:
-        if args.qformat in QUANT_CFG_CHOICES:
-            quant_cfg = QUANT_CFG_CHOICES[args.qformat]
-        elif hasattr(mtq, args.qformat):
-            quant_cfg = getattr(mtq, args.qformat)
-        else:
-            raise AssertionError(
-                f"Quantization format is not supported for low memory mode. "
-                f"Supported formats: {QUANT_CFG_CHOICES.keys()}"
-            )
+        assert args.qformat in QUANT_CFG_CHOICES, (
+            f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
+        )
+        quant_cfg = QUANT_CFG_CHOICES[args.qformat]
         if args.kv_cache_qformat != "none":
             quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
                 quant_cfg,
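The rewritten branch drops the fallback that looked the format up as an attribute on mtq, so low memory mode now only accepts formats listed in QUANT_CFG_CHOICES. A rough sketch of the resulting check, using an illustrative stand-in for the real mapping:

# Stand-in mapping; the real QUANT_CFG_CHOICES is defined in the example script.
QUANT_CFG_CHOICES = {"fp8": "FP8_DEFAULT_CFG", "int4_awq": "INT4_AWQ_CFG"}

def pick_quant_cfg(qformat: str):
    # Anything not in QUANT_CFG_CHOICES now fails immediately, including names
    # that previously resolved through getattr(mtq, qformat).
    assert qformat in QUANT_CFG_CHOICES, (
        f"Quantization format is not supported for low memory mode. "
        f"Supported formats: {QUANT_CFG_CHOICES.keys()}"
    )
    return QUANT_CFG_CHOICES[qformat]

pick_quant_cfg("fp8")        # returns the fp8 config entry
# pick_quant_cfg("some_other_format")  # would raise AssertionError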
@@ -1665,7 +1665,10 @@ def gptq(
 
     def _make_gptq_handle(name, m):
         backend = getattr(m.weight_quantizer, "backend", None)
-        cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
+        if backend is None:
+            cls = GPTQHelper
+        else:
+            cls = _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)
         return cls(m, name, offload_to_cpu=True)
 
     gptq_handles = {name: _make_gptq_handle(name, m) for name, m in quantized_layers}
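A compact sketch of the dispatch this hunk changes, with placeholder helper classes and registry contents: the explicit branch makes the no-backend case resolve to GPTQHelper directly instead of passing a None key through the registry lookup.

class GPTQHelper:
    """Placeholder for the default GPTQ helper used when no backend is set."""

class _CustomBackendHelper(GPTQHelper):
    """Hypothetical backend-specific helper, for illustration only."""

# Illustrative contents; the real _GPTQ_HELPER_REGISTRY maps backend identifiers
# to helper classes.
_GPTQ_HELPER_REGISTRY = {"custom_backend": _CustomBackendHelper}

def select_helper_cls(backend):
    if backend is None:
        return GPTQHelper
    return _GPTQ_HELPER_REGISTRY.get(backend, GPTQHelper)

assert select_helper_cls(None) is GPTQHelper
assert select_helper_cls("custom_backend") is _CustomBackendHelper
assert select_helper_cls("unknown") is GPTQHelper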
@@ -1685,10 +1688,6 @@ def _make_gptq_handle(name, m):
     print_rank_0("Updating weights using GPTQ algorithm...")
     for handle in gptq_handles.values():
         handle.update_weights(block_size, perc_damp)
-
-        # Disable weight quantizer after running GPTQ update since weights are already QDQ'ed
-        if hasattr(handle.module, "weight_quantizer"):
-            handle.module.weight_quantizer.disable()
         handle.free()
     del gptq_handles
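After this hunk, the per-handle quantizer-disabling step is gone, so each handle only updates weights and frees its buffers in the loop. A stripped-down sketch with a stub handle class, mirroring just the calls that remain:

class _StubGPTQHandle:
    """Stub standing in for the real per-layer GPTQ handle; methods only print."""

    def __init__(self, name):
        self.name = name

    def update_weights(self, block_size, perc_damp):
        print(f"GPTQ update for {self.name}: block_size={block_size}, damp={perc_damp}")

    def free(self):
        print(f"freeing buffers for {self.name}")

block_size, perc_damp = 128, 0.01
gptq_handles = {n: _StubGPTQHandle(n) for n in ("layers.0.proj", "layers.1.proj")}

for handle in gptq_handles.values():
    handle.update_weights(block_size, perc_damp)
    handle.free()
del gptq_handles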