Skip to content

Commit 148b475

Browse files
Merge pull request #945 from InfiniTensor/issue/810
issue/810 feat: allow a graph tensor's blob to be resumed into the allocator's in-use tracking
2 parents 7c97894 + c1535ae commit 148b475

File tree

10 files changed

+53
-0
lines changed

10 files changed

+53
-0
lines changed

include/infinicore/graph/graph.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class GraphManager;
1212
class GraphTensor : public Tensor {
1313
public:
1414
GraphTensor(const Tensor &);
15+
void resume() const;
1516
};
1617

1718
class GraphOperator {

include/infinicore/tensor.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ class Tensor {
9090
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
9191
std::shared_ptr<TensorImpl> impl_;
9292
friend class TensorImpl;
93+
94+
void resume_from_blob_() const;
9395
};
9496

9597
class TensorImpl : public std::enable_shared_from_this<TensorImpl> {

src/infinicore/context/allocators/pinnable_block_allocator.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,16 @@ void PinnableBlockAllocator::deallocate(std::byte *ptr) {
125125
}
126126
}
127127

128+
size_t PinnableBlockAllocator::mark_in_use_(void *ptr, bool in_use) {
129+
auto it = all_blocks_.find(reinterpret_cast<void *>(ptr));
130+
if (it == all_blocks_.end()) {
131+
throw std::runtime_error("Pointer not allocated by this allocator");
132+
}
133+
std::lock_guard<std::mutex> lock(mutex_);
134+
it->second->in_use = in_use;
135+
return it->second->size;
136+
}
137+
128138
// ------------------- trim -------------------
129139
void PinnableBlockAllocator::trim() {
130140
std::lock_guard<std::mutex> lock(mutex_);

src/infinicore/context/allocators/pinnable_block_allocator.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ class PinnableBlockAllocator : public MemoryAllocator {
3232
// Switch pinned/graph mode
3333
void set_pin_mode(bool pinned) { pinned_mode_ = pinned; }
3434

35+
// internal use only, force set in_use flag for a mem block
36+
// return the size of the block
37+
size_t mark_in_use_(void *ptr, bool in_use);
38+
3539
// trim cached blocks back to GPU (not pinned)
3640
void trim();
3741

src/infinicore/context/context_impl.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "context_impl.hpp"
2+
#include "internal.hpp"
23

34
#include "../utils.hpp"
45

@@ -194,6 +195,12 @@ void addGraphOperator(std::shared_ptr<graph::GraphOperator> op) {
194195
std::shared_ptr<graph::Graph> stopGraphRecording() {
195196
return ContextImpl::singleton().getCurrentRuntime()->stopGraphRecording();
196197
}
198+
199+
// Re-register a graph-tensor blob with the allocator of the runtime that
// owns it, returning a fresh Memory handle for the same storage.
// NOTE(review): switches the active device to blob->device() as a side
// effect and does not restore the previous device — confirm callers
// expect this.
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob) {
    // Ensure the current runtime is the one whose allocator owns the blob.
    setDevice(blob->device());
    return ContextImpl::singleton().getCurrentRuntime()->reinstantiateBlob(blob);
}
203+
197204
} // namespace context
198205

199206
} // namespace infinicore
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#pragma once

// Context-internal API: helpers shared between the context implementation
// and other infinicore translation units, not exposed to library users.

#include <memory> // std::shared_ptr (don't rely on transitive includes)

#include "infinicore/device.hpp"
#include "infinicore/memory.hpp"

#include "infinicore/graph/graph.hpp"

namespace infinicore::context {

// Re-register a graph blob with its runtime's device allocator and wrap it
// in a fresh Memory handle (implemented in context_impl.cc).
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);

} // namespace infinicore::context

src/infinicore/context/runtime/runtime.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ std::shared_ptr<Memory> Runtime::allocatePinnedHostMemory(size_t size) {
7777
true);
7878
}
7979

80+
std::shared_ptr<Memory> Runtime::reinstantiateBlob(std::shared_ptr<Memory> blob) {
81+
device_memory_allocator_.get()->mark_in_use_(blob->data(), true);
82+
return std::make_shared<Memory>(
83+
blob->data(), blob->size(), device_,
84+
[alloc = device_memory_allocator_.get()](std::byte *p) {
85+
alloc->deallocate(p);
86+
});
87+
}
88+
8089
void Runtime::memcpyH2D(void *dst, const void *src, size_t size, bool async) {
8190
if (async) {
8291
INFINICORE_CHECK_ERROR(infinirtMemcpyAsync(dst, src, size, INFINIRT_MEMCPY_H2D, stream_));

src/infinicore/context/runtime/runtime.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Runtime {
3737

3838
std::shared_ptr<Memory> allocateMemory(size_t size);
3939
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size);
40+
std::shared_ptr<Memory> reinstantiateBlob(std::shared_ptr<Memory> blob);
4041

4142
void memcpyH2D(void *dst, const void *src, size_t size, bool async = true);
4243
void memcpyD2H(void *dst, const void *src, size_t size);

src/infinicore/graph/graph.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ namespace infinicore::graph {
1111
// Construct a graph tensor from an existing tensor via its blob form.
// NOTE(review): to_blob() semantics assumed to detach storage for graph
// ownership — confirm against Tensor::to_blob.
GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
}
1313

14+
// Resume this graph tensor's backing blob into the allocator's in-use
// tracking (delegates to Tensor::resume_from_blob_).
void GraphTensor::resume() const {
    resume_from_blob_();
}
17+
1418
/* =========================
1519
* GraphOperator
1620
* ========================= */

src/infinicore/tensor/tensor.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "infinicore/tensor.hpp"
2+
#include "../context/internal.hpp"
23
#include "../utils.hpp"
34
#include "infinicore/context/context.hpp"
45
#include "infinicore/dtype.hpp"
@@ -64,6 +65,10 @@ Tensor::operator bool() const {
6465
return impl_ != nullptr;
6566
}
6667

68+
// Hand this tensor's backing memory back to the context so the device
// allocator marks the block as in-use again.
// NOTE(review): the Memory returned by reinstantiateBlob is discarded here;
// if Memory's destructor invokes its deleter, the temporary would
// immediately deallocate the block — verify Memory's ownership semantics.
void Tensor::resume_from_blob_() const {
    context::reinstantiateBlob(impl_->data_.memory);
}
71+
6772
TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype)
6873
: shape(_shape), strides(_strides), dtype(_dtype) {
6974
INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype));

0 commit comments

Comments
 (0)