
Commit 2a2ef8c

[WebGPU] Support continuous decoding (RewindTo) with graph capture (microsoft#2083)
This pull request improves attention-mask handling in the CUDA and WebGPU backends, focusing on more efficient and correct updates of mask buffers during decoding. The main changes are a CPU-side update for static attention masks in CUDA and a reusable staging buffer for efficient mask updates in WebGPU, with logic to avoid redundant work in single-beam cases.

**CUDA backend improvements:**

* Replaced the previous (commented-out and incorrect) CUDA memory-set logic in `DefaultPositionInputs::RewindMask` with a CPU-side update that correctly sets attended and non-attended positions in the attention mask for each batch/beam, followed by a copy back to the device. This ensures the mask holds 1s for attended tokens and 0s for future tokens, supporting both `int32_t` and `int64_t` types.

**WebGPU backend improvements:**

* Added reusable CPU staging buffers (`mask_staging_buffer_i32_` / `mask_staging_buffer_i64_`) to the `InterfaceImpl` struct for efficient attention-mask updates, avoiding repeated allocations and redundant writes.
* Implemented the `UpdateAttentionMask` method to efficiently update the mask in single-beam cases by filling only the new positions with 1s and copying the relevant portion to the device, falling back to CPU for multi-beam cases. It handles the static update path and supports both `int32_t` and `int64_t` mask types.
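For reference, the static-mask rewind described above reduces to a simple per-row fill. The following is a minimal standalone sketch of that pattern (illustrative only; `RewindStaticMask` is a hypothetical name, not the function added in this commit, which is `DefaultPositionInputs::RewindMask`):

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch: the mask layout is [batch_beam_size, max_length], row-major.
// After rewinding to `index`, each row holds 1s for attended positions
// [0, index) and 0s for future positions [index, max_length).
void RewindStaticMask(std::vector<int32_t>& mask, size_t batch_beam_size,
                      size_t max_length, size_t index) {
  for (size_t i = 0; i < batch_beam_size; ++i) {
    int32_t* row = mask.data() + i * max_length;
    std::fill_n(row, index, 1);
    std::fill_n(row + index, max_length - index, 0);
  }
}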
1 parent 09e69e4 commit 2a2ef8c

3 files changed

Lines changed: 137 additions & 14 deletions


src/models/position_inputs.cpp

Lines changed: 41 additions & 14 deletions
@@ -380,23 +380,50 @@ void DefaultPositionInputs::InitializeSequenceLengths(std::array<int64_t, 2> sha
 }
 
 void DefaultPositionInputs::RewindMask(size_t index) {
-  if (state_.params_->use_graph_capture) {
-    throw std::runtime_error("PositionInputs::RewindMask - Static buffer is not supported for continuous decoding.");
-#if 0  // TODO: Fix implementation, cudaMemsetAsync of 1 is setting bytes of 1 vs int32's of 1
-    int past_length = static_cast<int>(index);
-    int max_length = static_cast<int>(state_.params_->search.max_length);
-    cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(),
-                    0,
-                    (type_ == Ort::TypeToTensorType<int32_t> ? sizeof(int32_t) : sizeof(int64_t)) * max_length,
-                    model_.cuda_stream_);
-    cudaMemsetAsync(attention_mask_->GetTensorMutableRawData(),
-                    1,
-                    (type_ == Ort::TypeToTensorType<int32_t> ? sizeof(int32_t) : sizeof(int64_t)) * past_length,
-                    model_.cuda_stream_);
-#endif
+  if (ShouldUseStaticMaskHandling()) {
+    // Static mask layout: [batch_beam_size, max_length]
+    // Rewind to index: write 1s for [0, index), 0s for [index, max_length)
+    size_t max_len = static_cast<size_t>(state_.params_->search.max_length);
+    if (index > max_len) {
+      throw std::runtime_error("RewindMask: index exceeds max_length");
+    }
+    size_t batch_beam_size = static_cast<size_t>(attention_mask_shape_[0]);
+    auto byte_span = attention_mask_->GetByteSpan();
+    auto cpu_data = byte_span.CpuSpan();
+    if (type_ == Ort::TypeToTensorType<int32_t>) {
+      auto* data = reinterpret_cast<int32_t*>(cpu_data.data());
+      for (size_t i = 0; i < batch_beam_size; i++) {
+        std::fill_n(data + i * max_len, index, static_cast<int32_t>(1));
+        std::fill_n(data + i * max_len + index, max_len - index, static_cast<int32_t>(0));
+      }
+    } else {
+      auto* data = reinterpret_cast<int64_t*>(cpu_data.data());
+      for (size_t i = 0; i < batch_beam_size; i++) {
+        std::fill_n(data + i * max_len, index, static_cast<int64_t>(1));
+        std::fill_n(data + i * max_len + index, max_len - index, static_cast<int64_t>(0));
+      }
+    }
+    byte_span.CopyCpuToDevice();
+    return;
   }
+
+  // Dynamic mask: adjust shape so the next Update() creates the correct-sized tensor.
+  // For batch_beam_size == 1 (the only case RewindTo supports), the CPU UpdateAttentionMask
+  // fills the entire next mask with 1s, so no data fixup is needed - just the shape.
+  attention_mask_shape_[1] = static_cast<int64_t>(index);
 }
 
+// Returns true when the attention mask is a fixed-size [batch_beam_size, max_length] buffer
+// that must be updated in-place (write 1s/0s) rather than re-created per step.
+// Currently triggered by:
+//   - DML (always uses graph capture, see IsGraphCaptureEnabled in config.cpp)
+//   - WebGPU with enableGraphCapture=1 in provider options
+//   - NvTensorRtRtx with past-present shared buffers
+// Not yet using this path:
+//   - CUDA: graph capture is currently disabled in GenAI due to bugs
+//     (IsGraphCaptureEnabled throws for CUDA). Once re-enabled, RewindMask's
+//     static path will work for CUDA as well since it uses device-agnostic
+//     CpuSpan/CopyCpuToDevice.
 bool DefaultPositionInputs::ShouldUseStaticMaskHandling() const {
   return state_.params_->use_graph_capture ||
          (state_.params_->IsPastPresentShareBufferEnabled(model_.config_->model.type) &&
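
The static branch above establishes a simple invariant: every [max_length]-sized row of the mask ends up with exactly `index` leading 1s and 0s elsewhere. A small illustrative check of that invariant against a host copy of a mask row (a sketch; `MaskRowIsRewoundTo` is a hypothetical helper, not part of this commit):

#include <cstddef>
#include <cstdint>

// Sketch: returns true if `row` (length max_length) looks like a mask
// rewound to `index`: 1s in [0, index), 0s in [index, max_length).
bool MaskRowIsRewoundTo(const int32_t* row, size_t max_length, size_t index) {
  for (size_t j = 0; j < max_length; ++j) {
    const int32_t expected = (j < index) ? 1 : 0;
    if (row[j] != expected) return false;
  }
  return true;
}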

src/webgpu/interface.cpp

Lines changed: 46 additions & 0 deletions
@@ -171,6 +171,11 @@ struct InterfaceImpl : DeviceInterface {
  private:
   Ort::Allocator* ort_allocator_{};
   const OrtMemoryInfo* ort_memory_info_{};
+  // Reusable CPU staging buffers for UpdateAttentionMask, pre-filled with 1s.
+  // Content is always all 1s so sharing across generators is safe; only upload_bytes
+  // worth of data is copied each call, regardless of buffer capacity.
+  std::vector<int32_t> mask_staging_buffer_i32_;
+  std::vector<int64_t> mask_staging_buffer_i64_;
 
  public:
   Ort::Allocator& GetAllocator() override {
@@ -190,6 +195,47 @@ struct InterfaceImpl : DeviceInterface {
 
   void Synchronize() override {}  // Nothing to do?
 
+  bool UpdateAttentionMask([[maybe_unused]] void* next_mask_data, void* mask_data, int batch_beam_size, [[maybe_unused]] int new_kv_length, int total_length, [[maybe_unused]] int max_length, bool update_only, ONNXTensorElementDataType type) override {
+    if (batch_beam_size != 1 || !update_only) {
+      return false;  // Fall back to CPU for multi-beam or non-static mask
+    }
+    if (type != Ort::TypeToTensorType<int32_t> && type != Ort::TypeToTensorType<int64_t>) {
+      return false;  // Unsupported mask type; fall back to CPU handling.
+    }
+    // For batch_beam_size == 1 with static mask (update_only=true, no padding),
+    // the mask is always all 1s for attended positions.
+    size_t num_elements = static_cast<size_t>(total_length);
+    size_t upload_bytes;
+    void* staging_data;
+
+    // Use the correctly typed staging buffer. Each grows monotonically and
+    // only newly extended positions need to be filled with 1.
+    if (type == Ort::TypeToTensorType<int32_t>) {
+      if (mask_staging_buffer_i32_.size() < num_elements) {
+        mask_staging_buffer_i32_.resize(num_elements, static_cast<int32_t>(1));
+      }
+      staging_data = mask_staging_buffer_i32_.data();
+      upload_bytes = num_elements * sizeof(int32_t);
+    } else {
+      if (mask_staging_buffer_i64_.size() < num_elements) {
+        mask_staging_buffer_i64_.resize(num_elements, static_cast<int64_t>(1));
+      }
+      staging_data = mask_staging_buffer_i64_.data();
+      upload_bytes = num_elements * sizeof(int64_t);
+    }
+
+    int64_t shape_val = static_cast<int64_t>(upload_bytes);
+    std::span<const int64_t> shape{&shape_val, 1};
+    static const auto cpu_mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+    auto src_tensor = OrtValue::CreateTensor(*cpu_mem_info, staging_data, upload_bytes, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
+    auto dst_tensor = OrtValue::CreateTensor(*ort_memory_info_, mask_data, upload_bytes, shape, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8);
+    const std::vector<const OrtValue*> src_ptrs = {src_tensor.get()};
+    const std::vector<OrtValue*> dst_ptrs = {dst_tensor.get()};
+    GetOrtEnv().CopyTensors(src_ptrs, dst_ptrs, nullptr);
+
+    return true;
+  }
+
   bool Cast(void* input, void* output, ONNXTensorElementDataType input_type, ONNXTensorElementDataType output_type, size_t element_count) override {
     if (!ort_allocator_) {
       throw std::runtime_error("WebGPU allocator not initialized");

test/c_api_tests.cpp

Lines changed: 50 additions & 0 deletions
@@ -1334,6 +1334,56 @@ TEST(CAPITests, RewindGptFp32CAPI) {
 }
 #endif
 
+// Test RewindTo with static mask handling via NvTensorRtRtx past-present share buffer.
+// Skipped when the phi3-fp16-nvtrt model is not available (CI-only model).
+TEST(CAPITests, RewindGraphCaptureNvTensorRtRtxCAPI) {
+  std::string nvtrt_path = MODEL_PATH "hf-internal-testing/phi3-fp16-nvtrt";
+  if (!std::filesystem::exists(nvtrt_path)) {
+    GTEST_SKIP() << "NvTensorRtRtx model not available at " << nvtrt_path;
+  }
+
+  auto config = OgaConfig::Create(nvtrt_path.c_str());
+  config->ClearProviders();
+  config->AppendProvider("NvTensorRtRtx");
+
+  int max_length = 20;
+
+  auto model = OgaModel::Create(*config);
+  auto params = OgaGeneratorParams::Create(*model);
+  params->SetSearchOption("max_length", max_length);
+
+  std::vector<int32_t> input_ids{1, 15043, 29892, 920};
+
+  auto generator = OgaGenerator::Create(*model, *params);
+  generator->AppendTokens(input_ids.data(), input_ids.size());
+  while (!generator->IsDone()) {
+    generator->GenerateNextToken();
+  }
+
+  auto seq_len = generator->GetSequenceCount(0);
+  std::vector<int32_t> first_output(seq_len);
+  std::memcpy(first_output.data(), generator->GetSequenceData(0), seq_len * sizeof(int32_t));
+
+  generator->RewindTo(0);
+  generator->AppendTokens(input_ids.data(), input_ids.size());
+  while (!generator->IsDone()) {
+    generator->GenerateNextToken();
+  }
+
+  auto seq_len2 = generator->GetSequenceCount(0);
+  ASSERT_EQ(seq_len2, seq_len);
+  EXPECT_TRUE(0 == std::memcmp(first_output.data(), generator->GetSequenceData(0), seq_len * sizeof(int32_t)));
+
+  generator->RewindTo(6);
+  while (!generator->IsDone()) {
+    generator->GenerateNextToken();
+  }
+
+  seq_len2 = generator->GetSequenceCount(0);
+  ASSERT_EQ(seq_len2, seq_len);
+  EXPECT_TRUE(0 == std::memcmp(first_output.data(), generator->GetSequenceData(0), seq_len * sizeof(int32_t)));
+}
+
 #ifndef STREAMING_ASR_PATH
 #define STREAMING_ASR_PATH MODEL_PATH "nemotron-speech-streaming"
 #endif
