Skip to content

Commit 658dcd4

Browse files
authored
OSS add LoRA adapter as inputs loading on runtime to support the ANE LoRA-IO model (pytorch#19952)
Differential Revision: D107257631 Pull Request resolved: pytorch#19952
1 parent b5f8155 commit 658dcd4

1 file changed

Lines changed: 65 additions & 0 deletions

File tree

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
#pragma once
1010

1111
#include <algorithm>
12+
#include <cstring>
1213
#include <memory>
1314
#include <numeric>
1415
#include <unordered_map>
1516
#include <vector>
1617

1718
#include <c10/util/safe_numerics.h>
1819
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
20+
#include <executorch/runtime/core/named_data_map.h>
1921
#include <executorch/runtime/core/span.h>
2022
#include <executorch/runtime/executor/method.h>
2123
#include <executorch/runtime/platform/log.h>
@@ -459,6 +461,7 @@ class StaticAttentionIOManager {
459461
StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK;
460462
bool generate_full_logits = true;
461463
std::optional<size_t> last_valid_token_pos_index = 0;
464+
std::vector<size_t> lora_input_indices;
462465
};
463466

464467
StaticAttentionIOManager(StaticAttentionIOConfig config)
@@ -602,6 +605,49 @@ class StaticAttentionIOManager {
602605
return input_pos_;
603606
}
604607

608+
/**
609+
* Load LoRA adapter weights from a NamedDataMap and bind them to the
610+
* method's inputs.
611+
*
612+
* Keys are read in data-map index order and copied into internal buffers
613+
* before binding, so the bound input memory remains valid after this call.
614+
* If the data map and config_.lora_input_indices have different counts, this
615+
* method binds only the first min(counts) entries and leaves any remaining
616+
* configured LoRA inputs unchanged.
617+
*/
618+
void load_lora_io_adapter(
619+
torch::executor::Method& method,
620+
const executorch::runtime::NamedDataMap& data_map) {
621+
if (config_.lora_input_indices.empty()) {
622+
return;
623+
}
624+
auto num_keys_result = data_map.get_num_keys();
625+
ET_CHECK(num_keys_result.ok());
626+
auto num_keys = num_keys_result.get();
627+
if (num_keys != config_.lora_input_indices.size()) {
628+
num_keys = config_.lora_input_indices.size();
629+
}
630+
if (num_keys != lora_buffers_.size()) {
631+
lora_buffers_.resize(num_keys);
632+
}
633+
ET_LOG(Info, "Loading %u LoRA adapter tensors", num_keys);
634+
for (uint32_t i = 0; i < num_keys; i++) {
635+
auto key_result = data_map.get_key(i);
636+
ET_CHECK(key_result.ok());
637+
638+
auto data_result = data_map.get_data(key_result.get());
639+
ET_CHECK(data_result.ok());
640+
641+
auto nbytes = data_result.get().size();
642+
lora_buffers_[i].resize(nbytes);
643+
std::memcpy(lora_buffers_[i].data(), data_result.get().data(), nbytes);
644+
645+
set_input_raw(
646+
method, config_.lora_input_indices[i], lora_buffers_[i].data());
647+
}
648+
ET_LOG(Info, "Loaded %u LoRA adapter tensors", num_keys);
649+
}
650+
605651
/**
606652
* Prefill helper. Run multiple inferences as needed depending on the length
607653
* of the prompt and method's input length. Returns the position in the output
@@ -886,6 +932,24 @@ class StaticAttentionIOManager {
886932
}
887933

888934
private:
935+
void
936+
set_input_raw(executorch::runtime::Method& method, size_t idx, void* data) {
937+
auto methodMeta = method.method_meta();
938+
auto inputMeta = methodMeta.input_tensor_meta(idx);
939+
ET_CHECK(inputMeta.ok());
940+
auto impl = ::executorch::runtime::etensor::TensorImpl(
941+
inputMeta->scalar_type(),
942+
inputMeta->sizes().size(),
943+
const_cast<executorch::aten::TensorImpl::SizesType*>(
944+
inputMeta->sizes().data()),
945+
data,
946+
const_cast<executorch::aten::TensorImpl::DimOrderType*>(
947+
inputMeta->dim_order().data()));
948+
executorch::runtime::etensor::Tensor t(&impl);
949+
ET_CHECK(data != nullptr);
950+
ET_CHECK(method.set_input(t, idx) == executorch::runtime::Error::Ok);
951+
}
952+
889953
template <typename T>
890954
void set_input(executorch::runtime::Method& method, size_t idx, T* data) {
891955
auto methodMeta = method.method_meta();
@@ -1015,6 +1079,7 @@ class StaticAttentionIOManager {
10151079
std::vector<RopeT> rope_freqs_cos_override_;
10161080
std::vector<RopeT> rope_freqs_sin_override_;
10171081
int64_t last_valid_token_pos_;
1082+
std::vector<std::vector<uint8_t>> lora_buffers_;
10181083
};
10191084

10201085
} // namespace example

0 commit comments

Comments
 (0)