@@ -122,26 +122,30 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const bool is_seq_dim_2,
     Tensor& output);

 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos);
+    const int64_t start_pos,
+    const bool is_seq_dim_2);

 // New functions for update_cache_with_indices
 Tensor& update_cache_with_indices_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
     const Tensor& indices,
+    const bool is_seq_dim_2,
     Tensor& output);

 at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const at::Tensor& indices);
+    const at::Tensor& indices,
+    const bool is_seq_dim_2);

 Tensor& sdpa_with_kv_cache_out_no_context(
     const Tensor& q_projected,
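The new parameter's name suggests it selects which cache dimension indexes sequence positions: dim 1 for the default [batch, seq_len, n_heads, head_dim] layout, dim 2 for a [batch, n_heads, seq_len, head_dim] layout. A minimal sketch of that selection; the layout semantics are inferred from the flag name, not stated in this diff:

```cpp
#include <cstdint>

// Hypothetical helper illustrating what is_seq_dim_2 presumably selects:
// the cache dimension that indexes sequence positions. The default layout
// [batch, seq_len, n_heads, head_dim] puts sequence on dim 1; with
// is_seq_dim_2 the layout is assumed to be
// [batch, n_heads, seq_len, head_dim], putting sequence on dim 2.
inline int64_t cache_seq_dim(bool is_seq_dim_2) {
  return is_seq_dim_2 ? 2 : 1;
}
```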
@@ -338,19 +342,21 @@ Tensor& update_cache_out_no_context(
     const Tensor& value,
     Tensor& cache,
     const int64_t start_pos,
+    const bool is_seq_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_out(
-      context, value, cache, start_pos, output);
+      context, value, cache, start_pos, is_seq_dim_2, output);
 }

 at::Tensor update_cache_aten(
     const at::Tensor& value,
     at::Tensor& cache,
-    const int64_t start_pos) {
+    const int64_t start_pos,
+    const bool is_seq_dim_2) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_out_no_context, 3)
-  (value, cache, start_pos, output);
+  WRAP_TO_ATEN(update_cache_out_no_context, 4)
+  (value, cache, start_pos, is_seq_dim_2, output);
   return output;
 }

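The second argument to WRAP_TO_ATEN moves from 3 to 4 because, judging by these call sites, it is the zero-based position of the out tensor in the wrapped function's argument list, and is_seq_dim_2 is inserted ahead of output. A comment-only sketch of that convention (inferred from usage here, not from the macro's definition):

```cpp
// Out-tensor positions for update_cache_out_no_context, zero-based:
//   before: (value:0, cache:1, start_pos:2,                 output:3) -> 3
//   after:  (value:0, cache:1, start_pos:2, is_seq_dim_2:3, output:4) -> 4
// The index passed to WRAP_TO_ATEN must track the signature, or arguments
// would be bound to the wrong positions at the ATen boundary.
```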
@@ -360,20 +366,22 @@ Tensor& update_cache_with_indices_out_no_context(
     Tensor& cache,
     const int64_t start_pos,
     const Tensor& indices,
+    const bool is_seq_dim_2,
     Tensor& output) {
   executorch::aten::RuntimeContext context{};
   return torch::executor::native::update_cache_with_indices_out(
-      context, value, cache, start_pos, indices, output);
+      context, value, cache, start_pos, indices, is_seq_dim_2, output);
 }

 at::Tensor update_cache_with_indices_aten(
     const at::Tensor& value,
     at::Tensor& cache,
     const int64_t start_pos,
-    const at::Tensor& indices) {
+    const at::Tensor& indices,
+    const bool is_seq_dim_2) {
   auto output = at::empty({1});
-  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 4)
-  (value, cache, start_pos, indices, output);
+  WRAP_TO_ATEN(update_cache_with_indices_out_no_context, 5)
+  (value, cache, start_pos, indices, is_seq_dim_2, output);
   return output;
 }

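A hedged usage sketch of the new ATen entry point. The tensor shapes, the meaning of indices, and the layout implied by is_seq_dim_2 = true are illustrative assumptions, and the declarations from the header hunk above are assumed to be in scope:

```cpp
#include <ATen/ATen.h>

// Illustrative only: write one new token into per-batch cache slots named by
// `indices`, with the cache assumed to be [batch, n_heads, seq_len, head_dim].
void example_update_with_indices() {
  auto cache = at::zeros({1, 8, 128, 64});       // [B, H, S, D] cache
  auto value = at::randn({1, 8, 1, 64});         // one decoded token
  auto indices = at::full({1, 1}, 5, at::kLong); // assumed [B, S] positions
  torch::executor::native::update_cache_with_indices_aten(
      value, cache, /*start_pos=*/5, indices, /*is_seq_dim_2=*/true);
}
```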
@@ -400,16 +408,16 @@ TORCH_LIBRARY_FRAGMENT(llama, m) {
       "float? scale=None, *, Tensor(a!) out) -> Tensor(a!)");
   m.def(
       "update_cache(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos) -> Tensor");
+      "SymInt start_pos, bool is_seq_dim_2=False) -> Tensor");
   m.def(
       "update_cache.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "update_cache_with_indices(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor indices) -> Tensor");
+      "SymInt start_pos, Tensor indices, bool is_seq_dim_2=False) -> Tensor");
   m.def(
       "update_cache_with_indices.out(Tensor value, Tensor(a!) cache, "
-      "SymInt start_pos, Tensor indices, *, Tensor(b!) out) -> Tensor(b!)");
+      "SymInt start_pos, Tensor indices, bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)");
   m.def(
       "custom_quantized_sdpa(Tensor query, Tensor key, Tensor value, SymInt start_pos, "
       "Tensor? attn_mask=None, float drpout_p=0.0, bool is_causal=False, "
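Joining the split string literals, the effective schemas for the two updated ops read as below; because the new flag carries a schema-level default, existing call sites and already-exported programs that pass only the original arguments still resolve:

```cpp
// Effective schemas after this hunk (string fragments concatenated):
//   update_cache(Tensor value, Tensor(a!) cache,
//       SymInt start_pos, bool is_seq_dim_2=False) -> Tensor
//   update_cache.out(Tensor value, Tensor(a!) cache, SymInt start_pos,
//       bool is_seq_dim_2=False, *, Tensor(b!) out) -> Tensor(b!)
//   update_cache_with_indices(Tensor value, Tensor(a!) cache,
//       SymInt start_pos, Tensor indices, bool is_seq_dim_2=False) -> Tensor
//   update_cache_with_indices.out(Tensor value, Tensor(a!) cache,
//       SymInt start_pos, Tensor indices, bool is_seq_dim_2=False, *,
//       Tensor(b!) out) -> Tensor(b!)
```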
@@ -439,15 +447,15 @@ TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
   m.impl("update_cache", torch::executor::native::update_cache_aten);
   m.impl(
       "update_cache.out",
-      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 3));
+      WRAP_TO_ATEN(torch::executor::native::update_cache_out_no_context, 4));
   m.impl(
       "update_cache_with_indices",
       torch::executor::native::update_cache_with_indices_aten);
   m.impl(
       "update_cache_with_indices.out",
       WRAP_TO_ATEN(
           torch::executor::native::update_cache_with_indices_out_no_context,
-          4));
+          5));
   m.impl(
       "custom_quantized_sdpa",
       torch::executor::native::custom_quantized_sdpa_aten);
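A smoke-test sketch exercising update_cache_aten with both layouts through the registered kernels; as above, the shapes and the layout semantics of the flag are assumptions rather than something this diff documents:

```cpp
#include <ATen/ATen.h>

// Illustrative only: append one token at start_pos for each assumed layout.
void smoke_test_update_cache() {
  // Default layout, sequence on dim 1: cache is [B, S, H, D].
  auto cache_seq1 = at::zeros({1, 128, 8, 64});
  auto value_seq1 = at::randn({1, 1, 8, 64});
  torch::executor::native::update_cache_aten(
      value_seq1, cache_seq1, /*start_pos=*/0, /*is_seq_dim_2=*/false);

  // Transposed layout, sequence on dim 2: cache is [B, H, S, D].
  auto cache_seq2 = at::zeros({1, 8, 128, 64});
  auto value_seq2 = at::randn({1, 8, 1, 64});
  torch::executor::native::update_cache_aten(
      value_seq2, cache_seq2, /*start_pos=*/0, /*is_seq_dim_2=*/true);
}
```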