
Commit bb30696

omoYang and SigureMo committed
Fix arange/range accuracy and assign op handling
- Fix the size-accuracy problem when arange/range is used with different input and output types.
- Include paddle/phi/core/visit_type.h in xpu/arange_kernel.cc.
- Add the old infermeta to support the legacy path; align precision in range_kernel; fix a small bug in creation.
- Add an `assign` op after a Tensor argument to avoid confusing it with the scalar case in `arange`.
- Refine the comment about assign-op insertion.
- Clean up unused comments.
- Insert assign for the full op output only.
- Add TestRangeV2LegacyInferMeta.test_range_v2_legacy to exercise RangeTensorInferMetaLegacy via the legacy static-graph path; delete the inf check for `step`.
- Check whether CI reaches GetSizeForRange.
- Delete PADDLE_THROW in range_func.h; add a case covering size == 0 in the range kernel.
- Update python/paddle/jit/dy2static/convert_operators.py.
- Remove redundant code and fix type inconsistencies in the kernel; add missing includes to avoid relying on transitive includes; wrap the main body of test_range_v2_legacy in try/finally so static mode is disabled even on failure.
- Use MPType for intermediate computation in CPU kernels to support future 16-bit types; unify tensor input type conversion by converting the tensor to a Scalar and calling scalar.to<>().

Co-authored-by: omoYang <1115418865@qq.com>
Co-authored-by: SigureMo <sigure.qaq@gmail.com>
1 parent 2b9f8b6 · commit bb30696

14 files changed

Lines changed: 456 additions & 111 deletions


paddle/phi/common/data_type.h

Lines changed: 14 additions & 0 deletions
@@ -328,6 +328,20 @@ inline DataType StringToDataType(const std::string& dtype) {
   }
 }
 
+inline bool IsFloatingType(const DataType& type) {
+  return (type == DataType::FLOAT16 || type == DataType::BFLOAT16 ||
+          type == DataType::FLOAT32 || type == DataType::FLOAT64 ||
+          type == DataType::FLOAT8_E4M3FN || type == DataType::FLOAT8_E5M2);
+}
+
+inline bool IsIntegerType(const DataType& type) {
+  return (type == DataType::INT8 || type == DataType::INT16 ||
+          type == DataType::INT32 || type == DataType::INT64 ||
+          type == DataType::UINT8 || type == DataType::UINT16 ||
+          type == DataType::UINT32 || type == DataType::UINT64 ||
+          type == DataType::BOOL);
+}
+
 }  // namespace phi
 
 namespace paddle {
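
Note: these two helpers drive the dispatch in the arange/range kernels below — if any operand dtype is floating, the size computation runs in double, otherwise in int64_t. A minimal usage sketch, assuming Paddle's phi headers are on the include path (only IsFloatingType/IsIntegerType come from this commit; the rest is standard C++):

#include <iostream>
#include "paddle/phi/common/data_type.h"

int main() {
  using phi::DataType;
  // Mixed operands: one floating dtype is enough to force the double path.
  bool any_float = phi::IsFloatingType(DataType::INT64) ||
                   phi::IsFloatingType(DataType::FLOAT32);
  std::cout << std::boolalpha << any_float << "\n";         // true
  std::cout << phi::IsIntegerType(DataType::BOOL) << "\n";  // true: BOOL is
                                                            // grouped with the
                                                            // integer types
}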

paddle/phi/infermeta/ternary.cc

Lines changed: 55 additions & 2 deletions
@@ -882,9 +882,61 @@ void FlashAttnV3VarlenInferMeta(const MetaTensor& q,
   softmax_lse->set_dtype(DataType::FLOAT32);
 }
 
+void ArangeTensorInferMetaLegacy(const MetaTensor& start,
+                                 const MetaTensor& end,
+                                 const MetaTensor& step,
+                                 MetaTensor* out) {
+  PADDLE_ENFORCE_EQ(common::product(start.dims()),
+                    1,
+                    common::errors::InvalidArgument(
+                        "The numel of Input(start) should be 1, but got %d",
+                        common::product(start.dims())));
+
+  PADDLE_ENFORCE_EQ(common::product(end.dims()),
+                    1,
+                    common::errors::InvalidArgument(
+                        "The numel of Input(end) should be 1, but got %d",
+                        common::product(end.dims())));
+
+  PADDLE_ENFORCE_EQ(common::product(step.dims()),
+                    1,
+                    common::errors::InvalidArgument(
+                        "The numel of Input(step) should be 1, but got %d",
+                        common::product(step.dims())));
+
+  out->set_dims({-1});
+  out->set_dtype(start.dtype());
+}
+
+void RangeTensorInferMetaLegacy(const MetaTensor& start,
+                                const MetaTensor& end,
+                                const MetaTensor& step,
+                                MetaTensor* out) {
+  PADDLE_ENFORCE_EQ(common::product(start.dims()),
+                    1,
+                    common::errors::InvalidArgument(
+                        "The numel of Input(start) should be 1, but got %d",
+                        common::product(start.dims())));
+
+  PADDLE_ENFORCE_EQ(common::product(end.dims()),
+                    1,
+                    common::errors::InvalidArgument(
+                        "The numel of Input(end) should be 1, but got %d",
+                        common::product(end.dims())));
+
+  PADDLE_ENFORCE_EQ(common::product(step.dims()),
+                    1,
+                    common::errors::InvalidArgument(
+                        "The numel of Input(step) should be 1, but got %d",
+                        common::product(step.dims())));
+
+  out->set_dims({-1});
+  out->set_dtype(start.dtype());
+}
 void ArangeTensorInferMeta(const MetaTensor& start,
                            const MetaTensor& end,
                            const MetaTensor& step,
+                           DataType dtype,
                            MetaTensor* out) {
   PADDLE_ENFORCE_EQ(common::product(start.dims()),
                     1,
@@ -905,12 +957,13 @@ void ArangeTensorInferMeta(const MetaTensor& start,
                         common::product(step.dims())));
 
   out->set_dims({-1});
-  out->set_dtype(start.dtype());
+  out->set_dtype(dtype);
 }
 
 void RangeTensorInferMeta(const MetaTensor& start,
                           const MetaTensor& end,
                           const MetaTensor& step,
+                          DataType dtype,
                           MetaTensor* out) {
   PADDLE_ENFORCE_EQ(common::product(start.dims()),
                     1,
@@ -931,7 +984,7 @@ void RangeTensorInferMeta(const MetaTensor& start,
                         common::product(step.dims())));
 
   out->set_dims({-1});
-  out->set_dtype(start.dtype());
+  out->set_dtype(dtype);
 }
 
 void CollectFpnProposalsInferMeta(
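
Note: a compressed, standalone sketch of the contract these infermeta functions enforce (the Meta struct below is a hypothetical stand-in, not the real MetaTensor API): each input must hold exactly one element, the output shape stays dynamic ({-1}) because the element count depends on runtime values, and only the non-legacy variants take an explicit dtype — the legacy ones keep the old behavior of inheriting start's dtype.

#include <cstdint>
#include <stdexcept>

// Hypothetical stand-ins for MetaTensor / DataType, for illustration only.
struct Meta {
  int64_t numel = 1;  // product of dims
  int dtype = 0;      // stand-in for phi::DataType
  int64_t dim = 0;    // single output dim; -1 means "unknown until runtime"
};

// Mirrors ArangeTensorInferMeta: validate scalar-like inputs, then emit a
// dynamically shaped output typed by the explicit dtype argument.
void InferArangeSketch(const Meta& start, const Meta& end, const Meta& step,
                       int dtype, Meta* out) {
  for (const Meta* m : {&start, &end, &step}) {
    if (m->numel != 1) throw std::invalid_argument("numel must be 1");
  }
  out->dim = -1;       // same as out->set_dims({-1})
  out->dtype = dtype;  // the legacy variants use start.dtype here instead
}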

paddle/phi/infermeta/ternary.h

Lines changed: 12 additions & 0 deletions
@@ -65,13 +65,25 @@ PADDLE_API void AffineChannelInferMeta(const MetaTensor& x,
 PADDLE_API void ArangeTensorInferMeta(const MetaTensor& start,
                                       const MetaTensor& end,
                                       const MetaTensor& step,
+                                      DataType dtype,
                                       MetaTensor* out);
 
 PADDLE_API void RangeTensorInferMeta(const MetaTensor& start,
                                      const MetaTensor& end,
                                      const MetaTensor& step,
+                                     DataType dtype,
                                      MetaTensor* out);
 
+PADDLE_API void ArangeTensorInferMetaLegacy(const MetaTensor& start,
+                                            const MetaTensor& end,
+                                            const MetaTensor& step,
+                                            MetaTensor* out);
+
+PADDLE_API void RangeTensorInferMetaLegacy(const MetaTensor& start,
+                                           const MetaTensor& end,
+                                           const MetaTensor& step,
+                                           MetaTensor* out);
+
 PADDLE_API void AssignPosInferMeta(const MetaTensor& x,
                                    const MetaTensor& cum_count,
                                    const MetaTensor& eff_num_len,

paddle/phi/kernels/cpu/arange_kernel.cc

Lines changed: 58 additions & 8 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/arange_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/range_function.h"
 
@@ -43,10 +44,38 @@ void ArangeTensorKernel(const Context& dev_ctx,
                         const DenseTensor& end,
                         const DenseTensor& step,
                         DenseTensor* out) {
-  T start_value = start.data<T>()[0];
-  T end_value = end.data<T>()[0];
-  T step_value = step.data<T>()[0];
-  ArangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
+  int64_t size = 0;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  bool any_float = phi::IsFloatingType(start.dtype()) ||
+                   phi::IsFloatingType(end.dtype()) ||
+                   phi::IsFloatingType(step.dtype());
+
+  Scalar start_scalar(start);
+  Scalar end_scalar(end);
+  Scalar step_scalar(step);
+
+  if (any_float) {
+    double sv = start_scalar.to<double>();
+    double ev = end_scalar.to<double>();
+    double stv = step_scalar.to<double>();
+    funcs::GetSize<double>(sv, ev, stv, &size);
+  } else {
+    int64_t sv = start_scalar.to<int64_t>();
+    int64_t ev = end_scalar.to<int64_t>();
+    int64_t stv = step_scalar.to<int64_t>();
+    funcs::GetSize<int64_t>(sv, ev, stv, &size);
+  }
+  MPType start_value = start_scalar.to<MPType>();
+  MPType step_value = step_scalar.to<MPType>();
+
+  out->Resize({size});
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  MPType value = start_value;
+  for (int64_t i = 0; i < size; ++i) {
+    out_data[i] = static_cast<T>(value);
+    value += step_value;
+  }
 }
 
 template <typename T, typename Context>
@@ -55,10 +84,31 @@ void ArangeKernel(const Context& dev_ctx,
                   const Scalar& end,
                   const Scalar& step,
                   DenseTensor* out) {
-  T start_value = start.to<T>();
-  T end_value = end.to<T>();
-  T step_value = step.to<T>();
-  ArangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
+  bool any_float = phi::IsFloatingType(start.dtype()) ||
+                   phi::IsFloatingType(end.dtype()) ||
+                   phi::IsFloatingType(step.dtype());
+  int64_t size = 0;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  if (any_float) {
+    double sv = start.to<double>();
+    double ev = end.to<double>();
+    double stv = step.to<double>();
+    funcs::GetSize<double>(sv, ev, stv, &size);
+  } else {
+    int64_t sv = start.to<int64_t>();
+    int64_t ev = end.to<int64_t>();
+    int64_t stv = step.to<int64_t>();
+    funcs::GetSize<int64_t>(sv, ev, stv, &size);
+  }
+  MPType start_value = start.to<MPType>();
+  MPType step_value = step.to<MPType>();
+  out->Resize({size});
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  MPType value = start_value;
+  for (int64_t i = 0; i < size; ++i) {
+    out_data[i] = static_cast<T>(value);
+    value += step_value;
+  }
 }
 
 }  // namespace phi
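
Note: the substance of this change is that the element count is computed in double (or int64_t when every operand is integral) instead of in T, and the running value is accumulated in MPType (e.g. float when T is float16) so repeated additions don't drift before the final cast. A standalone sketch of the scheme, assuming funcs::GetSize implements the usual half-open arange count ceil((end - start) / step):

#include <cmath>
#include <cstdint>
#include <vector>

// Assumed GetSize semantics for the floating path: half-open [start, end).
int64_t ArangeSizeSketch(double start, double end, double step) {
  double n = std::ceil((end - start) / step);
  return n > 0 ? static_cast<int64_t>(n) : 0;  // empty if direction is wrong
}

// T is the output element type; Wide plays the role of MPType.
template <typename T, typename Wide = double>
std::vector<T> ArangeSketch(double start, double end, double step) {
  std::vector<T> out(ArangeSizeSketch(start, end, step));
  Wide value = static_cast<Wide>(start);  // accumulate in the wide type...
  for (T& elem : out) {
    elem = static_cast<T>(value);         // ...narrow only at the store
    value += static_cast<Wide>(step);
  }
  return out;
}
// ArangeSketch<float>(0.0, 1.0, 0.1) yields 10 elements: 0.0, 0.1, ..., 0.9.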

paddle/phi/kernels/cpu/range_kernel.cc

Lines changed: 36 additions & 15 deletions
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/phi/kernels/range_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/range_function.h"
 
@@ -46,13 +47,27 @@ void RangeTensorKernel(const Context& dev_ctx,
                        const DenseTensor& end,
                        const DenseTensor& step,
                        DenseTensor* out) {
-  T start_value = start.data<T>()[0];
-  T end_value = end.data<T>()[0];
-  T step_value = step.data<T>()[0];
-  if (step_value == static_cast<T>(0)) {
-    PADDLE_THROW(errors::InvalidArgument("step must be nonzero."));
+  int64_t size = 0;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  Scalar start_scalar(start);
+  Scalar end_scalar(end);
+  Scalar step_scalar(step);
+  MPType start_value = start_scalar.to<MPType>();
+  MPType end_value = end_scalar.to<MPType>();
+  MPType step_value = step_scalar.to<MPType>();
+
+  funcs::GetSizeForRange(start_value, end_value, step_value, &size);
+
+  out->Resize({size});
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  if (size == 0) {
+    return;
+  }
+  MPType value = start_value;
+  for (int64_t i = 0; i < size; ++i) {
+    out_data[i] = static_cast<T>(value);
+    value += step_value;
   }
-  RangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
 }
 
 template <typename T, typename Context>
@@ -61,16 +76,22 @@ void RangeKernel(const Context& dev_ctx,
                  const Scalar& end,
                  const Scalar& step,
                  DenseTensor* out) {
-  T start_value = start.to<T>();
-  T end_value = end.to<T>();
-  T step_value = step.to<T>();
-  if constexpr (std::is_floating_point_v<T>) {
-    if (std::isnan(end_value)) {
-      PADDLE_THROW(common::errors::InvalidArgument(
-          "The end value of range cannot be NaN. Please check your input."));
-    }
+  int64_t size = 0;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType start_value = start.to<MPType>();
+  MPType end_value = end.to<MPType>();
+  MPType step_value = step.to<MPType>();
+  funcs::GetSizeForRange(start_value, end_value, step_value, &size);
+  out->Resize({size});
+  T* out_data = dev_ctx.template Alloc<T>(out);
+  if (size == 0) {
+    return;
+  }
+  MPType value = start_value;
+  for (int64_t i = 0; i < size; ++i) {
+    out_data[i] = static_cast<T>(value);
+    value += step_value;
  }
-  RangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
 }
 
 }  // namespace phi
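
Note: the range kernels differ from arange in one rule — GetSizeForRange treats [start, end] as a closed interval, and size == 0 is now a legal empty output rather than an error (the zero-step and non-finite checks moved into GetSizeForRange itself). A small worked example of the closed-interval count, using the formula from range_function.h below:

#include <cstdint>
#include <cstdio>

// Closed-interval count used by range (see GetSizeForRange below). The cast
// truncates toward zero, so an end that is not a multiple of step rounds down.
int64_t RangeSizeSketch(double start, double end, double step) {
  return static_cast<int64_t>(((end - start) / step) + 1);
}

int main() {
  std::printf("%lld\n", (long long)RangeSizeSketch(0, 10, 2));  // 6: {0,2,4,6,8,10}
  std::printf("%lld\n", (long long)RangeSizeSketch(0, 9, 2));   // 5: {0,2,4,6,8}
  std::printf("%lld\n", (long long)RangeSizeSketch(0, -1, 1));  // 0: empty output
}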

paddle/phi/kernels/funcs/range_function.h

Lines changed: 34 additions & 0 deletions
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #pragma once
+#include <cmath>
+#include <type_traits>
 #include "paddle/phi/core/enforce.h"
 
 namespace phi {
@@ -61,5 +63,37 @@ void GetSize(T start, T end, T step, int64_t* size) {
                   : std::ceil(std::abs((end - start) / step));
 }
 
+template <typename T>
+void GetSizeForRange(T start, T end, T step, int64_t* size) {
+  // For range op: closed interval [start, end]
+  PADDLE_ENFORCE_NE(
+      step,
+      0,
+      common::errors::InvalidArgument("The step of range op should not be 0."));
+
+  if constexpr (std::is_same_v<T, phi::bfloat16> ||
+                std::is_same_v<T, phi::float16>) {
+    PADDLE_ENFORCE_EQ(
+        phi::dtype::isfinite(start) && phi::dtype::isfinite(end),
+        true,
+        common::errors::InvalidArgument(
+            "The start, end and step of range op should be finite "
+            "numbers, but received start=%f, end=%f.",
+            static_cast<double>(start),
+            static_cast<double>(end)));
+  } else if constexpr (std::is_floating_point_v<T>) {
+    PADDLE_ENFORCE_EQ(
+        std::isfinite(start) && std::isfinite(end),
+        true,
+        common::errors::InvalidArgument(
+            "The start, end and step of range op should be finite "
+            "numbers, but received start=%f, end=%f.",
+            static_cast<double>(start),
+            static_cast<double>(end)));
+  }
+  // Closed interval [start, end], so we add 1
+  *size = static_cast<int64_t>(((end - start) / step) + 1);
+}
+
 }  // namespace funcs
 }  // namespace phi
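
Note: the if constexpr chain exists because std::isfinite does not accept phi::float16 / phi::bfloat16; custom 16-bit types route to phi::dtype::isfinite, builtin floats to std::isfinite, and integral T skips the check entirely. A minimal sketch of the same dispatch pattern, with a hypothetical Half type standing in for phi::float16:

#include <cmath>
#include <cstdint>
#include <stdexcept>
#include <type_traits>

// Hypothetical 16-bit float stand-in (stored wide for simplicity) with its
// own finiteness test, playing the role of phi::float16 / phi::bfloat16.
struct Half {
  float v;
  explicit Half(float f) : v(f) {}
  explicit operator double() const { return v; }
};
inline bool isfinite_custom(Half h) { return std::isfinite(h.v); }

template <typename T>
int64_t GetSizeForRangeSketch(T start, T end, T step) {
  if (static_cast<double>(step) == 0.0)
    throw std::invalid_argument("step must be nonzero");
  if constexpr (std::is_same_v<T, Half>) {
    if (!isfinite_custom(start) || !isfinite_custom(end))
      throw std::invalid_argument("start/end must be finite");
  } else if constexpr (std::is_floating_point_v<T>) {
    if (!std::isfinite(start) || !std::isfinite(end))
      throw std::invalid_argument("start/end must be finite");
  }  // integral T: no finiteness check needed
  // Closed interval [start, end]: truncated quotient plus one.
  return static_cast<int64_t>((static_cast<double>(end) -
                               static_cast<double>(start)) /
                                  static_cast<double>(step) +
                              1);
}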
