Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions mamba-ssm-armsve-kernel/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Build description for the Mamba selective-scan CPU kernel.
[general]
name = "mamba_selective_scan"
backends = ["cpu"]

# Sources under torch-ext/ (presumably the Torch op registration — confirm).
[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h",
]

# Portable part: scalar kernel, runtime SVE detection (cpu_feature.hpp) and
# the dispatcher that picks between the scalar and SVE implementations.
[kernel.mamba_selective_scan]
backend = "cpu"
depends = ["torch"]
include = ["mamba_selective_scan_cpu"]
src = [
"mamba_selective_scan_cpu/mamba_selective_scan.cpp",
"mamba_selective_scan_cpu/mamba_selective_scan.hpp",
"mamba_selective_scan_cpu/cpu_feature.hpp",
]

# SVE implementation, built separately so that only this translation unit is
# compiled with SVE code generation enabled.
# NOTE(review): -ffast-math relaxes IEEE-754 semantics, so results from this
# kernel may differ slightly from the scalar fallback — confirm acceptable.
# NOTE(review): -march=armv8-a+sve is AArch64-only; presumably this kernel is
# never built for other architectures — confirm.
[kernel.mamba_selective_scan_sve]
backend = "cpu"
cxx-flags = ["-march=armv8-a+sve", "-O3", "-ffast-math"]
depends = ["torch"]
include = ["mamba_selective_scan_cpu"]
src = [
"mamba_selective_scan_cpu/mamba_selective_scan_sve.cpp",
"mamba_selective_scan_cpu/mamba_selective_scan_sve.hpp",
]
95 changes: 95 additions & 0 deletions mamba-ssm-armsve-kernel/flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions mamba-ssm-armsve-kernel/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Nix flake that builds this kernel through huggingface/kernel-builder.
{
  # The description previously said "ReLU kernel" (left over from a template);
  # this flake builds the Mamba selective-scan kernel declared in build.toml.
  description = "Flake for Mamba selective scan kernel";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include <fstream>
#include <string>

// <sys/auxv.h> declares getauxval() and is available on Linux/Android
// regardless of the CPU architecture.
#if defined(__linux__) || defined(__ANDROID__)
#include <sys/auxv.h>
#endif

// <asm/hwcap.h> (which defines HWCAP_SVE) is an architecture-specific kernel
// header; x86 kernel headers do not ship it, so it must only be included when
// building for AArch64, otherwise this header fails to compile there.
#if (defined(__linux__) || defined(__ANDROID__)) && \
    (defined(__aarch64__) || defined(_M_ARM64))
#include <asm/hwcap.h>
#endif

namespace cpuinfo {

namespace detail {

/**
 * Tells whether one line of /proc/cpuinfo advertises a CPU feature.
 *
 * The line must have the form "<key>: <tok> <tok> ..." where <key> contains
 * "Features" (the ARM spelling), "flags" or "Flags". The flag is compared
 * against whole whitespace-separated tokens, so asking for "sve" does not
 * match "sve2" or "svei8mm".
 *
 * @param line One line read from /proc/cpuinfo (without the newline).
 * @param flag Feature name to look for; an empty name never matches.
 * @return true if the line is a feature list containing exactly `flag`.
 */
inline bool featuresLineHasFlag(const std::string& line, const std::string& flag) {
    if (flag.empty()) return false;

    const std::string::size_type colon = line.find(':');
    if (colon == std::string::npos) return false;

    // Only the key (text before ':') decides whether this is a feature list;
    // the same words appearing inside a value must not count.
    const std::string key = line.substr(0, colon);
    const bool is_feature_list = key.find("Features") != std::string::npos ||
                                 key.find("flags") != std::string::npos ||
                                 key.find("Flags") != std::string::npos;
    if (!is_feature_list) return false;

    static const char* const kSeparators = " \t\r\n";
    std::string::size_type begin = line.find_first_not_of(kSeparators, colon + 1);
    while (begin != std::string::npos) {
        const std::string::size_type end = line.find_first_of(kSeparators, begin);
        const std::string::size_type length =
            (end == std::string::npos ? line.size() : end) - begin;
        if (length == flag.size() && line.compare(begin, length, flag) == 0) {
            return true;
        }
        if (end == std::string::npos) break;
        begin = line.find_first_not_of(kSeparators, end);
    }
    return false;
}

}  // namespace detail

/**
 * Runtime detection of ARM CPU features.
 */
class CPUFeaturesARM {
public:
    /**
     * @return true if the running CPU/kernel reports SVE support. The probe
     *         runs once, on the first call, and the result is cached.
     */
    static bool hasSVE() {
        static bool supported = checkSVE();
        return supported;
    }

private:
    /**
     * Performs the actual SVE probe. Always false when not built for AArch64
     * or when not on Linux/Android.
     */
    static bool checkSVE() {
#if (defined(__aarch64__) || defined(_M_ARM64))
#if defined(__linux__) || defined(__ANDROID__)
#ifdef HWCAP_SVE
        // Best: auxv HWCAP bit (kernel-validated for user-space).
        if (getauxval(AT_HWCAP) & HWCAP_SVE) return true;
#endif
        // Fallback, e.g. when the build headers predate HWCAP_SVE.
        return cpuinfoHasFlag("sve");
#else
        return false;
#endif
#else
        return false;
#endif
    }

    /**
     * Scans /proc/cpuinfo for a feature flag (exact token match).
     *
     * @param flag Feature name, e.g. "sve"; nullptr is treated as "not found".
     * @return true if any feature-list line contains the flag; false if the
     *         file cannot be opened or on non-Linux platforms.
     */
    static bool cpuinfoHasFlag(const char* flag) {
#if defined(__linux__) || defined(__ANDROID__)
        if (flag == nullptr) return false;

        std::ifstream f("/proc/cpuinfo");
        if (!f) return false;

        const std::string wanted(flag);
        std::string line;
        while (std::getline(f, line)) {
            if (detail::featuresLineHasFlag(line, wanted)) return true;
        }
        return false;
#else
        (void)flag;
        return false;
#endif
    }
};

}  // namespace cpuinfo
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#include "mamba_selective_scan.hpp"
#include "mamba_selective_scan_sve.hpp"

#include "cpu_feature.hpp"
#include <cmath>
#include <cstdint>
#include <iostream>
namespace mamba {

/**
 * Scalar selective-scan kernel operating on raw float buffers.
 *
 * Buffer layouts, as implied by the indexing below (all dense, row-major):
 *   A                                  : [D_size][N_size]
 *   B, C                               : [B_size][L_size][N_size]
 *   hidden_states, discrete_time_step  : [B_size][D_size][L_size]
 *   ssm_state                          : [B_size][D_size][N_size]  (updated in place)
 *   scan_output                        : [B_size][D_size][L_size]  (accumulated with +=)
 *
 * For every (batch, channel) pair the state is advanced along the sequence:
 *   state[n] = state[n] * exp(dt * A[n]) + u * (dt * B[n])
 *   out     += sum_n state[n] * C[n]
 *
 * Because the output is accumulated rather than overwritten, the caller
 * decides the initial contents of scan_output.
 */
static inline void mamba_selective_scan_kernel_cpu(
    float* A,
    float* B,
    float* C,
    float* hidden_states,
    float* discrete_time_step,
    float* ssm_state,
    float* scan_output,
    int64_t B_size,
    int64_t D_size,
    int64_t L_size,
    int64_t N_size
) {
    for (int64_t batch = 0; batch < B_size; ++batch) {
        for (int64_t chan = 0; chan < D_size; ++chan) {
            // Offsets that do not depend on the time step.
            const int64_t a_row = chan * N_size;
            const int64_t state_row = batch * D_size * N_size + chan * N_size;
            const int64_t seq_row = batch * D_size * L_size + chan * L_size;

            for (int64_t step = 0; step < L_size; ++step) {
                const int64_t seq_idx = seq_row + step;
                const int64_t bc_row = batch * L_size * N_size + step * N_size;

                const float dt = discrete_time_step[seq_idx];
                const float u = hidden_states[seq_idx];

                float acc = 0.0f;
                for (int64_t n = 0; n < N_size; ++n) {
                    // Load all inputs before touching the state, as the
                    // original code does.
                    const float a_val = A[a_row + n];
                    const float b_val = B[bc_row + n];
                    const float c_val = C[bc_row + n];

                    const float decay = std::exp(dt * a_val);  // discretised A
                    const float drive = u * (dt * b_val);      // discretised B * input

                    float& state = ssm_state[state_row + n];
                    state = state * decay + drive;

                    acc += state * c_val;
                }

                scan_output[seq_idx] += acc;
            }
        }
    }
}

/**
 * Dispatches the selective scan to the best available CPU implementation.
 *
 * On AArch64 builds, when the running CPU reports SVE, the tensors are handed
 * unchanged to mamba_selective_scan_sve(). Otherwise the scalar kernel runs on
 * the tensors' raw storage obtained through data_ptr<float>().
 * NOTE(review): the scalar path indexes that storage as dense row-major
 * buffers, so it presumably expects contiguous float32 tensors — confirm that
 * callers guarantee this.
 */
void mamba_selective_scan(
    torch::Tensor& A,
    torch::Tensor& B,
    torch::Tensor& C,
    torch::Tensor& hidden_states,
    torch::Tensor& discrete_time_step,
    torch::Tensor& ssm_state,
    torch::Tensor& scan_output,
    int64_t B_size,
    int64_t D_size,
    int64_t L_size,
    int64_t N_size
) {
#if (defined(__aarch64__) || defined(_M_ARM64))
    // Prefer the SVE build of the kernel whenever the hardware supports it.
    const bool sve_available = cpuinfo::CPUFeaturesARM::hasSVE();
    if (sve_available) {
        mamba_selective_scan_sve(
            A, B, C, hidden_states, discrete_time_step, ssm_state, scan_output,
            B_size, D_size, L_size, N_size
        );
        return;
    }
#endif

    // Scalar fallback: works directly on the underlying float buffers.
    float* const a_data = A.data_ptr<float>();
    float* const b_data = B.data_ptr<float>();
    float* const c_data = C.data_ptr<float>();
    float* const hidden_data = hidden_states.data_ptr<float>();
    float* const dt_data = discrete_time_step.data_ptr<float>();
    float* const state_data = ssm_state.data_ptr<float>();
    float* const out_data = scan_output.data_ptr<float>();

    mamba_selective_scan_kernel_cpu(
        a_data, b_data, c_data, hidden_data, dt_data, state_data, out_data,
        B_size, D_size, L_size, N_size
    );
}

} // namespace mamba

// Global-namespace entry point: forwards every argument, unchanged, to
// mamba::mamba_selective_scan().
void mamba_selective_scan(
    torch::Tensor &A,
    torch::Tensor &B,
    torch::Tensor &C,
    torch::Tensor &hidden_states,
    torch::Tensor &discrete_time_step,
    torch::Tensor &ssm_state,
    torch::Tensor &scan_output,
    int64_t B_size,
    int64_t D_size,
    int64_t L_size,
    int64_t N_size
) {
    mamba::mamba_selective_scan(
        A,
        B,
        C,
        hidden_states,
        discrete_time_step,
        ssm_state,
        scan_output,
        B_size,
        D_size,
        L_size,
        N_size
    );
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include <cstdint>
#include <torch/torch.h>

namespace mamba {

/**
 * Runs the Mamba selective scan on CPU.
 *
 * The implementation uses the SVE kernel on AArch64 when the running CPU
 * reports SVE, and a scalar kernel otherwise. The scalar kernel reads the
 * tensors through data_ptr<float>() and indexes them as dense row-major
 * buffers with the element layouts listed below.
 * NOTE(review): this presumably means contiguous float32 tensors are
 * required, and that the SVE kernel uses the same layouts — confirm.
 *
 * @param A                   [D_size][N_size] state-transition parameters.
 * @param B                   [B_size][L_size][N_size] input projection.
 * @param C                   [B_size][L_size][N_size] output projection.
 * @param hidden_states       [B_size][D_size][L_size] input sequence.
 * @param discrete_time_step  [B_size][D_size][L_size] per-step time deltas.
 * @param ssm_state           [B_size][D_size][N_size] recurrent state; updated in place.
 * @param scan_output         [B_size][D_size][L_size] result; the scalar kernel
 *                            adds into it (+=) rather than overwriting it.
 * @param B_size              Batch size.
 * @param D_size              Number of channels.
 * @param L_size              Sequence length.
 * @param N_size              State size per channel.
 */
void mamba_selective_scan(
torch::Tensor& A,
torch::Tensor& B,
torch::Tensor& C,
torch::Tensor& hidden_states,
torch::Tensor& discrete_time_step,
torch::Tensor& ssm_state,
torch::Tensor& scan_output,
int64_t B_size,
int64_t D_size,
int64_t L_size,
int64_t N_size
);

} // namespace mamba
Loading