@@ -585,7 +585,23 @@ def __init__(
         self.opset_version = _target(opset_version) if opset_version is not None else None
         self._prog = mil.Program()

+        self.src_model_has_all_fp16_weights = False
+
         if isinstance(loaded_model, torch.jit.ScriptModule):
+            # src_model_has_all_fp16_weights will be True
+            # if the model has at least one trainable layer
+            # and all of those trainable layers have the fp16 dtype,
+            # e.g. if pytorch_model.half() has been explicitly called.
+            num_trainable_layers = 0
+            num_trainable_fp16_layers = 0
+            for param in loaded_model.parameters():
+                if param.requires_grad:
+                    num_trainable_layers += 1
+                    if param.dtype == torch.float16:
+                        num_trainable_fp16_layers += 1
+            if num_trainable_layers > 0:
+                self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers
+
             self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
             self.graph = InternalTorchIRGraph.from_torchscript(
                 torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
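
Note that the detection above requires at least one trainable parameter, so a model with no trainable layers keeps src_model_has_all_fp16_weights == False. A minimal standalone sketch of the same check (only torch itself is assumed; the toy model is hypothetical):

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 2),
    ).half()  # cast all parameters to torch.float16, as pytorch_model.half() would

    trainable = [p for p in model.parameters() if p.requires_grad]
    has_all_fp16_weights = len(trainable) > 0 and all(
        p.dtype == torch.float16 for p in trainable
    )
    print(has_all_fp16_weights)  # True here; False for the default fp32 model
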
@@ -1261,6 +1277,11 @@ def convert(self) -> Program:
         user_names = list(ssa_func_inputs.keys())
         internal_names = list(self.graph.inputs.keys())
         internal_names.extend(user_names[len(internal_names) :])
+        input_dtypes = []
+        for torch_name, ssa_name in zip(internal_names, user_names):
+            input_var = ssa_func.inputs[ssa_name]
+            input_dtypes.append(input_var.dtype)
+        all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
         for torch_name, ssa_name in zip(internal_names, user_names):
             input_var = ssa_func.inputs[ssa_name]
             if self.context.frontend == TorchFrontend.TORCHSCRIPT:
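
The dtype scan above could be collapsed into a single expression; the explicit loop just mirrors the zip(internal_names, user_names) iteration reused directly below. An equivalent one-liner (a sketch, assuming the same ssa_func and types in scope):

    all_fp16_inputs = all(ssa_func.inputs[name].dtype == types.fp16 for name in user_names)
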
@@ -1272,7 +1293,7 @@ def convert(self) -> Program:
                 # So here we perform the "cast input to fp32" step
                 if (
                     types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
-                ) and input_var.dtype == types.fp16:
+                ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
                     # This cast should have placeholder scope
                     with mb.scope(
                         ScopeInfo(
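
Taken together, the three hunks mean the fp32 input cast is skipped only when every declared input is fp16 and every trainable weight is fp16, so such a model stays fp16 end to end. A hypothetical usage sketch (names are illustrative; fp16 kernels may require a recent PyTorch or a GPU for the trace, and fp16 I/O in Core ML requires iOS16+):

    import numpy as np
    import torch
    import coremltools as ct

    model = torch.nn.Linear(4, 4).eval().half()  # all weights fp16
    example = torch.rand(1, 4, dtype=torch.float16)
    traced = torch.jit.trace(model, example)

    mlmodel = ct.convert(
        traced,
        inputs=[ct.TensorType(name="x", shape=example.shape, dtype=np.float16)],
        outputs=[ct.TensorType(dtype=np.float16)],
        minimum_deployment_target=ct.target.iOS16,
    )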