@@ -89,6 +89,11 @@ PyObject* PyLayerNew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
8989 new (&v->grad_node ) std::weak_ptr<egr::GradNodePyLayer>();
9090 new (&v->forward_input_tensor_is_duplicable ) std::vector<bool >();
9191 new (&v->forward_output_tensor_is_duplicable ) std::vector<bool >();
92+ new (&v->tensor_hold_helper )
93+ std::vector<std::shared_ptr<phi::DenseTensor>>();
94+ v->closure_obj = nullptr ;
95+ new (&v->closure_tensor_hold_helper )
96+ std::vector<std::shared_ptr<phi::TensorBase>>();
9297#ifdef PADDLE_WITH_CUDA
9398 new (&v->reload_functors ) std::vector<egr::ReloadFunctor>();
9499#endif
@@ -110,6 +115,10 @@ static void PyLayerDealloc(PyLayerObject* self) {
110115 self->unpack_hook = nullptr ;
111116 self->forward_input_tensor_is_duplicable .~vector ();
112117 self->forward_output_tensor_is_duplicable .~vector ();
118+ self->tensor_hold_helper .~vector ();
119+ Py_XDECREF (self->closure_obj );
120+ self->closure_obj = nullptr ;
121+ self->closure_tensor_hold_helper .~vector ();
113122#ifdef PADDLE_WITH_CUDA
114123 self->reload_functors .~vector ();
115124#endif
@@ -271,8 +280,9 @@ PyObject* pylayer_method_apply(PyObject* cls,
271280
272281 for (int64_t i = inputs_size - 1 ; i >= 0 ; --i) {
273282 PyObject* obj = nullptr ;
274- if (i >= args_size) {
275- obj = PyList_GetItem (kwargs_value_list, i - args_size); // NOLINT
283+ if (i >= static_cast <int64_t >(args_size)) {
284+ obj = PyList_GetItem (kwargs_value_list,
285+ i - static_cast <int64_t >(args_size)); // NOLINT
276286 } else {
277287 obj = PyTuple_GET_ITEM (args, i);
278288 }
@@ -685,6 +695,62 @@ PyObject* pylayer_method_apply(PyObject* cls,
685695 EAGER_CATCH_AND_THROW_RETURN_NULL
686696}
687697
698+ // Deep-traverse a PyObject to collect shared_ptr<phi::DenseTensor> for all
699+ // DenseTensors found (Tensor / Tuple / List / Dict, recursively). Used by
700+ // tensor_properties_set_container and ctx._hold_tensors to hold strong
701+ // references so _clear_dataptr() cannot free the underlying allocation
702+ // before backward. DFS-walks obj and calls fn(tensor) for every Tensor
703+ // leaf. CollectDenseTensors and RestoreDenseTensors are built on top.
704+ template <typename Fn>
705+ static void WalkDenseTensors (PyObject* obj, Fn&& fn) {
706+ if (!obj || obj == Py_None) return ;
707+ if (PyCheckTensor (obj)) {
708+ fn (reinterpret_cast <TensorObject*>(obj)->tensor );
709+ return ;
710+ }
711+ if (PyTuple_Check (obj)) {
712+ Py_ssize_t n = PyTuple_GET_SIZE (obj);
713+ for (Py_ssize_t i = 0 ; i < n; ++i)
714+ WalkDenseTensors (PyTuple_GET_ITEM (obj, i), fn);
715+ return ;
716+ }
717+ if (PyList_Check (obj)) {
718+ Py_ssize_t n = PyList_GET_SIZE (obj);
719+ for (Py_ssize_t i = 0 ; i < n; ++i)
720+ WalkDenseTensors (PyList_GET_ITEM (obj, i), fn);
721+ return ;
722+ }
723+ if (PyDict_Check (obj)) {
724+ PyObject *k = nullptr , *v = nullptr ;
725+ Py_ssize_t pos = 0 ;
726+ while (PyDict_Next (obj, &pos, &k, &v)) {
727+ WalkDenseTensors (v, fn);
728+ }
729+ return ;
730+ }
731+ }
732+
733+ static void CollectDenseTensors (
734+ PyObject* obj, std::vector<std::shared_ptr<phi::TensorBase>>* holder) {
735+ WalkDenseTensors (obj, [holder](const paddle::Tensor& tensor) {
736+ if (tensor.impl ()) holder->push_back (tensor.impl ());
737+ });
738+ }
739+
740+ // Re-installs impl() for tensors cleared by _clear_dataptr(), using the
741+ // shared_ptrs stored in holder (same DFS order as CollectDenseTensors).
742+ static void RestoreDenseTensors (
743+ PyObject* obj,
744+ const std::vector<std::shared_ptr<phi::TensorBase>>& holder) {
745+ size_t idx = 0 ;
746+ WalkDenseTensors (obj, [&holder, &idx](paddle::Tensor& tensor) {
747+ if (idx < holder.size ()) {
748+ if (!tensor.impl ()) tensor.set_impl (holder[idx]);
749+ ++idx;
750+ }
751+ });
752+ }
753+
688754PyObject* call_unpack_hook (PyLayerObject* self) {
689755 auto unpack_hook = self->unpack_hook ;
690756 auto packed_value = self->container ;
@@ -734,10 +800,16 @@ PyObject* tensor_properties_get_container(PyLayerObject* self, void* closure) {
734800 }
735801 if (self->container_be_packed ) {
736802 return call_unpack_hook (self);
737- } else {
738- Py_INCREF (self->container );
739- return self->container ;
740803 }
804+ // Re-attach any DenseTensor impls that were freed by _clear_dataptr().
805+ // tensor_hold_helper keeps the underlying allocations alive; walk the
806+ // container in the same DFS order as CollectDenseTensors and reinstall
807+ // impls for tensors whose impl() is currently null.
808+ if (!self->tensor_hold_helper .empty ()) {
809+ RestoreDenseTensors (self->container , self->tensor_hold_helper );
810+ }
811+ Py_INCREF (self->container );
812+ return self->container ;
741813 EAGER_CATCH_AND_THROW_RETURN_NULL
742814}
743815
@@ -836,11 +908,18 @@ int tensor_properties_set_container(PyLayerObject* self,
836908 void * closure) {
837909 EAGER_TRY
838910 if (egr::SavedTensorsHooks::GetInstance ().IsEnable ()) {
911+ // Note 1: when hooks are enabled the tensors are packed; do NOT populate
912+ // tensor_hold_helper (the hook system manages tensor lifetimes itself).
839913 call_pack_hook (self, value);
840914 } else {
841915 Py_XINCREF (value);
842916 Py_XDECREF (self->container );
843917 self->container = value;
918+ // Note 2: deep-traverse value (Tensor / Tuple / List / nested) to hold
919+ // strong references to every DenseTensor impl, preventing _clear_dataptr()
920+ // from freeing the underlying allocation before backward runs.
921+ self->tensor_hold_helper .clear ();
922+ CollectDenseTensors (value, &self->tensor_hold_helper );
844923 }
845924 return 0 ;
846925 EAGER_CATCH_AND_THROW_RETURN_NEG
@@ -907,15 +986,107 @@ int tensor_properties_set_grad_in_dtype_consistent(PyLayerObject* self,
907986 EAGER_CATCH_AND_THROW_RETURN_NEG
908987}
909988
910- PyMethodDef pylayer_methods[] = {{" name" , // NOLINT
911- (PyCFunction)(void (*)())pylayer_method_name,
912- METH_NOARGS,
913- nullptr },
914- {" apply" ,
915- (PyCFunction)(void (*)())pylayer_method_apply,
916- METH_CLASS | METH_VARARGS | METH_KEYWORDS,
917- nullptr },
918- {nullptr , nullptr , 0 , nullptr }};
989+ // ctx._pop_saved_impl(tensor)
990+ // Removes the strong reference held in tensor_hold_helper for the given
991+ // tensor's underlying DenseTensor, allowing its memory to be freed early
992+ // (e.g. inside backward when the tensor is no longer needed).
993+ // The tensor must have a valid impl() — i.e. pass the recovered tensor
994+ // returned by ctx.saved_tensor(), not the already-cleared one.
995+ PyObject* pylayer_pop_saved_impl (PyObject* self_, PyObject* args) {
996+ EAGER_TRY
997+ auto * self = reinterpret_cast <PyLayerObject*>(self_);
998+ PyObject* tensor_obj = nullptr ;
999+ if (!PyArg_ParseTuple (args, " O" , &tensor_obj)) {
1000+ RETURN_PY_NONE;
1001+ }
1002+ if (!tensor_obj || !PyCheckTensor (tensor_obj)) {
1003+ RETURN_PY_NONE;
1004+ }
1005+ const auto & tensor = reinterpret_cast <TensorObject*>(tensor_obj)->tensor ;
1006+ if (!tensor.impl () || !tensor.is_dense_tensor ()) {
1007+ RETURN_PY_NONE;
1008+ }
1009+ auto * raw = static_cast <phi::DenseTensor*>(tensor.impl ().get ());
1010+ for (auto it = self->tensor_hold_helper .begin ();
1011+ it != self->tensor_hold_helper .end ();
1012+ ++it) {
1013+ if (it->get () == raw) {
1014+ self->tensor_hold_helper .erase (it);
1015+ break ;
1016+ }
1017+ }
1018+ RETURN_PY_NONE;
1019+ EAGER_CATCH_AND_THROW_RETURN_NULL
1020+ }
1021+
1022+ // ctx._hold_tensors(obj)
1023+ // Keep strong refs to the owning container (Py_INCREF'd) and the impl() of
1024+ // every DenseTensor leaf found in obj (Tensor / Tuple / List / Dict).
1025+ // Covers tensors captured via Python closure of the forward function that
1026+ // bypass save_for_backward / container. Skipped when saved_tensors_hooks is
1027+ // enabled (the hook system owns tensor lifetime in that case).
1028+ PyObject* pylayer_hold_tensors (PyObject* self_, PyObject* args) {
1029+ EAGER_TRY
1030+ auto * self = reinterpret_cast <PyLayerObject*>(self_);
1031+ PyObject* obj = nullptr ;
1032+ if (!PyArg_ParseTuple (args, " O" , &obj)) {
1033+ RETURN_PY_NONE;
1034+ }
1035+ if (obj && obj != Py_None &&
1036+ !egr::SavedTensorsHooks::GetInstance ().IsEnable ()) {
1037+ Py_INCREF (obj);
1038+ Py_XDECREF (self->closure_obj );
1039+ self->closure_obj = obj;
1040+ self->closure_tensor_hold_helper .clear ();
1041+ CollectDenseTensors (obj, &self->closure_tensor_hold_helper );
1042+ }
1043+ RETURN_PY_NONE;
1044+ EAGER_CATCH_AND_THROW_RETURN_NULL
1045+ }
1046+
1047+ // ctx._restore_held_tensors()
1048+ // Re-install impl() on any Python Tensor previously registered via
1049+ // _hold_tensors whose impl_ has been nulled by _clear_dataptr(). Typically
1050+ // called at the start of backward before recompute re-runs forward.
1051+ PyObject* pylayer_restore_held_tensors (PyObject* self_, PyObject* /* unused*/ ) {
1052+ EAGER_TRY
1053+ auto * self = reinterpret_cast <PyLayerObject*>(self_);
1054+ if (self->closure_obj && !self->closure_tensor_hold_helper .empty ()) {
1055+ RestoreDenseTensors (self->closure_obj , self->closure_tensor_hold_helper );
1056+ }
1057+ RETURN_PY_NONE;
1058+ EAGER_CATCH_AND_THROW_RETURN_NULL
1059+ }
1060+
1061+ PyMethodDef pylayer_methods[] = {
1062+ {" name" , // NOLINT
1063+ (PyCFunction)(void (*)())pylayer_method_name,
1064+ METH_NOARGS,
1065+ nullptr },
1066+ {" apply" ,
1067+ (PyCFunction)(void (*)())pylayer_method_apply,
1068+ METH_CLASS | METH_VARARGS | METH_KEYWORDS,
1069+ nullptr },
1070+ {" _pop_saved_impl" ,
1071+ (PyCFunction)(void (*)())pylayer_pop_saved_impl,
1072+ METH_VARARGS,
1073+ " Release the strong reference held for a "
1074+ " specific DenseTensor saved via "
1075+ " save_for_backward, allowing its memory to "
1076+ " be freed early if no other holder exists." },
1077+ {" _hold_tensors" ,
1078+ (PyCFunction)(void (*)())pylayer_hold_tensors,
1079+ METH_VARARGS,
1080+ " Deep-traverse the given object (Tensor / tuple / list / dict) and "
1081+ " keep strong references to every DenseTensor impl found, plus the "
1082+ " owning Python Tensor object. Used to protect tensors captured in "
1083+ " Python closures against _clear_dataptr() in pipeline parallel." },
1084+ {" _restore_held_tensors" ,
1085+ (PyCFunction)(void (*)())pylayer_restore_held_tensors,
1086+ METH_NOARGS,
1087+ " Reinstall impl() on Python Tensor objects previously registered via "
1088+ " _hold_tensors, if their impl() has been nulled by _clear_dataptr()." },
1089+ {nullptr , nullptr , 0 , nullptr }};
9191090
9201091struct PyGetSetDef pylayer_properties[] { // NOLINT
9211092 {" container" ,
0 commit comments