Skip to content

Commit 565b370

Browse files
psiddh and Copilot
authored and committed
Android: consistent error types across all modules (pytorch#19099)
TrainingModule: implement Closeable, replace Log.e + silent empty returns with IllegalStateException throws. Add checkNotDestroyed() guard on all public methods. SGD: throw IllegalStateException instead of bare RuntimeException when optimizer is destroyed. AsrModule: throw ExecutorchRuntimeException instead of bare RuntimeException on transcription failure. ExecuTorchRuntime.validateFilePath: throw IllegalArgumentException instead of bare RuntimeException, with descriptive message. JNI constructors: wrap ExecuTorchJni and ExecuTorchLlmJni constructor bodies in try-catch so C++ exceptions become ExecutorchRuntimeException instead of generic RuntimeException. This commit was authored with the help of Claude. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 9e5d17f commit 565b370

6 files changed

Lines changed: 152 additions & 112 deletions

File tree

extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,22 @@ public static ExecuTorchRuntime getRuntime() {
3636
/**
3737
* Validates that the given path points to a readable file.
3838
*
39-
* @throws RuntimeException if the file does not exist or is not readable.
39+
* @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not
40+
* readable.
4041
*/
4142
public static void validateFilePath(String path, String description) {
43+
if (path == null) {
44+
throw new IllegalArgumentException("Cannot load " + description + ": path is null");
45+
}
4246
File file = new File(path);
43-
if (!file.canRead() || !file.isFile()) {
44-
throw new RuntimeException("Cannot load " + description + " " + path);
47+
if (!file.exists()) {
48+
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
49+
}
50+
if (!file.isFile()) {
51+
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
52+
}
53+
if (!file.canRead()) {
54+
throw new IllegalArgumentException("Cannot load " + description + "!! " + path);
4555
}
4656
}
4757

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package org.pytorch.executorch.extension.asr
1111
import java.io.Closeable
1212
import java.io.File
1313
import java.util.concurrent.atomic.AtomicLong
14+
import org.pytorch.executorch.ExecutorchRuntimeException
1415
import org.pytorch.executorch.annotations.Experimental
1516

1617
/**
@@ -53,7 +54,10 @@ class AsrModule(
5354

5455
val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath)
5556
if (handle == 0L) {
56-
throw RuntimeException("Failed to create native AsrModule")
57+
throw ExecutorchRuntimeException(
58+
ExecutorchRuntimeException.INTERNAL,
59+
"Failed to create native AsrModule",
60+
)
5761
}
5862
nativeHandle.set(handle)
5963
}
@@ -129,7 +133,7 @@ class AsrModule(
129133
* @param callback Optional callback to receive tokens as they are generated (can be null)
130134
* @return The complete transcribed text
131135
* @throws IllegalStateException if the module has been destroyed
132-
* @throws RuntimeException if transcription fails (non-zero result code)
136+
* @throws ExecutorchRuntimeException if transcription fails (error code carried in exception)
133137
*/
134138
@JvmOverloads
135139
fun transcribe(
@@ -160,7 +164,7 @@ class AsrModule(
160164
)
161165

162166
if (status != 0) {
163-
throw RuntimeException("Transcription failed with error code: $status")
167+
throw ExecutorchRuntimeException(status, "Transcription failed")
164168
}
165169

166170
return result.toString()

extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public static SGD create(Map<String, Tensor> namedParameters, double learningRat
9393
*/
9494
public void step(Map<String, Tensor> namedGradients) {
9595
if (!mHybridData.isValid()) {
96-
throw new RuntimeException("Attempt to use a destroyed SGD optimizer");
96+
throw new IllegalStateException("SGD optimizer has been destroyed");
9797
}
9898
stepNative(namedGradients);
9999
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88

99
package org.pytorch.executorch.training;
1010

11-
import android.util.Log;
1211
import com.facebook.jni.HybridData;
1312
import com.facebook.jni.annotations.DoNotStrip;
1413
import com.facebook.soloader.nativeloader.NativeLoader;
1514
import com.facebook.soloader.nativeloader.SystemDelegate;
16-
import java.util.HashMap;
15+
import java.io.Closeable;
1716
import java.util.Map;
1817
import org.pytorch.executorch.EValue;
1918
import org.pytorch.executorch.ExecuTorchRuntime;
@@ -26,7 +25,7 @@
2625
* <p>Warning: These APIs are experimental and subject to change without notice
2726
*/
2827
@Experimental
29-
public class TrainingModule {
28+
public class TrainingModule implements Closeable {
3029

3130
static {
3231
if (!NativeLoader.isInitialized()) {
@@ -37,6 +36,7 @@ public class TrainingModule {
3736
}
3837

3938
private final HybridData mHybridData;
39+
private boolean mDestroyed = false;
4040

4141
@DoNotStrip
4242
private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath);
@@ -45,6 +45,10 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) {
4545
mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath);
4646
}
4747

48+
private void checkNotDestroyed() {
49+
if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed");
50+
}
51+
4852
/**
4953
* Loads a serialized ExecuTorch Training Module from the specified path on the disk.
5054
*
@@ -78,35 +82,33 @@ public static TrainingModule load(final String modelPath) {
7882
* @return return value(s) from the method.
7983
*/
8084
public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
81-
if (!mHybridData.isValid()) {
82-
Log.e("ExecuTorch", "Attempt to use a destroyed module");
83-
return new EValue[0];
84-
}
85+
checkNotDestroyed();
8586
return executeForwardBackwardNative(methodName, inputs);
8687
}
8788

8889
@DoNotStrip
8990
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
9091

9192
public Map<String, Tensor> namedParameters(String methodName) {
92-
if (!mHybridData.isValid()) {
93-
Log.e("ExecuTorch", "Attempt to use a destroyed module");
94-
return new HashMap<String, Tensor>();
95-
}
93+
checkNotDestroyed();
9694
return namedParametersNative(methodName);
9795
}
9896

9997
@DoNotStrip
10098
private native Map<String, Tensor> namedParametersNative(String methodName);
10199

102100
public Map<String, Tensor> namedGradients(String methodName) {
103-
if (!mHybridData.isValid()) {
104-
Log.e("ExecuTorch", "Attempt to use a destroyed module");
105-
return new HashMap<String, Tensor>();
106-
}
101+
checkNotDestroyed();
107102
return namedGradientsNative(methodName);
108103
}
109104

110105
@DoNotStrip
111106
private native Map<String, Tensor> namedGradientsNative(String methodName);
107+
108+
@Override
109+
public void close() {
110+
if (mDestroyed) return;
111+
mDestroyed = true;
112+
mHybridData.resetNative();
113+
}
112114
}

extension/android/jni/jni_layer.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,18 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
284284
#else
285285
auto etdump_gen = nullptr;
286286
#endif
287-
module_ = std::make_unique<Module>(
288-
modelPath->toStdString(), load_mode, std::move(etdump_gen));
287+
try {
288+
module_ = std::make_unique<Module>(
289+
modelPath->toStdString(), load_mode, std::move(etdump_gen));
290+
} catch (const std::exception& e) {
291+
executorch::jni_helper::throwExecutorchException(
292+
static_cast<uint32_t>(Error::Internal),
293+
std::string("Failed to create Module: ") + e.what());
294+
} catch (...) {
295+
executorch::jni_helper::throwExecutorchException(
296+
static_cast<uint32_t>(Error::Internal),
297+
"Failed to create Module: unknown native error");
298+
}
289299

290300
#ifdef ET_USE_THREADPOOL
291301
// Default to using cores/2 threadpool threads. The long-term plan is to

extension/android/jni/jni_layer_llama.cpp

Lines changed: 102 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -149,103 +149,117 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
149149
jint num_bos = 0,
150150
jint num_eos = 0,
151151
jint load_mode = 1) {
152-
temperature_ = temperature;
153-
num_bos_ = num_bos;
154-
num_eos_ = num_eos;
152+
try {
153+
temperature_ = temperature;
154+
num_bos_ = num_bos;
155+
num_eos_ = num_eos;
155156
#if defined(ET_USE_THREADPOOL)
156-
// Reserve 1 thread for the main thread.
157-
int32_t num_performant_cores =
158-
::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
159-
if (num_performant_cores > 0) {
160-
ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores);
161-
::executorch::extension::threadpool::get_threadpool()
162-
->_unsafe_reset_threadpool(num_performant_cores);
163-
}
157+
// Reserve 1 thread for the main thread.
158+
int32_t num_performant_cores =
159+
::executorch::extension::cpuinfo::get_num_performant_cores() - 1;
160+
if (num_performant_cores > 0) {
161+
ET_LOG(
162+
Info, "Resetting threadpool to %d threads", num_performant_cores);
163+
::executorch::extension::threadpool::get_threadpool()
164+
->_unsafe_reset_threadpool(num_performant_cores);
165+
}
164166
#endif
165167

166-
model_type_category_ = model_type_category;
167-
auto cpp_load_mode = load_mode_from_int(load_mode);
168-
std::vector<std::string> data_files_vector;
169-
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
170-
runner_ = llm::create_multimodal_runner(
171-
model_path->toStdString().c_str(),
172-
llm::load_tokenizer(tokenizer_path->toStdString()),
173-
std::nullopt,
174-
cpp_load_mode);
175-
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
176-
if (data_files != nullptr) {
177-
// Convert Java List<String> to C++ std::vector<string>
178-
auto list_class = facebook::jni::findClassStatic("java/util/List");
179-
auto size_method = list_class->getMethod<jint()>("size");
180-
auto get_method =
181-
list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
182-
"get");
183-
184-
jint size = size_method(data_files);
185-
for (jint i = 0; i < size; ++i) {
186-
auto str_obj = get_method(data_files, i);
187-
auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
188-
data_files_vector.push_back(jstr->toStdString());
168+
model_type_category_ = model_type_category;
169+
auto cpp_load_mode = load_mode_from_int(load_mode);
170+
std::vector<std::string> data_files_vector;
171+
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
172+
runner_ = llm::create_multimodal_runner(
173+
model_path->toStdString().c_str(),
174+
llm::load_tokenizer(tokenizer_path->toStdString()),
175+
std::nullopt,
176+
cpp_load_mode);
177+
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
178+
if (data_files != nullptr) {
179+
// Convert Java List<String> to C++ std::vector<string>
180+
auto list_class = facebook::jni::findClassStatic("java/util/List");
181+
auto size_method = list_class->getMethod<jint()>("size");
182+
auto get_method =
183+
list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
184+
"get");
185+
186+
jint size = size_method(data_files);
187+
for (jint i = 0; i < size; ++i) {
188+
auto str_obj = get_method(data_files, i);
189+
auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
190+
data_files_vector.push_back(jstr->toStdString());
191+
}
189192
}
190-
}
191-
runner_ = executorch::extension::llm::create_text_llm_runner(
192-
model_path->toStdString(),
193-
llm::load_tokenizer(tokenizer_path->toStdString()),
194-
data_files_vector,
195-
/*temperature=*/-1.0f,
196-
/*event_tracer=*/nullptr,
197-
/*method_name=*/"forward",
198-
cpp_load_mode);
193+
runner_ = executorch::extension::llm::create_text_llm_runner(
194+
model_path->toStdString(),
195+
llm::load_tokenizer(tokenizer_path->toStdString()),
196+
data_files_vector,
197+
/*temperature=*/-1.0f,
198+
/*event_tracer=*/nullptr,
199+
/*method_name=*/"forward",
200+
cpp_load_mode);
199201
#if defined(EXECUTORCH_BUILD_QNN)
200-
} else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
201-
std::unique_ptr<executorch::extension::Module> module =
202-
std::make_unique<executorch::extension::Module>(
203-
model_path->toStdString().c_str(),
204-
data_files_vector,
205-
cpp_load_mode);
206-
std::string decoder_model = "llama3"; // use llama3 for now
207-
// Using 8bit as default since this meta is introduced with 16bit kv io
208-
// support and older models only have 8bit kv io.
209-
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
210-
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
211-
kv_bitwidth = static_cast<example::KvBitWidth>(
212-
module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
213-
}
202+
} else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
203+
std::unique_ptr<executorch::extension::Module> module =
204+
std::make_unique<executorch::extension::Module>(
205+
model_path->toStdString().c_str(),
206+
data_files_vector,
207+
cpp_load_mode);
208+
std::string decoder_model = "llama3"; // use llama3 for now
209+
// Using 8bit as default since this meta is introduced with 16bit kv io
210+
// support and older models only have 8bit kv io.
211+
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
212+
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
213+
kv_bitwidth = static_cast<example::KvBitWidth>(
214+
module->get("get_kv_io_bit_width")
215+
.get()
216+
.toScalar()
217+
.to<int64_t>());
218+
}
214219

215-
if (kv_bitwidth == example::KvBitWidth::kWidth8) {
216-
runner_ = std::make_unique<example::Runner<uint8_t>>(
217-
std::move(module),
218-
decoder_model.c_str(),
219-
model_path->toStdString().c_str(),
220-
tokenizer_path->toStdString().c_str(),
221-
"",
222-
"",
223-
temperature_);
224-
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
225-
runner_ = std::make_unique<example::Runner<uint16_t>>(
226-
std::move(module),
227-
decoder_model.c_str(),
228-
model_path->toStdString().c_str(),
229-
tokenizer_path->toStdString().c_str(),
230-
"",
231-
"",
232-
temperature_);
233-
} else {
234-
ET_CHECK_MSG(
235-
false,
236-
"Unsupported kv bitwidth: %ld",
237-
static_cast<int64_t>(kv_bitwidth));
238-
}
239-
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
220+
if (kv_bitwidth == example::KvBitWidth::kWidth8) {
221+
runner_ = std::make_unique<example::Runner<uint8_t>>(
222+
std::move(module),
223+
decoder_model.c_str(),
224+
model_path->toStdString().c_str(),
225+
tokenizer_path->toStdString().c_str(),
226+
"",
227+
"",
228+
temperature_);
229+
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
230+
runner_ = std::make_unique<example::Runner<uint16_t>>(
231+
std::move(module),
232+
decoder_model.c_str(),
233+
model_path->toStdString().c_str(),
234+
tokenizer_path->toStdString().c_str(),
235+
"",
236+
"",
237+
temperature_);
238+
} else {
239+
ET_CHECK_MSG(
240+
false,
241+
"Unsupported kv bitwidth: %ld",
242+
static_cast<int64_t>(kv_bitwidth));
243+
}
244+
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
240245
#endif
241246
#if defined(EXECUTORCH_BUILD_MEDIATEK)
242-
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
243-
runner_ = std::make_unique<MTKLlamaRunner>(
244-
model_path->toStdString().c_str(),
245-
tokenizer_path->toStdString().c_str());
246-
// Interpret the model type as LLM
247-
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
247+
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
248+
runner_ = std::make_unique<MTKLlamaRunner>(
249+
model_path->toStdString().c_str(),
250+
tokenizer_path->toStdString().c_str());
251+
// Interpret the model type as LLM
252+
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
248253
#endif
254+
}
255+
} catch (const std::exception& e) {
256+
executorch::jni_helper::throwExecutorchException(
257+
static_cast<uint32_t>(Error::Internal),
258+
std::string("Failed to create LlmModule: ") + e.what());
259+
} catch (...) {
260+
executorch::jni_helper::throwExecutorchException(
261+
static_cast<uint32_t>(Error::Internal),
262+
"Failed to create LlmModule: unknown native error");
249263
}
250264
}
251265

0 commit comments

Comments (0)