|
9 | 9 | #pragma once |
10 | 10 |
|
11 | 11 | #include <algorithm> |
| 12 | +#include <cstring> |
12 | 13 | #include <memory> |
13 | 14 | #include <numeric> |
14 | 15 | #include <unordered_map> |
15 | 16 | #include <vector> |
16 | 17 |
|
17 | 18 | #include <c10/util/safe_numerics.h> |
18 | 19 | #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> |
| 20 | +#include <executorch/runtime/core/named_data_map.h> |
19 | 21 | #include <executorch/runtime/core/span.h> |
20 | 22 | #include <executorch/runtime/executor/method.h> |
21 | 23 | #include <executorch/runtime/platform/log.h> |
@@ -459,6 +461,7 @@ class StaticAttentionIOManager { |
459 | 461 | StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK; |
460 | 462 | bool generate_full_logits = true; |
461 | 463 | std::optional<size_t> last_valid_token_pos_index = 0; |
| 464 | + std::vector<size_t> lora_input_indices; |
462 | 465 | }; |
463 | 466 |
|
464 | 467 | StaticAttentionIOManager(StaticAttentionIOConfig config) |
@@ -602,6 +605,49 @@ class StaticAttentionIOManager { |
602 | 605 | return input_pos_; |
603 | 606 | } |
604 | 607 |
|
| 608 | + /** |
| 609 | + * Load LoRA adapter weights from a NamedDataMap and bind them to the |
| 610 | + * method's inputs. |
| 611 | + * |
| 612 | + * Keys are read in data-map index order and copied into internal buffers |
| 613 | + * before binding, so the bound input memory remains valid after this call. |
| 614 | + * If the data map and config_.lora_input_indices have different counts, this |
| 615 | + * method binds only the first min(counts) entries and leaves any remaining |
| 616 | + * configured LoRA inputs unchanged. |
| 617 | + */ |
| 618 | + void load_lora_io_adapter( |
| 619 | + torch::executor::Method& method, |
| 620 | + const executorch::runtime::NamedDataMap& data_map) { |
| 621 | + if (config_.lora_input_indices.empty()) { |
| 622 | + return; |
| 623 | + } |
| 624 | + auto num_keys_result = data_map.get_num_keys(); |
| 625 | + ET_CHECK(num_keys_result.ok()); |
| 626 | + auto num_keys = num_keys_result.get(); |
| 627 | + if (num_keys != config_.lora_input_indices.size()) { |
| 628 | + num_keys = config_.lora_input_indices.size(); |
| 629 | + } |
| 630 | + if (num_keys != lora_buffers_.size()) { |
| 631 | + lora_buffers_.resize(num_keys); |
| 632 | + } |
| 633 | + ET_LOG(Info, "Loading %u LoRA adapter tensors", num_keys); |
| 634 | + for (uint32_t i = 0; i < num_keys; i++) { |
| 635 | + auto key_result = data_map.get_key(i); |
| 636 | + ET_CHECK(key_result.ok()); |
| 637 | + |
| 638 | + auto data_result = data_map.get_data(key_result.get()); |
| 639 | + ET_CHECK(data_result.ok()); |
| 640 | + |
| 641 | + auto nbytes = data_result.get().size(); |
| 642 | + lora_buffers_[i].resize(nbytes); |
| 643 | + std::memcpy(lora_buffers_[i].data(), data_result.get().data(), nbytes); |
| 644 | + |
| 645 | + set_input_raw( |
| 646 | + method, config_.lora_input_indices[i], lora_buffers_[i].data()); |
| 647 | + } |
| 648 | + ET_LOG(Info, "Loaded %u LoRA adapter tensors", num_keys); |
| 649 | + } |
| 650 | + |
605 | 651 | /** |
606 | 652 | * Prefill helper. Run multiple inferences as needed depending on the length |
607 | 653 | * of the prompt and method's input length. Returns the position in the output |
@@ -886,6 +932,24 @@ class StaticAttentionIOManager { |
886 | 932 | } |
887 | 933 |
|
888 | 934 | private: |
| 935 | + void |
| 936 | + set_input_raw(executorch::runtime::Method& method, size_t idx, void* data) { |
| 937 | + auto methodMeta = method.method_meta(); |
| 938 | + auto inputMeta = methodMeta.input_tensor_meta(idx); |
| 939 | + ET_CHECK(inputMeta.ok()); |
| 940 | + auto impl = ::executorch::runtime::etensor::TensorImpl( |
| 941 | + inputMeta->scalar_type(), |
| 942 | + inputMeta->sizes().size(), |
| 943 | + const_cast<executorch::aten::TensorImpl::SizesType*>( |
| 944 | + inputMeta->sizes().data()), |
| 945 | + data, |
| 946 | + const_cast<executorch::aten::TensorImpl::DimOrderType*>( |
| 947 | + inputMeta->dim_order().data())); |
| 948 | + executorch::runtime::etensor::Tensor t(&impl); |
| 949 | + ET_CHECK(data != nullptr); |
| 950 | + ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok); |
| 951 | + } |
| 952 | + |
889 | 953 | template <typename T> |
890 | 954 | void set_input(executorch::runtime::Method& method, size_t idx, T* data) { |
891 | 955 | auto methodMeta = method.method_meta(); |
@@ -1015,6 +1079,7 @@ class StaticAttentionIOManager { |
1015 | 1079 | std::vector<RopeT> rope_freqs_cos_override_; |
1016 | 1080 | std::vector<RopeT> rope_freqs_sin_override_; |
1017 | 1081 | int64_t last_valid_token_pos_; |
| 1082 | + std::vector<std::vector<uint8_t>> lora_buffers_; |
1018 | 1083 | }; |
1019 | 1084 |
|
1020 | 1085 | } // namespace example |
0 commit comments