diff --git a/examples/ucm_config_asu.yaml b/examples/ucm_config_asu.yaml new file mode 100644 index 000000000..a89a3cbd2 --- /dev/null +++ b/examples/ucm_config_asu.yaml @@ -0,0 +1,33 @@ +# UCM ASU fake backend example for vLLM / vLLM-Ascend software integration tests. +# +# Use with: +# kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_asu.yaml"} + +ucm_connectors: + - ucm_connector_name: "UcmPipelineStore" + ucm_connector_config: + store_pipeline: "ASU" + asu_mode: "client" + asu_client_id: "ucm-vllm-asu-fake" + asu_ids: [1, 2] + asu_trans_provider_backend: "fake" + + # ASU client routing. Keep the default close to kv-test's fake_backend config. + asu_router_type: "RING_HASH" + asu_ring_hash_virtual_node_count: 128 + + # Local software backend. This uses normal AsuClient + AsuTransportImpl submit/CQE paths + # and replaces the missing device Send/backend execution with local files. + asu_fake_backend_path: "./asu-fake-backend-store" + asu_fake_backend_latency_ms: 1 + + # ASU wait and operation timeouts. + asu_default_wait_timeout_ms: 5000 + asu_query_timeout_ms: 5000 + asu_load_timeout_ms: 5000 + asu_store_timeout_ms: 5000 + asu_max_inflight_tasks: 1024 + +enable_event_sync: true +use_layerwise: false +persist_token_threshold: 0 diff --git a/scripts/build_ascend.sh b/scripts/build_ascend.sh index ce0ed58bb..3bcfee310 100644 --- a/scripts/build_ascend.sh +++ b/scripts/build_ascend.sh @@ -73,6 +73,7 @@ function build_wheels() cd ${KVCACHE_PROJECT_ROOT} export ENABLE_SPARSE=${ENABLE_SPARSE:-TRUE} + export BUILD_UCM_ASU=${BUILD_UCM_ASU:-ON} python3 -m build --no-isolation --wheel if [ $? -eq 0 ]; then @@ -105,6 +106,7 @@ function collect_artifacts() cp -r "${KVCACHE_PROJECT_ROOT}/docker" . cp -r "${KVCACHE_PROJECT_ROOT}/examples/deployments" . cp -r "${KVCACHE_PROJECT_ROOT}/examples/ucm_config_example.yaml" ucm_config.yaml + cp -r "${KVCACHE_PROJECT_ROOT}/examples/ucm_config_asu_fake_backend.yaml" ucm_config_asu_fake_backend.yaml cp -r "${KVCACHE_PROJECT_ROOT}/examples/metrics/metrics_configs.yaml" metrics_configs.yaml cp -r "${KVCACHE_PROJECT_ROOT}/test" . cp "${KVCACHE_PROJECT_ROOT}/ucm/integration/vllm/patch/0.11.0/vllm-adapt.patch" . @@ -124,4 +126,4 @@ function package_all() check_build_install build_wheels -package_all \ No newline at end of file +package_all diff --git a/setup.py b/setup.py index 9b2f35639..d4a75fc45 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ ROOT_DIR = os.path.abspath(os.path.dirname(__file__)) PLATFORM = os.getenv("PLATFORM") ENABLE_SPARSE = os.getenv("ENABLE_SPARSE") +BUILD_UCM_ASU = os.getenv("BUILD_UCM_ASU", "0") not in ("", "0", "false", "False") ENABLE_MINDIE = os.getenv("UCM_ENABLE_MINDIE", "0") not in ("", "0", "false", "False") @@ -166,6 +167,8 @@ def build_cmake(self, ext: CMakeExtension): if enable_sparse(): cmake_args += ["-DBUILD_UCM_SPARSE=ON"] + if BUILD_UCM_ASU: + cmake_args += ["-DBUILD_UCM_ASU=ON"] match PLATFORM: case "cuda": diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 3bd7a83e4..d24c8a560 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -355,6 +355,9 @@ def _create_store( * (1 if self.is_mla else self.num_head * 2) * self.blocks_per_chunk ) + if config.get("store_pipeline") == "ASU": + config["shard_size"] = config["block_size"] + config["tensor_size"] = config["block_size"] dp_rank = self._vllm_config.parallel_config.data_parallel_rank config["posix_gc_enable"] = ( self._role != KVConnectorRole.WORKER and dp_rank == 0 diff --git a/ucm/store/asu/CMakeLists.txt b/ucm/store/asu/CMakeLists.txt index a3e5d9446..a12f4a712 100644 --- a/ucm/store/asu/CMakeLists.txt +++ b/ucm/store/asu/CMakeLists.txt @@ -4,7 +4,12 @@ target_include_directories(asustore PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cc ) -target_link_libraries(asustore PUBLIC storeintf asu_client infra_logger) +target_link_libraries(asustore PUBLIC storeintf asu_client infra_logger PRIVATE asu_ascend_deps) +set_target_properties(asustore PROPERTIES + BUILD_WITH_INSTALL_RPATH TRUE + BUILD_RPATH "$ORIGIN;$ORIGIN/../../transport/kv/asu" + INSTALL_RPATH "$ORIGIN;$ORIGIN/../../transport/kv/asu" +) file(RELATIVE_PATH INSTALL_REL_PATH ${UCM_ROOT_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) install(TARGETS asustore LIBRARY DESTINATION ${INSTALL_REL_PATH} COMPONENT ucm) diff --git a/ucm/store/asu/cc/asu_store.cc b/ucm/store/asu/cc/asu_store.cc index bec5eb83d..1688dbd6f 100644 --- a/ucm/store/asu/cc/asu_store.cc +++ b/ucm/store/asu/cc/asu_store.cc @@ -22,13 +22,17 @@ * SOFTWARE. * */ #include "asu_store.h" +#include #include +#include +#include #include #include #include #include #include #include +#include #include #include #include "asu_client/asu_client.h" @@ -42,13 +46,16 @@ namespace { using AsuStatus = UC::ASU::Status; using AsuStatusCode = UC::ASU::StatusCode; -std::string ToHex(const Detail::BlockId& block) +std::uint64_t HashAsuKey(const Detail::BlockId& block) +{ + static Detail::BlockIdHasher hasher; + return static_cast(hasher(block)); +} + +std::string MakeAsuKey(const Detail::BlockId& block) { std::ostringstream os; - os << std::hex << std::setfill('0'); - for (auto b : block) { - os << std::setw(2) << static_cast(std::to_integer(b)); - } + os << std::hex << std::setfill('0') << std::setw(16) << HashAsuKey(block); return os.str(); } @@ -78,6 +85,36 @@ void LogAsuStatus(const char* operation, const AsuStatus& status) status.message); } +AsuStatus WaitPrerequisiteEvent(std::uintptr_t eventHandle) +{ + if (eventHandle == 0) { return AsuStatus::OK(); } + auto ret = aclrtSynchronizeEvent(reinterpret_cast(eventHandle)); + if (ret == ACL_SUCCESS) { return AsuStatus::OK(); } + return AsuStatus::Error(AsuStatusCode::INTERNAL_ERROR, + "aclrtSynchronizeEvent failed: " + std::to_string(ret)); +} + +const char* TransProviderBackendName(UC::ASU::TransProviderType providerType) +{ + switch (providerType) { + case UC::ASU::TransProviderType::FAKE: return "fake"; + case UC::ASU::TransProviderType::AIV: return "aiv"; + case UC::ASU::TransProviderType::AICPU: return "aicpu"; + case UC::ASU::TransProviderType::UNSUPPORTED: return "unsupported"; + } + return "unknown"; +} + +UC::ASU::TransProviderType ParseTransProviderBackend(std::string backend) +{ + std::transform(backend.begin(), backend.end(), backend.begin(), + [](unsigned char ch) { return static_cast(std::toupper(ch)); }); + if (backend == "FAKE") { return UC::ASU::TransProviderType::FAKE; } + if (backend == "AIV") { return UC::ASU::TransProviderType::AIV; } + if (backend == "AICPU") { return UC::ASU::TransProviderType::AICPU; } + return UC::ASU::TransProviderType::UNSUPPORTED; +} + UC::ASU::MemoryType ParseMemoryType(const std::string& memoryType) { if (memoryType == "host") { return UC::ASU::MemoryType::HOST; } @@ -86,6 +123,39 @@ UC::ASU::MemoryType ParseMemoryType(const std::string& memoryType) return UC::ASU::MemoryType::ASCEND_DEVICE; } +bool TryGetStringLike(const Detail::Dictionary& inConfig, const std::string& key, + std::string& value) +{ + if (!inConfig.Contains(key)) { return false; } + try { + inConfig.Get(key, value); + return true; + } catch (const std::bad_any_cast&) { + } + try { + bool boolValue = false; + inConfig.Get(key, boolValue); + value = boolValue ? "true" : "false"; + return true; + } catch (const std::bad_any_cast&) { + } + try { + ssize_t numberValue = 0; + inConfig.GetNumber(key, numberValue); + value = std::to_string(numberValue); + return true; + } catch (const std::bad_any_cast&) { + } + return false; +} + +void ReadClientAttr(const Detail::Dictionary& inConfig, const std::string& yamlKey, + const std::string& attrKey, Config& config) +{ + std::string value; + if (TryGetStringLike(inConfig, yamlKey, value)) { config.clientAttrs[attrKey] = value; } +} + } // namespace UC::ASU::TransportConfig BuildTransportConfig(const Config& config, std::size_t index) @@ -98,6 +168,7 @@ UC::ASU::TransportConfig BuildTransportConfig(const Config& config, std::size_t transportConfig.storeTimeoutMs = config.storeTimeoutMs; transportConfig.maxInflightTasks = static_cast(config.maxInflightTasks); transportConfig.maxInflightBytes = config.maxInflightBytes; + transportConfig.providerType = config.transProviderType; if (!config.asuIps.empty()) { UC::ASU::AsuEndpoint endpoint; endpoint.ip = config.asuIps[index]; @@ -105,6 +176,28 @@ UC::ASU::TransportConfig BuildTransportConfig(const Config& config, std::size_t endpoint.deviceId = config.deviceId; transportConfig.endpoints.emplace_back(std::move(endpoint)); } + if (config.transProviderType == UC::ASU::TransProviderType::FAKE) { + const auto fakeDeviceId = config.deviceId >= 0 ? config.deviceId : 0; + transportConfig.attrs.try_emplace("kernel_count", "1"); + transportConfig.attrs.try_emplace("quiet_count", "1"); + transportConfig.attrs["kv_ns_id"] = std::to_string(transportConfig.asuId); + transportConfig.attrs.try_emplace("dtype", "0"); + transportConfig.attrs.try_emplace("dspec", "0"); + transportConfig.attrs.try_emplace("lr", "false"); + transportConfig.attrs["sc"] = "true"; + transportConfig.attrs["fake_backend.path"] = config.fakeBackendPath; + transportConfig.attrs["fake_backend.latency_ms"] = + std::to_string(config.fakeBackendLatencyMs); + transportConfig.attrs["fake_backend.device_id"] = std::to_string(fakeDeviceId); + if (transportConfig.endpoints.empty()) { + UC::ASU::AsuEndpoint endpoint; + endpoint.ip = "fake_backend"; + endpoint.port = 19001; + endpoint.protocol = UC::ASU::Protocol::TCP; + endpoint.deviceId = fakeDeviceId; + transportConfig.endpoints.emplace_back(std::move(endpoint)); + } + } return transportConfig; } @@ -117,6 +210,7 @@ class ClientBackend final : public AsuBackend { asuConfig.clientId = config.clientId; asuConfig.viewServiceAddrs = config.viewServiceAddrs; asuConfig.defaultWaitTimeoutMs = config.defaultWaitTimeoutMs; + asuConfig.attrs = config.clientAttrs; asuConfig.transportConfigs.reserve(config.asuIds.size()); for (std::size_t i = 0; i < config.asuIds.size(); ++i) { asuConfig.transportConfigs.emplace_back(BuildTransportConfig(config, i)); @@ -244,10 +338,10 @@ class AsuStore final : public StoreV1 { ~AsuStore() override { - if (!backend_) { return; } - - auto status = backend_->Shutdown(); - if (!status.ok()) { UC_ERROR("Failed to shutdown ASU backend: {}.", status.message); } + if (backend_) { + auto status = backend_->Shutdown(); + if (!status.ok()) { UC_ERROR("Failed to shutdown ASU backend: {}.", status.message); } + } } Status Setup(const Detail::Dictionary& inConfig) override @@ -317,6 +411,11 @@ class AsuStore final : public StoreV1 { Expected Dump(Detail::TaskDesc task) override { + auto status = WaitPrerequisiteEvent(task.prerequisiteHandle); + if (!status.ok()) { + LogAsuStatus("wait prerequisite event", status); + return ConvertStatus(status); + } return Submit(std::move(task), &AsuBackend::StoreAsync); } @@ -374,6 +473,26 @@ class AsuStore final : public StoreV1 { inConfig.GetNumber("block_size", config.blockSize); inConfig.GetNumber("device_id", config.deviceId); inConfig.Get("asu_memory_type", config.memoryType); + std::string providerBackend; + if (TryGetStringLike(inConfig, "asu_trans_provider_backend", providerBackend)) { + config.transProviderType = ParseTransProviderBackend(providerBackend); + } + inConfig.Get("asu_fake_backend_path", config.fakeBackendPath); + inConfig.GetNumber("asu_fake_backend_latency_ms", config.fakeBackendLatencyMs); + ReadClientAttr(inConfig, "asu_router_type", "hash_table.type", config); + ReadClientAttr(inConfig, "asu_ring_hash_virtual_node_count", "ring_hash.virtual_node_count", + config); + ReadClientAttr(inConfig, "asu_maglev_table_size", "maglev.table_size", config); + ReadClientAttr(inConfig, "asu_contiguous_block_affinity_block_count", + "contiguous_block_affinity.block_count", config); + ReadClientAttr(inConfig, "asu_contiguous_block_affinity_full_spread_type", + "contiguous_block_affinity.full_spread_type", config); + ReadClientAttr(inConfig, "asu_contiguous_block_affinity_dynamic_adjust_enabled", + "contiguous_block_affinity.dynamic_adjust_enabled", config); + ReadClientAttr(inConfig, "asu_batch_topk_affinity_top_k", "batch_topk_affinity.top_k", + config); + ReadClientAttr(inConfig, "asu_batch_topk_affinity_dynamic_adjust_enabled", + "batch_topk_affinity.dynamic_adjust_enabled", config); std::size_t tensorSize = 0; inConfig.GetNumber("tensor_size", tensorSize); @@ -406,6 +525,14 @@ class AsuStore final : public StoreV1 { if (!config.asuIps.empty() && config.asuIps.size() != config.asuIds.size()) { return Status::InvalidParam("asu_ips size must match asu_ids size"); } + if (config.transProviderType == UC::ASU::TransProviderType::UNSUPPORTED) { + return Status::Unsupported(); + } + if (config.transProviderType == UC::ASU::TransProviderType::FAKE && + !config.configPath.empty()) { + return Status::InvalidParam( + "asu_trans_provider_backend=fake does not support asu_config_path"); + } if (config.tensorSizes.empty()) { return Status::InvalidParam("invalid tensor size"); } if (config.shardSize == 0) { return Status::InvalidParam("invalid shard size"); } if (config.blockSize == 0) { return Status::InvalidParam("invalid block size"); } @@ -417,12 +544,6 @@ class AsuStore final : public StoreV1 { if (config.blockSize % config.shardSize != 0) { return Status::InvalidParam("invalid block size({})", config.blockSize); } - if (config.blockSize != config.shardSize) { - return Status::InvalidParam("asu store requires one shard per block"); - } - if (config.tensorSizes.size() != 1) { - return Status::InvalidParam("asu store requires one tensor buffer per block"); - } return Status::OK(); } @@ -464,7 +585,7 @@ class AsuStore final : public StoreV1 { std::vector keys; keys.reserve(num); for (std::size_t blockIndex = 0; blockIndex < num; ++blockIndex) { - keys.emplace_back(ToHex(blocks[blockIndex])); + keys.emplace_back(MakeAsuKey(blocks[blockIndex])); } return keys; } @@ -501,7 +622,7 @@ class AsuStore final : public StoreV1 { } for (std::size_t tensorIndex = 0; tensorIndex < shard.addrs.size(); ++tensorIndex) { UC::ASU::KVBuffer entry; - entry.key = ToHex(shard.owner); + entry.key = MakeAsuKey(shard.owner); entry.buffer.region.memoryType = memoryType; entry.buffer.region.addr = reinterpret_cast(shard.addrs[tensorIndex]); @@ -526,6 +647,9 @@ class AsuStore final : public StoreV1 { UC_INFO("Set AsuStore::BlockSize to {}.", config.blockSize); UC_INFO("Set AsuStore::TensorSizes to {}.", config.tensorSizes); UC_INFO("Set AsuStore::DeviceId to {}.", config.deviceId); + UC_INFO("Set AsuStore::TransProviderBackend to {}.", + TransProviderBackendName(config.transProviderType)); + UC_INFO("Set AsuStore::FakeBackendPath to {}.", config.fakeBackendPath); } Config config_; diff --git a/ucm/store/asu/cc/asu_store.h b/ucm/store/asu/cc/asu_store.h index e3b16826f..60b72735a 100644 --- a/ucm/store/asu/cc/asu_store.h +++ b/ucm/store/asu/cc/asu_store.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "asu_transport/types.h" @@ -28,6 +29,10 @@ struct Config { std::size_t blockSize{0}; std::int32_t deviceId{-1}; std::string memoryType; + UC::ASU::TransProviderType transProviderType{UC::ASU::TransProviderType::AICPU}; + std::string fakeBackendPath; + std::uint64_t fakeBackendLatencyMs{1}; + std::unordered_map clientAttrs; }; class AsuBackend { diff --git a/ucm/store/test/CMakeLists.txt b/ucm/store/test/CMakeLists.txt index e49871c26..658aa80a1 100644 --- a/ucm/store/test/CMakeLists.txt +++ b/ucm/store/test/CMakeLists.txt @@ -11,7 +11,7 @@ if(BUILD_UNIT_TESTS) gtest_main gtest gmock ) if(BUILD_UCM_ASU) - target_link_libraries(ucmstore.test PRIVATE asu_client) + target_link_libraries(ucmstore.test PRIVATE asu_client asu_ascend_deps) target_compile_definitions(ucmstore.test PRIVATE ASU_BUILD_TESTS) endif() gtest_discover_tests(ucmstore.test) diff --git a/ucm/store/test/case/asu/asu_store_test.cc b/ucm/store/test/case/asu/asu_store_test.cc index f5519569c..f755dfce9 100644 --- a/ucm/store/test/case/asu/asu_store_test.cc +++ b/ucm/store/test/case/asu/asu_store_test.cc @@ -352,9 +352,10 @@ TEST(UCAsuStoreTest, RejectsInvalidTensorLayout) ASSERT_TRUE(status.Failure()); } -TEST(UCAsuStoreTest, RejectsMultipleShardsPerBlock) +TEST(UCAsuStoreTest, AllowsMultipleShardsPerBlock) { UC::AsuStore::AsuStore store; + UseFakeBackend(store); auto config = MakeBaseConfig(); config.Set("asu_mode", std::string{"transport"}); config.Set("asu_ips", std::vector{"127.0.0.1"}); @@ -362,12 +363,13 @@ TEST(UCAsuStoreTest, RejectsMultipleShardsPerBlock) config.SetNumber("block_size", std::size_t{128}); auto status = store.Setup(config); - ASSERT_TRUE(status.Failure()); + ASSERT_TRUE(status.Success()) << status.ToString(); } -TEST(UCAsuStoreTest, RejectsMultipleTensorBuffersPerBlock) +TEST(UCAsuStoreTest, AllowsMultipleTensorBuffersPerShard) { UC::AsuStore::AsuStore store; + UseFakeBackend(store); auto config = MakeBaseConfig(); config.Set("asu_mode", std::string{"transport"}); config.Set("asu_ips", std::vector{"127.0.0.1"}); @@ -376,5 +378,17 @@ TEST(UCAsuStoreTest, RejectsMultipleTensorBuffersPerBlock) config.Set("tensor_size_list", std::vector{32, 32}); auto status = store.Setup(config); - ASSERT_TRUE(status.Failure()); + ASSERT_TRUE(status.Success()) << status.ToString(); + + std::array first{}; + std::array second{}; + auto block = UC::Test::Detail::TypesHelper::MakeBlockId("aab2c3d4e5f6789012345678901234ab"); + UC::Detail::TaskDesc task; + task.brief = "asu-store-test"; + task.push_back(UC::Detail::Shard{ + block, 0, {first.data(), second.data()} + }); + auto dump = store.Dump(task); + ASSERT_TRUE(dump.HasValue()) << dump.Error().ToString(); + ASSERT_TRUE(store.Wait(dump.Value()).Success()); } diff --git a/ucm/transport/kv/asu/CMakeLists.txt b/ucm/transport/kv/asu/CMakeLists.txt index df4ecbf88..09a173b1a 100644 --- a/ucm/transport/kv/asu/CMakeLists.txt +++ b/ucm/transport/kv/asu/CMakeLists.txt @@ -3,7 +3,13 @@ if(NOT RUNTIME_ENVIRONMENT STREQUAL "ascend") endif() if(NOT DEFINED ASCEND_ROOT) - set(ASCEND_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory") + if(DEFINED ENV{ASCEND_HOME_PATH}) + set(ASCEND_ROOT "$ENV{ASCEND_HOME_PATH}" CACHE PATH "Path to Ascend root directory") + elseif(DEFINED ENV{ASCEND_TOOLKIT_HOME}) + set(ASCEND_ROOT "$ENV{ASCEND_TOOLKIT_HOME}" CACHE PATH "Path to Ascend root directory") + else() + set(ASCEND_ROOT "/usr/local/Ascend/ascend-toolkit/latest" CACHE PATH "Path to Ascend root directory") + endif() endif() find_path(ASU_ASCEND_INCLUDE_DIR @@ -23,6 +29,9 @@ find_library(ASU_ASCENDCL_LIB HINTS ${ASCEND_ROOT}/lib64 ${ASCEND_ROOT}/aarch64-linux/lib64 + ${ASCEND_ROOT}/aarch64-linux/devlib + ${ASCEND_ROOT}/arm64-linux/lib64 + ${ASCEND_ROOT}/arm64-linux/devlib NO_DEFAULT_PATH ) if(NOT ASU_ASCENDCL_LIB) @@ -82,6 +91,12 @@ target_include_directories(asu_client target_link_libraries(asu_client PUBLIC asu_transport kv_common pthread) target_compile_definitions(asu_client PRIVATE SPDLOG_FMT_EXTERNAL) +file(RELATIVE_PATH INSTALL_REL_PATH ${UCM_ROOT_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) +set_target_properties(asu_client PROPERTIES INSTALL_RPATH "$ORIGIN") +set_target_properties(asu_transport PROPERTIES INSTALL_RPATH "$ORIGIN") +install(TARGETS asu_client asu_transport LIBRARY DESTINATION ${INSTALL_REL_PATH} COMPONENT ucm) +install(TARGETS asu_client asu_transport LIBRARY DESTINATION ucm/store/asu COMPONENT ucm) + if(BUILD_UNIT_TESTS) target_compile_definitions(asu_transport PRIVATE ASU_BUILD_TESTS) include(GoogleTest) diff --git a/ucm/transport/kv/asu/client/src/asu_client_impl.cpp b/ucm/transport/kv/asu/client/src/asu_client_impl.cpp index 11c6f9dce..c448f9cae 100644 --- a/ucm/transport/kv/asu/client/src/asu_client_impl.cpp +++ b/ucm/transport/kv/asu/client/src/asu_client_impl.cpp @@ -43,6 +43,41 @@ Status PartialFailed(const std::string& message) return Status::Error(StatusCode::PARTIAL_FAILED, message); } +const char* ClientOpTypeName(ClientOpType opType) +{ + switch (opType) { + case ClientOpType::LOAD: return "load"; + case ClientOpType::STORE: return "store"; + case ClientOpType::DELETE: return "delete"; + default: return "unknown"; + } +} + +std::size_t SubTaskItemCount(const ClientSubTask& subTask) +{ + return subTask.entries.empty() ? subTask.keys.size() : subTask.entries.size(); +} + +std::string SubTaskContext(const ClientTaskContext& ctx, const ClientSubTask& subTask) +{ + return "client_task_id=" + std::to_string(ctx.taskId) + " op=" + ClientOpTypeName(ctx.opType) + + " asuId=" + std::to_string(subTask.asuId) + + " trans_task_id=" + std::to_string(subTask.transTaskId) + + " item_count=" + std::to_string(SubTaskItemCount(subTask)); +} + +std::string FirstFailedSubTaskContext(const ClientTaskContext& ctx) +{ + for (const auto& subTask : ctx.subTasks) { + if (!subTask.failed) { continue; } + + return SubTaskContext(ctx, subTask) + + " code=" + std::to_string(static_cast(subTask.status.code)) + + " message=" + subTask.status.message; + } + return "client_task_id=" + std::to_string(ctx.taskId) + " op=" + ClientOpTypeName(ctx.opType); +} + std::vector ExtractEntryKeys(const std::vector& entries) { std::vector keys; @@ -535,6 +570,10 @@ bool AsuClientImpl::PollTask(const ClientTaskContextPtr& ctx) if (transIter == snapshot->transports.end()) { subTask.completed = true; subTask.failed = true; + subTask.status = Status::Error(StatusCode::NOT_FOUND, "routed asu transport not found"); + UC_ERROR("ASU client subtask check failed: {} code={} message={}.", + SubTaskContext(*ctx, subTask), static_cast(subTask.status.code), + subTask.status.message); anyFailed = true; continue; } @@ -544,6 +583,9 @@ bool AsuClientImpl::PollTask(const ClientTaskContextPtr& ctx) if (!status.ok()) { subTask.completed = true; subTask.failed = true; + subTask.status = status; + UC_ERROR("ASU client subtask check failed: {} code={} message={}.", + SubTaskContext(*ctx, subTask), static_cast(status.code), status.message); anyFailed = true; continue; } @@ -554,7 +596,13 @@ bool AsuClientImpl::PollTask(const ClientTaskContextPtr& ctx) subTask.completed = true; if (!subResult.status.ok()) { subTask.failed = true; + subTask.status = subResult.status; + UC_ERROR("ASU client subtask result failed after check: {} code={} message={}.", + SubTaskContext(*ctx, subTask), static_cast(subResult.status.code), + subResult.status.message); anyFailed = true; + } else { + subTask.status = Status::OK(); } const auto& originalIndices = subTask.originalIndices; @@ -566,8 +614,10 @@ bool AsuClientImpl::PollTask(const ClientTaskContextPtr& ctx) if (allDone) { ctx->finalStatus = - anyFailed ? Status::Error(StatusCode::PARTIAL_FAILED, "client task partially failed") - : Status::OK(); + anyFailed + ? Status::Error(StatusCode::PARTIAL_FAILED, + "client task partially failed: " + FirstFailedSubTaskContext(*ctx)) + : Status::OK(); ctx->state.store(ClientTaskState::COMPLETED, std::memory_order_release); ctx->cv.notify_all(); return true; @@ -599,7 +649,15 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint if (!snapshot || ctx->state.load(std::memory_order_acquire) != ClientTaskState::INFLIGHT) { if (std::chrono::steady_clock::now() >= deadline) { BuildResult(ctx, result); - result.status = Status::Error(StatusCode::TIMEOUT, "client task wait timeout"); + result.status = Status::Error( + StatusCode::TIMEOUT, + "client task wait timeout before inflight: client_task_id=" + + std::to_string(ctx->taskId) + " op=" + ClientOpTypeName(ctx->opType) + + " wait_ms=" + std::to_string(waitMs)); + UC_ERROR( + "ASU client task wait timeout before inflight: client_task_id={} op={} " + "wait_ms={}.", + ctx->taskId, ClientOpTypeName(ctx->opType), waitMs); return result.status; } ctx->cv.wait_until(lock, deadline); @@ -616,6 +674,11 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint if (transIter == snapshot->transports.end()) { subTask.completed = true; subTask.failed = true; + subTask.status = + Status::Error(StatusCode::NOT_FOUND, "routed asu transport not found"); + UC_ERROR("ASU client subtask wait failed: {} code={} message={}.", + SubTaskContext(*ctx, subTask), static_cast(subTask.status.code), + subTask.status.message); anyFailed = true; continue; } @@ -623,7 +686,12 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint const auto now = std::chrono::steady_clock::now(); if (now >= deadline) { BuildResult(ctx, result); - result.status = Status::Error(StatusCode::TIMEOUT, "client task wait timeout"); + result.status = Status::Error(StatusCode::TIMEOUT, + "client task wait timeout before subtask wait: " + + SubTaskContext(*ctx, subTask) + + " wait_ms=" + std::to_string(waitMs)); + UC_ERROR("ASU client task wait timeout before subtask wait: {} wait_ms={}.", + SubTaskContext(*ctx, subTask), waitMs); return result.status; } const auto remainingMs = @@ -636,7 +704,14 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint if (status.code == StatusCode::TIMEOUT) { BuildResult(ctx, result); - result.status = Status::Error(StatusCode::TIMEOUT, "client task wait timeout"); + subTask.status = status; + result.status = Status::Error( + StatusCode::TIMEOUT, + "client task transport wait timeout: " + SubTaskContext(*ctx, subTask) + + " sub_timeout_ms=" + std::to_string(subTimeoutMs) + + " message=" + status.message); + UC_ERROR("ASU client transport wait timeout: {} sub_timeout_ms={} message={}.", + SubTaskContext(*ctx, subTask), subTimeoutMs, status.message); return result.status; } if (status.code == StatusCode::IN_PROGRESS || @@ -647,7 +722,15 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint subTask.completed = true; if (!status.ok() || !subResult.status.ok()) { subTask.failed = true; + subTask.status = !status.ok() ? status : subResult.status; + UC_ERROR( + "ASU client subtask result failed after wait: {} wait_status_code={} " + "wait_message={} result_status_code={} result_message={}.", + SubTaskContext(*ctx, subTask), static_cast(status.code), status.message, + static_cast(subResult.status.code), subResult.status.message); anyFailed = true; + } else { + subTask.status = Status::OK(); } const auto& originalIndices = subTask.originalIndices; @@ -663,7 +746,8 @@ Status AsuClientImpl::WaitTaskContext(const ClientTaskContextPtr& ctx, std::uint } if (allDone) { ctx->finalStatus = anyFailed ? Status::Error(StatusCode::PARTIAL_FAILED, - "client task partially failed") + "client task partially failed: " + + FirstFailedSubTaskContext(*ctx)) : Status::OK(); ctx->state.store(ClientTaskState::COMPLETED, std::memory_order_release); ctx->cv.notify_all(); diff --git a/ucm/transport/kv/asu/client/src/client_config_parser.cpp b/ucm/transport/kv/asu/client/src/client_config_parser.cpp index 1dc8cb528..8aabbf1ce 100644 --- a/ucm/transport/kv/asu/client/src/client_config_parser.cpp +++ b/ucm/transport/kv/asu/client/src/client_config_parser.cpp @@ -77,7 +77,8 @@ Status LoadAsuClientConfig(const std::string& configPath, AsuClientConfig& confi config.attrs["view.config_path"] = value; } else if (key == "defaultWaitTimeoutMs" || key == "default_wait_timeout_ms") { config.defaultWaitTimeoutMs = ParseConfigUint64(value); - } else if (key == "hashTable.type" || key == "hash_table.type") { + } else if (key == "router.type" || key == "routerType" || key == "hashTable.type" || + key == "hash_table.type") { auto type = value; std::transform(type.begin(), type.end(), type.begin(), [](unsigned char ch) { return static_cast(std::toupper(ch)); }); @@ -136,6 +137,9 @@ Status LoadAsuClientConfig(const std::string& configPath, AsuClientConfig& confi if (ApplyTransportIoNumConfigField(transportConfig, field.first, field.second)) { continue; } + if (ApplyTransportProviderConfigField(transportConfig, field.first, field.second)) { + continue; + } transportConfig.attrs.emplace(field); } diff --git a/ucm/transport/kv/asu/client/src/client_task_manager.h b/ucm/transport/kv/asu/client/src/client_task_manager.h index 2743872b7..22053b1f0 100644 --- a/ucm/transport/kv/asu/client/src/client_task_manager.h +++ b/ucm/transport/kv/asu/client/src/client_task_manager.h @@ -42,6 +42,7 @@ struct ClientSubTask { TaskId transTaskId{kInvalidTaskId}; bool completed{false}; bool failed{false}; + Status status{Status::OK()}; // TODO: optimize by zero-copy ? std::vector entries; diff --git a/ucm/transport/kv/asu/common/config_parser_common.cpp b/ucm/transport/kv/asu/common/config_parser_common.cpp index f1f21799e..0651ea542 100644 --- a/ucm/transport/kv/asu/common/config_parser_common.cpp +++ b/ucm/transport/kv/asu/common/config_parser_common.cpp @@ -121,6 +121,16 @@ Protocol ParseConfigProtocol(std::string value) return Protocol::TCP; } +TransProviderType ParseConfigTransProviderType(std::string value) +{ + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast(std::toupper(ch)); }); + if (value == "FAKE") { return TransProviderType::FAKE; } + if (value == "AIV") { return TransProviderType::AIV; } + if (value == "AICPU") { return TransProviderType::AICPU; } + return TransProviderType::UNSUPPORTED; +} + bool ApplyTransportBufferConfigField(TransportConfig& config, const std::string& key, const std::string& value) { @@ -159,6 +169,18 @@ bool ApplyTransportIoNumConfigField(TransportConfig& config, const std::string& return true; } +bool ApplyTransportProviderConfigField(TransportConfig& config, const std::string& key, + const std::string& value) +{ + if (key == "providerBackend" || key == "provider_backend" || key == "transProviderBackend" || + key == "trans_provider_backend") { + config.providerType = ParseConfigTransProviderType(value); + } else { + return false; + } + return true; +} + bool TryParseAsuInfoKey(const std::string& key, AsuId& asuId) { constexpr const char* kCamelPrefix = "asuInfo."; diff --git a/ucm/transport/kv/asu/common/config_parser_common.h b/ucm/transport/kv/asu/common/config_parser_common.h index 4e9c4ec3e..005a5946f 100644 --- a/ucm/transport/kv/asu/common/config_parser_common.h +++ b/ucm/transport/kv/asu/common/config_parser_common.h @@ -34,11 +34,14 @@ std::string TrimConfigValue(const std::string& value); std::vector SplitConfigValue(const std::string& value, char delimiter); std::uint64_t ParseConfigUint64(const std::string& value); Protocol ParseConfigProtocol(std::string value); +TransProviderType ParseConfigTransProviderType(std::string value); bool ApplyTransportBufferConfigField(TransportConfig& config, const std::string& key, const std::string& value); bool ApplyTransportIoNumConfigField(TransportConfig& config, const std::string& key, const std::string& value); +bool ApplyTransportProviderConfigField(TransportConfig& config, const std::string& key, + const std::string& value); bool TryParseAsuInfoKey(const std::string& key, AsuId& asuId); bool TryGetTransportAttrKey(const std::string& key, std::string& attrKey); diff --git a/ucm/transport/kv/asu/test/case/connection_manager_test.cc b/ucm/transport/kv/asu/test/case/connection_manager_test.cc index 9ab5e9602..0be2f5d19 100644 --- a/ucm/transport/kv/asu/test/case/connection_manager_test.cc +++ b/ucm/transport/kv/asu/test/case/connection_manager_test.cc @@ -64,9 +64,15 @@ class StubTransProvider : public TransProvider { { return {}; } - Status RegisterMemory(ConnectionHandle, const std::vector&, - std::vector&) override + Status RegisterMemory(ConnectionHandle, const std::vector& descs, + std::vector& handles) override { + handles.clear(); + handles.reserve(descs.size()); + for (std::size_t index = 0; index < descs.size(); ++index) { + handles.push_back(reinterpret_cast(static_cast(index) + + static_cast(1))); + } return Status::OK(); } std::vector UnregisterMemory(const std::vector&) override @@ -78,7 +84,11 @@ class StubTransProvider : public TransProvider { return Status::OK(); } std::vector FreeThread(const std::vector&) override { return {}; } - Status GetMemTokenId(MemHandle, uint32_t&) override { return Status::OK(); } + Status GetMemTokenId(MemHandle, uint32_t& tokenId) override + { + tokenId = 1; + return Status::OK(); + } }; AsuEndpoint MakeEndpoint(const std::string& ip = "10.0.0.1") diff --git a/ucm/transport/kv/asu/trans/include/asu_transport/types.h b/ucm/transport/kv/asu/trans/include/asu_transport/types.h index 8a512e9f7..2b6702e56 100644 --- a/ucm/transport/kv/asu/trans/include/asu_transport/types.h +++ b/ucm/transport/kv/asu/trans/include/asu_transport/types.h @@ -36,7 +36,7 @@ using MRHandle = std::uint64_t; using CacheKey = std::string; using AsuId = std::uint64_t; -enum class TransProviderType { AICPU }; +enum class TransProviderType { AICPU, FAKE, AIV, UNSUPPORTED }; constexpr TaskId kInvalidTaskId = 0; constexpr MRHandle kInvalidMRHandle = 0; diff --git a/ucm/transport/kv/asu/trans/src/aicpu_trans_provider.cpp b/ucm/transport/kv/asu/trans/src/aicpu_trans_provider.cpp deleted file mode 100644 index 2a98032f0..000000000 --- a/ucm/transport/kv/asu/trans/src/aicpu_trans_provider.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include "aicpu_trans_provider.h" -#include - -namespace UC::ASU { -namespace { - -std::atomic g_sendHook{nullptr}; - -} // namespace - -void SetAICPUTransProviderSendHook(AICPUTransProviderSendHook hook) -{ - // Temporary hook for the kv-test fake_backend phase. Production AICPU sends keep the default - // provider behavior when no hook is registered. - g_sendHook.store(hook, std::memory_order_release); -} - -AICPUTransProviderSendHook GetAICPUTransProviderSendHook() -{ - return g_sendHook.load(std::memory_order_acquire); -} - -} // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/aicpu_trans_provider.h b/ucm/transport/kv/asu/trans/src/aicpu_trans_provider.h index 1cab1f296..2502f8c33 100644 --- a/ucm/transport/kv/asu/trans/src/aicpu_trans_provider.h +++ b/ucm/transport/kv/asu/trans/src/aicpu_trans_provider.h @@ -4,24 +4,11 @@ namespace UC::ASU { -using AICPUTransProviderSendHook = - std::vector (*)(const std::vector& ioBatches, - uint32_t kernelCount, uint32_t quietCount); - -void SetAICPUTransProviderSendHook(AICPUTransProviderSendHook hook); -AICPUTransProviderSendHook GetAICPUTransProviderSendHook(); - class AICPUTransProvider : public TransProvider { public: - Status CreateConnection(const std::string&, const std::string&, uint32_t, uint32_t qpNum, - uint32_t, std::vector& handles) override + Status CreateConnection(const std::string&, const std::string&, uint32_t, uint32_t, uint32_t, + std::vector&) override { - handles.clear(); - handles.reserve(qpNum); - for (uint32_t index = 0; index < qpNum; ++index) { - handles.push_back(reinterpret_cast( - static_cast(index) + static_cast(1))); - } return Status::OK(); } @@ -30,11 +17,8 @@ class AICPUTransProvider : public TransProvider { return std::vector(handles.size(), Status::OK()); } - std::vector Send(const std::vector& ioBatches, - uint32_t kernelCount, uint32_t quietCount) override + std::vector Send(const std::vector& ioBatches, uint32_t, uint32_t) override { - auto hook = GetAICPUTransProviderSendHook(); - if (hook != nullptr) { return hook(ioBatches, kernelCount, quietCount); } return std::vector(ioBatches.size(), Status::OK()); } diff --git a/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp b/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp index 657e3b912..dba478430 100644 --- a/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp +++ b/ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp @@ -32,6 +32,7 @@ #include "asu_transport/asu_transport.h" #include "connection_internal.h" #include "connection_manager.h" +#include "fake_trans_provider.h" #include "logger.h" #include "transport_config_parser.h" @@ -62,6 +63,16 @@ Status AsuTransportImpl::Init(const TransportConfig& config) case TransProviderType::AICPU: transProvider_ = std::make_unique(); break; + case TransProviderType::FAKE: + transProvider_ = + std::make_unique(MakeFakeTransProviderConfig(config_)); + break; + case TransProviderType::AIV: + return Status::Error(StatusCode::UNSUPPORTED, + "AIV trans provider is not implemented"); + case TransProviderType::UNSUPPORTED: + return Status::Error(StatusCode::UNSUPPORTED, + "ASU trans provider backend is not supported"); } } if (!transProvider_) { diff --git a/ucm/transport/kv/asu/trans/src/fake_trans_provider.cpp b/ucm/transport/kv/asu/trans/src/fake_trans_provider.cpp new file mode 100644 index 000000000..87e79d5d2 --- /dev/null +++ b/ucm/transport/kv/asu/trans/src/fake_trans_provider.cpp @@ -0,0 +1,461 @@ +/** + * MIT License + * + * Copyright (c) 2026 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#include "fake_trans_provider.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "kv_protocol.h" +#include "logger.h" + +namespace UC::ASU { +namespace { + +constexpr std::uint16_t kCqeSuccess = 0x000; +constexpr std::uint16_t kCqeCheckResultBuffer = 0x732; +constexpr std::uint8_t kBatchEntryOk = 0x0; +constexpr std::uint8_t kBatchEntryKeyNotFound = 0x3; +constexpr std::uint8_t kDeleteEntryOk = 0x0; +constexpr std::uint8_t kDeleteEntryFailed = 0x1; +constexpr std::uint8_t kExistEntryNotExist = 0x0; +constexpr std::uint8_t kExistEntryExist = 0x1; + +std::uint64_t ReadU64(std::uint32_t low, std::uint32_t high) +{ + return static_cast(low) | (static_cast(high) << 32); +} + +std::uint32_t RequestCid(const std::uint32_t* request) { return request[0] >> 16; } + +AsuId RequestAsuId(const std::uint32_t* request) { return request[1]; } + +KvOpcode RequestOpcode(const std::uint32_t* request) +{ + return static_cast(request[0] & 0xFF); +} + +std::string ReadKey(const std::uint32_t* data) +{ + char key[17] = {}; + std::memcpy(key, data, 16); + const auto keyLen = std::find(key, key + 16, '\0') - key; + return std::string(key, static_cast(keyLen)); +} + +std::string KeyFileName(const std::string& key) +{ + std::uint64_t hash = 1469598103934665603ULL; + for (unsigned char ch : key) { + hash ^= ch; + hash *= 1099511628211ULL; + } + + std::ostringstream stream; + stream << std::hex << std::setw(16) << std::setfill('0') << hash << ".bin"; + return stream.str(); +} + +std::filesystem::path AsuRoot(const FakeTransProviderConfig& config, AsuId asuId) +{ + return std::filesystem::path(config.storePath) / ("asu-" + std::to_string(asuId)); +} + +std::filesystem::path KeyPath(const FakeTransProviderConfig& config, AsuId asuId, + const std::string& key) +{ + return AsuRoot(config, asuId) / KeyFileName(key); +} + +bool StoreBytes(const FakeTransProviderConfig& config, AsuId asuId, const std::string& key, + std::uint64_t addr, std::uint32_t length) +{ + std::vector buffer(length); + auto ret = aclrtMemcpy(buffer.data(), buffer.size(), reinterpret_cast(addr), + length, ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != ACL_SUCCESS) { + UC_ERROR( + "ASU fake backend device-to-host copy failed asuId={} key={} addr={} length={} " + "ret={}.", + asuId, key, addr, length, ret); + return false; + } + + std::filesystem::create_directories(AsuRoot(config, asuId)); + std::ofstream output(KeyPath(config, asuId, key), std::ios::binary | std::ios::trunc); + if (!output) { + UC_ERROR("ASU fake backend failed to open store file asuId={} key={} path={}.", asuId, key, + KeyPath(config, asuId, key).string()); + return false; + } + output.write(buffer.data(), static_cast(buffer.size())); + return output.good(); +} + +bool LoadBytes(const FakeTransProviderConfig& config, AsuId asuId, const std::string& key, + std::uint64_t addr, std::uint32_t length) +{ + std::ifstream input(KeyPath(config, asuId, key), std::ios::binary); + if (!input) { + UC_ERROR("ASU fake backend failed to open load file asuId={} key={} path={}.", asuId, key, + KeyPath(config, asuId, key).string()); + return false; + } + std::vector buffer(length, 0); + input.read(buffer.data(), static_cast(buffer.size())); + const auto readCount = input.gcount(); + if (readCount < static_cast(length)) { + std::fill(buffer.begin() + readCount, buffer.end(), 0); + } + auto ret = aclrtMemcpy(reinterpret_cast(addr), length, buffer.data(), buffer.size(), + ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != ACL_SUCCESS) { + UC_ERROR( + "ASU fake backend host-to-device copy failed asuId={} key={} addr={} length={} " + "ret={}.", + asuId, key, addr, length, ret); + return false; + } + return true; +} + +bool DeleteKey(const FakeTransProviderConfig& config, AsuId asuId, const std::string& key) +{ + std::error_code errorCode; + std::filesystem::remove(KeyPath(config, asuId, key), errorCode); + return !errorCode; +} + +bool ExistsKey(const FakeTransProviderConfig& config, AsuId asuId, const std::string& key) +{ + std::error_code errorCode; + return std::filesystem::exists(KeyPath(config, asuId, key), errorCode); +} + +void PackCqeHeader(std::uint32_t* flagBuffer, std::uint16_t cid, std::uint16_t status) +{ + flagBuffer[0] = 0; + flagBuffer[1] = 0; + flagBuffer[2] = 0; + flagBuffer[3] = static_cast(cid) | (static_cast(status) << 17); +} + +void PackResultBuffer4Bit(std::uint32_t* resultData, const std::vector& results) +{ + const auto dwordCount = (results.size() + 7) / 8; + std::fill(resultData, resultData + dwordCount, 0); + for (std::size_t index = 0; index < results.size(); ++index) { + resultData[index / 8] |= static_cast(results[index] & 0xF) + << ((index % 8) * 4); + } +} + +void PackResultBuffer1Bit(std::uint32_t* resultData, const std::vector& results) +{ + const auto dwordCount = (results.size() + 31) / 32; + std::fill(resultData, resultData + dwordCount, 0); + for (std::size_t index = 0; index < results.size(); ++index) { + resultData[index / 32] |= static_cast(results[index] & 0x1) << (index % 32); + } +} + +struct BatchEntry { + std::string key; + std::uint64_t bufferAddr{0}; + std::uint32_t length{0}; +}; + +std::vector ReadBatchEntries(const std::uint32_t* request, std::uint16_t batchNumber) +{ + std::vector entries; + entries.reserve(batchNumber); + for (std::uint16_t index = 0; index < batchNumber; ++index) { + const auto* entry = request + kSqeDwordCount + index * kBatchEntryDwordCount; + BatchEntry parsed; + parsed.key = ReadKey(entry + 1); + parsed.bufferAddr = ReadU64(entry[5], entry[6]); + parsed.length = entry[7] & 0xFFFFFF; + entries.emplace_back(std::move(parsed)); + } + return entries; +} + +std::vector ReadKeyEntries(const std::uint32_t* request, std::uint16_t batchNumber) +{ + std::vector keys; + keys.reserve(batchNumber); + for (std::uint16_t index = 0; index < batchNumber; ++index) { + const auto* entry = request + kSqeDwordCount + index * kKeyEntryDwordCount; + keys.emplace_back(ReadKey(entry)); + } + return keys; +} + +Status CompleteBatchStore(const FakeTransProviderConfig& config, AsuId asuId, + const std::uint32_t* request) +{ + const auto cid = static_cast(RequestCid(request)); + const auto responseBufferAddr = ReadU64(request[3], request[4]); + const auto batchNumber = static_cast(request[10] & 0xFFFF); + auto* flagBuffer = reinterpret_cast(responseBufferAddr); + std::vector results(batchNumber, kBatchEntryOk); + + const auto entries = ReadBatchEntries(request, batchNumber); + for (std::size_t index = 0; index < entries.size(); ++index) { + const auto& entry = entries[index]; + if (!StoreBytes(config, asuId, entry.key, entry.bufferAddr, entry.length)) { + results[index] = kBatchEntryKeyNotFound; + } + } + + const auto allOk = std::all_of(results.begin(), results.end(), + [](std::uint8_t result) { return result == kBatchEntryOk; }); + const auto cqeStatus = allOk ? kCqeSuccess : kCqeCheckResultBuffer; + PackCqeHeader(flagBuffer, cid, cqeStatus); + if (!allOk) { PackResultBuffer4Bit(flagBuffer + kCqeDwordCount, results); } + return Status::OK(); +} + +Status CompleteBatchRetrieve(const FakeTransProviderConfig& config, AsuId asuId, + const std::uint32_t* request) +{ + const auto cid = static_cast(RequestCid(request)); + const auto responseBufferAddr = ReadU64(request[3], request[4]); + const auto batchNumber = static_cast(request[10] & 0xFFFF); + auto* flagBuffer = reinterpret_cast(responseBufferAddr); + std::vector results(batchNumber, kBatchEntryOk); + + const auto entries = ReadBatchEntries(request, batchNumber); + for (std::size_t index = 0; index < entries.size(); ++index) { + const auto& entry = entries[index]; + if (!LoadBytes(config, asuId, entry.key, entry.bufferAddr, entry.length)) { + results[index] = kBatchEntryKeyNotFound; + } + } + + const auto allOk = std::all_of(results.begin(), results.end(), + [](std::uint8_t result) { return result == kBatchEntryOk; }); + const auto cqeStatus = allOk ? kCqeSuccess : kCqeCheckResultBuffer; + PackCqeHeader(flagBuffer, cid, cqeStatus); + if (!allOk) { PackResultBuffer4Bit(flagBuffer + kCqeDwordCount, results); } + return Status::OK(); +} + +Status CompleteDelete(const FakeTransProviderConfig& config, AsuId asuId, + const std::uint32_t* request) +{ + const auto cid = static_cast(RequestCid(request)); + const auto responseBufferAddr = ReadU64(request[3], request[4]); + const auto batchNumber = static_cast(request[10] & 0xFFFF); + auto* flagBuffer = reinterpret_cast(responseBufferAddr); + std::vector results(batchNumber, kDeleteEntryOk); + + const auto keys = ReadKeyEntries(request, batchNumber); + for (std::size_t index = 0; index < keys.size(); ++index) { + if (!DeleteKey(config, asuId, keys[index])) { results[index] = kDeleteEntryFailed; } + } + + const auto allOk = std::all_of(results.begin(), results.end(), + [](std::uint8_t result) { return result == kDeleteEntryOk; }); + const auto cqeStatus = allOk ? kCqeSuccess : kCqeCheckResultBuffer; + PackCqeHeader(flagBuffer, cid, cqeStatus); + if (!allOk) { PackResultBuffer1Bit(flagBuffer + kCqeDwordCount, results); } + return Status::OK(); +} + +Status CompleteExist(const FakeTransProviderConfig& config, AsuId asuId, + const std::uint32_t* request) +{ + const auto cid = static_cast(RequestCid(request)); + const auto responseBufferAddr = ReadU64(request[3], request[4]); + const auto batchNumber = static_cast(request[10] & 0xFFFF); + auto* flagBuffer = reinterpret_cast(responseBufferAddr); + std::vector results(batchNumber, kExistEntryNotExist); + std::uint16_t existingKeyNumber = 0; + + const auto keys = ReadKeyEntries(request, batchNumber); + for (std::size_t index = 0; index < keys.size(); ++index) { + if (ExistsKey(config, asuId, keys[index])) { + results[index] = kExistEntryExist; + ++existingKeyNumber; + } + } + + const auto allExist = existingKeyNumber == batchNumber; + const auto cqeStatus = allExist ? kCqeSuccess : kCqeCheckResultBuffer; + PackCqeHeader(flagBuffer, cid, cqeStatus); + flagBuffer[0] = existingKeyNumber; + if (!allExist) { PackResultBuffer1Bit(flagBuffer + kCqeDwordCount, results); } + return Status::OK(); +} + +Status CompleteFakeBackendRequest(const FakeTransProviderConfig& config, const void* sendBuffer, + std::uint64_t len) +{ + if (config.latencyMs > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(config.latencyMs)); + } + + if (sendBuffer == nullptr || len < sizeof(std::uint32_t)) { + return Status::Error(StatusCode::INVALID_ARGUMENT, "fake backend send buffer is empty"); + } + + const auto* request = reinterpret_cast(sendBuffer); + const auto asuId = RequestAsuId(request); + switch (RequestOpcode(request)) { + case KvOpcode::BatchStore: return CompleteBatchStore(config, asuId, request); + case KvOpcode::BatchRetrieve: return CompleteBatchRetrieve(config, asuId, request); + case KvOpcode::Delete: return CompleteDelete(config, asuId, request); + case KvOpcode::Exist: return CompleteExist(config, asuId, request); + case KvOpcode::KeepAlive: { + auto* flagBuffer = reinterpret_cast(ReadU64(request[3], request[4])); + PackCqeHeader(flagBuffer, static_cast(RequestCid(request)), kCqeSuccess); + return Status::OK(); + } + default: + return Status::Error(StatusCode::UNSUPPORTED, + "fake backend only supports batch ASU operations"); + } +} + +} // namespace + +FakeTransProviderConfig MakeFakeTransProviderConfig(const TransportConfig& config) +{ + FakeTransProviderConfig fakeConfig; + fakeConfig.deviceId = config.endpoints.empty() ? 0 : config.endpoints.front().deviceId; + auto pathIter = config.attrs.find("fake_backend.path"); + if (pathIter != config.attrs.end() && !pathIter->second.empty()) { + fakeConfig.storePath = pathIter->second; + } + auto latencyIter = config.attrs.find("fake_backend.latency_ms"); + if (latencyIter != config.attrs.end()) { + fakeConfig.latencyMs = static_cast(std::stoull(latencyIter->second)); + } + auto deviceIter = config.attrs.find("fake_backend.device_id"); + if (deviceIter != config.attrs.end()) { + fakeConfig.deviceId = static_cast(std::stol(deviceIter->second)); + } + return fakeConfig; +} + +FakeTransProvider::FakeTransProvider(FakeTransProviderConfig config) : config_(std::move(config)) {} + +Status FakeTransProvider::SetUpAclRuntime() +{ + if (aclReady_) { return Status::OK(); } + auto ret = aclInit(nullptr); + if (ret != ACL_SUCCESS && ret != ACL_ERROR_REPEAT_INITIALIZE) { + return Status::Error(StatusCode::INTERNAL_ERROR, + "ASU fake backend aclInit failed: " + std::to_string(ret)); + } + + const auto deviceId = config_.deviceId < 0 ? 0 : config_.deviceId; + ret = aclrtSetDevice(deviceId); + if (ret != ACL_SUCCESS) { + return Status::Error(StatusCode::INTERNAL_ERROR, + "ASU fake backend aclrtSetDevice failed: device_id=" + + std::to_string(deviceId) + " ret=" + std::to_string(ret)); + } + aclReady_ = true; + return Status::OK(); +} + +Status FakeTransProvider::CreateConnection(const std::string&, const std::string&, uint32_t, + uint32_t qpNum, uint32_t, + std::vector& handles) +{ + auto status = SetUpAclRuntime(); + if (!status.ok()) { return status; } + handles.clear(); + handles.reserve(qpNum); + for (uint32_t index = 0; index < qpNum; ++index) { + handles.push_back(reinterpret_cast(static_cast(index) + + static_cast(1))); + } + return Status::OK(); +} + +std::vector FakeTransProvider::DeleteConnections( + const std::vector& handles) +{ + return std::vector(handles.size(), Status::OK()); +} + +std::vector FakeTransProvider::Send(const std::vector& ioBatches, + uint32_t kernelCount, uint32_t quietCount) +{ + (void)kernelCount; + (void)quietCount; + std::vector statuses; + statuses.reserve(ioBatches.size()); + for (const auto& ioBatch : ioBatches) { + statuses.emplace_back(CompleteFakeBackendRequest(config_, ioBatch.sendBuffer, ioBatch.len)); + } + return statuses; +} + +Status FakeTransProvider::RegisterMemory(ConnectionHandle, + const std::vector& memoryDescs, + std::vector& memoryHandles) +{ + memoryHandles.clear(); + memoryHandles.reserve(memoryDescs.size()); + for (std::size_t index = 0; index < memoryDescs.size(); ++index) { + memoryHandles.push_back(reinterpret_cast(static_cast(index) + + static_cast(1))); + } + return Status::OK(); +} + +std::vector FakeTransProvider::UnregisterMemory( + const std::vector& handles) +{ + return std::vector(handles.size(), Status::OK()); +} + +Status FakeTransProvider::AllocThread(uint32_t, const std::vector&, + std::vector&) +{ + return Status::OK(); +} + +std::vector FakeTransProvider::FreeThread(const std::vector& threads) +{ + return std::vector(threads.size(), Status::OK()); +} + +Status FakeTransProvider::GetMemTokenId(MemHandle, uint32_t& tokenId) +{ + tokenId = 1; + return Status::OK(); +} + +} // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/fake_trans_provider.h b/ucm/transport/kv/asu/trans/src/fake_trans_provider.h new file mode 100644 index 000000000..b287e50ba --- /dev/null +++ b/ucm/transport/kv/asu/trans/src/fake_trans_provider.h @@ -0,0 +1,69 @@ +/** + * MIT License + * + * Copyright (c) 2026 Huawei Technologies Co., Ltd. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * */ +#pragma once + +#include "asu_transport/asu_transport.h" +#include "trans_provider.h" + +namespace UC::ASU { + +struct FakeTransProviderConfig { + std::string storePath{"./asu-fake-backend-store"}; + std::uint64_t latencyMs{1}; + std::int32_t deviceId{0}; +}; + +class FakeTransProvider : public TransProvider { +public: + explicit FakeTransProvider(FakeTransProviderConfig config); + + Status CreateConnection(const std::string&, const std::string&, uint32_t, uint32_t qpNum, + uint32_t, std::vector& handles) override; + + std::vector DeleteConnections(const std::vector& handles) override; + + std::vector Send(const std::vector& ioBatches, uint32_t kernelCount, + uint32_t quietCount) override; + + Status RegisterMemory(ConnectionHandle, const std::vector& memoryDescs, + std::vector& memoryHandles) override; + + std::vector UnregisterMemory(const std::vector& handles) override; + + Status AllocThread(uint32_t, const std::vector&, std::vector&) override; + + std::vector FreeThread(const std::vector& threads) override; + + Status GetMemTokenId(MemHandle, uint32_t& tokenId) override; + +private: + Status SetUpAclRuntime(); + + FakeTransProviderConfig config_; + bool aclReady_{false}; +}; + +FakeTransProviderConfig MakeFakeTransProviderConfig(const TransportConfig& config); + +} // namespace UC::ASU diff --git a/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp b/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp index 339d881c2..4df5ae37d 100644 --- a/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp +++ b/ucm/transport/kv/asu/trans/src/transport_config_parser.cpp @@ -163,6 +163,8 @@ Status LoadTransportConfig(const std::string& configPath, TransportConfig& confi continue; } else if (ApplyTransportIoNumConfigField(config, key, value)) { continue; + } else if (ApplyTransportProviderConfigField(config, key, value)) { + continue; } else { config.attrs[key] = value; } diff --git a/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp b/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp index 92b3406d1..4dd4a581c 100644 --- a/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp +++ b/ucm/transport/kv/asu/trans/test/transport_task_completion_test.cpp @@ -23,8 +23,10 @@ * */ #include #include +#include #include #include +#include #define private public #include "asu_transport_impl.h" #undef private diff --git a/ucm/transport/kv/kv-test/include/kv_test/kv_test_types.h b/ucm/transport/kv/kv-test/include/kv_test/kv_test_types.h index a029833fc..faec65404 100644 --- a/ucm/transport/kv/kv-test/include/kv_test/kv_test_types.h +++ b/ucm/transport/kv/kv-test/include/kv_test/kv_test_types.h @@ -141,7 +141,7 @@ struct OutputConfig { std::uint64_t realtimeFileMaxBytes{0}; }; -struct FakeBackendConfig { +struct KvTestFakeBackendConfig { std::string storePath; std::uint64_t latencyMs{1}; }; @@ -171,7 +171,7 @@ struct KvTestConfig { HcommProtocolMapping hcommProtocolMapping; BenchConfig bench; OutputConfig output; - FakeBackendConfig fakeBackend; + KvTestFakeBackendConfig fakeBackend; std::string asuClientMode; std::string localStorePath; std::string keyPrefix; diff --git a/ucm/transport/kv/kv-test/src/fake_backend.cpp b/ucm/transport/kv/kv-test/src/fake_backend.cpp index c163320fb..15a919c91 100644 --- a/ucm/transport/kv/kv-test/src/fake_backend.cpp +++ b/ucm/transport/kv/kv-test/src/fake_backend.cpp @@ -2,47 +2,15 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "aicpu_trans_provider.h" -#include "buffer_manager.h" -#include "connection_manager.h" -#include "kv_protocol.h" -#include "trans_provider.h" - -namespace UC::ASU { - -std::vector MockSend(const std::vector& ioBatches, - std::uint32_t kernelCount, std::uint32_t quietCount); - -} // namespace UC::ASU +#include +#include namespace UC::KVTest { namespace { -constexpr std::uint16_t kCqeSuccess = 0x000; -constexpr std::uint16_t kCqeCheckResultBuffer = 0x732; -constexpr std::uint8_t kBatchEntryOk = 0x0; -constexpr std::uint8_t kBatchEntryKeyNotFound = 0x3; -constexpr std::uint8_t kDeleteEntryOk = 0x0; -constexpr std::uint8_t kDeleteEntryFailed = 0x1; -constexpr std::uint8_t kExistEntryNotExist = 0x0; -constexpr std::uint8_t kExistEntryExist = 0x1; constexpr int kExitInvalidArgument = 1; constexpr int kFakeBackendAclDeviceId = 0; -std::mutex g_fakeBackendMu; -std::mutex g_traceMu; -FakeBackendConfig g_fakeBackendConfig; -bool g_fakeBackendEnabled = false; - std::string NormalizeMode(std::string value) { std::transform(value.begin(), value.end(), value.begin(), @@ -50,369 +18,26 @@ std::string NormalizeMode(std::string value) return value; } -std::filesystem::path StoreRoot(const FakeBackendConfig& config) -{ - if (!config.storePath.empty()) { return config.storePath; } - return "./kv-test-fake-backend-store"; -} - -std::uint64_t ReadU64(std::uint32_t low, std::uint32_t high) -{ - return static_cast(low) | (static_cast(high) << 32); -} - -std::uint32_t RequestCid(const std::uint32_t* request) { return request[0] >> 16; } - -UC::ASU::AsuId RequestAsuId(const std::uint32_t* request) -{ - // kv-test fake backend temporarily reuses kv_ns_id as the ASU store namespace. - return request[1]; -} - -UC::ASU::KvOpcode RequestOpcode(const std::uint32_t* request) -{ - return static_cast(request[0] & 0xFF); -} - -const char* OpcodeName(UC::ASU::KvOpcode opcode) -{ - switch (opcode) { - case UC::ASU::KvOpcode::BatchStore: return "BatchStore"; - case UC::ASU::KvOpcode::BatchRetrieve: return "BatchRetrieve"; - case UC::ASU::KvOpcode::Delete: return "Delete"; - case UC::ASU::KvOpcode::Exist: return "Exist"; - case UC::ASU::KvOpcode::KeepAlive: return "KeepAlive"; - default: return "Unknown"; - } -} - -void TraceCompletion(UC::ASU::KvOpcode opcode, UC::ASU::AsuId asuId, std::uint16_t cid, - std::uint16_t status, bool resultBuffer, std::uint16_t batchNumber, - std::uint16_t existingKeyNumber = 0) -{ - const auto* tracePath = std::getenv("KV_TEST_FAKE_BACKEND_TRACE"); - if (tracePath == nullptr || tracePath[0] == '\0') { return; } - - std::lock_guard lock(g_traceMu); - std::ofstream trace(tracePath, std::ios::app); - if (!trace) { return; } - - trace << "opcode=" << OpcodeName(opcode) << " asu_id=" << asuId << " cid=" << cid - << " status=0x" << std::hex << std::setw(3) << std::setfill('0') << status << std::dec - << " result_buffer=" << (resultBuffer ? 1 : 0) << " batch_number=" << batchNumber - << " existing_key_number=" << existingKeyNumber << '\n'; -} - -std::string ReadKey(const std::uint32_t* data) -{ - char key[17] = {}; - std::memcpy(key, data, 16); - const auto keyLen = std::find(key, key + 16, '\0') - key; - return std::string(key, static_cast(keyLen)); -} - -std::string KeyFileName(const std::string& key) -{ - std::uint64_t hash = 1469598103934665603ULL; - for (unsigned char ch : key) { - hash ^= ch; - hash *= 1099511628211ULL; - } - - std::ostringstream stream; - stream << std::hex << std::setw(16) << std::setfill('0') << hash << ".bin"; - return stream.str(); -} - -std::filesystem::path AsuRoot(const FakeBackendConfig& config, UC::ASU::AsuId asuId) -{ - return StoreRoot(config) / ("asu-" + std::to_string(asuId)); -} - -std::filesystem::path KeyPath(const FakeBackendConfig& config, UC::ASU::AsuId asuId, - const std::string& key) -{ - return AsuRoot(config, asuId) / KeyFileName(key); -} - -bool StoreBytes(const FakeBackendConfig& config, UC::ASU::AsuId asuId, const std::string& key, - std::uint64_t addr, std::uint32_t length) -{ - std::filesystem::create_directories(AsuRoot(config, asuId)); - std::ofstream output(KeyPath(config, asuId, key), std::ios::binary | std::ios::trunc); - if (!output) { return false; } - output.write(reinterpret_cast(addr), length); - return output.good(); -} - -bool LoadBytes(const FakeBackendConfig& config, UC::ASU::AsuId asuId, const std::string& key, - std::uint64_t addr, std::uint32_t length) -{ - std::ifstream input(KeyPath(config, asuId, key), std::ios::binary); - if (!input) { return false; } - input.read(reinterpret_cast(addr), length); - const auto readCount = input.gcount(); - if (readCount < static_cast(length)) { - std::memset(reinterpret_cast(addr) + readCount, 0, - length - static_cast(readCount)); - } - return true; -} - -bool DeleteKey(const FakeBackendConfig& config, UC::ASU::AsuId asuId, const std::string& key) -{ - std::error_code errorCode; - std::filesystem::remove(KeyPath(config, asuId, key), errorCode); - // Delete result buffer uses 0 for success. A missing key is still a successful delete. - return !errorCode; -} - -bool ExistsKey(const FakeBackendConfig& config, UC::ASU::AsuId asuId, const std::string& key) -{ - std::error_code errorCode; - return std::filesystem::exists(KeyPath(config, asuId, key), errorCode); -} - -void PackCqeHeader(std::uint32_t* flagBuffer, std::uint16_t cid, std::uint16_t status) -{ - flagBuffer[0] = 0; - flagBuffer[1] = 0; - flagBuffer[2] = 0; - flagBuffer[3] = static_cast(cid) | (static_cast(status) << 17); -} - -void PackResultBuffer4Bit(std::uint32_t* resultData, const std::vector& results) -{ - const auto dwordCount = (results.size() + 7) / 8; - std::fill(resultData, resultData + dwordCount, 0); - for (std::size_t index = 0; index < results.size(); ++index) { - resultData[index / 8] |= static_cast(results[index] & 0xF) - << ((index % 8) * 4); - } -} - -void PackResultBuffer1Bit(std::uint32_t* resultData, const std::vector& results) -{ - const auto dwordCount = (results.size() + 31) / 32; - std::fill(resultData, resultData + dwordCount, 0); - for (std::size_t index = 0; index < results.size(); ++index) { - resultData[index / 32] |= static_cast(results[index] & 0x1) << (index % 32); - } -} - -struct BatchEntry { - std::string key; - std::uint64_t bufferAddr{0}; - std::uint32_t length{0}; -}; - -std::vector ReadBatchEntries(const std::uint32_t* request, std::uint16_t batchNumber) -{ - std::vector entries; - entries.reserve(batchNumber); - for (std::uint16_t index = 0; index < batchNumber; ++index) { - const auto* entry = - request + UC::ASU::kSqeDwordCount + index * UC::ASU::kBatchEntryDwordCount; - BatchEntry parsed; - parsed.key = ReadKey(entry + 1); - parsed.bufferAddr = ReadU64(entry[5], entry[6]); - parsed.length = entry[7] & 0xFFFFFF; - entries.emplace_back(std::move(parsed)); - } - return entries; -} - -std::vector ReadKeyEntries(const std::uint32_t* request, std::uint16_t batchNumber) -{ - std::vector keys; - keys.reserve(batchNumber); - for (std::uint16_t index = 0; index < batchNumber; ++index) { - const auto* entry = - request + UC::ASU::kSqeDwordCount + index * UC::ASU::kKeyEntryDwordCount; - keys.emplace_back(ReadKey(entry)); - } - return keys; -} - -UC::ASU::Status CompleteBatchStore(const FakeBackendConfig& config, UC::ASU::AsuId asuId, - const std::uint32_t* request) -{ - const auto cid = static_cast(RequestCid(request)); - const auto responseBufferAddr = ReadU64(request[3], request[4]); - const auto batchNumber = static_cast(request[10] & 0xFFFF); - auto* flagBuffer = reinterpret_cast(responseBufferAddr); - std::vector results(batchNumber, kBatchEntryOk); - - const auto entries = ReadBatchEntries(request, batchNumber); - for (std::size_t index = 0; index < entries.size(); ++index) { - const auto& entry = entries[index]; - if (!StoreBytes(config, asuId, entry.key, entry.bufferAddr, entry.length)) { - results[index] = kBatchEntryKeyNotFound; - } - } - - const auto allOk = std::all_of(results.begin(), results.end(), - [](std::uint8_t result) { return result == kBatchEntryOk; }); - const auto cqeStatus = allOk ? kCqeSuccess : kCqeCheckResultBuffer; - PackCqeHeader(flagBuffer, cid, cqeStatus); - if (!allOk) { PackResultBuffer4Bit(flagBuffer + UC::ASU::kCqeDwordCount, results); } - TraceCompletion(UC::ASU::KvOpcode::BatchStore, asuId, cid, cqeStatus, !allOk, batchNumber); - return UC::ASU::Status::OK(); -} - -UC::ASU::Status CompleteBatchRetrieve(const FakeBackendConfig& config, UC::ASU::AsuId asuId, - const std::uint32_t* request) -{ - const auto cid = static_cast(RequestCid(request)); - const auto responseBufferAddr = ReadU64(request[3], request[4]); - const auto batchNumber = static_cast(request[10] & 0xFFFF); - auto* flagBuffer = reinterpret_cast(responseBufferAddr); - std::vector results(batchNumber, kBatchEntryOk); - - const auto entries = ReadBatchEntries(request, batchNumber); - for (std::size_t index = 0; index < entries.size(); ++index) { - const auto& entry = entries[index]; - if (!LoadBytes(config, asuId, entry.key, entry.bufferAddr, entry.length)) { - results[index] = kBatchEntryKeyNotFound; - } - } - - const auto allOk = std::all_of(results.begin(), results.end(), - [](std::uint8_t result) { return result == kBatchEntryOk; }); - const auto cqeStatus = allOk ? kCqeSuccess : kCqeCheckResultBuffer; - PackCqeHeader(flagBuffer, cid, cqeStatus); - if (!allOk) { PackResultBuffer4Bit(flagBuffer + UC::ASU::kCqeDwordCount, results); } - TraceCompletion(UC::ASU::KvOpcode::BatchRetrieve, asuId, cid, cqeStatus, !allOk, batchNumber); - return UC::ASU::Status::OK(); -} - -UC::ASU::Status CompleteDelete(const FakeBackendConfig& config, UC::ASU::AsuId asuId, - const std::uint32_t* request) -{ - const auto cid = static_cast(RequestCid(request)); - const auto responseBufferAddr = ReadU64(request[3], request[4]); - const auto batchNumber = static_cast(request[10] & 0xFFFF); - auto* flagBuffer = reinterpret_cast(responseBufferAddr); - std::vector results(batchNumber, kDeleteEntryOk); - - const auto keys = ReadKeyEntries(request, batchNumber); - for (std::size_t index = 0; index < keys.size(); ++index) { - if (!DeleteKey(config, asuId, keys[index])) { results[index] = kDeleteEntryFailed; } - } - - const auto allOk = std::all_of(results.begin(), results.end(), - [](std::uint8_t result) { return result == kDeleteEntryOk; }); - const auto cqeStatus = allOk ? kCqeSuccess : kCqeCheckResultBuffer; - PackCqeHeader(flagBuffer, cid, cqeStatus); - if (!allOk) { PackResultBuffer1Bit(flagBuffer + UC::ASU::kCqeDwordCount, results); } - TraceCompletion(UC::ASU::KvOpcode::Delete, asuId, cid, cqeStatus, !allOk, batchNumber); - return UC::ASU::Status::OK(); -} - -UC::ASU::Status CompleteExist(const FakeBackendConfig& config, UC::ASU::AsuId asuId, - const std::uint32_t* request) -{ - const auto cid = static_cast(RequestCid(request)); - const auto responseBufferAddr = ReadU64(request[3], request[4]); - const auto batchNumber = static_cast(request[10] & 0xFFFF); - auto* flagBuffer = reinterpret_cast(responseBufferAddr); - std::vector results(batchNumber, kExistEntryNotExist); - std::uint16_t existingKeyNumber = 0; - - const auto keys = ReadKeyEntries(request, batchNumber); - for (std::size_t index = 0; index < keys.size(); ++index) { - if (!ExistsKey(config, asuId, keys[index])) { continue; } - results[index] = kExistEntryExist; - ++existingKeyNumber; - } - - const auto allExist = std::all_of(results.begin(), results.end(), [](std::uint8_t result) { - return result == kExistEntryExist; - }); - const auto cqeStatus = allExist ? kCqeSuccess : kCqeCheckResultBuffer; - PackCqeHeader(flagBuffer, cid, cqeStatus); - flagBuffer[0] = existingKeyNumber; - if (!allExist) { PackResultBuffer1Bit(flagBuffer + UC::ASU::kCqeDwordCount, results); } - TraceCompletion(UC::ASU::KvOpcode::Exist, asuId, cid, cqeStatus, !allExist, batchNumber, - existingKeyNumber); - return UC::ASU::Status::OK(); -} - -UC::ASU::Status CompleteFakeBackendRequest(FakeBackendConfig config, const void* sendBuffer, - std::uint64_t len) -{ - if (config.latencyMs > 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(config.latencyMs)); - } - - if (sendBuffer == nullptr || len < sizeof(std::uint32_t)) { - return UC::ASU::Status::Error(UC::ASU::StatusCode::INVALID_ARGUMENT, - "fake backend send buffer is empty"); - } - - const auto* request = reinterpret_cast(sendBuffer); - const auto asuId = RequestAsuId(request); - switch (RequestOpcode(request)) { - case UC::ASU::KvOpcode::BatchStore: return CompleteBatchStore(config, asuId, request); - case UC::ASU::KvOpcode::BatchRetrieve: return CompleteBatchRetrieve(config, asuId, request); - case UC::ASU::KvOpcode::Delete: return CompleteDelete(config, asuId, request); - case UC::ASU::KvOpcode::Exist: return CompleteExist(config, asuId, request); - case UC::ASU::KvOpcode::KeepAlive: { - auto* flagBuffer = reinterpret_cast(ReadU64(request[3], request[4])); - PackCqeHeader(flagBuffer, static_cast(RequestCid(request)), kCqeSuccess); - TraceCompletion(UC::ASU::KvOpcode::KeepAlive, asuId, - static_cast(RequestCid(request)), kCqeSuccess, false, 0); - return UC::ASU::Status::OK(); - } - default: - return UC::ASU::Status::Error(UC::ASU::StatusCode::UNSUPPORTED, - "fake backend only supports batch ASU operations"); - } -} - -FakeBackendConfig GetFakeBackendConfig(bool& enabled) -{ - std::lock_guard lock(g_fakeBackendMu); - enabled = g_fakeBackendEnabled; - return g_fakeBackendConfig; -} - -void SetFakeBackendConfig(FakeBackendConfig config) -{ - { - std::lock_guard lock(g_fakeBackendMu); - g_fakeBackendConfig = std::move(config); - g_fakeBackendEnabled = true; - } - UC::ASU::SetAICPUTransProviderSendHook(&UC::ASU::MockSend); -} - -void DisableFakeBackend() -{ - UC::ASU::SetAICPUTransProviderSendHook(nullptr); - { - std::lock_guard lock(g_fakeBackendMu); - g_fakeBackendConfig = FakeBackendConfig{}; - g_fakeBackendEnabled = false; - } -} - -void PatchTransportConfig(UC::ASU::TransportConfig& config) +void PatchTransportConfig(UC::ASU::TransportConfig& config, + const KvTestFakeBackendConfig& fakeConfig) { + config.providerType = UC::ASU::TransProviderType::FAKE; config.attrs.try_emplace("kernel_count", "1"); config.attrs.try_emplace("quiet_count", "1"); - // kv-test fake backend has no direct TransportConfig context in Send, so this temporary - // test-only mapping lets the mock recover the ASU store namespace from the packed SQE. config.attrs["kv_ns_id"] = std::to_string(config.asuId); config.attrs.try_emplace("dtype", "0"); config.attrs.try_emplace("dspec", "0"); config.attrs.try_emplace("lr", "false"); config.attrs["sc"] = "true"; + config.attrs["fake_backend.path"] = fakeConfig.storePath; + config.attrs["fake_backend.latency_ms"] = std::to_string(fakeConfig.latencyMs); + config.attrs["fake_backend.device_id"] = std::to_string(kFakeBackendAclDeviceId); if (config.endpoints.empty()) { UC::ASU::AsuEndpoint endpoint; endpoint.ip = "fake_backend"; endpoint.port = 19001; endpoint.protocol = UC::ASU::Protocol::TCP; + endpoint.deviceId = kFakeBackendAclDeviceId; config.endpoints.emplace_back(std::move(endpoint)); } } @@ -465,16 +90,12 @@ bool IsFakeBackendMode(const KvTestConfig& config) void MaybePrepareFakeBackend(KvTestConfig& config) { - if (!IsFakeBackendMode(config)) { - DisableFakeBackend(); - return; - } + if (!IsFakeBackendMode(config)) { return; } if (config.fakeBackend.storePath.empty()) { config.fakeBackend.storePath = config.localStorePath.empty() ? "./kv-test-fake-backend-store" : config.localStorePath; } - SetFakeBackendConfig(config.fakeBackend); config.asuClientConfig.attrs.try_emplace("hash_table.type", "RING_HASH"); config.asuClientConfig.attrs.try_emplace("ring_hash.virtual_node_count", "128"); @@ -484,50 +105,8 @@ void MaybePrepareFakeBackend(KvTestConfig& config) config.asuClientConfig.transportConfigs.emplace_back(std::move(transportConfig)); } for (auto& transportConfig : config.asuClientConfig.transportConfigs) { - PatchTransportConfig(transportConfig); + PatchTransportConfig(transportConfig, config.fakeBackend); } } } // namespace UC::KVTest - -namespace UC::ASU { - -std::vector MockSend(const std::vector& ioBatches, - std::uint32_t kernelCount, std::uint32_t quietCount) -{ - (void)kernelCount; - (void)quietCount; - - bool enabled = false; - const auto config = UC::KVTest::GetFakeBackendConfig(enabled); - if (!enabled) { - return std::vector( - ioBatches.size(), - Status::Error(StatusCode::UNSUPPORTED, "kv-test fake backend Send is not enabled")); - } - - std::vector statuses; - statuses.reserve(ioBatches.size()); - for (const auto& ioBatch : ioBatches) { - if (ioBatch.sendBuffer == nullptr || ioBatch.len == 0) { - statuses.emplace_back( - Status::Error(StatusCode::INVALID_ARGUMENT, "fake backend send buffer is empty")); - continue; - } - - // kv-test fake backend temporarily completes the CQE before Send returns. The production - // path still observes completion through Transport polling, while the mock avoids detached - // threads racing with sub-batch buffer lifetime in multi sub-batch tests. - statuses.emplace_back( - UC::KVTest::CompleteFakeBackendRequest(config, ioBatch.sendBuffer, ioBatch.len)); - } - return statuses; -} - -std::vector Send(const std::vector& ioBatches, - std::uint32_t kernelCount, std::uint32_t quietCount) -{ - return MockSend(ioBatches, kernelCount, quietCount); -} - -} // namespace UC::ASU diff --git a/ucm/transport/kv/kv-test/src/kv_test_config_loader.cpp b/ucm/transport/kv/kv-test/src/kv_test_config_loader.cpp index 581de7a60..1a3ca67ba 100644 --- a/ucm/transport/kv/kv-test/src/kv_test_config_loader.cpp +++ b/ucm/transport/kv/kv-test/src/kv_test_config_loader.cpp @@ -167,7 +167,7 @@ Status KvTestConfigLoader::Load(const std::string& configPath, KvTestConfig& con config.hcommProtocolMapping = HcommProtocolMapping{}; config.bench = BenchConfig{}; config.output = OutputConfig{}; - config.fakeBackend = FakeBackendConfig{}; + config.fakeBackend = KvTestFakeBackendConfig{}; config.asuClientMode.clear(); config.localStorePath.clear(); config.keyPrefix.clear();