Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1135,8 +1135,6 @@ void ComputeGraph::prepack() {
int i = 0;
bool submitted = false;
const bool reduce_peak_memory = total_constant_nbytes_ > 10 * MB;
// int count = 0;

context_->set_cmd();
for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
// Do not trigger on the first or last prepack node.
Expand Down
8 changes: 6 additions & 2 deletions backends/vulkan/runtime/graph/containers/Constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ struct TensorRef final {
// This will be empty (default constructed) for the raw pointer constructor
executorch::runtime::FreeableBuffer buffer;

// Number of PrepackNodes that still need to read from this TensorRef. When
// this reaches 0, the buffer can be safely freed. This prevents
// use-after-free when multiple PrepackNodes reference the same TensorRef
// (e.g. shared/tied weights).
int32_t prepack_use_count{0};

explicit TensorRef(
const std::vector<int64_t>& t_sizes,
vkapi::ScalarType t_dtype,
Expand All @@ -44,8 +50,6 @@ struct TensorRef final {
return utils::multiply_integers(sizes) * vkapi::element_size(dtype);
}

// Manually free the buffer if needed (though it will be freed automatically
// on destruction)
void free_buffer() {
buffer.Free();
}
Expand Down
10 changes: 7 additions & 3 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ PrepackNode::PrepackNode(
push_constants_(push_constants) {
graph.update_descriptor_counts(shader, /*execute = */ false);
graph.update_descriptor_counts(noop_shader_, /*execute = */ false);
if (!graph.val_is_none(tref)) {
graph.get_tref(tref)->prepack_use_count++;
}
}

api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
Expand Down Expand Up @@ -100,9 +103,10 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
}
}

// Once the staging buffer is copied, if the TensorRef owns a FreeableBuffer,
// it can be freed.
tref->free_buffer();
if (--tref->prepack_use_count == 0) {
tref->free_buffer();
}

return staging;
}

Expand Down
Loading