Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/ucm_config_asu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# UCM ASU fake backend example for vLLM / vLLM-Ascend software integration tests.
Comment thread
Fengli5355 marked this conversation as resolved.
#
# Use with:
# kv_connector_extra_config={"UCM_CONFIG_FILE": "/workspace/unified-cache-management/examples/ucm_config_asu.yaml"}

ucm_connectors:
- ucm_connector_name: "UcmPipelineStore"
ucm_connector_config:
store_pipeline: "ASU"
asu_mode: "client"
asu_client_id: "ucm-vllm-asu-fake"
asu_ids: [1, 2]
asu_trans_provider_backend: "fake"

# ASU client routing. Keep the default close to kv-test's fake_backend config.
asu_router_type: "RING_HASH"
asu_ring_hash_virtual_node_count: 128

# Local software backend. This uses normal AsuClient + AsuTransportImpl submit/CQE paths
# and replaces the missing device Send/backend execution with local files.
asu_fake_backend_path: "./asu-fake-backend-store"
asu_fake_backend_latency_ms: 1

# ASU wait and operation timeouts.
asu_default_wait_timeout_ms: 5000
asu_query_timeout_ms: 5000
asu_load_timeout_ms: 5000
asu_store_timeout_ms: 5000
asu_max_inflight_tasks: 1024

enable_event_sync: true
use_layerwise: false
persist_token_threshold: 0
4 changes: 3 additions & 1 deletion scripts/build_ascend.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function build_wheels()

cd ${KVCACHE_PROJECT_ROOT}
export ENABLE_SPARSE=${ENABLE_SPARSE:-TRUE}
export BUILD_UCM_ASU=${BUILD_UCM_ASU:-ON}
python3 -m build --no-isolation --wheel

if [ $? -eq 0 ]; then
Expand Down Expand Up @@ -105,6 +106,7 @@ function collect_artifacts()
cp -r "${KVCACHE_PROJECT_ROOT}/docker" .
cp -r "${KVCACHE_PROJECT_ROOT}/examples/deployments" .
cp -r "${KVCACHE_PROJECT_ROOT}/examples/ucm_config_example.yaml" ucm_config.yaml
cp -r "${KVCACHE_PROJECT_ROOT}/examples/ucm_config_asu_fake_backend.yaml" ucm_config_asu_fake_backend.yaml
cp -r "${KVCACHE_PROJECT_ROOT}/examples/metrics/metrics_configs.yaml" metrics_configs.yaml
cp -r "${KVCACHE_PROJECT_ROOT}/test" .
cp "${KVCACHE_PROJECT_ROOT}/ucm/integration/vllm/patch/0.11.0/vllm-adapt.patch" .
Expand All @@ -124,4 +126,4 @@ function package_all()

check_build_install
build_wheels
package_all
package_all
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ROOT_DIR = os.path.abspath(os.path.dirname(__file__))
PLATFORM = os.getenv("PLATFORM")
ENABLE_SPARSE = os.getenv("ENABLE_SPARSE")
BUILD_UCM_ASU = os.getenv("BUILD_UCM_ASU", "0") not in ("", "0", "false", "False")
ENABLE_MINDIE = os.getenv("UCM_ENABLE_MINDIE", "0") not in ("", "0", "false", "False")


Expand Down Expand Up @@ -166,6 +167,8 @@ def build_cmake(self, ext: CMakeExtension):

if enable_sparse():
cmake_args += ["-DBUILD_UCM_SPARSE=ON"]
if BUILD_UCM_ASU:
cmake_args += ["-DBUILD_UCM_ASU=ON"]

match PLATFORM:
case "cuda":
Expand Down
3 changes: 3 additions & 0 deletions ucm/integration/vllm/ucm_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def _create_store(
* (1 if self.is_mla else self.num_head * 2)
* self.blocks_per_chunk
)
if config.get("store_pipeline") == "ASU":
config["shard_size"] = config["block_size"]
config["tensor_size"] = config["block_size"]
dp_rank = self._vllm_config.parallel_config.data_parallel_rank
config["posix_gc_enable"] = (
self._role != KVConnectorRole.WORKER and dp_rank == 0
Expand Down
7 changes: 6 additions & 1 deletion ucm/store/asu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ target_include_directories(asustore
PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/cc
)
target_link_libraries(asustore PUBLIC storeintf asu_client infra_logger)
target_link_libraries(asustore PUBLIC storeintf asu_client infra_logger PRIVATE asu_ascend_deps)
set_target_properties(asustore PROPERTIES
BUILD_WITH_INSTALL_RPATH TRUE
BUILD_RPATH "$ORIGIN;$ORIGIN/../../transport/kv/asu"
INSTALL_RPATH "$ORIGIN;$ORIGIN/../../transport/kv/asu"
)

file(RELATIVE_PATH INSTALL_REL_PATH ${UCM_ROOT_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
install(TARGETS asustore LIBRARY DESTINATION ${INSTALL_REL_PATH} COMPONENT ucm)
158 changes: 141 additions & 17 deletions ucm/store/asu/cc/asu_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@
* SOFTWARE.
* */
#include "asu_store.h"
#include <acl/acl.h>
#include <algorithm>
#include <any>
#include <cctype>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <memory>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include "asu_client/asu_client.h"
Expand All @@ -42,13 +46,16 @@ namespace {
using AsuStatus = UC::ASU::Status;
using AsuStatusCode = UC::ASU::StatusCode;

std::string ToHex(const Detail::BlockId& block)
std::uint64_t HashAsuKey(const Detail::BlockId& block)
{
static Detail::BlockIdHasher hasher;
return static_cast<std::uint64_t>(hasher(block));
}

std::string MakeAsuKey(const Detail::BlockId& block)
{
std::ostringstream os;
os << std::hex << std::setfill('0');
for (auto b : block) {
os << std::setw(2) << static_cast<unsigned>(std::to_integer<unsigned char>(b));
}
os << std::hex << std::setfill('0') << std::setw(16) << HashAsuKey(block);
return os.str();
}

Expand Down Expand Up @@ -78,6 +85,36 @@ void LogAsuStatus(const char* operation, const AsuStatus& status)
status.message);
}

AsuStatus WaitPrerequisiteEvent(std::uintptr_t eventHandle)
{
if (eventHandle == 0) { return AsuStatus::OK(); }
auto ret = aclrtSynchronizeEvent(reinterpret_cast<aclrtEvent>(eventHandle));
if (ret == ACL_SUCCESS) { return AsuStatus::OK(); }
return AsuStatus::Error(AsuStatusCode::INTERNAL_ERROR,
"aclrtSynchronizeEvent failed: " + std::to_string(ret));
}

const char* TransProviderBackendName(UC::ASU::TransProviderType providerType)
{
switch (providerType) {
case UC::ASU::TransProviderType::FAKE: return "fake";
case UC::ASU::TransProviderType::AIV: return "aiv";
case UC::ASU::TransProviderType::AICPU: return "aicpu";
case UC::ASU::TransProviderType::UNSUPPORTED: return "unsupported";
}
return "unknown";
}

UC::ASU::TransProviderType ParseTransProviderBackend(std::string backend)
{
std::transform(backend.begin(), backend.end(), backend.begin(),
[](unsigned char ch) { return static_cast<char>(std::toupper(ch)); });
if (backend == "FAKE") { return UC::ASU::TransProviderType::FAKE; }
if (backend == "AIV") { return UC::ASU::TransProviderType::AIV; }
if (backend == "AICPU") { return UC::ASU::TransProviderType::AICPU; }
return UC::ASU::TransProviderType::UNSUPPORTED;
}

UC::ASU::MemoryType ParseMemoryType(const std::string& memoryType)
{
if (memoryType == "host") { return UC::ASU::MemoryType::HOST; }
Expand All @@ -86,6 +123,39 @@ UC::ASU::MemoryType ParseMemoryType(const std::string& memoryType)
return UC::ASU::MemoryType::ASCEND_DEVICE;
}

bool TryGetStringLike(const Detail::Dictionary& inConfig, const std::string& key,
std::string& value)
{
if (!inConfig.Contains(key)) { return false; }
try {
inConfig.Get(key, value);
return true;
} catch (const std::bad_any_cast&) {
}
try {
bool boolValue = false;
inConfig.Get(key, boolValue);
value = boolValue ? "true" : "false";
return true;
} catch (const std::bad_any_cast&) {
}
try {
ssize_t numberValue = 0;
inConfig.GetNumber(key, numberValue);
value = std::to_string(numberValue);
return true;
} catch (const std::bad_any_cast&) {
}
return false;
}

void ReadClientAttr(const Detail::Dictionary& inConfig, const std::string& yamlKey,
const std::string& attrKey, Config& config)
{
std::string value;
if (TryGetStringLike(inConfig, yamlKey, value)) { config.clientAttrs[attrKey] = value; }
}

} // namespace

UC::ASU::TransportConfig BuildTransportConfig(const Config& config, std::size_t index)
Expand All @@ -98,13 +168,36 @@ UC::ASU::TransportConfig BuildTransportConfig(const Config& config, std::size_t
transportConfig.storeTimeoutMs = config.storeTimeoutMs;
transportConfig.maxInflightTasks = static_cast<std::uint32_t>(config.maxInflightTasks);
transportConfig.maxInflightBytes = config.maxInflightBytes;
transportConfig.providerType = config.transProviderType;
if (!config.asuIps.empty()) {
UC::ASU::AsuEndpoint endpoint;
endpoint.ip = config.asuIps[index];
endpoint.port = config.asuPort;
endpoint.deviceId = config.deviceId;
transportConfig.endpoints.emplace_back(std::move(endpoint));
}
if (config.transProviderType == UC::ASU::TransProviderType::FAKE) {
const auto fakeDeviceId = config.deviceId >= 0 ? config.deviceId : 0;
transportConfig.attrs.try_emplace("kernel_count", "1");
transportConfig.attrs.try_emplace("quiet_count", "1");
transportConfig.attrs["kv_ns_id"] = std::to_string(transportConfig.asuId);
transportConfig.attrs.try_emplace("dtype", "0");
transportConfig.attrs.try_emplace("dspec", "0");
transportConfig.attrs.try_emplace("lr", "false");
transportConfig.attrs["sc"] = "true";
transportConfig.attrs["fake_backend.path"] = config.fakeBackendPath;
transportConfig.attrs["fake_backend.latency_ms"] =
std::to_string(config.fakeBackendLatencyMs);
transportConfig.attrs["fake_backend.device_id"] = std::to_string(fakeDeviceId);
if (transportConfig.endpoints.empty()) {
UC::ASU::AsuEndpoint endpoint;
endpoint.ip = "fake_backend";
endpoint.port = 19001;
endpoint.protocol = UC::ASU::Protocol::TCP;
endpoint.deviceId = fakeDeviceId;
transportConfig.endpoints.emplace_back(std::move(endpoint));
}
}
return transportConfig;
}

Expand All @@ -117,6 +210,7 @@ class ClientBackend final : public AsuBackend {
asuConfig.clientId = config.clientId;
asuConfig.viewServiceAddrs = config.viewServiceAddrs;
asuConfig.defaultWaitTimeoutMs = config.defaultWaitTimeoutMs;
asuConfig.attrs = config.clientAttrs;
asuConfig.transportConfigs.reserve(config.asuIds.size());
for (std::size_t i = 0; i < config.asuIds.size(); ++i) {
asuConfig.transportConfigs.emplace_back(BuildTransportConfig(config, i));
Expand Down Expand Up @@ -244,10 +338,10 @@ class AsuStore final : public StoreV1 {

~AsuStore() override
{
if (!backend_) { return; }

auto status = backend_->Shutdown();
if (!status.ok()) { UC_ERROR("Failed to shutdown ASU backend: {}.", status.message); }
if (backend_) {
auto status = backend_->Shutdown();
if (!status.ok()) { UC_ERROR("Failed to shutdown ASU backend: {}.", status.message); }
}
}

Status Setup(const Detail::Dictionary& inConfig) override
Expand Down Expand Up @@ -317,6 +411,11 @@ class AsuStore final : public StoreV1 {

Expected<Detail::TaskHandle> Dump(Detail::TaskDesc task) override
{
auto status = WaitPrerequisiteEvent(task.prerequisiteHandle);
if (!status.ok()) {
LogAsuStatus("wait prerequisite event", status);
return ConvertStatus(status);
}
return Submit(std::move(task), &AsuBackend::StoreAsync);
}

Expand Down Expand Up @@ -374,6 +473,26 @@ class AsuStore final : public StoreV1 {
inConfig.GetNumber("block_size", config.blockSize);
inConfig.GetNumber("device_id", config.deviceId);
inConfig.Get("asu_memory_type", config.memoryType);
std::string providerBackend;
if (TryGetStringLike(inConfig, "asu_trans_provider_backend", providerBackend)) {
config.transProviderType = ParseTransProviderBackend(providerBackend);
}
inConfig.Get("asu_fake_backend_path", config.fakeBackendPath);
inConfig.GetNumber("asu_fake_backend_latency_ms", config.fakeBackendLatencyMs);
ReadClientAttr(inConfig, "asu_router_type", "hash_table.type", config);
ReadClientAttr(inConfig, "asu_ring_hash_virtual_node_count", "ring_hash.virtual_node_count",
config);
ReadClientAttr(inConfig, "asu_maglev_table_size", "maglev.table_size", config);
ReadClientAttr(inConfig, "asu_contiguous_block_affinity_block_count",
"contiguous_block_affinity.block_count", config);
ReadClientAttr(inConfig, "asu_contiguous_block_affinity_full_spread_type",
"contiguous_block_affinity.full_spread_type", config);
ReadClientAttr(inConfig, "asu_contiguous_block_affinity_dynamic_adjust_enabled",
"contiguous_block_affinity.dynamic_adjust_enabled", config);
ReadClientAttr(inConfig, "asu_batch_topk_affinity_top_k", "batch_topk_affinity.top_k",
config);
ReadClientAttr(inConfig, "asu_batch_topk_affinity_dynamic_adjust_enabled",
"batch_topk_affinity.dynamic_adjust_enabled", config);

std::size_t tensorSize = 0;
inConfig.GetNumber("tensor_size", tensorSize);
Expand Down Expand Up @@ -406,6 +525,14 @@ class AsuStore final : public StoreV1 {
if (!config.asuIps.empty() && config.asuIps.size() != config.asuIds.size()) {
return Status::InvalidParam("asu_ips size must match asu_ids size");
}
if (config.transProviderType == UC::ASU::TransProviderType::UNSUPPORTED) {
return Status::Unsupported();
}
if (config.transProviderType == UC::ASU::TransProviderType::FAKE &&
!config.configPath.empty()) {
return Status::InvalidParam(
"asu_trans_provider_backend=fake does not support asu_config_path");
}
if (config.tensorSizes.empty()) { return Status::InvalidParam("invalid tensor size"); }
if (config.shardSize == 0) { return Status::InvalidParam("invalid shard size"); }
if (config.blockSize == 0) { return Status::InvalidParam("invalid block size"); }
Expand All @@ -417,12 +544,6 @@ class AsuStore final : public StoreV1 {
if (config.blockSize % config.shardSize != 0) {
return Status::InvalidParam("invalid block size({})", config.blockSize);
}
if (config.blockSize != config.shardSize) {
return Status::InvalidParam("asu store requires one shard per block");
}
if (config.tensorSizes.size() != 1) {
return Status::InvalidParam("asu store requires one tensor buffer per block");
}
return Status::OK();
}

Expand Down Expand Up @@ -464,7 +585,7 @@ class AsuStore final : public StoreV1 {
std::vector<UC::ASU::CacheKey> keys;
keys.reserve(num);
for (std::size_t blockIndex = 0; blockIndex < num; ++blockIndex) {
keys.emplace_back(ToHex(blocks[blockIndex]));
keys.emplace_back(MakeAsuKey(blocks[blockIndex]));
}
return keys;
}
Expand Down Expand Up @@ -501,7 +622,7 @@ class AsuStore final : public StoreV1 {
}
for (std::size_t tensorIndex = 0; tensorIndex < shard.addrs.size(); ++tensorIndex) {
UC::ASU::KVBuffer entry;
entry.key = ToHex(shard.owner);
entry.key = MakeAsuKey(shard.owner);
entry.buffer.region.memoryType = memoryType;
entry.buffer.region.addr =
reinterpret_cast<std::uint64_t>(shard.addrs[tensorIndex]);
Expand All @@ -526,6 +647,9 @@ class AsuStore final : public StoreV1 {
UC_INFO("Set AsuStore::BlockSize to {}.", config.blockSize);
UC_INFO("Set AsuStore::TensorSizes to {}.", config.tensorSizes);
UC_INFO("Set AsuStore::DeviceId to {}.", config.deviceId);
UC_INFO("Set AsuStore::TransProviderBackend to {}.",
TransProviderBackendName(config.transProviderType));
UC_INFO("Set AsuStore::FakeBackendPath to {}.", config.fakeBackendPath);
}

Config config_;
Expand Down
5 changes: 5 additions & 0 deletions ucm/store/asu/cc/asu_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
#include "asu_transport/types.h"

Expand All @@ -28,6 +29,10 @@ struct Config {
std::size_t blockSize{0};
std::int32_t deviceId{-1};
std::string memoryType;
UC::ASU::TransProviderType transProviderType{UC::ASU::TransProviderType::AICPU};
std::string fakeBackendPath;
std::uint64_t fakeBackendLatencyMs{1};
std::unordered_map<std::string, std::string> clientAttrs;
};

class AsuBackend {
Expand Down
2 changes: 1 addition & 1 deletion ucm/store/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ if(BUILD_UNIT_TESTS)
gtest_main gtest gmock
)
if(BUILD_UCM_ASU)
target_link_libraries(ucmstore.test PRIVATE asu_client)
target_link_libraries(ucmstore.test PRIVATE asu_client asu_ascend_deps)
target_compile_definitions(ucmstore.test PRIVATE ASU_BUILD_TESTS)
endif()
gtest_discover_tests(ucmstore.test)
Expand Down
Loading
Loading