Merged
Changes from all commits
42 commits
5465d8b
Replace chunked FLA with recurrent gated delta rule for T=1 decode
Gasoonjia Apr 2, 2026
a6ebe8a
Runtime dispatch: recurrent (T=1) vs chunked (T>1) inside triton_op
Gasoonjia Apr 3, 2026
fc5018e
Revert model.py, export.py, main.cpp to main branch
Gasoonjia Apr 3, 2026
c90a8e8
Add tests for recurrent (T=1) and multi-T dispatch
Gasoonjia Apr 3, 2026
ce3e9ca
lint fix - 2
Gasoonjia Apr 3, 2026
8d35c65
lint fix - 2
Gasoonjia Apr 3, 2026
709deb0
Merge branch 'main' into recurrent-fla
Gasoonjia Apr 3, 2026
eff976d
lint fix - 3
Gasoonjia Apr 3, 2026
7dd4280
Optimize recurrent kernel: parallelize over V tiles
Gasoonjia Apr 3, 2026
3a1ee31
Dual-method PTE with GPU-resident state for Qwen3.5 MoE
Apr 5, 2026
63c162e
Use share_mutable_buffers to eliminate select_scatter overhead
Apr 6, 2026
47d6b98
Merge branch 'main' into recurrent-fla
Gasoonjia Apr 6, 2026
375e5c0
lint
Gasoonjia Apr 6, 2026
2b36797
remove redundant updates
Gasoonjia Apr 6, 2026
c06d58b
Cross-method AOTI constant sharing for KV cache
Apr 7, 2026
6945b2a
Fix cross-method AOTI constant sharing and add dual-method runner
Gasoonjia Apr 7, 2026
ea51d0d
Remove debug printf and decode_only flag
Gasoonjia Apr 7, 2026
a0a62f1
Lint formatting fixes
Gasoonjia Apr 7, 2026
ca69871
Improve CUDA backend error handling and add dual-method runner fallback
Apr 9, 2026
7c148f7
Add CUDA graph capture/replay for decode method
Apr 10, 2026
ee75c2e
Merge branch 'main' into cuda-graph
Gasoonjia Apr 10, 2026
10e7aad
lint and reformat
Gasoonjia Apr 13, 2026
9042f36
Merge branch 'main' into cuda-graph
Gasoonjia Apr 13, 2026
84d1587
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
e00a499
solve claude
Gasoonjia Apr 15, 2026
aa7bb82
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
cef386b
Merge branch 'main' into cuda-graph
Gasoonjia Apr 15, 2026
2d32422
Merge branch 'main' into cuda-graph
Gasoonjia Apr 16, 2026
1270870
Merge branch 'main' into cuda-graph
Gasoonjia Apr 16, 2026
8fc7355
solve stride out of scope
Gasoonjia Apr 17, 2026
2c46ed2
Merge branch 'main' into cuda-graph
Gasoonjia Apr 21, 2026
855eb93
Merge branch 'main' into cuda-graph
Gasoonjia Apr 22, 2026
4237d17
remove unused env var
Gasoonjia Apr 22, 2026
9b4705e
Merge branch 'main' into cuda-graph
Gasoonjia Apr 23, 2026
0492e8d
Add GPU-side Gumbel-max sampling for CUDA graph compatibility
Apr 13, 2026
8c0bbf3
lintrunner
Gasoonjia Apr 13, 2026
5245f64
remove git info
Gasoonjia Apr 23, 2026
880391d
reintro llm headers
Gasoonjia Apr 23, 2026
6f411af
lint
Gasoonjia Apr 24, 2026
eff4294
add top-p and top-k arg
Gasoonjia Apr 24, 2026
61d47aa
move top-p and top-k support into an individual PR
Gasoonjia Apr 24, 2026
3e185c0
Merge branch 'main' into cuda-graph-sampling
Gasoonjia Apr 27, 2026
4 changes: 2 additions & 2 deletions .github/workflows/cuda.yml
@@ -145,8 +145,8 @@ jobs:
# Run CUDA backend Python tests
python -m pytest backends/cuda/tests backends/cuda/passes/tests -v -o "addopts="

# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py -v -o "addopts="
# Run Qwen 3.5 MoE tests (quantize roundtrip + TurboQuant KV cache + sampler)
python -m pytest examples/models/qwen3_5_moe/test_quantize_roundtrip.py examples/models/qwen3_5_moe/test_turboquant.py examples/models/qwen3_5_moe/test_sampler.py -v -o "addopts="

export-model-cuda-artifact:
name: export-model-cuda-artifact
9 changes: 5 additions & 4 deletions backends/cuda/CMakeLists.txt
@@ -107,10 +107,10 @@ set(_aoti_cuda_shim_sources runtime/shims/memory.cpp
runtime/shims/cuda_guard.cpp
)

# Only build int4mm shim when CUDA language/toolchain is available.
# Only build CUDA shims when CUDA language/toolchain is available.
if(CMAKE_CUDA_COMPILER)
list(APPEND _aoti_cuda_shim_sources runtime/shims/int4mm.cu
runtime/shims/sort.cu
runtime/shims/sort.cu runtime/shims/rand.cu
)
endif()

@@ -152,7 +152,8 @@ endif()
# retention.
if(_cuda_is_msvc_toolchain)
target_link_libraries(
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart ${CMAKE_DL_LIBS}
aoti_cuda_shims PRIVATE cuda_platform CUDA::cudart CUDA::curand
${CMAKE_DL_LIBS}
)
# Link object library directly so symbols are pulled exactly once while
# avoiding duplicate static/object inclusion and interface leakage.
@@ -162,7 +163,7 @@ else()
aoti_cuda_shims
PRIVATE cuda_platform
PUBLIC -Wl,--whole-archive aoti_common_shims_slim -Wl,--no-whole-archive
CUDA::cudart ${CMAKE_DL_LIBS}
CUDA::cudart CUDA::curand ${CMAKE_DL_LIBS}
)
endif()

1 change: 1 addition & 0 deletions backends/cuda/cuda_backend.py
@@ -146,6 +146,7 @@ def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
return {
"at::_ops::_weight_int4pack_mm::call": None,
"at::_ops::sort_stable::call": None,
"aoti_torch_cuda_randint_low_out": None,
}

@classmethod
2 changes: 2 additions & 0 deletions backends/cuda/runtime/TARGETS
@@ -33,6 +33,7 @@ runtime.cxx_library(
"shims/cuda_guard.cpp",
"shims/int4mm.cu",
"shims/memory.cpp",
"shims/rand.cu",
"shims/sort.cu",
"shims/tensor_attribute.cpp",
],
@@ -41,6 +42,7 @@
"shims/int4mm.cuh",
"shims/int4mm.h",
"shims/memory.h",
"shims/rand.h",
"shims/sort.h",
"shims/tensor_attribute.h",
"utils.h",
273 changes: 273 additions & 0 deletions backends/cuda/runtime/shims/rand.cu
@@ -0,0 +1,273 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/shims/rand.h>

#include <executorch/backends/aoti/slim/cuda/guard.h>
#include <executorch/backends/aoti/slim/factory/empty.h>
#include <executorch/backends/aoti/slim/util/size_util.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>

#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <cstdint>
#include <cstring>
#include <ctime>
#include <mutex>
#include <vector>

namespace executorch::backends::cuda {

namespace c10 = executorch::backends::aoti::slim::c10;
using c10::Device;
using c10::DeviceIndex;
using c10::DeviceType;
using c10::ScalarType;
using executorch::backends::aoti::slim::empty_strided;
using executorch::backends::aoti::slim::IntArrayRef;
using executorch::backends::aoti::slim::makeArrayRef;

namespace {

// ---- GPU-resident RNG state ----
// Seed and counter live in device memory allocated during the first call
// (warmup phase, before CUDA graph capture). The counter is atomically
// advanced by each kernel invocation on-device, so it automatically
// produces different random sequences on every CUDA graph replay.

struct RngState {
unsigned long long seed;
unsigned long long counter;
// Per-launch scratch — written by advance_counter_kernel and read by
// the main RNG kernels. Single-threaded host driver is assumed
// (typical inference / CUDA-graph replay use case).
unsigned long long base_scratch;
};
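
// Per-launch protocol (both kernels are enqueued on the same stream, so
// ordering is guaranteed):
//   1. advance_counter_kernel<<<1, 1>>>:
//        base_scratch = atomicAdd(&counter, numel);
//   2. main RNG kernel<<<blocks, threads>>>: thread i calls
//        curand_init(seed, /*subsequence=*/i, /*offset=*/base_scratch, &state);
// Because the counter lives in device memory, every CUDA graph replay
// re-executes step 1 and draws from a fresh offset range.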

static RngState* d_rng = nullptr;
// std::call_once guarantees one-shot initialization even when shims are
// invoked from multiple host threads (e.g. concurrent models / streams).
static std::once_flag g_rng_init_flag;

// Initialize RNG state on the given stream.
// Must be called during warmup (before graph capture). Subsequent calls
// from any thread are no-ops thanks to std::call_once.
void ensure_rng_init(cudaStream_t stream) {
std::call_once(g_rng_init_flag, [&]() {
cudaMallocAsync(&d_rng, sizeof(RngState), stream);
RngState h;
h.seed = static_cast<unsigned long long>(time(nullptr));
h.counter = 0;
h.base_scratch = 0;
cudaMemcpyAsync(
d_rng, &h, sizeof(RngState), cudaMemcpyHostToDevice, stream);
// Synchronize to ensure the copy completes before we return
// (the host-side RngState `h` is on the stack).
cudaStreamSynchronize(stream);
});
}

// Philox-based randint kernel. Reads its base offset from `rng->base_scratch`
// (populated by `advance_counter_kernel` immediately before this launch).
// This replaces the previous per-element atomicAdd contention with a single
// atomic per kernel launch.
__global__ void philox_randint_graph_kernel(
int64_t* __restrict__ out,
int64_t numel,
int64_t low,
int64_t range,
RngState* __restrict__ rng) {
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx < numel) {
curandStatePhilox4_32_10_t state;
curand_init(rng->seed, idx, rng->base_scratch, &state);
double val = curand_uniform_double(&state);
int64_t ival = static_cast<int64_t>(val * range);
out[idx] = low + (ival >= range ? range - 1 : ival);
}
}

// Philox-based uniform float32 generator (graph-safe version).
__global__ void philox_rand_float_graph_kernel(
float* __restrict__ out,
int64_t numel,
RngState* __restrict__ rng) {
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx < numel) {
curandStatePhilox4_32_10_t state;
curand_init(rng->seed, idx, rng->base_scratch, &state);
out[idx] = curand_uniform(&state);
}
}

// Philox-based uniform bfloat16 generator (graph-safe version).
__global__ void philox_rand_bf16_graph_kernel(
uint16_t* __restrict__ out,
int64_t numel,
RngState* __restrict__ rng) {
int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
if (idx < numel) {
curandStatePhilox4_32_10_t state;
curand_init(rng->seed, idx, rng->base_scratch, &state);
float val = curand_uniform(&state);
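// Convert float32 -> bfloat16 with round-to-nearest-even: add 0x7FFF
// plus the LSB of the truncated mantissa, rounding ties to even, then
// keep the high 16 bits.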
uint32_t bits;
memcpy(&bits, &val, sizeof(uint32_t));
uint32_t lsb = (bits >> 16) & 1;
bits += 0x7FFFu + lsb;
out[idx] = static_cast<uint16_t>(bits >> 16);
}
}

// Single-thread helper that grabs a contiguous range of `numel` offsets
// from the on-device counter and writes the base into `rng->base_scratch`.
// Replaces `numel` per-element atomics with a single atomic per launch
// while staying graph-capturable.
__global__ void advance_counter_kernel(
RngState* __restrict__ rng,
unsigned long long numel) {
if (blockIdx.x == 0 && threadIdx.x == 0) {
rng->base_scratch = atomicAdd(&rng->counter, numel);
}
}

} // anonymous namespace

extern "C" {

Contributor: Is this from PyTorch/ATen, or are we rolling our own?

Contributor Author: Rolling our own. I didn't see an ATen version.

AOTITorchError aoti_torch_cuda_rand(

const int64_t* size,
int64_t size_len_,
int32_t* dtype,
int32_t* layout,
int32_t* device,
int32_t device_index_,
int32_t* pin_memory,
SlimTensor** ret0) {
(void)layout;
(void)device;
(void)pin_memory;

ET_CHECK_OR_RETURN_ERROR(
ret0 != nullptr,
InvalidArgument,
"aoti_torch_cuda_rand: ret0 is null");

// Default to float32 if dtype not specified.
ScalarType scalar_type = ScalarType::Float;
if (dtype != nullptr) {
scalar_type = static_cast<ScalarType>(*dtype);
}

// Compute contiguous strides and total elements.
std::vector<int64_t> strides(size_len_);
int64_t numel = 1;
for (int64_t i = size_len_ - 1; i >= 0; i--) {
strides[i] = numel;
numel *= size[i];
}

// Allocate output tensor.
IntArrayRef sizes_ref(size, static_cast<size_t>(size_len_));
*ret0 = new SlimTensor(empty_strided(
sizes_ref,
makeArrayRef(strides),
scalar_type,
Device(DeviceType::CUDA, static_cast<DeviceIndex>(device_index_))));

if (numel == 0) {
return Error::Ok;
}

// Get the current CUDA stream.
auto stream_result = getCurrentCUDAStream(0);
ET_CHECK_OR_RETURN_ERROR(
stream_result.ok(),
Internal,
"aoti_torch_cuda_rand: failed to get CUDA stream");
cudaStream_t stream = stream_result.get();

ensure_rng_init(stream);

constexpr int kThreads = 256;
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);

// Single atomicAdd per launch — grabs `numel` consecutive counter slots
// for the kernel below, eliminating per-element contention on the GPU
// counter.
advance_counter_kernel<<<1, 1, 0, stream>>>(
d_rng, static_cast<unsigned long long>(numel));

if (scalar_type == ScalarType::Float) {
philox_rand_float_graph_kernel<<<blocks, kThreads, 0, stream>>>(
static_cast<float*>((*ret0)->data_ptr()), numel, d_rng);
} else if (scalar_type == ScalarType::BFloat16) {
philox_rand_bf16_graph_kernel<<<blocks, kThreads, 0, stream>>>(
static_cast<uint16_t*>((*ret0)->data_ptr()), numel, d_rng);
} else {
ET_LOG(
Error,
"aoti_torch_cuda_rand: unsupported dtype %d",
static_cast<int>(scalar_type));
return Error::NotSupported;
}

return Error::Ok;
}

AOTITorchError aoti_torch_cuda_randint_low_out(
SlimTensor* out,
int64_t low,
int64_t high,
const int64_t* size,
int64_t size_len_) {
ET_CHECK_OR_RETURN_ERROR(
out != nullptr,
InvalidArgument,
"aoti_torch_cuda_randint_low_out: out tensor is null");

ET_CHECK_OR_RETURN_ERROR(
high > low,
InvalidArgument,
"aoti_torch_cuda_randint_low_out: requires high > low");

int64_t numel = 1;
for (int64_t i = 0; i < size_len_; i++) {
numel *= size[i];
}
if (numel == 0) {
return Error::Ok;
}

// Get the current CUDA stream.
auto stream_result = getCurrentCUDAStream(0);
ET_CHECK_OR_RETURN_ERROR(
stream_result.ok(),
Internal,
"aoti_torch_cuda_randint_low_out: failed to get CUDA stream");
cudaStream_t stream = stream_result.get();

ensure_rng_init(stream);

int64_t range = high - low;
int64_t* out_data = static_cast<int64_t*>(out->data_ptr());

constexpr int kThreads = 256;
int blocks = static_cast<int>((numel + kThreads - 1) / kThreads);
// One atomicAdd per launch; subsequent kernel reads `rng->base_scratch`.
advance_counter_kernel<<<1, 1, 0, stream>>>(
d_rng, static_cast<unsigned long long>(numel));
philox_randint_graph_kernel<<<blocks, kThreads, 0, stream>>>(
out_data, numel, low, range, d_rng);

return Error::Ok;
}

} // extern "C"

} // namespace executorch::backends::cuda
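
The on-device counter is what makes these shims safe to record into a CUDA graph. Below is a minimal capture/replay sketch; it is hypothetical driver code, not part of this PR, and assumes it lives in this translation unit (so the anonymous-namespace kernels and `d_rng` are visible), the CUDA 12 `cudaGraphInstantiate` signature, and elides error checking.

// Hypothetical capture/replay sketch (assumptions: same translation unit,
// CUDA 12 cudaGraphInstantiate signature, error checks omitted).
void demo_graph_replay() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Warmup outside capture: the RngState allocation and host-to-device
  // copy in ensure_rng_init() never enter the graph.
  ensure_rng_init(stream);

  constexpr int64_t numel = 8;
  int64_t* out = nullptr;
  cudaMalloc(&out, numel * sizeof(int64_t));

  // Capture: only the two graph-safe kernel launches are recorded.
  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  advance_counter_kernel<<<1, 1, 0, stream>>>(
      d_rng, static_cast<unsigned long long>(numel));
  philox_randint_graph_kernel<<<1, 256, 0, stream>>>(
      out, numel, /*low=*/0, /*range=*/100, d_rng);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&exec, graph, /*flags=*/0);

  int64_t host[numel];
  for (int replay = 0; replay < 3; ++replay) {
    cudaGraphLaunch(exec, stream);
    cudaMemcpyAsync(host, out, sizeof(host), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    // Each replay re-runs the on-device atomicAdd, so `host` holds a
    // different batch of draws every time.
  }
  cudaFree(out);
}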