Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions src/infinicore/context/context_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,16 @@ Runtime *ContextImpl::getCurrentRuntime() {
return current_runtime_;
}

Runtime *ContextImpl::getCpuRuntime() {
return runtime_table_[int(Device::Type::CPU)][0].get();
}

void ContextImpl::setDevice(Device device) {
if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set.
return;
}

if (getCurrentRuntime()->isGraphRecording()) {
thread_local bool warn_switch_runtime = false;
if (getCurrentRuntime()->isGraphRecording() && !warn_switch_runtime) {
spdlog::warn("Switching device runtime during graph recording may break the graph!");
warn_switch_runtime = true;
}

if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
Expand Down Expand Up @@ -104,11 +102,8 @@ infinirtStream_t getStream() {
}

infiniopHandle_t getInfiniopHandle(Device device) {
if (device.getType() == Device::Type::CPU) {
return ContextImpl::singleton().getCpuRuntime()->infiniopHandle();
}
if (device != getDevice()) {
throw std::runtime_error("Requested device doesn't match current runtime.");
setDevice(device);
}
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
}
Expand All @@ -127,7 +122,7 @@ std::shared_ptr<Memory> allocateMemory(size_t size) {

std::shared_ptr<Memory> allocateHostMemory(size_t size) {
setDevice(Device::cpu());
return ContextImpl::singleton().getCpuRuntime()->allocateMemory(size);
return allocateMemory(size);
}

std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size) {
Expand All @@ -147,7 +142,8 @@ void memcpyD2D(void *dst, const void *src, size_t size, bool async) {
}

void memcpyH2H(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCpuRuntime()->memcpyD2D(dst, src, size);
setDevice(Device::cpu());
return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size);
}

// Timing API implementations
Expand Down
2 changes: 0 additions & 2 deletions src/infinicore/context/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class ContextImpl {
public:
Runtime *getCurrentRuntime();

Runtime *getCpuRuntime();

void setDevice(Device);

size_t getDeviceCount(Device::Type type);
Expand Down
6 changes: 4 additions & 2 deletions src/infinicore/tensor/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ Tensor TensorImpl::to(Device device) const {

void TensorImpl::copy_from(Tensor src) {
if (src->shape() != this->shape()) {
throw std::runtime_error("Cannot copy from tensor with different shape");
throw std::runtime_error(
"Cannot copy from tensor with different shape. Src: " + src->info() + " Dst: " + this->info());
}
if (this->device() == src->device()) {
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), src);
Expand All @@ -31,11 +32,12 @@ void TensorImpl::copy_from(Tensor src) {
// Use nbytes() to get the actual tensor size, not the full memory size
size_t copy_size = std::min(this->nbytes(), src->nbytes());
if (this->device().getType() == Device::Type::CPU) {
context::setDevice(src->device());
if (this->is_contiguous()) {
context::setDevice(src->device());
context::memcpyD2H(this->data(), src->data(), copy_size);
} else {
auto local_src = Tensor::empty(this->shape(), this->dtype(), this->device());
context::setDevice(src->device());
context::memcpyD2H(local_src->data(), src->data(), this->data_.memory->size());
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src);
}
Expand Down
19 changes: 11 additions & 8 deletions src/infinicore/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ inline struct SpdlogInitializer {
#define STRINGIZE_(x) #x
#define STRINGIZE(x) STRINGIZE_(x)

#define INFINICORE_CHECK_ERROR(call) \
do { \
SPDLOG_DEBUG("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
infiniStatus_t ret = (call); \
SPDLOG_DEBUG("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
if (ret != INFINI_STATUS_SUCCESS) { \
throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \
} \
#define INFINICORE_CHECK_ERROR(call) \
do { \
SPDLOG_DEBUG("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
infiniStatus_t ret = (call); \
SPDLOG_DEBUG("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
if (ret != INFINI_STATUS_SUCCESS) { \
throw std::runtime_error("`" #call "` failed with error: " + std::string(infini_status_string(ret)) \
+ " from " + std::string(__func__) \
+ " at " + std::string(__FILE__) \
+ ":" + std::to_string(__LINE__) + "."); \
} \
} while (false)

#define INFINICORE_ASSERT_TENSORS_SAME_DEVICE(FIRST___, ...) \
Expand Down
18 changes: 13 additions & 5 deletions src/infinirt/cuda/infinirt_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@

#define CHECK_CUDART(RT_API) CHECK_INTERNAL(RT_API, cudaSuccess)

#define RUN_CUDART(RT_API) \
do { \
auto api_result_ = (RT_API); \
if (api_result_ != (cudaSuccess)) { \
{ return INFINI_STATUS_INTERNAL_ERROR; } \
} \
} while (0)

// 根据宏定义选择命名空间并实现
#if defined(ENABLE_NVIDIA_API)
namespace infinirt::cuda {
Expand Down Expand Up @@ -40,7 +48,7 @@ infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) {
}

infiniStatus_t streamDestroy(infinirtStream_t stream) {
CHECK_CUDART(cudaStreamDestroy((cudaStream_t)stream));
RUN_CUDART(cudaStreamDestroy((cudaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}

Expand Down Expand Up @@ -105,7 +113,7 @@ infiniStatus_t eventSynchronize(infinirtEvent_t event) {
}

infiniStatus_t eventDestroy(infinirtEvent_t event) {
CHECK_CUDART(cudaEventDestroy((cudaEvent_t)event));
RUN_CUDART(cudaEventDestroy((cudaEvent_t)event));
return INFINI_STATUS_SUCCESS;
}

Expand All @@ -125,12 +133,12 @@ infiniStatus_t mallocHost(void **p_ptr, size_t size) {
}

infiniStatus_t freeDevice(void *ptr) {
CHECK_CUDART(cudaFree(ptr));
RUN_CUDART(cudaFree(ptr));
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t freeHost(void *ptr) {
CHECK_CUDART(cudaFreeHost(ptr));
RUN_CUDART(cudaFreeHost(ptr));
return INFINI_STATUS_SUCCESS;
}

Expand Down Expand Up @@ -165,7 +173,7 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
}

infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
CHECK_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream));
RUN_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream));
return INFINI_STATUS_SUCCESS;
}
}