@@ -54,7 +54,7 @@ def check_bias(model: GraphModelWrapper, ref_bias: list):
5454 for node in nncf_graph .get_all_nodes ():
5555 if not is_node_with_fused_bias (node , nncf_graph ):
5656 continue
57- bias_value = get_fused_bias_value (node , nncf_graph , model .model )
57+ bias_value = get_fused_bias_value (node , nncf_graph , model .model ). cpu ()
5858 # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
5959 assert torch .all (torch .isclose (bias_value , ref_bias , atol = 0.02 )), f"{ bias_value } != { ref_bias } "
6060 return
@@ -76,17 +76,3 @@ def backend_specific_model(model: bool, tmp_dir: str):
7676 @staticmethod
7777 def fn_to_type (tensor ):
7878 return torch .Tensor (tensor ).cuda ()
79-
80- @staticmethod
81- def check_bias (model : GraphModelWrapper , ref_bias : list ):
82- ref_bias = torch .Tensor (ref_bias )
83- nncf_graph = model .get_graph ()
84- for node in nncf_graph .get_all_nodes ():
85- if not is_node_with_fused_bias (node , nncf_graph ):
86- continue
87- bias_value = get_fused_bias_value (node , nncf_graph , model .model ).cpu ()
88- # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
89- assert torch .all (torch .isclose (bias_value , ref_bias , atol = 0.02 )), f"{ bias_value } != { ref_bias } "
90- return
91- msg = "Not found node with bias"
92- raise ValueError (msg )
0 commit comments