Skip to content

Commit 7433512

Browse files
committed
fix
1 parent 7783afa commit 7433512

1 file changed

Lines changed: 18 additions & 32 deletions

File tree

paddle/phi/core/tensor_utils.cc

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -123,48 +123,41 @@ void Copy(const Context& dev_ctx,
123123
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
124124
} else if (src_place.GetType() == AllocationType::GPU && // NOLINT
125125
dst_place.GetType() == AllocationType::CPU) {
126-
auto src_gpu_place = src_place;
127-
auto dst_cpu_place = dst_place;
128126
auto ctx_place = dev_ctx.GetPlace();
129127
PADDLE_ENFORCE_EQ(
130128
ctx_place.GetType() == AllocationType::GPU,
131129
true,
132130
errors::PreconditionNotMet(
133131
"Context place error, excepted GPUPlace, but actually %s.",
134132
ctx_place));
135-
auto ctx_gpu_place = ctx_place;
136-
PADDLE_ENFORCE_EQ(src_gpu_place,
137-
ctx_gpu_place,
133+
PADDLE_ENFORCE_EQ(src_place,
134+
ctx_place,
138135
errors::Unavailable(
139136
"Source place and context place do not match, source "
140137
"place is %s, context place is %s.",
141-
src_gpu_place,
142-
ctx_gpu_place));
138+
src_place,
139+
ctx_place));
143140
auto stream =
144141
blocking ? nullptr
145142
: reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
146-
memory_utils::Copy(
147-
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
143+
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
148144
} else if ((src_place.GetType() == AllocationType::CPU ||
149145
src_place.GetType() == AllocationType::GPUPINNED) && // NOLINT
150146
dst_place.GetType() == AllocationType::GPU) {
151-
auto src_cpu_place = src_place;
152-
auto dst_gpu_place = dst_place;
153147
auto ctx_place = dev_ctx.GetPlace();
154148
PADDLE_ENFORCE_EQ(
155149
ctx_place.GetType() == AllocationType::GPU,
156150
true,
157151
errors::PreconditionNotMet(
158152
"Context place error, excepted GPUPlace, but actually %s.",
159153
ctx_place));
160-
auto ctx_gpu_place = ctx_place;
161154
PADDLE_ENFORCE_EQ(
162-
dst_gpu_place,
163-
ctx_gpu_place,
155+
dst_place,
156+
ctx_place,
164157
errors::Unavailable("Destination place and context place do not match, "
165158
"destination place is %s, context place is %s.",
166-
dst_gpu_place,
167-
ctx_gpu_place));
159+
dst_place,
160+
ctx_place));
168161
auto stream =
169162
blocking ? nullptr
170163
: reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
@@ -176,11 +169,9 @@ void Copy(const Context& dev_ctx,
176169
const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(src_ptr)),
177170
size);
178171
memory_utils::Copy(
179-
dst_gpu_place, dst_ptr, src_cpu_place, stable_src_ptr, size, stream);
172+
dst_place, dst_ptr, src_place, stable_src_ptr, size, stream);
180173
} else if (src_place.GetType() == AllocationType::GPU && // NOLINT
181174
dst_place.GetType() == AllocationType::GPU) {
182-
auto src_gpu_place = src_place;
183-
auto dst_gpu_place = dst_place;
184175
auto ctx_place = dev_ctx.GetPlace();
185176

186177
PADDLE_ENFORCE_EQ(
@@ -193,17 +184,16 @@ void Copy(const Context& dev_ctx,
193184
blocking ? nullptr
194185
: reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
195186
if (src_place.GetDeviceId() == dst_place.GetDeviceId()) {
196-
memory_utils::Copy(
197-
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
187+
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
198188
} else {
199189
if (ctx_place.GetDeviceId() == src_place.GetDeviceId()) {
200190
memory_utils::Copy(
201-
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
191+
dst_place, dst_ptr, src_place, src_ptr, size, stream);
202192
phi::DeviceContextPool::Instance().Get(src.place())->Wait();
203193
} else if (ctx_place.GetDeviceId() == dst_place.GetDeviceId()) {
204194
phi::DeviceContextPool::Instance().Get(src.place())->Wait();
205195
memory_utils::Copy(
206-
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
196+
dst_place, dst_ptr, src_place, src_ptr, size, stream);
207197
phi::DeviceContextPool::Instance().Get(dst_place)->Wait();
208198
} else {
209199
PADDLE_THROW(errors::Unavailable(
@@ -212,28 +202,24 @@ void Copy(const Context& dev_ctx,
212202
}
213203
} else if (src_place.GetType() == AllocationType::GPU && // NOLINT
214204
dst_place.GetType() == AllocationType::GPUPINNED) {
215-
auto src_gpu_place = src_place;
216-
auto dst_cuda_pinned_place = dst_place;
217205
auto ctx_place = dev_ctx.GetPlace();
218206
PADDLE_ENFORCE_EQ(
219207
ctx_place.GetType() == AllocationType::GPU,
220208
true,
221209
errors::PreconditionNotMet(
222210
"Context place error, excepted GPUPlace, but actually %s.",
223211
ctx_place));
224-
auto ctx_gpu_place = ctx_place;
225-
PADDLE_ENFORCE_EQ(src_gpu_place,
226-
ctx_gpu_place,
212+
PADDLE_ENFORCE_EQ(src_place,
213+
ctx_place,
227214
errors::Unavailable(
228215
"Source place and context place do not match, source "
229216
"place is %s, context place is %s.",
230-
src_gpu_place,
231-
ctx_gpu_place));
217+
src_place,
218+
ctx_place));
232219
auto stream =
233220
blocking ? nullptr
234221
: reinterpret_cast<const phi::GPUContext&>(dev_ctx).stream();
235-
memory_utils::Copy(
236-
dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
222+
memory_utils::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
237223
#endif
238224
#ifdef PADDLE_WITH_XPU
239225
} else if ((src_place.GetType() == AllocationType::CPU ||

0 commit comments

Comments
 (0)