Skip to content

Commit 475141a

Browse files
authored
Add the check of API names in check_nan_or_inf, so that the checks for nan or inf can be skipped in certain APIs (#78816)
* Add the check of API names in `check_nan_or_inf`, so that the checks for nan or inf can be skipped in certain APIs
* Fix Windows LNK2019 for FLAGS_check_nan_inf_blacklist in nan_inf_utils_test
1 parent 101ea9c commit 475141a

3 files changed

Lines changed: 98 additions & 0 deletions

File tree

paddle/common/flags.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,19 @@ PHI_DEFINE_EXPORTED_int32(
102102
0,
103103
"Setting the check and print level when FLAGS_check_nan_inf is set.");
104104

105+
/**
106+
* Operator related FLAG
107+
* Name: FLAGS_check_nan_inf_blacklist
108+
* Since Version:
109+
* Value Range: string, default=""
110+
* Example: FLAGS_check_nan_inf_blacklist="op1,op2,op3"
111+
* Note: Blacklist of ops to skip when checking NAN/INF
112+
*/
113+
// String flag, empty by default. CheckTensorHasNanOrInf splits the value on
// ',' and compares each entry to the API name with ==, so entries must match
// the API name exactly (no spaces around the commas).
PHI_DEFINE_EXPORTED_string(
114+
check_nan_inf_blacklist,
115+
"",
116+
"Blacklist of ops to skip when checking NAN/INF, split by ','");
117+
105118
/**
106119
* Operator related FLAG
107120
* Name: FLAGS_check_nan_inf

paddle/fluid/eager/nan_inf_utils.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
2424
#include "paddle/phi/core/selected_rows.h"
2525

26+
COMMON_DECLARE_string(check_nan_inf_blacklist);
2627
COMMON_DECLARE_int32(check_nan_inf_level);
2728
namespace egr {
2829

@@ -82,6 +83,27 @@ bool CheckOp(const std::string& api_name) {
8283
}
8384

8485
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) {
86+
if (api_name == "empty") {
87+
VLOG(4) << "Current op is \"empty\", skip nan inf check.";
88+
return;
89+
}
90+
91+
if (api_name == "empty_like") {
92+
VLOG(4) << "Current op is \"empty_like\", skip nan inf check.";
93+
return;
94+
}
95+
96+
if (!FLAGS_check_nan_inf_blacklist.empty()) {
97+
std::stringstream blacklist_ss(FLAGS_check_nan_inf_blacklist);
98+
std::string blacklisted_op;
99+
while (std::getline(blacklist_ss, blacklisted_op, ',')) {
100+
if (api_name == blacklisted_op) {
101+
VLOG(4) << "Current op is in blacklist, skip nan inf check: "
102+
<< api_name;
103+
return;
104+
}
105+
}
106+
}
85107
auto op_name = phi::TransToFluidOpName(api_name);
86108
if (tensor.initialized() && CheckOp(op_name)) {
87109
auto& tensor_name = tensor.name();

test/cpp/eager/task_tests/nan_inf_utils_test.cc

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616

1717
#include <iostream>
1818
#include <limits>
19+
#include <ostream>
20+
#include <string>
1921
#include <tuple>
2022

2123
#include "gtest/gtest.h"
24+
#include "paddle/common/flags.h"
2225
#include "paddle/fluid/framework/tensor_util.h"
2326
#include "paddle/fluid/platform/enforce.h"
2427
#include "paddle/phi/api/include/api.h"
@@ -28,8 +31,11 @@
2831
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
2932
PD_DECLARE_KERNEL(strings_empty, CPU, ALL_LAYOUT);
3033

34+
COMMON_DECLARE_string(check_nan_inf_blacklist);
35+
3136
namespace egr {
3237

38+
using paddle_flags::FLAGS_check_nan_inf_blacklist;
3339
#define CHECK_NAN_INF(tensors) \
3440
{ \
3541
bool caught_exception = false; \
@@ -56,6 +62,63 @@ namespace egr {
5662
EXPECT_FALSE(caught_exception); \
5763
}
5864

65+
// Calls CheckTensorHasNanOrInf(api_name, tensor) and expects that no
// paddle::platform::EnforceNotMet is thrown, i.e. the call did not raise
// for this api_name.
#define CHECK_APINAME_SKIP(api_name, tensor) \
66+
{ \
67+
bool caught_exception = false; \
68+
try { \
69+
CheckTensorHasNanOrInf(api_name, tensor); \
70+
} catch (paddle::platform::EnforceNotMet & error) { \
71+
caught_exception = true; \
72+
} \
73+
EXPECT_FALSE(caught_exception); \
74+
}
75+
76+
// Calls CheckTensorHasNanOrInf(api_name, tensor) and expects that a
// paddle::platform::EnforceNotMet IS thrown, i.e. the check ran for this
// api_name and rejected the tensor.
#define CHECK_APINAME_NO_SKIP(api_name, tensor) \
77+
{ \
78+
bool caught_exception = false; \
79+
try { \
80+
CheckTensorHasNanOrInf(api_name, tensor); \
81+
} catch (paddle::platform::EnforceNotMet & error) { \
82+
caught_exception = true; \
83+
} \
84+
EXPECT_TRUE(caught_exception); \
85+
}
86+
87+
// Exercises the API-name based skipping in CheckTensorHasNanOrInf using a
// 3x4 FLOAT64 tensor filled with NaN: the built-in "empty"/"empty_like"
// names and any name listed in FLAGS_check_nan_inf_blacklist must not raise,
// while every other name must.
TEST(NanInfUtils, BlacklistSkipCheck) {
88+
auto nan_tensor = paddle::experimental::full(
89+
{3, 4}, std::numeric_limits<double>::quiet_NaN(), phi::DataType::FLOAT64);
90+
91+
// "empty" is skipped even when the blacklist is empty
FLAGS_check_nan_inf_blacklist = "";
92+
CHECK_APINAME_SKIP("empty", nan_tensor);
93+
94+
// Test that "empty_like" always skips regardless of blacklist
95+
FLAGS_check_nan_inf_blacklist = "";
96+
CHECK_APINAME_SKIP("empty_like", nan_tensor);
97+
98+
// Test with empty blacklist (default behavior)
99+
FLAGS_check_nan_inf_blacklist = "";
100+
CHECK_APINAME_NO_SKIP("some_op", nan_tensor);
101+
102+
// Test with single op in blacklist
103+
FLAGS_check_nan_inf_blacklist = "single_op";
104+
CHECK_APINAME_SKIP("single_op", nan_tensor);
105+
CHECK_APINAME_NO_SKIP("other_op", nan_tensor);
106+
107+
// Even when blacklist is set, these should still skip
108+
CHECK_APINAME_SKIP("empty", nan_tensor);
109+
CHECK_APINAME_SKIP("empty_like", nan_tensor);
110+
111+
// blacklist="op1,op2,op3" and op is in blacklist
112+
FLAGS_check_nan_inf_blacklist = "op1,op2,op3";
113+
CHECK_APINAME_SKIP("op1", nan_tensor);
114+
CHECK_APINAME_SKIP("op2", nan_tensor);
115+
CHECK_APINAME_SKIP("op3", nan_tensor);
116+
// not in blacklist, should perform nan_or_inf check
117+
CHECK_APINAME_NO_SKIP("op4", nan_tensor);
118+
119+
// Restore the flag to the empty string it was given at the start of this test
FLAGS_check_nan_inf_blacklist = "";
120+
}
121+
59122
TEST(NanInfUtils, Functions) {
60123
// test all methods
61124
auto tensor = paddle::experimental::full(

0 commit comments

Comments
 (0)