Skip to content

Commit fc5018e

Browse files
committed
Revert model.py, export.py, main.cpp to main branch
Only chunk_gated_delta_rule.py needs modification — dispatch logic is internal to the triton_op, no model/export/runner changes needed.
1 parent a6ebe8a commit fc5018e

File tree

3 files changed

+32
-44
lines changed

3 files changed

+32
-44
lines changed

examples/models/qwen3_5_moe/export.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -374,19 +374,17 @@ def export_and_lower(model, config, args):
374374
# -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
375375
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
376376

377-
# --- Single method: dynamic T ---
378-
# Runtime dispatch between recurrent (T=1) and chunked (T>1) happens
379-
# inside the chunk_gated_delta_rule triton_op, not at model level.
380-
tokens = torch.tensor([[0, 1]], dtype=torch.long)
381-
input_pos = torch.tensor([0, 1], dtype=torch.long)
377+
# Dynamic shapes
378+
example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
379+
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
382380
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
383381
dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
384382

385-
print("Exporting model (single method, dynamic T)...")
383+
print("Exporting with torch.export...")
386384
with torch.no_grad():
387-
prog = export(
385+
exported = export(
388386
model,
389-
(tokens, input_pos),
387+
(example_tokens, example_input_pos),
390388
dynamic_shapes=dynamic_shapes,
391389
strict=True,
392390
)
@@ -404,7 +402,7 @@ def export_and_lower(model, config, args):
404402
"enable_dynamic_shape": True,
405403
}
406404
et_prog = to_edge_transform_and_lower(
407-
prog,
405+
exported,
408406
partitioner=[CudaPartitioner(compile_specs)],
409407
compile_config=EdgeCompileConfig(
410408
_check_ir_validity=False,

examples/models/qwen3_5_moe/main.cpp

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
#include <gflags/gflags.h>
1010

1111
#include <executorch/extension/llm/runner/text_llm_runner.h>
12+
#include <executorch/extension/module/module.h>
1213
#include <executorch/runtime/platform/log.h>
1314
#include <pytorch/tokenizers/hf_tokenizer.h>
1415

15-
#include <optional>
1616
#include <string>
17+
#include <vector>
1718

1819
DEFINE_string(model_path, "", "Model .pte file path.");
1920
DEFINE_string(data_path, "", "Data file (.ptd) for CUDA backend.");
@@ -23,7 +24,6 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
2324
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
2425

2526
namespace llm = ::executorch::extension::llm;
26-
using ::executorch::runtime::Error;
2727

2828
int main(int argc, char** argv) {
2929
gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -37,6 +37,11 @@ int main(int argc, char** argv) {
3737
return 1;
3838
}
3939

40+
std::vector<std::string> data_files;
41+
if (!FLAGS_data_path.empty()) {
42+
data_files.push_back(FLAGS_data_path);
43+
}
44+
4045
// Load tokenizer
4146
auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
4247
auto tok_status = tokenizer->load(FLAGS_tokenizer_path);
@@ -48,37 +53,23 @@ int main(int argc, char** argv) {
4853
return 1;
4954
}
5055

51-
// Single-method runner: "forward" handles both prefill (T>1) and decode (T=1)
52-
// via torch.cond dispatch inside the model.
53-
fprintf(stderr, "Loading model from %s...\n", FLAGS_model_path.c_str());
54-
std::optional<const std::string> data_path =
55-
FLAGS_data_path.empty() ? std::nullopt
56-
: std::optional<const std::string>(FLAGS_data_path);
56+
// Create LLM runner
5757
auto runner = llm::create_text_llm_runner(
58-
FLAGS_model_path,
59-
std::move(tokenizer),
60-
data_path,
61-
FLAGS_temperature);
62-
fprintf(stderr, "Runner created successfully\n");
58+
FLAGS_model_path, std::move(tokenizer), data_files, FLAGS_temperature);
59+
60+
if (runner == nullptr) {
61+
ET_LOG(Error, "Failed to create runner");
62+
return 1;
63+
}
6364

6465
// Generate
6566
llm::GenerationConfig config;
6667
config.temperature = FLAGS_temperature;
6768
config.max_new_tokens = FLAGS_max_new_tokens;
6869

69-
fprintf(stderr, "Starting generation with prompt: %s\n", FLAGS_prompt.c_str());
70-
try {
71-
auto error = runner->generate(FLAGS_prompt.c_str(), config);
72-
if (error != Error::Ok) {
73-
fprintf(stderr, "Generation failed with error code: %d\n", static_cast<int>(error));
74-
return 1;
75-
}
76-
fprintf(stderr, "Generation completed successfully\n");
77-
} catch (const std::exception& e) {
78-
fprintf(stderr, "Exception during generation: %s\n", e.what());
79-
return 1;
80-
} catch (...) {
81-
fprintf(stderr, "Unknown exception during generation\n");
70+
auto error = runner->generate(FLAGS_prompt.c_str(), config);
71+
if (error != executorch::runtime::Error::Ok) {
72+
ET_LOG(Error, "Generation failed");
8273
return 1;
8374
}
8475

examples/models/qwen3_5_moe/model.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def __init__(self, dim, eps=1e-6):
114114
self.eps = eps
115115

116116
def forward(self, x):
117-
x_fp32 = x.float()
118-
rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
119-
return (x_fp32 * rms * (1.0 + self.weight.float())).to(x.dtype)
117+
x_float = x.float()
118+
normed = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
119+
return (normed * (1.0 + self.weight.float())).type_as(x)
120120

121121

122122
class RMSNormGated(nn.Module):
@@ -128,10 +128,10 @@ def __init__(self, dim, eps=1e-6):
128128
self.eps = eps
129129

130130
def forward(self, x, z):
131-
x_fp32 = x.float()
132-
rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
133-
normed = x_fp32 * rms
134-
return (self.weight.float() * normed * torch.nn.functional.silu(z.float())).to(x.dtype)
131+
x_float = x.float()
132+
normed = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
133+
normed = self.weight * normed.type_as(x)
134+
return (normed * F.silu(z.float())).type_as(x)
135135

136136

137137
# ---------------------------------------------------------------------------
@@ -390,8 +390,7 @@ def forward(self, x, input_pos):
390390
beta = b.sigmoid()
391391
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
392392

393-
# Gated delta rule: dispatch happens inside the triton_op
394-
# (recurrent kernel for T=1 decode, chunked FLA for T>1 prefill).
393+
# FLA Triton kernel (returns final_state separately, does not mutate initial_state)
395394
output, state = torch.ops.triton.chunk_gated_delta_rule(
396395
q, k, v, g, beta, self.recurrent_state[:B]
397396
)

0 commit comments

Comments (0)