@@ -669,7 +669,10 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred)
669669 self .target_cf_class = np .array (
670670 [[self .infer_target_cfs_class (desired_class , test_pred , self .num_output_nodes )]],
671671 dtype = np .float32 )
672- desired_class = int (self .target_cf_class [0 ][0 ])
672+ target_cf_class_scalar = self .target_cf_class [0 ][0 ]
673+ if hasattr (target_cf_class_scalar , "item" ):
674+ target_cf_class_scalar = target_cf_class_scalar .item ()
675+ desired_class = int (target_cf_class_scalar )
673676 if self .target_cf_class == 0 and self .stopping_threshold > 0.5 :
674677 self .stopping_threshold = 0.25
675678 elif self .target_cf_class == 1 and self .stopping_threshold < 0.5 :
@@ -692,12 +695,16 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_
692695 # already contains the class.
693696 if hasattr (original_pred , "__len__" ) and len (original_pred ) > 1 :
694697 original_pred_1 = np .argmax (original_pred )
698+ if hasattr (original_pred_1 , "item" ):
699+ original_pred_1 = original_pred_1 .item ()
695700 else :
696701 original_pred_1 = original_pred
697702 target_class = int (1 - original_pred_1 )
698703 return target_class
699704 elif num_output_nodes == 1 : # only for pytorch DL model
700705 original_pred_1 = np .round (original_pred )
706+ if hasattr (original_pred_1 , "item" ):
707+ original_pred_1 = original_pred_1 .item ()
701708 target_class = int (1 - original_pred_1 )
702709 return target_class
703710 elif num_output_nodes > 2 :
@@ -768,6 +775,8 @@ def is_cf_valid(self, model_score):
768775 target_cf_class = self .target_cf_class [0 ]
769776 elif len (self .target_cf_class .shape ) == 2 :
770777 target_cf_class = self .target_cf_class [0 ][0 ]
778+ if hasattr (target_cf_class , "item" ):
779+ target_cf_class = target_cf_class .item ()
771780 target_cf_class = int (target_cf_class )
772781
773782 if len (model_score ) == 1 : # for tensorflow/pytorch models
0 commit comments