Skip to content

Commit 22e049e

Browse files
committed
update l2_norm op
1 parent 44ca51d commit 22e049e

3 files changed

Lines changed: 13 additions & 2 deletions

File tree

xllm/core/kernels/ops_api.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,14 @@ void fused_indexer_k(FusedIndexerKParams& params) {
756756
#endif
757757
}
758758

759+
// Backend dispatch wrapper for the L2-norm op.
//
// x:   input tensor, forwarded unchanged to the backend kernel.
// eps: forwarded unchanged to the backend kernel (presumably a
//      numerical-stability epsilon -- confirm against the kernel).
//
// Returns the tensor produced by the backend kernel. Only the NPU backend
// is wired up here; every other build falls through to NOT_IMPLEMENTED().
torch::Tensor l2_norm(torch::Tensor& x, double eps) {
760+
#if defined(USE_NPU)
761+
// NPU build: delegate to the NPU kernel, which (per its name) normalizes
// over the last dimension.
return npu::npu_l2norm_last_dim(x, eps);
762+
#else
763+
// No kernel is registered for non-NPU builds.
// NOTE(review): this branch has no return statement; it relies on
// NOT_IMPLEMENTED() not returning -- confirm the macro's definition.
NOT_IMPLEMENTED();
764+
#endif
765+
}
766+
759767
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
760768
moe_init_routing_v2(MoeInitRoutingV2Params& params) {
761769
#if defined(USE_NPU)

xllm/core/kernels/ops_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ void fused_indexer_q(FusedIndexerQParams& params);
9494

9595
void fused_indexer_k(FusedIndexerKParams& params);
9696

97+
// L2 normalization along the last dimension.
// `eps` defaults to 1e-6 and is passed through to the backend kernel.
// Only implemented when built with USE_NPU; other builds hit
// NOT_IMPLEMENTED().
98+
torch::Tensor l2_norm(torch::Tensor& x, double eps = 1e-6);
99+
97100
// TODO: NPU moe_init_routing_v2 is equivalent to moe_gen_idx + moe_expand_input
98101
// (and token_count/cusum outputs) on other backends.
99102
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>

xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,8 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward(
440440
{input_params.block_tables.select(1, 0)},
441441
last_recurrent_state.transpose(-1, -2).to(ssm_cache.dtype()));
442442
} else {
443-
processed_q = l2norm(processed_q, -1, 1e-6);
444-
processed_k = l2norm(processed_k, -1, 1e-6);
443+
processed_q = xllm::kernel::l2_norm(processed_q, 1e-6);
444+
processed_k = xllm::kernel::l2_norm(processed_k, 1e-6);
445445
torch::Tensor ssm_state_indices =
446446
attn_metadata.block_table.select(1, 0).contiguous();
447447

0 commit comments

Comments
 (0)