@@ -148,13 +148,13 @@ def _collect_shared_input_modules(
148148 def _input_hook (module , input , output ):
149149 """Update dictionary with list of all modules that share the same input."""
150150 if len (input ) > 0 and isinstance (input [0 ], torch .Tensor ):
151- # TODO: Handle DBRX MoE case
152- input_to_linear [input [0 ]].append (module )
151+ # TODO: Handle DBRX MoE case
152+ input_to_linear [input [0 ]].append (module )
153153
154154 def _output_hook (module , input , output ):
155155 """Update dictionary with mapping of layernorms and their outputs."""
156156 if output_to_layernorm is not None and isinstance (output , torch .Tensor ):
157- output_to_layernorm [output ] = module
157+ output_to_layernorm [output ] = module
158158
159159 handles = []
160160
@@ -323,29 +323,29 @@ def llm_dummy_forward():
323323 if is_vl_model and ("nemotron" in model_type or is_nemotron_parse ):
324324 # For Nemotron VL models (including Nemotron-Parse), run optimization on just the
325325 # language model/decoder. This avoids needing pixel_values for the vision encoder.
326- language_model_lineage = get_language_model_from_vl (model )
326+ language_model_lineage = get_language_model_from_vl (model )
327327
328- if language_model_lineage is not None :
329- language_model = language_model_lineage [- 1 ]
330- print (
331- f"Running optimization on language model with fake_input shape: { fake_input .shape } "
332- )
333- # For Nemotron-Parse decoder, force use_cache=False to avoid tuple index errors
334- if is_nemotron_parse :
335- language_model (fake_input , use_cache = False )
336- else :
337- language_model (fake_input )
328+ if language_model_lineage is not None :
329+ language_model = language_model_lineage [- 1 ]
330+ print (
331+ f"Running optimization on language model with fake_input shape: { fake_input .shape } "
332+ )
333+ # For Nemotron-Parse decoder, force use_cache=False to avoid tuple index errors
334+ if is_nemotron_parse :
335+ language_model (fake_input , use_cache = False )
338336 else :
339- raise ValueError (
340- f"Cannot extract language_model from Nemotron VL model (type: { model_type } ). "
341- "This is required for requantization/resmoothing optimization. "
342- "Please ensure the model architecture is supported or file an issue."
343- )
337+ language_model (fake_input )
338+ else :
339+ raise ValueError (
340+ f"Cannot extract language_model from Nemotron VL model (type: { model_type } ). "
341+ "This is required for requantization/resmoothing optimization. "
342+ "Please ensure the model architecture is supported or file an issue."
343+ )
344344 elif getattr (model .config , "is_encoder_decoder" , False ):
345345 # For other encoder-decoder models (non-VL), pass both encoder and decoder input ids
346346 model (fake_input , decoder_input_ids = decoder_fake_input )
347- else :
348- model (fake_input )
347+ else :
348+ model (fake_input )
349349
350350 input_to_linear , output_to_layernorm = _collect_shared_input_modules (
351351 model , llm_dummy_forward , collect_layernorms = True
@@ -440,19 +440,14 @@ def _export_quantized_weight(
440440 weight_scaling_factor ,
441441 )
442442
443- sub_module .register_buffer (
444- quantizer_attrs .weight_scale ,
445- weight_scaling_factor ,
446- )
447-
448443 if hasattr (input_quantizer , "_amax" ) or (
449444 input_quantizer is not None
450445 and hasattr (input_quantizer , "amax" )
451446 and input_quantizer .amax is not None
452447 ):
453448 assert input_quantizer is not None
454449 if hasattr (input_quantizer , "_amax" ) and input_quantizer ._amax is not None :
455- input_quantizer ._amax = input_quantizer ._amax .to (torch .float32 )
450+ input_quantizer ._amax = input_quantizer ._amax .to (torch .float32 )
456451
457452 sub_module .register_buffer (
458453 quantizer_attrs .input_scale ,
@@ -468,7 +463,7 @@ def _export_quantized_weight(
468463 ):
469464 assert output_quantizer is not None
470465 if hasattr (output_quantizer , "_amax" ) and output_quantizer ._amax is not None :
471- output_quantizer ._amax = output_quantizer ._amax .to (torch .float32 )
466+ output_quantizer ._amax = output_quantizer ._amax .to (torch .float32 )
472467 else :
473468 # Register weight_scale and input_scale
474469 if quantization_format == QUANTIZATION_FP8_PB_REAL :
@@ -485,7 +480,7 @@ def _export_quantized_weight(
485480 )
486481 sub_module .register_buffer (quantizer_attrs .weight_scale , e8m0_scale )
487482 if hasattr (weight_quantizer , "_scale" ) and weight_quantizer ._scale is not None :
488- del weight_quantizer ._scale
483+ del weight_quantizer ._scale
489484 else :
490485 sub_module .register_buffer (
491486 quantizer_attrs .weight_scale , get_weight_scaling_factor (sub_module , weight_name )
0 commit comments