Commit f84a17c

ssjia (SS-JIA) authored and committed
[ET-VK] Fix use-after-free in PrepackNode when TensorRefs are shared
Pull Request resolved: #18906

When a model has shared/tied weights (e.g. tied embeddings in transformers), the serialization deduplicates them into a single TensorRef that multiple PrepackNodes reference.

Previously, `PrepackNode::create_staging_buffer()` called `tref->free_buffer()` unconditionally after copying weight data to a GPU staging buffer. This meant the first PrepackNode to execute would free the underlying host memory, and subsequent PrepackNodes sharing the same TensorRef would read from a dangling pointer, producing garbage/NaN values in prepacked weight and bias tensors on the GPU.

The fix adds a `prepack_use_count` field to `TensorRef` that tracks how many PrepackNodes still need to read from it. Each PrepackNode increments the count in its constructor and decrements it after copying data. The buffer is only freed when the count reaches zero.

This preserves the original eager-free behavior for non-shared weights (freeing immediately after the single consumer copies) while correctly deferring the free for shared weights until the last consumer is done, avoiding both the use-after-free and an unnecessary increase in peak memory.

ghstack-source-id: 367726483
@exported-using-ghexport

Differential Revision: [D101009402](https://our.internmc.facebook.com/intern/diff/D101009402/)
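The counting scheme described in the commit message can be illustrated in isolation. The following is a minimal sketch, not the ExecuTorch code: `HostWeights` and `Consumer` are hypothetical stand-ins for `TensorRef` and `PrepackNode`, a `std::vector<float>` stands in for the `FreeableBuffer`, and a plain vector copy stands in for the copy into a GPU staging buffer.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for TensorRef: host data plus a pending-reader count.
struct HostWeights {
  std::vector<float> data;
  int32_t prepack_use_count{0};
  bool freed{false};

  void free_buffer() {
    data.clear();
    data.shrink_to_fit();
    freed = true;
  }
};

// Hypothetical stand-in for PrepackNode.
class Consumer {
 public:
  explicit Consumer(HostWeights* weights) : weights_(weights) {
    ++weights_->prepack_use_count;  // register as a pending reader
  }

  // Copies the host data (stand-in for the staging-buffer copy), then drops
  // this consumer's claim. Only the last pending consumer frees the buffer.
  std::vector<float> copy_to_staging() {
    std::vector<float> staging = weights_->data;
    if (--weights_->prepack_use_count == 0) {
      weights_->free_buffer();
    }
    return staging;
  }

 private:
  HostWeights* weights_;
};

// Two consumers share one weight: the free is deferred to the second copy.
bool shared_weights_survive_until_last_consumer() {
  HostWeights tied;
  tied.data = {1.0f, 2.0f, 3.0f};

  Consumer embedding(&tied);
  Consumer output_head(&tied);

  const std::vector<float> first = embedding.copy_to_staging();
  const bool live_after_first = !tied.freed;
  const std::vector<float> second = output_head.copy_to_staging();

  return live_after_first && first.size() == 3 && second == first &&
      tied.freed;
}

// A single consumer: the buffer is freed right after its copy, as before.
bool unshared_weights_are_freed_eagerly() {
  HostWeights solo;
  solo.data = {4.0f, 5.0f};

  Consumer only(&solo);
  const std::vector<float> staged = only.copy_to_staging();

  return staged.size() == 2 && solo.freed;
}
```

Both helper functions are expected to return `true`. Like the `int32_t` field in the commit, the counter in this sketch is a plain integer, so it assumes the consumers run on a single thread.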
1 parent 28c56fe commit f84a17c

3 files changed

Lines changed: 13 additions & 7 deletions

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 0 additions & 2 deletions
@@ -1135,8 +1135,6 @@ void ComputeGraph::prepack() {
   int i = 0;
   bool submitted = false;
   const bool reduce_peak_memory = total_constant_nbytes_ > 10 * MB;
-  // int count = 0;
-
   context_->set_cmd();
   for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
     // Do not trigger on the first or last prepack node.

backends/vulkan/runtime/graph/containers/Constant.h

Lines changed: 6 additions & 2 deletions
@@ -29,6 +29,12 @@ struct TensorRef final {
   // This will be empty (default constructed) for the raw pointer constructor
   executorch::runtime::FreeableBuffer buffer;
 
+  // Number of PrepackNodes that still need to read from this TensorRef. When
+  // this reaches 0, the buffer can be safely freed. This prevents
+  // use-after-free when multiple PrepackNodes reference the same TensorRef
+  // (e.g. shared/tied weights).
+  int32_t prepack_use_count{0};
+
   explicit TensorRef(
       const std::vector<int64_t>& t_sizes,
       vkapi::ScalarType t_dtype,
@@ -44,8 +50,6 @@ struct TensorRef final {
     return utils::multiply_integers(sizes) * vkapi::element_size(dtype);
   }
 
-  // Manually free the buffer if needed (though it will be freed automatically
-  // on destruction)
   void free_buffer() {
     buffer.Free();
   }

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 7 additions & 3 deletions
@@ -44,6 +44,9 @@ PrepackNode::PrepackNode(
       push_constants_(push_constants) {
   graph.update_descriptor_counts(shader, /*execute = */ false);
   graph.update_descriptor_counts(noop_shader_, /*execute = */ false);
+  if (!graph.val_is_none(tref)) {
+    graph.get_tref(tref)->prepack_use_count++;
+  }
 }
 
 api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
@@ -100,9 +103,10 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
     }
   }
 
-  // Once the staging buffer is copied, if the TensorRef owns a FreeableBuffer,
-  // it can be freed.
-  tref->free_buffer();
+  if (--tref->prepack_use_count == 0) {
+    tref->free_buffer();
+  }
+
   return staging;
 }

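One property of this change is worth spelling out: the count is incremented in the constructor and decremented in `create_staging_buffer()`, so the scheme appears to rely on every PrepackNode being constructed before the first one executes. That seems consistent with `ComputeGraph::prepack()` iterating over an already-populated `prepack_nodes_` list, but it is an inference from the diff rather than something the commit states. The sketch below uses hypothetical names (`Ref`, `register_consumer`, `consume`) and a boolean flag instead of real memory to show why the ordering matters.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical miniature of the counting scheme; not the ExecuTorch types.
struct Ref {
  int32_t use_count{0};
  bool freed{false};
};

// Mirrors the constructor-side increment.
inline void register_consumer(Ref& r) {
  ++r.use_count;
}

// Mirrors the decrement after the copy. Returns true if the data was still
// live at the moment this consumer read it.
inline bool consume(Ref& r) {
  const bool was_live = !r.freed;
  if (--r.use_count == 0) {
    r.freed = true;
  }
  return was_live;
}

// All consumers register before any of them runs: both reads are safe and
// the buffer ends up freed.
inline bool register_all_then_consume() {
  Ref r;
  register_consumer(r);
  register_consumer(r);
  const bool first = consume(r);
  const bool second = consume(r);
  return first && second && r.freed;
}

// Registration interleaved with consumption: the count reaches zero after
// the first consumer, so the second one reads freed data.
inline bool interleaved_second_read_is_live() {
  Ref r;
  register_consumer(r);
  const bool first = consume(r);  // count 1 -> 0, buffer marked freed
  register_consumer(r);
  const bool second = consume(r);  // read happens after the free
  return first && second;
}
```

Here `register_all_then_consume()` returns `true` and `interleaved_second_read_is_live()` returns `false`, which is the case a build-then-prepack ordering avoids.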