
Commit 7783afa

Merge branch 'develop' into o7

2 parents 7a505c7 + 7e4289d

133 files changed: 3642 additions & 1125 deletions


cmake/cupti.cmake

Lines changed: 12 additions & 2 deletions

@@ -5,8 +5,18 @@ endif()
 include(${PROJECT_SOURCE_DIR}/cmake/architecture.cmake)
 
 if(WITH_ROCM)
+  if(EXISTS "${ROCM_PATH}/cuda/extras/CUPTI")
+    set(ROCM_CUDA_DIR "${ROCM_PATH}/cuda")
+  elseif(EXISTS "${ROCM_PATH}/cuda/cuda/extras/CUPTI")
+    set(ROCM_CUDA_DIR "${ROCM_PATH}/cuda/cuda")
+  else()
+    message(
+      FATAL_ERROR
+        "CUPTI not found under ${ROCM_PATH}/cuda/extras/CUPTI or ${ROCM_PATH}/cuda/cuda/extras/CUPTI"
+    )
+  endif()
   set(CUPTI_ROOT
-      "${ROCM_PATH}/cuda/extras/CUPTI"
+      "${ROCM_CUDA_DIR}/extras/CUPTI"
       CACHE PATH "CUPTI ROOT")
 else()
   set(CUPTI_ROOT
@@ -59,7 +69,7 @@ get_filename_component(CUPTI_LIBRARY_PATH ${CUPTI_LIBRARY} DIRECTORY)
 if(CUPTI_INCLUDE_DIR AND CUPTI_LIBRARY)
   set(CUPTI_FOUND ON)
   if(WITH_ROCM)
-    include_directories(${ROCM_PATH}/cuda/include)
+    include_directories(${ROCM_CUDA_DIR}/include)
     add_definitions(-D__CUDA_HIP_PLATFORM_AMD__)
   endif()
 else()

paddle/common/flags.cc

Lines changed: 30 additions & 0 deletions

@@ -230,6 +230,36 @@ PHI_DEFINE_EXPORTED_bool(
     "operator. The autotuning algorithm may be non-deterministic. If "
     "true, the algorithm is deterministic.");
 
+/**
+ * GPU RNG related FLAG
+ * Name: FLAGS_deterministic_rng
+ * Since Version: 3.4
+ * Value Range: bool, default=false
+ * Example: paddle.set_flags({'FLAGS_deterministic_rng': True})
+ * Note: Fix RNG kernel launch config so same seed gives same results
+ * across GPU types.
+ */
+PHI_DEFINE_EXPORTED_bool(
+    deterministic_rng,
+    false,
+    "Enable cross-device RNG consistency by fixing GPU kernel launch "
+    "configuration. When true, RNG kernels use a fixed grid/block size "
+    "so that the same seed produces identical results across GPU types.");
+
+/**
+ * GPU RNG related FLAG
+ * Name: FLAGS_deterministic_rng_grid
+ * Since Version: 3.4
+ * Value Range: int32, default=1024
+ * Example: paddle.set_flags({'FLAGS_deterministic_rng_grid': 4096})
+ * Note: Grid size cap used when FLAGS_deterministic_rng is enabled.
+ * Cross-device consistency requires the same value on all devices.
+ */
+PHI_DEFINE_EXPORTED_int32(
+    deterministic_rng_grid,
+    1024,
+    "Grid size cap when FLAGS_deterministic_rng is enabled.");
+
 /**
  * CUDA related FLAG
  * Name: FLAGS_embedding_deterministic
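
Usage sketch for the two new flags above (not part of the diff): the flag names and defaults are taken from the comments in the hunk, and paddle.set_flags / paddle.seed are existing Paddle APIs. Whether results actually match across devices also depends on every device using the same grid cap.

import paddle

# A minimal sketch, assuming both flags are honored by the RNG kernels.
paddle.set_flags({
    'FLAGS_deterministic_rng': True,       # fix RNG kernel launch config
    'FLAGS_deterministic_rng_grid': 4096,  # same cap must be set on every device
})
paddle.seed(2024)
x = paddle.rand([8, 8])  # with the flags set, identical across GPU types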

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc

Lines changed: 18 additions & 0 deletions

@@ -307,6 +307,24 @@ bool AminOpInferSymbolicShape(pir::Operation *op,
                              axis.size() == 0 /*reduce_all*/);
 }
 
+bool AminmaxOpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  const auto &axis = details::GetVectorAttr(op, "axis");
+  bool keepdim = GetBoolAttr(op, "keepdim");
+  bool reduce_all = axis.size() == 0;
+
+  // ReduceInferDim only sets result(0). We need the same shape for both
+  // outputs, so call it for result(0) then copy to result(1).
+  bool ret =
+      details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all);
+  if (ret) {
+    const auto &out_shape =
+        infer_context->GetShapeOrDataForValue(op->result(0));
+    infer_context->SetShapeOrDataForValue(op->result(1), out_shape);
+  }
+  return ret;
+}
+
 bool AnyOpInferSymbolicShape(pir::Operation *op,
                             pir::InferSymbolicShapeContext *infer_context) {
  const auto &axis = details::GetVectorAttr(op, "axis");
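
The shape contract this function enforces can be seen from Python (a hedged illustration, not part of the diff): an aminmax-style op reduces both outputs over the same axes, so the two results always share one shape. paddle.amin / paddle.amax stand in here for the two outputs of the new op.

import paddle

x = paddle.rand([2, 3, 4])
# Both outputs reduce over the same axes, so their shapes must match;
# keepdim=True keeps the reduced axes as size-1 dims.
mn = paddle.amin(x, axis=[1], keepdim=True)  # shape [2, 1, 4]
mx = paddle.amax(x, axis=[1], keepdim=True)  # shape [2, 1, 4]
assert mn.shape == mx.shape
# In the symbolic-shape code, axis.size() == 0 means reduce_all:
# both results then collapse to a 0-D tensor.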

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(AffineGrid)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(All)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Amax)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Amin)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(Aminmax)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Any)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmax)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmin)

paddle/fluid/pybind/eager.h

Lines changed: 13 additions & 0 deletions

@@ -36,6 +36,19 @@ typedef struct {
   std::vector<bool> forward_input_tensor_is_duplicable;
   std::vector<bool> forward_output_tensor_is_duplicable;
   std::weak_ptr<egr::GradNodePyLayer> grad_node;
+  // Holds strong references to DenseTensor impls saved via save_for_backward,
+  // preventing _clear_dataptr() from freeing the underlying memory before
+  // backward runs. Lifecycle: born with container (set_container), dies with
+  // the PyLayerObject (PyLayerDealloc).
+  std::vector<std::shared_ptr<phi::TensorBase>> tensor_hold_helper;
+  // Holds strong references to DenseTensor impls captured in Python closures
+  // of the forward function (not recorded via save_for_backward). The
+  // top-level ``closure_obj`` keeps the owning Python Tensor objects alive
+  // and defines the DFS order used by RestoreDenseTensors. Populated by
+  // ctx._hold_tensors(obj); applied by ctx._restore_held_tensors() to
+  // re-install impl_ after _clear_dataptr(). Released in PyLayerDealloc.
+  PyObject* closure_obj;
+  std::vector<std::shared_ptr<phi::TensorBase>> closure_tensor_hold_helper;
 #ifdef PADDLE_WITH_CUDA
   std::vector<egr::ReloadFunctor> reload_functors;
 #endif
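
For orientation, a minimal PyLayer round trip that exercises the path these fields protect (a hedged sketch: PyLayer, ctx.save_for_backward, and ctx.saved_tensor are existing Paddle APIs, while _clear_dataptr() is the internal call the holders guard against):

import paddle
from paddle.autograd import PyLayer

class Square(PyLayer):
    @staticmethod
    def forward(ctx, x):
        # Saving into ctx.container via save_for_backward now also fills
        # tensor_hold_helper with strong refs to the DenseTensor impls.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad):
        # Reading the container back re-installs any impl that an
        # intervening _clear_dataptr() nulled out.
        (x,) = ctx.saved_tensor()
        return 2.0 * x * grad

x = paddle.randn([4])
x.stop_gradient = False
y = Square.apply(x)
y.sum().backward()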

paddle/fluid/pybind/eager_py_layer.cc

Lines changed: 185 additions & 14 deletions

@@ -89,6 +89,11 @@ PyObject* PyLayerNew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
   new (&v->grad_node) std::weak_ptr<egr::GradNodePyLayer>();
   new (&v->forward_input_tensor_is_duplicable) std::vector<bool>();
   new (&v->forward_output_tensor_is_duplicable) std::vector<bool>();
+  new (&v->tensor_hold_helper)
+      std::vector<std::shared_ptr<phi::TensorBase>>();
+  v->closure_obj = nullptr;
+  new (&v->closure_tensor_hold_helper)
+      std::vector<std::shared_ptr<phi::TensorBase>>();
 #ifdef PADDLE_WITH_CUDA
   new (&v->reload_functors) std::vector<egr::ReloadFunctor>();
 #endif
@@ -110,6 +115,10 @@ static void PyLayerDealloc(PyLayerObject* self) {
   self->unpack_hook = nullptr;
   self->forward_input_tensor_is_duplicable.~vector();
   self->forward_output_tensor_is_duplicable.~vector();
+  self->tensor_hold_helper.~vector();
+  Py_XDECREF(self->closure_obj);
+  self->closure_obj = nullptr;
+  self->closure_tensor_hold_helper.~vector();
 #ifdef PADDLE_WITH_CUDA
   self->reload_functors.~vector();
 #endif
@@ -271,8 +280,9 @@ PyObject* pylayer_method_apply(PyObject* cls,
 
   for (int64_t i = inputs_size - 1; i >= 0; --i) {
     PyObject* obj = nullptr;
-    if (i >= args_size) {
-      obj = PyList_GetItem(kwargs_value_list, i - args_size);  // NOLINT
+    if (i >= static_cast<int64_t>(args_size)) {
+      obj = PyList_GetItem(kwargs_value_list,
+                           i - static_cast<int64_t>(args_size));  // NOLINT
     } else {
       obj = PyTuple_GET_ITEM(args, i);
     }
@@ -685,6 +695,62 @@ PyObject* pylayer_method_apply(PyObject* cls,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
 
+// Deep-traverse a PyObject to collect shared_ptr<phi::DenseTensor> for all
+// DenseTensors found (Tensor / Tuple / List / Dict, recursively). Used by
+// tensor_properties_set_container and ctx._hold_tensors to hold strong
+// references so _clear_dataptr() cannot free the underlying allocation
+// before backward. DFS-walks obj and calls fn(tensor) for every Tensor
+// leaf. CollectDenseTensors and RestoreDenseTensors are built on top.
+template <typename Fn>
+static void WalkDenseTensors(PyObject* obj, Fn&& fn) {
+  if (!obj || obj == Py_None) return;
+  if (PyCheckTensor(obj)) {
+    fn(reinterpret_cast<TensorObject*>(obj)->tensor);
+    return;
+  }
+  if (PyTuple_Check(obj)) {
+    Py_ssize_t n = PyTuple_GET_SIZE(obj);
+    for (Py_ssize_t i = 0; i < n; ++i)
+      WalkDenseTensors(PyTuple_GET_ITEM(obj, i), fn);
+    return;
+  }
+  if (PyList_Check(obj)) {
+    Py_ssize_t n = PyList_GET_SIZE(obj);
+    for (Py_ssize_t i = 0; i < n; ++i)
+      WalkDenseTensors(PyList_GET_ITEM(obj, i), fn);
+    return;
+  }
+  if (PyDict_Check(obj)) {
+    PyObject *k = nullptr, *v = nullptr;
+    Py_ssize_t pos = 0;
+    while (PyDict_Next(obj, &pos, &k, &v)) {
+      WalkDenseTensors(v, fn);
+    }
+    return;
+  }
+}
+
+static void CollectDenseTensors(
+    PyObject* obj, std::vector<std::shared_ptr<phi::TensorBase>>* holder) {
+  WalkDenseTensors(obj, [holder](const paddle::Tensor& tensor) {
+    if (tensor.impl()) holder->push_back(tensor.impl());
+  });
+}
+
+// Re-installs impl() for tensors cleared by _clear_dataptr(), using the
+// shared_ptrs stored in holder (same DFS order as CollectDenseTensors).
+static void RestoreDenseTensors(
+    PyObject* obj,
+    const std::vector<std::shared_ptr<phi::TensorBase>>& holder) {
+  size_t idx = 0;
+  WalkDenseTensors(obj, [&holder, &idx](paddle::Tensor& tensor) {
+    if (idx < holder.size()) {
+      if (!tensor.impl()) tensor.set_impl(holder[idx]);
+      ++idx;
+    }
+  });
+}
+
 PyObject* call_unpack_hook(PyLayerObject* self) {
   auto unpack_hook = self->unpack_hook;
   auto packed_value = self->container;
@@ -734,10 +800,16 @@ PyObject* tensor_properties_get_container(PyLayerObject* self, void* closure) {
   }
   if (self->container_be_packed) {
     return call_unpack_hook(self);
-  } else {
-    Py_INCREF(self->container);
-    return self->container;
   }
+  // Re-attach any DenseTensor impls that were freed by _clear_dataptr().
+  // tensor_hold_helper keeps the underlying allocations alive; walk the
+  // container in the same DFS order as CollectDenseTensors and reinstall
+  // impls for tensors whose impl() is currently null.
+  if (!self->tensor_hold_helper.empty()) {
+    RestoreDenseTensors(self->container, self->tensor_hold_helper);
+  }
+  Py_INCREF(self->container);
+  return self->container;
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
@@ -836,11 +908,18 @@ int tensor_properties_set_container(PyLayerObject* self,
                                     void* closure) {
   EAGER_TRY
   if (egr::SavedTensorsHooks::GetInstance().IsEnable()) {
+    // Note 1: when hooks are enabled the tensors are packed; do NOT populate
+    // tensor_hold_helper (the hook system manages tensor lifetimes itself).
     call_pack_hook(self, value);
   } else {
     Py_XINCREF(value);
     Py_XDECREF(self->container);
     self->container = value;
+    // Note 2: deep-traverse value (Tensor / Tuple / List / nested) to hold
+    // strong references to every DenseTensor impl, preventing _clear_dataptr()
+    // from freeing the underlying allocation before backward runs.
+    self->tensor_hold_helper.clear();
+    CollectDenseTensors(value, &self->tensor_hold_helper);
   }
   return 0;
   EAGER_CATCH_AND_THROW_RETURN_NEG
@@ -907,15 +986,107 @@ int tensor_properties_set_grad_in_dtype_consistent(PyLayerObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NEG
 }
 
-PyMethodDef pylayer_methods[] = {{"name",  // NOLINT
-                                  (PyCFunction)(void (*)())pylayer_method_name,
-                                  METH_NOARGS,
-                                  nullptr},
-                                 {"apply",
-                                  (PyCFunction)(void (*)())pylayer_method_apply,
-                                  METH_CLASS | METH_VARARGS | METH_KEYWORDS,
-                                  nullptr},
-                                 {nullptr, nullptr, 0, nullptr}};
+// ctx._pop_saved_impl(tensor)
+// Removes the strong reference held in tensor_hold_helper for the given
+// tensor's underlying DenseTensor, allowing its memory to be freed early
+// (e.g. inside backward when the tensor is no longer needed).
+// The tensor must have a valid impl(), i.e. pass the recovered tensor
+// returned by ctx.saved_tensor(), not the already-cleared one.
+PyObject* pylayer_pop_saved_impl(PyObject* self_, PyObject* args) {
+  EAGER_TRY
+  auto* self = reinterpret_cast<PyLayerObject*>(self_);
+  PyObject* tensor_obj = nullptr;
+  if (!PyArg_ParseTuple(args, "O", &tensor_obj)) {
+    RETURN_PY_NONE;
+  }
+  if (!tensor_obj || !PyCheckTensor(tensor_obj)) {
+    RETURN_PY_NONE;
+  }
+  const auto& tensor = reinterpret_cast<TensorObject*>(tensor_obj)->tensor;
+  if (!tensor.impl() || !tensor.is_dense_tensor()) {
+    RETURN_PY_NONE;
+  }
+  auto* raw = static_cast<phi::DenseTensor*>(tensor.impl().get());
+  for (auto it = self->tensor_hold_helper.begin();
+       it != self->tensor_hold_helper.end();
+       ++it) {
+    if (it->get() == raw) {
+      self->tensor_hold_helper.erase(it);
+      break;
+    }
+  }
+  RETURN_PY_NONE;
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
+// ctx._hold_tensors(obj)
+// Keep strong refs to the owning container (Py_INCREF'd) and the impl() of
+// every DenseTensor leaf found in obj (Tensor / Tuple / List / Dict).
+// Covers tensors captured via Python closure of the forward function that
+// bypass save_for_backward / container. Skipped when saved_tensors_hooks is
+// enabled (the hook system owns tensor lifetime in that case).
+PyObject* pylayer_hold_tensors(PyObject* self_, PyObject* args) {
+  EAGER_TRY
+  auto* self = reinterpret_cast<PyLayerObject*>(self_);
+  PyObject* obj = nullptr;
+  if (!PyArg_ParseTuple(args, "O", &obj)) {
+    RETURN_PY_NONE;
+  }
+  if (obj && obj != Py_None &&
+      !egr::SavedTensorsHooks::GetInstance().IsEnable()) {
+    Py_INCREF(obj);
+    Py_XDECREF(self->closure_obj);
+    self->closure_obj = obj;
+    self->closure_tensor_hold_helper.clear();
+    CollectDenseTensors(obj, &self->closure_tensor_hold_helper);
+  }
+  RETURN_PY_NONE;
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
+// ctx._restore_held_tensors()
+// Re-install impl() on any Python Tensor previously registered via
+// _hold_tensors whose impl_ has been nulled by _clear_dataptr(). Typically
+// called at the start of backward before recompute re-runs forward.
+PyObject* pylayer_restore_held_tensors(PyObject* self_, PyObject* /*unused*/) {
+  EAGER_TRY
+  auto* self = reinterpret_cast<PyLayerObject*>(self_);
+  if (self->closure_obj && !self->closure_tensor_hold_helper.empty()) {
+    RestoreDenseTensors(self->closure_obj, self->closure_tensor_hold_helper);
+  }
+  RETURN_PY_NONE;
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
+PyMethodDef pylayer_methods[] = {
+    {"name",  // NOLINT
+     (PyCFunction)(void (*)())pylayer_method_name,
+     METH_NOARGS,
+     nullptr},
+    {"apply",
+     (PyCFunction)(void (*)())pylayer_method_apply,
+     METH_CLASS | METH_VARARGS | METH_KEYWORDS,
+     nullptr},
+    {"_pop_saved_impl",
+     (PyCFunction)(void (*)())pylayer_pop_saved_impl,
+     METH_VARARGS,
+     "Release the strong reference held for a "
+     "specific DenseTensor saved via "
+     "save_for_backward, allowing its memory to "
+     "be freed early if no other holder exists."},
+    {"_hold_tensors",
+     (PyCFunction)(void (*)())pylayer_hold_tensors,
+     METH_VARARGS,
+     "Deep-traverse the given object (Tensor / tuple / list / dict) and "
+     "keep strong references to every DenseTensor impl found, plus the "
+     "owning Python Tensor object. Used to protect tensors captured in "
+     "Python closures against _clear_dataptr() in pipeline parallel."},
+    {"_restore_held_tensors",
+     (PyCFunction)(void (*)())pylayer_restore_held_tensors,
+     METH_NOARGS,
+     "Reinstall impl() on Python Tensor objects previously registered via "
+     "_hold_tensors, if their impl() has been nulled by _clear_dataptr()."},
+    {nullptr, nullptr, 0, nullptr}};
 
 struct PyGetSetDef pylayer_properties[] { // NOLINT
     {"container",

paddle/phi/CMakeLists.txt

Lines changed: 7 additions & 2 deletions

@@ -372,9 +372,14 @@ if(WITH_CUTLASS)
   )# for memory_efficient_attention.h
 endif()
 # PADDLE_WARP_SIZE: warp size for the target GPU platform.
-# Default 32 (NVIDIA). Override via -DPADDLE_WARP_SIZE=64 for iluvatar (COREX).
+# Default 32 (NVIDIA). ROCm (AMD/Hygon) wavefront size is 64.
+# Override via -DPADDLE_WARP_SIZE for other platforms.
 if(NOT DEFINED PADDLE_WARP_SIZE)
-  set(PADDLE_WARP_SIZE 32)
+  if(WITH_ROCM)
+    set(PADDLE_WARP_SIZE 64)
+  else()
+    set(PADDLE_WARP_SIZE 32)
+  endif()
 endif()
 math(EXPR PADDLE_WARP_MASK "${PADDLE_WARP_SIZE} - 1")
 if(PADDLE_WARP_SIZE EQUAL 64)
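
The mask computed by math(EXPR PADDLE_WARP_MASK ...) works as a lane mask only because warp sizes are powers of two. A small Python check of that arithmetic (illustrative, not part of the diff):

# For power-of-two warp sizes, (warp_size - 1) extracts the lane index
# within a warp, mirroring math(EXPR PADDLE_WARP_MASK "${PADDLE_WARP_SIZE} - 1").
for warp_size in (32, 64):  # NVIDIA warp / ROCm wavefront
    warp_mask = warp_size - 1
    shift = warp_size.bit_length() - 1  # log2(warp_size)
    for tid in (0, 1, warp_size, warp_size + 5):
        assert tid & warp_mask == tid % warp_size  # lane id
        assert tid >> shift == tid // warp_size    # warp id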

paddle/phi/backends/dynload/rocm_driver.cc

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,8 @@ void* rocm_dso_handle = nullptr;
 
 #define DEFINE_WRAP(__name) DynLoad__##__name __name
 
+ROCM_ROUTINE_EACH_VVM(DEFINE_WRAP);
+ROCM_ROUTINE_EACH_GPU_GRAPH(DEFINE_WRAP);
 ROCM_ROUTINE_EACH(DEFINE_WRAP);
 
 bool HasCUDADriver() {
