Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions paddle/phi/common/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,20 @@ inline DataType StringToDataType(const std::string& dtype) {
}
}

// Returns true iff `type` is one of the floating-point element types
// (fp16, bf16, fp32, fp64, and the two fp8 variants).
inline bool IsFloatingType(const DataType& type) {
  switch (type) {
    case DataType::FLOAT16:
    case DataType::BFLOAT16:
    case DataType::FLOAT32:
    case DataType::FLOAT64:
    case DataType::FLOAT8_E4M3FN:
    case DataType::FLOAT8_E5M2:
      return true;
    default:
      return false;
  }
}

// Returns true iff `type` is an integral element type.
// NOTE: BOOL is deliberately counted as an integer type here.
inline bool IsIntegerType(const DataType& type) {
  switch (type) {
    case DataType::INT8:
    case DataType::INT16:
    case DataType::INT32:
    case DataType::INT64:
    case DataType::UINT8:
    case DataType::UINT16:
    case DataType::UINT32:
    case DataType::UINT64:
    case DataType::BOOL:
      return true;
    default:
      return false;
  }
}

} // namespace phi

namespace paddle {
Expand Down
57 changes: 55 additions & 2 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -882,9 +882,61 @@ void FlashAttnV3VarlenInferMeta(const MetaTensor& q,
softmax_lse->set_dtype(DataType::FLOAT32);
}

// Legacy infer-meta for arange with tensor inputs: validates that start, end
// and step are scalar tensors (numel == 1), then marks the output as a 1-D
// tensor of unknown length whose dtype follows `start`.
void ArangeTensorInferMetaLegacy(const MetaTensor& start,
                                 const MetaTensor& end,
                                 const MetaTensor& step,
                                 MetaTensor* out) {
  const int64_t start_numel = common::product(start.dims());
  PADDLE_ENFORCE_EQ(start_numel,
                    1,
                    common::errors::InvalidArgument(
                        "The numel of Input(start) should be 1, but got %d",
                        start_numel));

  const int64_t end_numel = common::product(end.dims());
  PADDLE_ENFORCE_EQ(end_numel,
                    1,
                    common::errors::InvalidArgument(
                        "The numel of Input(end) should be 1, but got %d",
                        end_numel));

  const int64_t step_numel = common::product(step.dims());
  PADDLE_ENFORCE_EQ(step_numel,
                    1,
                    common::errors::InvalidArgument(
                        "The numel of Input(step) should be 1, but got %d",
                        step_numel));

  // The element count depends on runtime tensor values, so the length is
  // unknown (-1) at infer-meta time.
  out->set_dims({-1});
  out->set_dtype(start.dtype());
}

// Legacy infer-meta for range with tensor inputs. Mirrors
// ArangeTensorInferMetaLegacy: each of start/end/step must be a scalar
// tensor; the output is 1-D with a runtime-determined length and the dtype
// of `start`.
void RangeTensorInferMetaLegacy(const MetaTensor& start,
                                const MetaTensor& end,
                                const MetaTensor& step,
                                MetaTensor* out) {
  const int64_t n_start = common::product(start.dims());
  PADDLE_ENFORCE_EQ(n_start,
                    1,
                    common::errors::InvalidArgument(
                        "The numel of Input(start) should be 1, but got %d",
                        n_start));

  const int64_t n_end = common::product(end.dims());
  PADDLE_ENFORCE_EQ(n_end,
                    1,
                    common::errors::InvalidArgument(
                        "The numel of Input(end) should be 1, but got %d",
                        n_end));

  const int64_t n_step = common::product(step.dims());
  PADDLE_ENFORCE_EQ(n_step,
                    1,
                    common::errors::InvalidArgument(
                        "The numel of Input(step) should be 1, but got %d",
                        n_step));

  // Length is only known once the tensor values are available at run time.
  out->set_dims({-1});
  out->set_dtype(start.dtype());
}
void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
DataType dtype,
MetaTensor* out) {
PADDLE_ENFORCE_EQ(common::product(start.dims()),
1,
Expand All @@ -905,12 +957,13 @@ void ArangeTensorInferMeta(const MetaTensor& start,
common::product(step.dims())));

out->set_dims({-1});
out->set_dtype(start.dtype());
out->set_dtype(dtype);
}

void RangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
DataType dtype,
MetaTensor* out) {
PADDLE_ENFORCE_EQ(common::product(start.dims()),
1,
Expand All @@ -931,7 +984,7 @@ void RangeTensorInferMeta(const MetaTensor& start,
common::product(step.dims())));

out->set_dims({-1});
out->set_dtype(start.dtype());
out->set_dtype(dtype);
}

void CollectFpnProposalsInferMeta(
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,25 @@ PADDLE_API void AffineChannelInferMeta(const MetaTensor& x,
PADDLE_API void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
DataType dtype,
MetaTensor* out);

PADDLE_API void RangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
DataType dtype,
MetaTensor* out);

PADDLE_API void ArangeTensorInferMetaLegacy(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
MetaTensor* out);

PADDLE_API void RangeTensorInferMetaLegacy(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
MetaTensor* out);

PADDLE_API void AssignPosInferMeta(const MetaTensor& x,
const MetaTensor& cum_count,
const MetaTensor& eff_num_len,
Expand Down
66 changes: 58 additions & 8 deletions paddle/phi/kernels/cpu/arange_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/phi/kernels/arange_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/range_function.h"

Expand Down Expand Up @@ -43,10 +44,38 @@ void ArangeTensorKernel(const Context& dev_ctx,
const DenseTensor& end,
const DenseTensor& step,
DenseTensor* out) {
T start_value = start.data<T>()[0];
T end_value = end.data<T>()[0];
T step_value = step.data<T>()[0];
ArangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
int64_t size = 0;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

bool any_float = phi::IsFloatingType(start.dtype()) ||
phi::IsFloatingType(end.dtype()) ||
phi::IsFloatingType(step.dtype());

Scalar start_scalar(start);
Scalar end_scalar(end);
Scalar step_scalar(step);

if (any_float) {
double sv = start_scalar.to<double>();
double ev = end_scalar.to<double>();
double stv = step_scalar.to<double>();
funcs::GetSize<double>(sv, ev, stv, &size);
} else {
int64_t sv = start_scalar.to<int64_t>();
int64_t ev = end_scalar.to<int64_t>();
int64_t stv = step_scalar.to<int64_t>();
funcs::GetSize<int64_t>(sv, ev, stv, &size);
}
MPType start_value = start_scalar.to<MPType>();
MPType step_value = step_scalar.to<MPType>();

out->Resize({size});
T* out_data = dev_ctx.template Alloc<T>(out);
MPType value = start_value;
for (int64_t i = 0; i < size; ++i) {
out_data[i] = static_cast<T>(value);
value += step_value;
}
}

template <typename T, typename Context>
Expand All @@ -55,10 +84,31 @@ void ArangeKernel(const Context& dev_ctx,
const Scalar& end,
const Scalar& step,
DenseTensor* out) {
T start_value = start.to<T>();
T end_value = end.to<T>();
T step_value = step.to<T>();
ArangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
bool any_float = phi::IsFloatingType(start.dtype()) ||
phi::IsFloatingType(end.dtype()) ||
phi::IsFloatingType(step.dtype());
int64_t size = 0;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
if (any_float) {
double sv = start.to<double>();
double ev = end.to<double>();
double stv = step.to<double>();
funcs::GetSize<double>(sv, ev, stv, &size);
} else {
int64_t sv = start.to<int64_t>();
int64_t ev = end.to<int64_t>();
int64_t stv = step.to<int64_t>();
funcs::GetSize<int64_t>(sv, ev, stv, &size);
}
MPType start_value = start.to<MPType>();
MPType step_value = step.to<MPType>();
out->Resize({size});
T* out_data = dev_ctx.template Alloc<T>(out);
MPType value = start_value;
for (int64_t i = 0; i < size; ++i) {
out_data[i] = static_cast<T>(value);
value += step_value;
}
}

} // namespace phi
Expand Down
51 changes: 36 additions & 15 deletions paddle/phi/kernels/cpu/range_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/phi/kernels/range_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/range_function.h"

Expand Down Expand Up @@ -46,13 +47,27 @@ void RangeTensorKernel(const Context& dev_ctx,
const DenseTensor& end,
const DenseTensor& step,
DenseTensor* out) {
T start_value = start.data<T>()[0];
T end_value = end.data<T>()[0];
T step_value = step.data<T>()[0];
if (step_value == static_cast<T>(0)) {
PADDLE_THROW(errors::InvalidArgument("step must be nonzero."));
int64_t size = 0;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
Scalar start_scalar(start);
Scalar end_scalar(end);
Scalar step_scalar(step);
MPType start_value = start_scalar.to<MPType>();
MPType end_value = end_scalar.to<MPType>();
MPType step_value = step_scalar.to<MPType>();

funcs::GetSizeForRange(start_value, end_value, step_value, &size);

out->Resize({size});
T* out_data = dev_ctx.template Alloc<T>(out);
if (size == 0) {
return;
}
MPType value = start_value;
for (int64_t i = 0; i < size; ++i) {
out_data[i] = static_cast<T>(value);
value += step_value;
}
RangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
}

template <typename T, typename Context>
Expand All @@ -61,16 +76,22 @@ void RangeKernel(const Context& dev_ctx,
const Scalar& end,
const Scalar& step,
DenseTensor* out) {
T start_value = start.to<T>();
T end_value = end.to<T>();
T step_value = step.to<T>();
if constexpr (std::is_floating_point_v<T>) {
if (std::isnan(end_value)) {
PADDLE_THROW(common::errors::InvalidArgument(
"The end value of range cannot be NaN. Please check your input."));
}
int64_t size = 0;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType start_value = start.to<MPType>();
MPType end_value = end.to<MPType>();
MPType step_value = step.to<MPType>();
funcs::GetSizeForRange(start_value, end_value, step_value, &size);
out->Resize({size});
T* out_data = dev_ctx.template Alloc<T>(out);
if (size == 0) {
return;
}
MPType value = start_value;
for (int64_t i = 0; i < size; ++i) {
out_data[i] = static_cast<T>(value);
value += step_value;
}
RangeFunc<T, Context>(dev_ctx, start_value, end_value, step_value, out);
}

} // namespace phi
Expand Down
34 changes: 34 additions & 0 deletions paddle/phi/kernels/funcs/range_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
// limitations under the License.

#pragma once
#include <cmath>
#include <type_traits>
#include "paddle/phi/core/enforce.h"

namespace phi {
Expand Down Expand Up @@ -61,5 +63,37 @@ void GetSize(T start, T end, T step, int64_t* size) {
: std::ceil(std::abs((end - start) / step));
}

// Computes the number of elements produced by the range op over the CLOSED
// interval [start, end] with stride `step`, writing it to *size.
//
// Fixes vs. the previous version:
//   1. `step` is now included in the finiteness check — the error message
//      already claimed all three values must be finite, but `step` was never
//      validated, so a NaN/inf step flowed into the size computation (the
//      int64 cast of a NaN is undefined behavior).
//   2. An interval containing no points (e.g. end < start with a positive
//      step) used to yield a NEGATIVE size, which callers then passed to
//      Resize(); it is now clamped to 0 (empty output).
template <typename T>
void GetSizeForRange(T start, T end, T step, int64_t* size) {
  // For range op: closed interval [start, end]
  PADDLE_ENFORCE_NE(
      step,
      0,
      common::errors::InvalidArgument("The step of range op should not be 0."));

  if constexpr (std::is_same_v<T, phi::bfloat16> ||
                std::is_same_v<T, phi::float16>) {
    PADDLE_ENFORCE_EQ(
        phi::dtype::isfinite(start) && phi::dtype::isfinite(end) &&
            phi::dtype::isfinite(step),
        true,
        common::errors::InvalidArgument(
            "The start, end and step of range op should be finite "
            "numbers, but received start=%f, end=%f.",
            static_cast<double>(start),
            static_cast<double>(end)));
  } else if constexpr (std::is_floating_point_v<T>) {
    PADDLE_ENFORCE_EQ(
        std::isfinite(start) && std::isfinite(end) && std::isfinite(step),
        true,
        common::errors::InvalidArgument(
            "The start, end and step of range op should be finite "
            "numbers, but received start=%f, end=%f.",
            static_cast<double>(start),
            static_cast<double>(end)));
  }
  // Closed interval [start, end] => floor((end - start) / step) + 1 points.
  // Clamp to 0 so an empty interval never produces a negative size.
  const int64_t n = static_cast<int64_t>(((end - start) / step) + 1);
  *size = n > 0 ? n : 0;
}

} // namespace funcs
} // namespace phi
Loading
Loading