Skip to content

Commit 3edcd2a

Browse files
committed
update utils.cc
1 parent b334e9c commit 3edcd2a

1 file changed

Lines changed: 37 additions & 37 deletions

File tree

backends/metax_gpu/common/utils.cc

Lines changed: 37 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -104,17 +104,17 @@ C_Status AsyncMemCpyD2D(const C_Device device,
104104

105105
template <typename Context>
106106
inline void TensorCopy(const Context& dev_ctx,
107-
const phi::DenseTensor& src,
107+
const DenseTensor& src,
108108
bool blocking,
109-
phi::DenseTensor* dst,
110-
const phi::Place& dst_place = phi::CustomPlace()) {
109+
DenseTensor* dst,
110+
const Place& dst_place = CustomPlace()) {
111111
auto* src_ptr = src.data();
112112
const auto& src_place = src.place();
113113
if (src_ptr == nullptr) {
114114
return;
115115
}
116116
auto dst_place_ = dst_place;
117-
if (dst_place_.GetType() != phi::AllocationType::CPU) {
117+
if (dst_place_.GetType() != AllocationType::CPU) {
118118
dst_place_ = dev_ctx.GetPlace();
119119
}
120120

@@ -125,7 +125,7 @@ inline void TensorCopy(const Context& dev_ctx,
125125
} else {
126126
VLOG(6) << "Src and dst are the same Tensor, in-place copy data("
127127
<< src_ptr << ") from " << src_place << " to " << dst_place_;
128-
const phi::DenseTensor src_copy = src;
128+
const DenseTensor src_copy = src;
129129
TensorCopy(dev_ctx, src_copy, blocking, dst, dst_place_);
130130
}
131131
return;
@@ -134,7 +134,7 @@ inline void TensorCopy(const Context& dev_ctx,
134134
auto dst_dims = dst->dims();
135135
dst->Resize(src.dims());
136136
void* dst_ptr = nullptr;
137-
if (dst_place_.GetType() != phi::AllocationType::CPU) {
137+
if (dst_place_.GetType() != AllocationType::CPU) {
138138
dst_ptr = dev_ctx.Alloc(dst, src.dtype());
139139
} else {
140140
dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
@@ -156,7 +156,7 @@ inline void TensorCopy(const Context& dev_ctx,
156156
return;
157157
} else {
158158
// scatter memory
159-
phi::DenseTensor tmp_dst;
159+
DenseTensor tmp_dst;
160160
tmp_dst.set_meta(dst->meta());
161161
tmp_dst.Resize(dst_dims);
162162
dst_ptr = dev_ctx.Alloc(&tmp_dst, tmp_dst.dtype());
@@ -176,26 +176,26 @@ inline void TensorCopy(const Context& dev_ctx,
176176
return;
177177
}
178178

179-
if (src_place.GetType() == phi::AllocationType::CPU &&
180-
dst_place_.GetType() == phi::AllocationType::CUSTOM) {
179+
if (src_place.GetType() == AllocationType::CPU &&
180+
dst_place_.GetType() == AllocationType::CUSTOM) {
181181
VLOG(6) << "TensorCopy from cpu to cus";
182182
C_Device_st device;
183183
device.id = dst_place_.GetDeviceId();
184184
AsyncMemCpyH2D(&device, stream, dst_ptr, src_ptr, size);
185185
if (blocking) {
186186
dev_ctx.Wait();
187187
}
188-
} else if (src_place.GetType() == phi::AllocationType::CUSTOM &&
189-
dst_place_.GetType() == phi::AllocationType::CPU) {
188+
} else if (src_place.GetType() == AllocationType::CUSTOM &&
189+
dst_place_.GetType() == AllocationType::CPU) {
190190
VLOG(6) << "TensorCopy from cus to cpu";
191191
C_Device_st device;
192192
device.id = src_place.GetDeviceId();
193193
AsyncMemCpyD2H(&device, stream, dst_ptr, src_ptr, size);
194194
if (blocking) {
195195
dev_ctx.Wait();
196196
}
197-
} else if (src_place.GetType() == phi::AllocationType::CUSTOM &&
198-
dst_place_.GetType() == phi::AllocationType::CUSTOM) {
197+
} else if (src_place.GetType() == AllocationType::CUSTOM &&
198+
dst_place_.GetType() == AllocationType::CUSTOM) {
199199
VLOG(6) << "TensorCopy from cus to cus";
200200
if (src_place.GetDeviceType() == dst_place_.GetDeviceType()) {
201201
if (src_place.GetDeviceId() == dst_place_.GetDeviceId()) {
@@ -212,26 +212,26 @@ inline void TensorCopy(const Context& dev_ctx,
212212
} else {
213213
PADDLE_THROW(phi::errors::Unimplemented("TensorCopy is not supported."));
214214
}
215-
} else if (src_place.GetType() == phi::AllocationType::CPU &&
216-
dst_place_.GetType() == phi::AllocationType::CPU) {
215+
} else if (src_place.GetType() == AllocationType::CPU &&
216+
dst_place_.GetType() == AllocationType::CPU) {
217217
VLOG(6) << "TensorCopy from cpu to cpu";
218218
std::memcpy(dst_ptr, src_ptr, size);
219219
}
220220
}
221221

222222
template <typename T = float>
223-
std::ostream& PrintTensor(std::ostream& os, const phi::DenseTensor& tensor) {
224-
phi::DenseTensor cpu_tensor;
225-
if (tensor.place().GetType() != phi::AllocationType::CPU) {
226-
auto dev_ctx = static_cast<const phi::CustomContext*>(
227-
phi::DeviceContextPool::Instance().Get(tensor.place()));
228-
TensorCopy(*dev_ctx, tensor, true, &cpu_tensor, phi::CPUPlace());
223+
std::ostream& PrintTensor(std::ostream& os, const DenseTensor& tensor) {
224+
DenseTensor cpu_tensor;
225+
if (tensor.place().GetType() != AllocationType::CPU) {
226+
auto dev_ctx = static_cast<const CustomContext*>(
227+
DeviceContextPool::Instance().Get(tensor.place()));
228+
TensorCopy(*dev_ctx, tensor, true, &cpu_tensor, CPUPlace());
229229
} else {
230230
cpu_tensor = tensor;
231231
}
232232
os << "DenseTensor<";
233233
if (tensor.initialized()) {
234-
os << phi::DataTypeToString(tensor.dtype()) << ", ";
234+
os << DataTypeToString(tensor.dtype()) << ", ";
235235
os << tensor.place() << ", ";
236236
os << "Shape(" << tensor.dims() << "), ";
237237
os << "Strides(" << tensor.strides() << "), ";
@@ -266,27 +266,27 @@ std::ostream& PrintTensor(std::ostream& os, const phi::DenseTensor& tensor) {
266266
}
267267
} // namespace
268268

269-
#define FOR_EACH_DATA_TYPE_TO_PRINT(_) \
270-
_(bool, phi::DataType::BOOL) \
271-
_(int8_t, phi::DataType::INT8) \
272-
_(uint8_t, phi::DataType::UINT8) \
273-
_(int16_t, phi::DataType::INT16) \
274-
_(uint16_t, phi::DataType::UINT16) \
275-
_(int32_t, phi::DataType::INT32) \
276-
_(uint32_t, phi::DataType::UINT32) \
277-
_(int64_t, phi::DataType::INT64) \
278-
_(uint64_t, phi::DataType::UINT64) \
279-
_(phi::bfloat16, phi::DataType::BFLOAT16) \
280-
_(phi::float16, phi::DataType::FLOAT16) \
281-
_(float, phi::DataType::FLOAT32) \
282-
_(double, phi::DataType::FLOAT64)
269+
#define FOR_EACH_DATA_TYPE_TO_PRINT(_) \
270+
_(bool, DataType::BOOL) \
271+
_(int8_t, DataType::INT8) \
272+
_(uint8_t, DataType::UINT8) \
273+
_(int16_t, DataType::INT16) \
274+
_(uint16_t, DataType::UINT16) \
275+
_(int32_t, DataType::INT32) \
276+
_(uint32_t, DataType::UINT32) \
277+
_(int64_t, DataType::INT64) \
278+
_(uint64_t, DataType::UINT64) \
279+
_(bfloat16, DataType::BFLOAT16) \
280+
_(float16, DataType::FLOAT16) \
281+
_(float, DataType::FLOAT32) \
282+
_(double, DataType::FLOAT64)
283283

284284
#define CALL_PRINT_TENSOR(cpp_type, data_type) \
285285
case data_type: \
286286
PrintTensor<cpp_type>(os, t); \
287287
break;
288288

289-
std::ostream& operator<<(std::ostream& os, const phi::DenseTensor& t) {
289+
std::ostream& operator<<(std::ostream& os, const DenseTensor& t) {
290290
switch (t.dtype()) {
291291
FOR_EACH_DATA_TYPE_TO_PRINT(CALL_PRINT_TENSOR)
292292
default:

0 commit comments

Comments
 (0)