Skip to content

Commit b2f6e15

Browse files
authored
[MLAS] RVV-Optimized LLM Operators for RISC-V (#28518)
## Description Added RVV implementations for a subset of LLM inference operators. Optimization of activation functions is in #28308. All tests were conducted on a Spacemit K3 CPU (VLEN=256). | Operator | File | Speedup vs Scalar | Precision | | :--- | :--- | :--- | :--- | | FP16 GEMM | `riscv64/halfgemm_kernel_rvv.cpp` | 51–191x | max_abs ≤ 0.0005 (PASS) | | FP16↔FP32 Cast | `riscv64/cast_kernel_rvv.cpp` | 4–12x | Bit-exact (PASS) | | RotaryEmbedding | `riscv64/rotary_embedding_kernel_rvv.cpp` | 3.1x | max_abs ~6e-08 (PASS) | | SimplifiedLayerNorm | `layer_norm_impl.cc` (inline RVV) | 4.3x | Bit-exact (PASS) | ## Operator Performance ### **FP16 GEMM** | Shape (M×N×K) | ORT Scalar | RVV | Speedup | Max Abs Error | | :--- | :--- | :--- | :--- | :--- | | 1×768×768 | 7.32 ms | 0.14 ms | 51.6x | 1.22e-04 | | 32×768×768 | 235 ms | 1.25 ms | 187.9x | 3.66e-04 | | 64×768×768 | 469 ms | 2.49 ms | 188.8x | 2.44e-04 | | 128×3072×768 | 5718 ms | 30.4 ms | 187.8x | 4.88e-04 | --- ### **FP16↔FP32 Cast** | Elements | Direction | ORT Scalar | RVV | Speedup | | :--- | :--- | :--- | :--- | :--- | | 1K | F16→F32 | 0.002 ms | 0.000 ms | 9.6x | | 1K | F32→F16 | 0.002 ms | 0.000 ms | 10.5x | | 64K | F16→F32 | 0.127 ms | 0.013 ms | 9.7x | | 64K | F32→F16 | 0.150 ms | 0.013 ms | 11.3x | | 1M | F16→F32 | 2.03 ms | 0.26 ms | 7.7x | | 1M | F32→F16 | 2.39 ms | 0.50 ms | 4.8x | --- ### **RotaryEmbedding** | Dim | Mode | ORT Scalar | RVV | Speedup | Max Abs Error | | :--- | :--- | :--- | :--- | :--- | :--- | | 64 | non-interleaved | 0.32 us | 0.05 us | 7.1x | 5.96e-08 | | 64 | interleaved | 0.22 us | 0.07 us | 3.2x | 0 | | 128 | non-interleaved | 0.64 us | 0.07 us | 9.7x | 5.96e-08 | | 128 | interleaved | 0.44 us | 0.13 us | 3.4x | 0 | | 256 | non-interleaved | 1.28 us | 0.10 us | 13.0x | 1.19e-07 | | 256 | interleaved | 1.05 us | 0.25 us | 4.3x | 0 | --- ### **RMSNorm** | Hidden | ORT Scalar | RVV | Speedup | Max Abs Error | | :--- | :--- | :--- | :--- | :--- | | 512 | 2.31 us | 0.38 us | 6.0x | 2.38e-07 | | 1024 | 4.64 us | 0.71 us | 6.5x | 2.38e-07 | | 2048 | 9.24 us | 1.42 us | 6.5x | 3.58e-07 | | 4096 | 18.5 us | 2.82 us | 6.6x | 3.58e-06 | > **Note**: ORT's LayerNorm ComputeJob is in an anonymous namespace — there's no public API to call it separately. So I rewrite the benchmark using the same algorithm as ORT's ComputeJob. ## Model Performance The ONNX model comes from: https://huggingface.co/onnx-community/Qwen3-0.6B-ONNX | Metric | FP32 | FP16 | | :--- | :--- | :--- | | Prompt processing | 61.1 tok/s (255 ms p50) | 58.8 tok/s (272 ms p50) | | Token generation | 6.5 tok/s (152 ms p50) | 6.1 tok/s (162 ms p50) | | E2E (16+32 tokens) | 4987 ms p50 | 5357 ms p50 | | Peak memory | 3.1 GB | 4.1 GB | > **Note**: FP32 is slightly faster because it runs SGEMM directly without the FP16↔FP32 cast overhead. FP16 uses ~1 GB less storage on disk but more runtime memory (the cast creates FP32 copies). Both use the RVV SGEMM kernel for the actual compute.
1 parent a4f79e8 commit b2f6e15

25 files changed

Lines changed: 1758 additions & 5 deletions

cmake/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
9090
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
9191
option(onnxruntime_USE_SVE "Build with SVE support in MLAS" OFF)
9292
option(onnxruntime_USE_RVV "Build with RISC-V Vector support in MLAS" OFF)
93+
option(onnxruntime_USE_RVV_ZVFH "Build with RISC-V Zvfh (FP16 vector) support in MLAS" OFF)
9394
option(onnxruntime_USE_ARM_NEON_NCHWC "Build with ARM Neon NCHWc kernels in MLAS" OFF)
9495

9596
option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF)

cmake/onnxruntime_mlas.cmake

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
5757
${MLAS_SRC_DIR}/flashattn.cpp
5858
${MLAS_SRC_DIR}/qkv_quant.cpp
5959
${MLAS_SRC_DIR}/cast.cpp
60+
${MLAS_SRC_DIR}/layernorm.cpp
6061
${MLAS_SRC_DIR}/rotary_embedding.h
6162
${MLAS_SRC_DIR}/rotary_embedding.cpp
6263
${MLAS_SRC_DIR}/softmax.h
@@ -959,6 +960,8 @@ endif()
959960
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
960961
${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp
961962
${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp
963+
${MLAS_SRC_DIR}/riscv64/rotary_embedding_kernel_rvv.cpp
964+
${MLAS_SRC_DIR}/riscv64/layernorm_kernel_rvv.cpp
962965
)
963966
list(REMOVE_ITEM mlas_platform_srcs
964967
"${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp")
@@ -968,8 +971,22 @@ endif()
968971
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
969972
${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp
970973
${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp
974+
${MLAS_SRC_DIR}/riscv64/rotary_embedding_kernel_rvv.cpp
975+
${MLAS_SRC_DIR}/riscv64/layernorm_kernel_rvv.cpp
971976
PROPERTIES COMPILE_FLAGS "-march=rv64gcv -mabi=lp64d")
972977
list(APPEND mlas_private_compile_definitions MLAS_USE_RVV=1)
978+
979+
if(onnxruntime_USE_RVV_ZVFH)
980+
list(APPEND mlas_platform_srcs
981+
${MLAS_SRC_DIR}/riscv64/halfgemm_kernel_rvv.cpp
982+
${MLAS_SRC_DIR}/riscv64/cast_kernel_rvv.cpp
983+
)
984+
set_source_files_properties(
985+
${MLAS_SRC_DIR}/riscv64/halfgemm_kernel_rvv.cpp
986+
${MLAS_SRC_DIR}/riscv64/cast_kernel_rvv.cpp
987+
PROPERTIES COMPILE_FLAGS "-march=rv64gcv_zvfh -mabi=lp64d")
988+
list(APPEND mlas_private_compile_definitions MLAS_USE_RVV_ZVFH=1)
989+
endif()
973990
else()
974991
message(
975992
WARNING

cmake/onnxruntime_unittests.cmake

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,50 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
14501450
PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS})
14511451
target_compile_definitions(onnxruntime_mlas_softmax_riscv_compare PRIVATE ${mlas_private_compile_definitions})
14521452
set_target_properties(onnxruntime_mlas_softmax_riscv_compare PROPERTIES FOLDER "ONNXRuntimeTest")
1453+
1454+
onnxruntime_add_executable(
1455+
onnxruntime_mlas_halfgemm_rvv_bench
1456+
${MLAS_RISCV64_BENCH_DIR}/halfgemm_rvv_bench.cpp)
1457+
target_include_directories(onnxruntime_mlas_halfgemm_rvv_bench PRIVATE
1458+
${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib)
1459+
target_link_libraries(
1460+
onnxruntime_mlas_halfgemm_rvv_bench
1461+
PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS})
1462+
target_compile_definitions(onnxruntime_mlas_halfgemm_rvv_bench PRIVATE ${mlas_private_compile_definitions})
1463+
set_target_properties(onnxruntime_mlas_halfgemm_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest")
1464+
1465+
onnxruntime_add_executable(
1466+
onnxruntime_mlas_cast_rvv_bench
1467+
${MLAS_RISCV64_BENCH_DIR}/cast_rvv_bench.cpp)
1468+
target_include_directories(onnxruntime_mlas_cast_rvv_bench PRIVATE
1469+
${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib)
1470+
target_link_libraries(
1471+
onnxruntime_mlas_cast_rvv_bench
1472+
PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS})
1473+
target_compile_definitions(onnxruntime_mlas_cast_rvv_bench PRIVATE ${mlas_private_compile_definitions})
1474+
set_target_properties(onnxruntime_mlas_cast_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest")
1475+
1476+
onnxruntime_add_executable(
1477+
onnxruntime_mlas_rope_rvv_bench
1478+
${MLAS_RISCV64_BENCH_DIR}/rope_rvv_bench.cpp)
1479+
target_include_directories(onnxruntime_mlas_rope_rvv_bench PRIVATE
1480+
${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib)
1481+
target_link_libraries(
1482+
onnxruntime_mlas_rope_rvv_bench
1483+
PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS})
1484+
target_compile_definitions(onnxruntime_mlas_rope_rvv_bench PRIVATE ${mlas_private_compile_definitions})
1485+
set_target_properties(onnxruntime_mlas_rope_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest")
1486+
1487+
onnxruntime_add_executable(
1488+
onnxruntime_mlas_rmsnorm_rvv_bench
1489+
${MLAS_RISCV64_BENCH_DIR}/rmsnorm_rvv_bench.cpp)
1490+
target_include_directories(onnxruntime_mlas_rmsnorm_rvv_bench PRIVATE
1491+
${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}/core/mlas/lib)
1492+
target_link_libraries(
1493+
onnxruntime_mlas_rmsnorm_rvv_bench
1494+
PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common ${CMAKE_DL_LIBS})
1495+
target_compile_definitions(onnxruntime_mlas_rmsnorm_rvv_bench PRIVATE ${mlas_private_compile_definitions})
1496+
set_target_properties(onnxruntime_mlas_rmsnorm_rvv_bench PROPERTIES FOLDER "ONNXRuntimeTest")
14531497
endif()
14541498

14551499
if(WIN32)

onnxruntime/core/common/cpuid_arch_definition.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@
1212
#if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__)
1313
#define CPUIDINFO_ARCH_ARM
1414
#endif // ARM or ARM64
15+
16+
#if defined(__riscv) && __riscv_xlen == 64
17+
#define CPUIDINFO_ARCH_RISCV64
18+
#endif

onnxruntime/core/common/cpuid_info.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@
4747

4848
#endif // ARM
4949

50+
#if defined(CPUIDINFO_ARCH_RISCV64)
51+
#include <asm/hwprobe.h>
52+
#ifndef RISCV_HWPROBE_EXT_ZVFH
53+
#define RISCV_HWPROBE_EXT_ZVFH (1 << 30)
54+
#endif
55+
#ifndef RISCV_HWPROBE_IMA_V
56+
#define RISCV_HWPROBE_IMA_V (1 << 2)
57+
#endif
58+
#endif // RISCV64
59+
5060
#endif // Linux
5161

5262
#if _WIN32
@@ -334,6 +344,17 @@ void CPUIDInfo::ArmAppleInit() {
334344

335345
#endif // defined(CPUIDINFO_ARCH_ARM)
336346

347+
#if defined(CPUIDINFO_ARCH_RISCV64) && defined(__linux__)
348+
void CPUIDInfo::RiscvLinuxInit() {
349+
struct riscv_hwprobe pairs[] = {
350+
{RISCV_HWPROBE_KEY_IMA_EXT_0, 0},
351+
};
352+
if (syscall(__NR_riscv_hwprobe, pairs, 1, 0, nullptr, 0) == 0) {
353+
has_fp16_ = (pairs[0].value & RISCV_HWPROBE_EXT_ZVFH) != 0;
354+
}
355+
}
356+
#endif // defined(CPUIDINFO_ARCH_RISCV64) && defined(__linux__)
357+
337358
uint32_t CPUIDInfo::GetCurrentCoreIdx() const {
338359
#ifdef _WIN32
339360
return GetCurrentProcessorNumber();
@@ -377,5 +398,11 @@ CPUIDInfo::CPUIDInfo() {
377398
ArmAppleInit();
378399
#endif
379400
#endif // defined(CPUIDINFO_ARCH_ARM)
401+
402+
#if defined(CPUIDINFO_ARCH_RISCV64)
403+
#if defined(__linux__)
404+
RiscvLinuxInit();
405+
#endif
406+
#endif // defined(CPUIDINFO_ARCH_RISCV64)
380407
}
381408
} // namespace onnxruntime

onnxruntime/core/common/cpuid_info.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,12 @@ class CPUIDInfo {
135135

136136
#endif // defined(CPUIDINFO_ARCH_ARM)
137137

138+
#if defined(CPUIDINFO_ARCH_RISCV64)
139+
#if defined(__linux__)
140+
void RiscvLinuxInit();
141+
#endif
142+
#endif // defined(CPUIDINFO_ARCH_RISCV64)
143+
138144
#if defined(CPUINFO_SUPPORTED)
139145
bool pytorch_cpuinfo_init_{false};
140146
#endif // defined(CPUINFO_SUPPORTED)

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,27 @@ MlasRotaryEmbedOneRow(
16651665
T* output
16661666
);
16671667

1668+
/**
1669+
* @brief Compute LayerNorm or RMSNorm (simplified) for one row of float data.
1670+
* Uses platform-optimized kernel if available, otherwise returns false.
1671+
* Any platform (AMD64/ARM64/RISC-V) can register a LayerNormF32Kernel.
1672+
*
1673+
* @return true if an optimized kernel was used, false if caller should fall back
1674+
*/
1675+
bool
1676+
MLASCALL
1677+
MlasLayerNormF32(
1678+
const float* Input,
1679+
const float* Scale,
1680+
const float* Bias,
1681+
float* Output,
1682+
float* MeanOut,
1683+
float* InvStdDevOut,
1684+
size_t NormSize,
1685+
float Epsilon,
1686+
bool Simplified
1687+
);
1688+
16681689
/**
16691690
* @brief Supply matrices data information to half precision gemm functions
16701691
*/

onnxruntime/core/mlas/lib/halfgemm.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ MlasFp16AccelerationSupported()
2727
{
2828
#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
2929
return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration();
30+
#elif defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH)
31+
return MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration();
3032
#else
3133
return false;
3234
#endif

onnxruntime/core/mlas/lib/halfgemm.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,12 +503,21 @@ extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault;
503503
extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchNeon;
504504
#endif
505505

506+
#if defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH)
507+
extern const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchRvv;
508+
#endif
509+
506510
MLAS_FORCEINLINE
507511
const MLAS_HALFGEMM_DISPATCH*
508512
MlasHalfGemmGetDispatch()
509513
{
510514
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
511515
return &MlasHalfGemmDispatchNeon;
516+
#elif defined(MLAS_TARGET_RISCV64) && defined(MLAS_USE_RVV_ZVFH)
517+
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasFp16VectorAcceleration()) {
518+
return &MlasHalfGemmDispatchRvv;
519+
}
520+
return &MlasHalfGemmDispatchDefault;
512521
#else
513522
return &MlasHalfGemmDispatchDefault;
514523
#endif
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*++
2+
3+
Copyright (c) Microsoft Corporation. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
layernorm.cpp
10+
11+
Abstract:
12+
13+
This module implements the dispatch for platform-optimized
14+
LayerNorm/RMSNorm kernels.
15+
16+
--*/
17+
18+
#include "mlasi.h"
19+
20+
bool
21+
MLASCALL
22+
MlasLayerNormF32(
23+
const float* Input,
24+
const float* Scale,
25+
const float* Bias,
26+
float* Output,
27+
float* MeanOut,
28+
float* InvStdDevOut,
29+
size_t NormSize,
30+
float Epsilon,
31+
bool Simplified
32+
)
33+
{
34+
auto kernel = GetMlasPlatform().LayerNormF32Kernel;
35+
if (kernel == nullptr) {
36+
return false;
37+
}
38+
39+
kernel(Input, Scale, Bias, Output, MeanOut, InvStdDevOut, NormSize, Epsilon, Simplified);
40+
return true;
41+
}

0 commit comments

Comments
 (0)