Skip to content

Commit 7f0b4b5

Browse files
fix tinyformat (#78793)
* fix tinyformat * for test
1 parent 475141a commit 7f0b4b5

5 files changed

Lines changed: 46 additions & 49 deletions

File tree

paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ COMMON_DECLARE_bool(use_stride_compute_kernel);
3838

3939
namespace phi {
4040

41-
inline void PrepareStridedOut(DenseTensor* out) {
41+
inline void PrepareStridedOut_elementwise(DenseTensor* out) {
4242
if (!FLAGS_use_stride_kernel) {
4343
PADDLE_THROW(common::errors::Fatal(
4444
"FLAGS_use_stride_kernel is closed. Strided kernel "
@@ -56,7 +56,7 @@ void SumStrideKernel(const Context& dev_ctx,
5656
DataType out_dtype,
5757
bool keep_dim,
5858
DenseTensor* out) {
59-
PrepareStridedOut(out);
59+
PrepareStridedOut_elementwise(out);
6060

6161
phi::SumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
6262
}

paddle/phi/kernels/stride/matmul_grad_stride_kernel.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ inline bool UseCanonicalizedTransposeGradPath(const Context& dev_ctx) {
4444
#endif
4545
}
4646

47-
inline void PrepareStridedOut(DenseTensor* out) {
47+
inline void PrepareStridedOut_matmul(DenseTensor* out) {
4848
if (out == nullptr) {
4949
return;
5050
}
@@ -175,8 +175,8 @@ void MatmulGradStrideKernel(const Context& dev_ctx,
175175
if (!out_grad_.meta().is_contiguous()) {
176176
out_grad_ = Tensor2Contiguous<Context>(dev_ctx, out_grad_);
177177
}
178-
PrepareStridedOut(dx);
179-
PrepareStridedOut(dy);
178+
PrepareStridedOut_matmul(dx);
179+
PrepareStridedOut_matmul(dy);
180180
phi::MatmulGradKernel<T, Context>(
181181
dev_ctx, x_, y_, out_grad_, transpose_x, transpose_y, dx, dy);
182182
return;
@@ -204,14 +204,14 @@ void MatmulGradStrideKernel(const Context& dev_ctx,
204204
dx_tmp.Resize(x_.dims());
205205
dx_out = &dx_tmp;
206206
} else {
207-
PrepareStridedOut(dx_out);
207+
PrepareStridedOut_matmul(dx_out);
208208
}
209209

210210
if (dy != nullptr && y_info.applied) {
211211
dy_tmp.Resize(y_.dims());
212212
dy_out = &dy_tmp;
213213
} else {
214-
PrepareStridedOut(dy_out);
214+
PrepareStridedOut_matmul(dy_out);
215215
}
216216

217217
phi::MatmulGradKernel<T, Context>(

paddle/phi/kernels/stride/reduce_stride_kernel.cu

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ COMMON_DECLARE_bool(force_stride_compute_contig_out);
3434

3535
namespace phi {
3636

37-
inline void PrepareStridedOut(DenseTensor* out) {
37+
inline void PrepareStridedOut_reduce(DenseTensor* out) {
3838
if (!FLAGS_use_stride_kernel) {
3939
PADDLE_THROW(common::errors::Fatal(
4040
"FLAGS_use_stride_kernel is closed. Strided kernel "
@@ -51,7 +51,7 @@ void AMaxStrideKernel(const Context& dev_ctx,
5151
const std::vector<int64_t>& dims,
5252
bool keep_dim,
5353
DenseTensor* out) {
54-
PrepareStridedOut(out);
54+
PrepareStridedOut_reduce(out);
5555

5656
phi::AMaxKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
5757
}
@@ -62,7 +62,7 @@ void AMinStrideKernel(const Context& dev_ctx,
6262
const std::vector<int64_t>& dims,
6363
bool keep_dim,
6464
DenseTensor* out) {
65-
PrepareStridedOut(out);
65+
PrepareStridedOut_reduce(out);
6666

6767
phi::AMinKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
6868
}
@@ -73,7 +73,7 @@ void MaxStrideKernel(const Context& dev_ctx,
7373
const IntArray& dims,
7474
bool keep_dim,
7575
DenseTensor* out) {
76-
PrepareStridedOut(out);
76+
PrepareStridedOut_reduce(out);
7777

7878
phi::MaxKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
7979
}
@@ -84,7 +84,7 @@ void MinStrideKernel(const Context& dev_ctx,
8484
const IntArray& dims,
8585
bool keep_dim,
8686
DenseTensor* out) {
87-
PrepareStridedOut(out);
87+
PrepareStridedOut_reduce(out);
8888

8989
phi::MinKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
9090
}
@@ -96,7 +96,7 @@ void ProdStrideKernel(const Context& dev_ctx,
9696
bool keep_dim,
9797
bool reduce_all,
9898
DenseTensor* out) {
99-
PrepareStridedOut(out);
99+
PrepareStridedOut_reduce(out);
100100

101101
phi::ProdKernel<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);
102102
}
@@ -107,7 +107,7 @@ void AllStrideKernel(const Context& dev_ctx,
107107
const std::vector<int64_t>& dims,
108108
bool keep_dim,
109109
DenseTensor* out) {
110-
PrepareStridedOut(out);
110+
PrepareStridedOut_reduce(out);
111111

112112
phi::AllKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
113113
}
@@ -118,7 +118,7 @@ void AnyStrideKernel(const Context& dev_ctx,
118118
const std::vector<int64_t>& dims,
119119
bool keep_dim,
120120
DenseTensor* out) {
121-
PrepareStridedOut(out);
121+
PrepareStridedOut_reduce(out);
122122

123123
phi::AnyKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
124124
}
@@ -130,7 +130,7 @@ void SumStrideKernel(const Context& dev_ctx,
130130
DataType out_dtype,
131131
bool keep_dim,
132132
DenseTensor* out) {
133-
PrepareStridedOut(out);
133+
PrepareStridedOut_reduce(out);
134134

135135
phi::SumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
136136
}
@@ -142,7 +142,7 @@ void NansumStrideKernel(const Context& dev_ctx,
142142
DataType out_dtype,
143143
bool keep_dim,
144144
DenseTensor* out) {
145-
PrepareStridedOut(out);
145+
PrepareStridedOut_reduce(out);
146146
phi::NansumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
147147
}
148148

@@ -152,7 +152,7 @@ void MeanStrideKernel(const Context& dev_ctx,
152152
const IntArray& dims,
153153
bool keep_dim,
154154
DenseTensor* out) {
155-
PrepareStridedOut(out);
155+
PrepareStridedOut_reduce(out);
156156

157157
phi::MeanKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
158158
}

paddle/utils/string/printf.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@ namespace string {
8181

8282
template <typename... Args>
8383
void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
84-
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
84+
try {
85+
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
86+
} catch (const tinyformat::detail::FormatError&) {
87+
out << fmt;
88+
}
8589
}
8690

8791
inline std::string Sprintf() { return ""; }
@@ -95,9 +99,13 @@ std::string Sprintf(const Args&... args) {
9599

96100
template <typename... Args>
97101
std::string Sprintf(const char* fmt, const Args&... args) {
98-
std::ostringstream oss;
99-
Fprintf(oss, fmt, args...);
100-
return oss.str();
102+
try {
103+
std::ostringstream oss;
104+
Fprintf(oss, fmt, args...);
105+
return oss.str();
106+
} catch (const tinyformat::detail::FormatError&) {
107+
return fmt;
108+
}
101109
}
102110

103111
template <typename... Args>

paddle/utils/string/tinyformat/tinyformat.h

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,8 @@
119119
// Additional API information
120120
// --------------------------
121121
//
122-
// Error handling: Define TINYFORMAT_ERROR to customize the error handling for
123-
// format strings which are unsupported or have the wrong number of format
124-
// specifiers (calls assert() by default).
122+
// Error handling: Format errors throw detail::FormatError, which is caught
123+
// at the public API level to fall back to the raw format string.
125124
//
126125
// User defined types: Uses operator<< for user defined types by default.
127126
// Overload formatValue() for more control.
@@ -139,13 +138,14 @@ namespace paddle {
139138
namespace string {
140139
namespace tinyformat {
141140

142-
#ifndef TINYFORMAT_ERROR
143-
#define TINYFORMAT_ERROR(reason) assert(0 && reason)
144-
#endif
145-
146141
//------------------------------------------------------------------------------
147142
namespace detail {
148143

144+
// Exception thrown on format errors instead of crashing via assert.
145+
// Caught at the public API level to fall back to returning the raw format
146+
// string, so that a wrong PADDLE_ENFORCE format never causes an abort.
147+
struct FormatError {};
148+
149149
// Test whether type T1 is convertible to type T2
150150
template <typename T1, typename T2>
151151
struct is_convertible {
@@ -192,9 +192,7 @@ struct formatValueAsType<T, fmtT, true> {
192192
template <typename T, bool convertible = is_convertible<T, int>::value>
193193
struct convertToInt {
194194
static int invoke(const T & /*value*/) {
195-
TINYFORMAT_ERROR(
196-
"tinyformat: Cannot convert from argument type to "
197-
"integer for use as variable width or precision");
195+
throw FormatError();
198196
return 0;
199197
}
200198
};
@@ -579,8 +577,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
579577
int &argIndex, // NOLINT
580578
int numFormatters) {
581579
if (*fmtStart != '%') {
582-
TINYFORMAT_ERROR(
583-
"tinyformat: Not enough conversion specifiers in format string");
580+
throw FormatError();
584581
return fmtStart;
585582
}
586583
// Reset stream state to defaults.
@@ -639,8 +636,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
639636
if (argIndex < numFormatters)
640637
width = formatters[argIndex++].toInt();
641638
else
642-
TINYFORMAT_ERROR(
643-
"tinyformat: Not enough arguments to read variable width");
639+
throw FormatError();
644640
if (width < 0) {
645641
// negative widths correspond to '-' flag set
646642
out.fill(' ');
@@ -659,8 +655,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
659655
if (argIndex < numFormatters)
660656
precision = formatters[argIndex++].toInt();
661657
else
662-
TINYFORMAT_ERROR(
663-
"tinyformat: Not enough arguments to read variable precision");
658+
throw FormatError();
664659
} else {
665660
if (*c >= '0' && *c <= '9')
666661
precision = parseIntAndAdvance(c);
@@ -724,9 +719,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
724719
break;
725720
case 'a':
726721
case 'A':
727-
TINYFORMAT_ERROR(
728-
"tinyformat: the %a and %A conversion specs "
729-
"are not supported");
722+
throw FormatError();
730723
break;
731724
case 'c':
732725
// Handled as special case inside formatValue()
@@ -738,12 +731,10 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
738731
break;
739732
case 'n':
740733
// Not supported - will cause problems!
741-
TINYFORMAT_ERROR("tinyformat: %n conversion spec not supported");
734+
throw FormatError();
742735
break;
743736
case '\0':
744-
TINYFORMAT_ERROR(
745-
"tinyformat: Conversion spec incorrectly "
746-
"terminated by end of string");
737+
throw FormatError();
747738
return c;
748739
default:
749740
break;
@@ -785,7 +776,7 @@ inline void formatImpl(std::ostream &out,
785776
numFormatters);
786777
if (argIndex >= numFormatters) {
787778
// Check args remain after reading any variable width/precision
788-
TINYFORMAT_ERROR("tinyformat: Not enough format arguments");
779+
throw FormatError();
789780
return;
790781
}
791782
const FormatArg &arg = formatters[argIndex];
@@ -811,9 +802,7 @@ inline void formatImpl(std::ostream &out,
811802

812803
// Print remaining part of format string.
813804
fmt = printFormatStringLiteral(out, fmt);
814-
if (fmt != nullptr && *fmt != '\0' && *fmt != 0)
815-
TINYFORMAT_ERROR(
816-
"tinyformat: Too many conversion specifiers in format string");
805+
if (fmt != nullptr && *fmt != '\0' && *fmt != 0) throw FormatError();
817806

818807
// Restore stream state
819808
out.width(origWidth);

0 commit comments

Comments
 (0)