Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 37 additions & 37 deletions backends/metax_gpu/common/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,17 @@ C_Status AsyncMemCpyD2D(const C_Device device,

template <typename Context>
inline void TensorCopy(const Context& dev_ctx,
const phi::DenseTensor& src,
const DenseTensor& src,
bool blocking,
phi::DenseTensor* dst,
const phi::Place& dst_place = phi::CustomPlace()) {
DenseTensor* dst,
const Place& dst_place = CustomPlace()) {
auto* src_ptr = src.data();
const auto& src_place = src.place();
if (src_ptr == nullptr) {
return;
}
auto dst_place_ = dst_place;
if (dst_place_.GetType() != phi::AllocationType::CPU) {
if (dst_place_.GetType() != AllocationType::CPU) {
dst_place_ = dev_ctx.GetPlace();
}

Expand All @@ -125,7 +125,7 @@ inline void TensorCopy(const Context& dev_ctx,
} else {
VLOG(6) << "Src and dst are the same Tensor, in-place copy data("
<< src_ptr << ") from " << src_place << " to " << dst_place_;
const phi::DenseTensor src_copy = src;
const DenseTensor src_copy = src;
TensorCopy(dev_ctx, src_copy, blocking, dst, dst_place_);
}
return;
Expand All @@ -134,7 +134,7 @@ inline void TensorCopy(const Context& dev_ctx,
auto dst_dims = dst->dims();
dst->Resize(src.dims());
void* dst_ptr = nullptr;
if (dst_place_.GetType() != phi::AllocationType::CPU) {
if (dst_place_.GetType() != AllocationType::CPU) {
dst_ptr = dev_ctx.Alloc(dst, src.dtype());
} else {
dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
Expand All @@ -156,7 +156,7 @@ inline void TensorCopy(const Context& dev_ctx,
return;
} else {
// scatter memory
phi::DenseTensor tmp_dst;
DenseTensor tmp_dst;
tmp_dst.set_meta(dst->meta());
tmp_dst.Resize(dst_dims);
dst_ptr = dev_ctx.Alloc(&tmp_dst, tmp_dst.dtype());
Expand All @@ -176,26 +176,26 @@ inline void TensorCopy(const Context& dev_ctx,
return;
}

if (src_place.GetType() == phi::AllocationType::CPU &&
dst_place_.GetType() == phi::AllocationType::CUSTOM) {
if (src_place.GetType() == AllocationType::CPU &&
dst_place_.GetType() == AllocationType::CUSTOM) {
VLOG(6) << "TensorCopy from cpu to cus";
C_Device_st device;
device.id = dst_place_.GetDeviceId();
AsyncMemCpyH2D(&device, stream, dst_ptr, src_ptr, size);
if (blocking) {
dev_ctx.Wait();
}
} else if (src_place.GetType() == phi::AllocationType::CUSTOM &&
dst_place_.GetType() == phi::AllocationType::CPU) {
} else if (src_place.GetType() == AllocationType::CUSTOM &&
dst_place_.GetType() == AllocationType::CPU) {
VLOG(6) << "TensorCopy from cus to cpu";
C_Device_st device;
device.id = src_place.GetDeviceId();
AsyncMemCpyD2H(&device, stream, dst_ptr, src_ptr, size);
if (blocking) {
dev_ctx.Wait();
}
} else if (src_place.GetType() == phi::AllocationType::CUSTOM &&
dst_place_.GetType() == phi::AllocationType::CUSTOM) {
} else if (src_place.GetType() == AllocationType::CUSTOM &&
dst_place_.GetType() == AllocationType::CUSTOM) {
VLOG(6) << "TensorCopy from cus to cus";
if (src_place.GetDeviceType() == dst_place_.GetDeviceType()) {
if (src_place.GetDeviceId() == dst_place_.GetDeviceId()) {
Expand All @@ -212,26 +212,26 @@ inline void TensorCopy(const Context& dev_ctx,
} else {
PADDLE_THROW(phi::errors::Unimplemented("TensorCopy is not supported."));
}
} else if (src_place.GetType() == phi::AllocationType::CPU &&
dst_place_.GetType() == phi::AllocationType::CPU) {
} else if (src_place.GetType() == AllocationType::CPU &&
dst_place_.GetType() == AllocationType::CPU) {
VLOG(6) << "TensorCopy from cpu to cpu";
std::memcpy(dst_ptr, src_ptr, size);
}
}

template <typename T = float>
std::ostream& PrintTensor(std::ostream& os, const phi::DenseTensor& tensor) {
phi::DenseTensor cpu_tensor;
if (tensor.place().GetType() != phi::AllocationType::CPU) {
auto dev_ctx = static_cast<const phi::CustomContext*>(
phi::DeviceContextPool::Instance().Get(tensor.place()));
TensorCopy(*dev_ctx, tensor, true, &cpu_tensor, phi::CPUPlace());
std::ostream& PrintTensor(std::ostream& os, const DenseTensor& tensor) {
DenseTensor cpu_tensor;
if (tensor.place().GetType() != AllocationType::CPU) {
auto dev_ctx = static_cast<const CustomContext*>(
DeviceContextPool::Instance().Get(tensor.place()));
TensorCopy(*dev_ctx, tensor, true, &cpu_tensor, CPUPlace());
} else {
cpu_tensor = tensor;
}
os << "DenseTensor<";
if (tensor.initialized()) {
os << phi::DataTypeToString(tensor.dtype()) << ", ";
os << DataTypeToString(tensor.dtype()) << ", ";
os << tensor.place() << ", ";
os << "Shape(" << tensor.dims() << "), ";
os << "Strides(" << tensor.strides() << "), ";
Expand Down Expand Up @@ -266,27 +266,27 @@ std::ostream& PrintTensor(std::ostream& os, const phi::DenseTensor& tensor) {
}
} // namespace

// X-macro: applies the caller-supplied macro `_` once per
// (C++ element type, phi::DataType enumerator) pair listed below.
// In this file it is expanded with CALL_PRINT_TENSOR inside the
// `switch (t.dtype())` of operator<<, producing one `case` per entry.
// To support printing another dtype, add a `_(cpp_type, enumerator)` row here.
#define FOR_EACH_DATA_TYPE_TO_PRINT(_) \
_(bool, phi::DataType::BOOL) \
_(int8_t, phi::DataType::INT8) \
_(uint8_t, phi::DataType::UINT8) \
_(int16_t, phi::DataType::INT16) \
_(uint16_t, phi::DataType::UINT16) \
_(int32_t, phi::DataType::INT32) \
_(uint32_t, phi::DataType::UINT32) \
_(int64_t, phi::DataType::INT64) \
_(uint64_t, phi::DataType::UINT64) \
_(phi::bfloat16, phi::DataType::BFLOAT16) \
_(phi::float16, phi::DataType::FLOAT16) \
_(float, phi::DataType::FLOAT32) \
_(double, phi::DataType::FLOAT64)
// X-macro: applies the caller-supplied macro `_` once per
// (C++ element type, DataType enumerator) pair listed below.
// In this file it is expanded with CALL_PRINT_TENSOR inside the
// `switch (t.dtype())` of operator<<, producing one `case` per entry.
// NOTE(review): `DataType`, `bfloat16` and `float16` are unqualified here, so
// they are looked up at the point of expansion, not at this definition.
// Confirm every expansion site sits where these names resolve (enclosing
// namespace or using-declarations) — the surrounding scope is not visible here.
#define FOR_EACH_DATA_TYPE_TO_PRINT(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(uint16_t, DataType::UINT16) \
_(int32_t, DataType::INT32) \
_(uint32_t, DataType::UINT32) \
_(int64_t, DataType::INT64) \
_(uint64_t, DataType::UINT64) \
_(bfloat16, DataType::BFLOAT16) \
_(float16, DataType::FLOAT16) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64)

// Expands to one `case data_type:` arm that calls PrintTensor<cpp_type>(os, t)
// and then breaks out of the switch.
//   cpp_type  - element type passed as PrintTensor's template argument.
//   data_type - enumerator used as the case label.
// Must be expanded inside a `switch` where `os` and `t` are in scope (it is
// used that way in operator<< via FOR_EACH_DATA_TYPE_TO_PRINT). The stream
// reference returned by PrintTensor is discarded here.
#define CALL_PRINT_TENSOR(cpp_type, data_type) \
case data_type: \
PrintTensor<cpp_type>(os, t); \
break;

std::ostream& operator<<(std::ostream& os, const phi::DenseTensor& t) {
std::ostream& operator<<(std::ostream& os, const DenseTensor& t) {
switch (t.dtype()) {
FOR_EACH_DATA_TYPE_TO_PRINT(CALL_PRINT_TENSOR)
default:
Expand Down
Loading