Skip to content

Commit a5a818d

Browse files
committed
Default LLM Android model loading to mmap-only (no mlock)
On Android, ExecuTorch LLM apps previously used mmap+mlock to load .pte model files. While mmap memory-maps the file (pages loaded on demand), mlock pins all mapped pages into physical RAM upfront — defeating mmap's lazy-loading benefit for large models (1-4GB). This causes a high risk of OOM kills on devices with 6-12GB of RAM shared across all apps.

Changes:
- LlmModuleConfig.java: Add LOAD_MODE_* constants and a loadMode field (default LOAD_MODE_MMAP) with a builder method and getter
- LlmModule.java: Thread loadMode through to JNI initHybrid; existing constructors default to LOAD_MODE_MMAP — no breaking change
- jni_layer_llama.cpp: Accept loadMode from Java, map it to the C++ Module::LoadMode enum, and pass it to all runner creation paths (text, multimodal, QNN) instead of the hardcoded MmapUseMlockIgnoreErrors

Apps needing the old behavior can pass LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS.
1 parent f7977a6 commit a5a818d

3 files changed

Lines changed: 114 additions & 18 deletions

File tree

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public class LlmModule {
3434
private static final float DEFAULT_TEMPERATURE = -1.0f;
3535
private static final int DEFAULT_BOS = 0;
3636
private static final int DEFAULT_EOS = 0;
37+
private static final int DEFAULT_LOAD_MODE = 1; // LOAD_MODE_MMAP
3738

3839
@DoNotStrip
3940
private static native HybridData initHybrid(
@@ -43,11 +44,12 @@ private static native HybridData initHybrid(
4344
float temperature,
4445
List<String> dataFiles,
4546
int numBos,
46-
int numEos);
47+
int numEos,
48+
int loadMode);
4749

4850
/**
49-
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
50-
* dataFiles.
51+
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature,
52+
* dataFiles, and load mode.
5153
*/
5254
public LlmModule(
5355
int modelType,
@@ -56,13 +58,38 @@ public LlmModule(
5658
float temperature,
5759
List<String> dataFiles,
5860
int numBos,
59-
int numEos) {
61+
int numEos,
62+
int loadMode) {
6063
ExecuTorchRuntime.getRuntime();
6164
ExecuTorchRuntime.validateFilePath(modulePath, "model path");
6265
ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path");
6366

6467
mHybridData =
65-
initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles, numBos, numEos);
68+
initHybrid(
69+
modelType, modulePath, tokenizerPath, temperature, dataFiles, numBos, numEos, loadMode);
70+
}
71+
72+
/**
73+
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
74+
* dataFiles.
75+
*/
76+
public LlmModule(
77+
int modelType,
78+
String modulePath,
79+
String tokenizerPath,
80+
float temperature,
81+
List<String> dataFiles,
82+
int numBos,
83+
int numEos) {
84+
this(
85+
modelType,
86+
modulePath,
87+
tokenizerPath,
88+
temperature,
89+
dataFiles,
90+
numBos,
91+
numEos,
92+
DEFAULT_LOAD_MODE);
6693
}
6794

6895
/**
@@ -75,7 +102,15 @@ public LlmModule(
75102
String tokenizerPath,
76103
float temperature,
77104
List<String> dataFiles) {
78-
this(modelType, modulePath, tokenizerPath, temperature, dataFiles, DEFAULT_BOS, DEFAULT_EOS);
105+
this(
106+
modelType,
107+
modulePath,
108+
tokenizerPath,
109+
temperature,
110+
dataFiles,
111+
DEFAULT_BOS,
112+
DEFAULT_EOS,
113+
DEFAULT_LOAD_MODE);
79114
}
80115

81116
/**
@@ -148,9 +183,10 @@ public LlmModule(LlmModuleConfig config) {
148183
config.getModulePath(),
149184
config.getTokenizerPath(),
150185
config.getTemperature(),
151-
config.getDataPath(),
186+
config.getDataPath() != null ? List.of(config.getDataPath()) : List.of(),
152187
config.getNumBos(),
153-
config.getNumEos());
188+
config.getNumEos(),
189+
config.getLoadMode());
154190
}
155191

156192
public void resetNative() {

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,19 @@ public class LlmModuleConfig {
2121
private final int modelType;
2222
private final int numBos;
2323
private final int numEos;
24+
private final int loadMode;
25+
26+
/** Load model from file descriptor (no mmap). */
27+
public static final int LOAD_MODE_FILE = 0;
28+
29+
/** Load model via mmap without mlock (default). Pages faulted in on demand. */
30+
public static final int LOAD_MODE_MMAP = 1;
31+
32+
/** Load model via mmap and pin all pages with mlock. */
33+
public static final int LOAD_MODE_MMAP_USE_MLOCK = 2;
34+
35+
/** Load model via mmap and attempt mlock, ignoring mlock failures. */
36+
public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
2437

2538
private LlmModuleConfig(Builder builder) {
2639
this.modulePath = builder.modulePath;
@@ -30,6 +43,7 @@ private LlmModuleConfig(Builder builder) {
3043
this.modelType = builder.modelType;
3144
this.numBos = builder.numBos;
3245
this.numEos = builder.numEos;
46+
this.loadMode = builder.loadMode;
3347
}
3448

3549
/** Model type constant for text-only models. */
@@ -100,6 +114,13 @@ public int getNumEos() {
100114
return numEos;
101115
}
102116

117+
/**
118+
* @return Load mode for the model file (one of LOAD_MODE_* constants)
119+
*/
120+
public int getLoadMode() {
121+
return loadMode;
122+
}
123+
103124
/**
104125
* Builder class for constructing LlmModuleConfig instances with optional parameters.
105126
*
@@ -114,6 +135,7 @@ public static class Builder {
114135
private int modelType = MODEL_TYPE_TEXT;
115136
private int numBos = 0;
116137
private int numEos = 0;
138+
private int loadMode = LOAD_MODE_MMAP;
117139

118140
Builder() {}
119141

@@ -194,6 +216,19 @@ public Builder numEos(int numEos) {
194216
return this;
195217
}
196218

219+
/**
220+
* Sets the load mode for the model file. Defaults to {@link #LOAD_MODE_MMAP} (mmap without
221+
* mlock), which avoids pinning model pages in RAM.
222+
*
223+
* @param loadMode One of LOAD_MODE_FILE, LOAD_MODE_MMAP, LOAD_MODE_MMAP_USE_MLOCK,
224+
* LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS
225+
* @return This builder instance for method chaining
226+
*/
227+
public Builder loadMode(int loadMode) {
228+
this.loadMode = loadMode;
229+
return this;
230+
}
231+
197232
/**
198233
* Constructs the LlmModuleConfig instance with validated parameters.
199234
*

extension/android/jni/jni_layer_llama.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,34 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
109109
facebook::jni::alias_ref<facebook::jni::JList<jstring>::javaobject>
110110
data_files,
111111
jint num_bos,
112-
jint num_eos) {
112+
jint num_eos,
113+
jint load_mode) {
113114
return makeCxxInstance(
114115
model_type_category,
115116
model_path,
116117
tokenizer_path,
117118
temperature,
118119
data_files,
119120
num_bos,
120-
num_eos);
121+
num_eos,
122+
load_mode);
123+
}
124+
125+
static executorch::extension::Module::LoadMode load_mode_from_int(
126+
jint load_mode) {
127+
switch (load_mode) {
128+
case 0:
129+
return executorch::extension::Module::LoadMode::File;
130+
case 1:
131+
return executorch::extension::Module::LoadMode::Mmap;
132+
case 2:
133+
return executorch::extension::Module::LoadMode::MmapUseMlock;
134+
case 3:
135+
return executorch::extension::Module::LoadMode::
136+
MmapUseMlockIgnoreErrors;
137+
default:
138+
return executorch::extension::Module::LoadMode::Mmap;
139+
}
121140
}
122141

123142
ExecuTorchLlmJni(
@@ -127,7 +146,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
127146
jfloat temperature,
128147
facebook::jni::alias_ref<jobject> data_files = nullptr,
129148
jint num_bos = 0,
130-
jint num_eos = 0) {
149+
jint num_eos = 0,
150+
jint load_mode = 1) {
131151
temperature_ = temperature;
132152
num_bos_ = num_bos;
133153
num_eos_ = num_eos;
@@ -143,13 +163,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
143163
#endif
144164

145165
model_type_category_ = model_type_category;
166+
auto cpp_load_mode = load_mode_from_int(load_mode);
146167
std::vector<std::string> data_files_vector;
147168
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
148169
runner_ = llm::create_multimodal_runner(
149170
model_path->toStdString().c_str(),
150171
llm::load_tokenizer(tokenizer_path->toStdString()),
151172
std::nullopt,
152-
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
173+
cpp_load_mode);
153174
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
154175
if (data_files != nullptr) {
155176
// Convert Java List<String> to C++ std::vector<string>
@@ -169,14 +190,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
169190
runner_ = executorch::extension::llm::create_text_llm_runner(
170191
model_path->toStdString(),
171192
llm::load_tokenizer(tokenizer_path->toStdString()),
172-
data_files_vector);
193+
data_files_vector,
194+
/*temperature=*/-1.0f,
195+
/*event_tracer=*/nullptr,
196+
/*method_name=*/"forward",
197+
cpp_load_mode);
173198
#if defined(EXECUTORCH_BUILD_QNN)
174199
} else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
175-
std::unique_ptr<executorch::extension::Module> module = std::make_unique<
176-
executorch::extension::Module>(
177-
model_path->toStdString().c_str(),
178-
data_files_vector,
179-
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
200+
std::unique_ptr<executorch::extension::Module> module =
201+
std::make_unique<executorch::extension::Module>(
202+
model_path->toStdString().c_str(),
203+
data_files_vector,
204+
cpp_load_mode);
180205
std::string decoder_model = "llama3"; // use llama3 for now
181206
runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
182207
std::move(module),

0 commit comments

Comments (0)