diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 4b6c3caed94..0974a04af44 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -54,9 +54,7 @@ class LlmModuleInstrumentationTest : LlmCallback { @Test @Throws(IOException::class, URISyntaxException::class) fun testGenerate() { - val loadResult = llmModule.load() - // Check that the model can be load successfully - assertEquals(OK.toLong(), loadResult.toLong()) + llmModule.load() llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) assertEquals(results.size.toLong(), SEQ_LEN.toLong()) @@ -273,11 +271,26 @@ class LlmModuleInstrumentationTest : LlmCallback { } } + // --- Lifecycle tests --- + + @Test + fun testUseAfterCloseThrows() { + llmModule.close() + assertThrows(IllegalStateException::class.java) { + llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) + } + } + + @Test + fun testCloseIsIdempotent() { + llmModule.close() + llmModule.close() + } + companion object { private const val TEST_FILE_NAME = "/stories.pte" private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" private const val TEST_PROMPT = "Hello" - private const val OK = 0x00 private const val SEQ_LEN = 32 } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 99d53b6dba3..e9e9ff4637c 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -17,7 +17,6 @@ import java.util.concurrent.atomic.AtomicInteger import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Before -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.pytorch.executorch.TestFileUtils.getTestFilePath @@ -40,49 +39,42 @@ class ModuleInstrumentationTest { inputStream.close() } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class, URISyntaxException::class) fun testModuleLoadAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) - } - - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMethodMetadata() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + module.loadMethod(FORWARD_METHOD) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(IOException::class) fun testModuleLoadForwardExplicit() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.execute(FORWARD_METHOD) - Assert.assertTrue(results[0].isTensor) + try { + val results = module.execute(FORWARD_METHOD, EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } } @Test(expected = RuntimeException::class) @@ -95,9 +87,18 @@ class ModuleInstrumentationTest { @Throws(IOException::class) fun testModuleLoadMethodNonExistantMethod() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(NONE_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + try { + val exception = + Assert.assertThrows(ExecutorchRuntimeException::class.java) { + module.loadMethod(NONE_METHOD) + } + Assert.assertEquals( + ExecutorchRuntimeException.INVALID_ARGUMENT, + exception.getErrorCode(), + ) + } finally { + module.destroy() + } } @Test(expected = RuntimeException::class) @@ -105,8 +106,7 @@ class ModuleInstrumentationTest { fun testNonPteFile() { val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + module.loadMethod(FORWARD_METHOD) } @Test @@ -116,8 +116,7 @@ class ModuleInstrumentationTest { module.destroy() - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) + Assert.assertThrows(IllegalStateException::class.java) { module.loadMethod(FORWARD_METHOD) } } @Test @@ -125,18 +124,13 @@ class ModuleInstrumentationTest { fun testForwardOnDestroyedModule() { val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + module.loadMethod(FORWARD_METHOD) module.destroy() - val results = module.forward() - Assert.assertEquals(0, results.size.toLong()) + Assert.assertThrows(IllegalStateException::class.java) { module.forward() } } - @Ignore( - "The forward has failure that needs to be fixed before enabling this test: [Executorch Error 0x12] Invalid argument: Execution failed for method: forward " - ) @Test @Throws(InterruptedException::class, IOException::class) fun testForwardFromMultipleThreads() { @@ -150,7 +144,7 @@ class ModuleInstrumentationTest { try { latch.countDown() latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward() + val results = module.forward(EValue.from(dummyInput())) Assert.assertTrue(results[0].isTensor) completed.incrementAndGet() } catch (_: InterruptedException) {} @@ -167,6 +161,139 @@ class ModuleInstrumentationTest { } Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) + module.destroy() + } + + // --- Load mode tests --- + + @Test + @Throws(IOException::class) + fun testLoadWithMmapMode() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_MMAP) + try { + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } + } + + @Test + @Throws(IOException::class) + fun testLoadWithFileMode() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME), Module.LOAD_MODE_FILE) + try { + val results = module.forward(EValue.from(dummyInput())) + Assert.assertTrue(results[0].isTensor) + } finally { + module.destroy() + } + } + + // --- getMethods / getMethodMetadata tests --- + + @Test + @Throws(IOException::class) + fun testGetMethods() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + val methods = module.getMethods() + Assert.assertNotNull(methods) + Assert.assertTrue(methods.contains(FORWARD_METHOD)) + } finally { + module.destroy() + } + } + + @Test + @Throws(IOException::class) + fun testGetMethodMetadata() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + val metadata = module.getMethodMetadata(FORWARD_METHOD) + Assert.assertNotNull(metadata) + Assert.assertEquals(FORWARD_METHOD, metadata.name) + Assert.assertNotNull(metadata.backends) + } finally { + module.destroy() + } + } + + // --- Log buffer tests --- + + @Test + @Throws(IOException::class) + fun testReadLogBuffer() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + val logs = module.readLogBuffer() + Assert.assertNotNull(logs) + } finally { + module.destroy() + } + } + + @Test + fun testReadLogBufferStatic() { + val logs = Module.readLogBufferStatic() + Assert.assertNotNull(logs) + } + + // --- etdump test --- + + @Test + @Throws(IOException::class) + fun testEtdump() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + try { + module.etdump() + } finally { + module.destroy() + } + } + + // --- Destroyed-state tests for remaining methods --- + + @Test + @Throws(IOException::class) + fun testGetMethodsOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + module.destroy() + Assert.assertThrows(IllegalStateException::class.java) { module.getMethods() } + } + + @Test + @Throws(IOException::class) + fun testGetMethodMetadataOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + module.destroy() + Assert.assertThrows(IllegalStateException::class.java) { + module.getMethodMetadata(FORWARD_METHOD) + } + } + + @Test + @Throws(IOException::class) + fun testReadLogBufferOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + module.destroy() + Assert.assertThrows(IllegalStateException::class.java) { module.readLogBuffer() } + } + + @Test + @Throws(IOException::class) + fun testEtdumpOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + module.destroy() + Assert.assertThrows(IllegalStateException::class.java) { module.etdump() } + } + + @Test + @Throws(IOException::class) + fun testDoubleDestroyIsSafe() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + module.destroy() + module.destroy() } companion object { @@ -175,9 +302,8 @@ class ModuleInstrumentationTest { private const val NON_PTE_FILE_NAME = "/test.txt" private const val FORWARD_METHOD = "forward" private const val NONE_METHOD = "none" - private const val OK = 0x00 - private const val INVALID_STATE = 0x2 - private const val INVALID_ARGUMENT = 0x12 - private const val ACCESS_FAILED = 0x22 + private val inputShape = longArrayOf(1, 3, 224, 224) + + private fun dummyInput(): Tensor = Tensor.ones(inputShape, DType.FLOAT) } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 30ebf1a2c1d..6372da9a397 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -36,12 +36,25 @@ public static ExecuTorchRuntime getRuntime() { /** * Validates that the given path points to a readable file. * - * @throws RuntimeException if the file does not exist or is not readable. + * @throws IllegalArgumentException if the path is null, does not exist, is not a file, or is not + * readable. */ public static void validateFilePath(String path, String description) { + if (path == null) { + throw new IllegalArgumentException("Cannot load " + description + ": path is null"); + } File file = new File(path); - if (!file.canRead() || !file.isFile()) { - throw new RuntimeException("Cannot load " + description + " " + path); + if (!file.exists()) { + throw new IllegalArgumentException( + "Cannot load " + description + ": path does not exist: " + path); + } + if (!file.isFile()) { + throw new IllegalArgumentException( + "Cannot load " + description + ": path is not a file: " + path); + } + if (!file.canRead()) { + throw new IllegalArgumentException( + "Cannot load " + description + ": path is not readable: " + path); } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java index d69763c8fd2..6f9d654be66 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java @@ -13,34 +13,83 @@ import java.util.HashMap; import java.util.Map; +/** + * Base exception for all ExecuTorch runtime errors. Each instance carries an integer error code + * corresponding to the native {@code runtime/core/error.h} values, accessible via {@link + * #getErrorCode()}. + */ public class ExecutorchRuntimeException extends RuntimeException { // Error code constants - keep in sync with runtime/core/error.h + // System errors + + /** Operation completed successfully. */ public static final int OK = 0x00; + + /** An unexpected internal error occurred in the runtime. */ public static final int INTERNAL = 0x01; + + /** The runtime or method is in an invalid state for the requested operation. */ public static final int INVALID_STATE = 0x02; + + /** The method has finished execution and has no more work to do. */ public static final int END_OF_METHOD = 0x03; + /** A required resource has already been loaded. */ + public static final int ALREADY_LOADED = 0x04; + // Logical errors + + /** The requested operation is not supported by this build or backend. */ public static final int NOT_SUPPORTED = 0x10; + + /** The requested operation has not been implemented. */ public static final int NOT_IMPLEMENTED = 0x11; + + /** One or more arguments passed to the operation are invalid. */ public static final int INVALID_ARGUMENT = 0x12; + + /** A value or tensor has an unexpected type. */ public static final int INVALID_TYPE = 0x13; + + /** A required operator kernel is not registered. */ public static final int OPERATOR_MISSING = 0x14; + + /** The maximum number of registered kernels has been exceeded. */ public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15; + + /** A kernel with the same name is already registered. */ public static final int REGISTRATION_ALREADY_REGISTERED = 0x16; // Resource errors + + /** A required resource (file, tensor, program) was not found. */ public static final int NOT_FOUND = 0x20; + + /** A memory allocation failed. */ public static final int MEMORY_ALLOCATION_FAILED = 0x21; + + /** Access to a resource was denied or failed. */ public static final int ACCESS_FAILED = 0x22; + + /** The loaded program is malformed or incompatible. */ public static final int INVALID_PROGRAM = 0x23; + + /** External data referenced by the program is invalid or missing. */ public static final int INVALID_EXTERNAL_DATA = 0x24; + + /** The system has run out of a required resource. */ public static final int OUT_OF_RESOURCES = 0x25; // Delegate errors + + /** A delegate reported an incompatible model or configuration. */ public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30; + + /** A delegate failed to allocate required memory. */ public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31; + + /** A delegate received an invalid or stale handle. */ public static final int DELEGATE_INVALID_HANDLE = 0x32; private static final Map ERROR_CODE_MESSAGES; @@ -53,6 +102,7 @@ public class ExecutorchRuntimeException extends RuntimeException { map.put(INTERNAL, "Internal error"); map.put(INVALID_STATE, "Invalid state"); map.put(END_OF_METHOD, "End of method reached"); + map.put(ALREADY_LOADED, "Already loaded"); // Logical errors map.put(NOT_SUPPORTED, "Operation not supported"); map.put(NOT_IMPLEMENTED, "Operation not implemented"); @@ -84,7 +134,7 @@ static String formatMessage(int errorCode, String details) { String safeDetails = details != null ? details : "No details provided"; return String.format( - "[Executorch Error 0x%s] %s: %s", + "[ExecuTorch Error 0x%s] %s: %s", Integer.toHexString(errorCode), baseMessage, safeDetails); } @@ -113,10 +163,17 @@ public ExecutorchRuntimeException(int errorCode, String details) { this.errorCode = errorCode; } + public ExecutorchRuntimeException(int errorCode, String details, Throwable cause) { + super(ErrorHelper.formatMessage(errorCode, details), cause); + this.errorCode = errorCode; + } + + /** Returns the numeric error code from {@code runtime/core/error.h}. */ public int getErrorCode() { return errorCode; } + /** Returns detailed log output captured from the native runtime, if available. */ public String getDetailedError() { return ErrorHelper.getDetailedErrorLogs(); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index f7e2e37dcec..6cf99966e6a 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -8,11 +8,11 @@ package org.pytorch.executorch; -import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.io.Closeable; import java.util.HashMap; import java.util.Map; import java.util.concurrent.locks.Lock; @@ -25,7 +25,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class Module { +public class Module implements Closeable { static { if (!NativeLoader.isInitialized()) { @@ -130,11 +130,10 @@ public EValue[] forward(EValue... inputs) { * @return return value from the method. */ public EValue[] execute(String methodName, EValue... inputs) { + mLock.lock(); try { - mLock.lock(); if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new EValue[0]; + throw new IllegalStateException("Module has been destroyed"); } return executeNative(methodName, inputs); } finally { @@ -151,17 +150,17 @@ public EValue[] execute(String methodName, EValue... inputs) { * synchronous, and will block until the method is loaded. Therefore, it is recommended to call * this on a background thread. However, users need to make sure that they don't execute before * this function returns. - * - * @return the Error code if there was an error loading the method */ - public int loadMethod(String methodName) { + public void loadMethod(String methodName) { + mLock.lock(); try { - mLock.lock(); if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return 0x2; // InvalidState + throw new IllegalStateException("Module has been destroyed"); + } + int errorCode = loadMethodNative(methodName); + if (errorCode != 0) { + throw new ExecutorchRuntimeException(errorCode, "Failed to load method: " + methodName); } - return loadMethodNative(methodName); } finally { mLock.unlock(); } @@ -184,8 +183,20 @@ public int loadMethod(String methodName) { * * @return name of methods in this Module */ + public String[] getMethods() { + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return getMethodsNative(); + } finally { + mLock.unlock(); + } + } + @DoNotStrip - public native String[] getMethods(); + private native String[] getMethodsNative(); /** * Get the corresponding @MethodMetadata for a method @@ -194,11 +205,19 @@ public int loadMethod(String methodName) { * @return @MethodMetadata for this method */ public MethodMetadata getMethodMetadata(String name) { - MethodMetadata methodMetadata = mMethodMetadata.get(name); - if (methodMetadata == null) { - throw new IllegalArgumentException("method " + name + " does not exist for this module"); + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + MethodMetadata methodMetadata = mMethodMetadata.get(name); + if (methodMetadata == null) { + throw new IllegalArgumentException("method " + name + " does not exist for this module"); + } + return methodMetadata; + } finally { + mLock.unlock(); } - return methodMetadata; } @DoNotStrip @@ -210,7 +229,15 @@ public static String[] readLogBufferStatic() { /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { - return readLogBufferNative(); + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return readLogBufferNative(); + } finally { + mLock.unlock(); + } } @DoNotStrip @@ -224,8 +251,20 @@ public String[] readLogBuffer() { * @return true if the etdump was successfully written, false otherwise. */ @Experimental + public boolean etdump() { + mLock.lock(); + try { + if (!mHybridData.isValid()) { + throw new IllegalStateException("Module has been destroyed"); + } + return etdumpNative(); + } finally { + mLock.unlock(); + } + } + @DoNotStrip - public native boolean etdump(); + private native boolean etdumpNative(); /** * Explicitly destroys the native Module object. Calling this method is not required, as the @@ -236,15 +275,19 @@ public String[] readLogBuffer() { public void destroy() { if (mLock.tryLock()) { try { - mHybridData.resetNative(); + if (mHybridData.isValid()) { + mHybridData.resetNative(); + } } finally { mLock.unlock(); } } else { - Log.w( - "ExecuTorch", - "Destroy was called while the module was in use. Resources will not be immediately" - + " released."); + throw new IllegalStateException("Cannot destroy module while method is executing"); } } + + @Override + public void close() { + destroy(); + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt index 987cb3ec3be..ab9099ba405 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/asr/AsrModule.kt @@ -11,6 +11,7 @@ package org.pytorch.executorch.extension.asr import java.io.Closeable import java.io.File import java.util.concurrent.atomic.AtomicLong +import org.pytorch.executorch.ExecutorchRuntimeException import org.pytorch.executorch.annotations.Experimental /** @@ -53,7 +54,10 @@ class AsrModule( val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath) if (handle == 0L) { - throw RuntimeException("Failed to create native AsrModule") + throw ExecutorchRuntimeException( + ExecutorchRuntimeException.INTERNAL, + "Failed to create native AsrModule", + ) } nativeHandle.set(handle) } @@ -129,7 +133,7 @@ class AsrModule( * @param callback Optional callback to receive tokens as they are generated (can be null) * @return The complete transcribed text * @throws IllegalStateException if the module has been destroyed - * @throws RuntimeException if transcription fails (non-zero result code) + * @throws ExecutorchRuntimeException if transcription fails (error code carried in exception) */ @JvmOverloads fun transcribe( @@ -160,7 +164,7 @@ class AsrModule( ) if (status != 0) { - throw RuntimeException("Transcription failed with error code: $status") + throw ExecutorchRuntimeException(status, "Transcription failed") } return result.toString() diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index a563dc6bcc7..ce72eb42c46 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -10,9 +10,12 @@ import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; +import java.io.Closeable; import java.nio.ByteBuffer; import java.util.List; +import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.ExecuTorchRuntime; +import org.pytorch.executorch.ExecutorchRuntimeException; import org.pytorch.executorch.annotations.Experimental; /** @@ -22,13 +25,15 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class LlmModule { +public class LlmModule implements Closeable { public static final int MODEL_TYPE_TEXT = 1; public static final int MODEL_TYPE_TEXT_VISION = 2; public static final int MODEL_TYPE_MULTIMODAL = 2; private final HybridData mHybridData; + private final ReentrantLock mLock = new ReentrantLock(); + private boolean mDestroyed = false; private static final int DEFAULT_SEQ_LEN = 128; private static final boolean DEFAULT_ECHO = true; private static final float DEFAULT_TEMPERATURE = -1.0f; @@ -185,8 +190,41 @@ public LlmModule(LlmModuleConfig config) { config.getLoadMode()); } + private void checkNotDestroyed() { + if (mDestroyed) throw new IllegalStateException("LlmModule has been destroyed"); + } + + /** + * Releases native resources. Callers must ensure no other methods are in-flight. Call {@link + * #stop()} and wait for {@link #generate(String, LlmCallback)} to return before calling this + * method. + */ + @Override + public void close() { + if (mLock.tryLock()) { + try { + if (mLock.getHoldCount() > 1) { + throw new IllegalStateException( + "Cannot close module from within a callback during execution"); + } + if (!mDestroyed) { + mDestroyed = true; + mHybridData.resetNative(); + } + } finally { + mLock.unlock(); + } + } else { + throw new IllegalStateException("Cannot close module while method is executing"); + } + } + + /** + * @deprecated Use {@link #close()} instead. + */ + @Deprecated public void resetNative() { - mHybridData.resetNative(); + close(); } /** @@ -195,8 +233,8 @@ public void resetNative() { * @param prompt Input prompt * @param llmCallback callback object to receive results. */ - public int generate(String prompt, LlmCallback llmCallback) { - return generate( + public void generate(String prompt, LlmCallback llmCallback) { + generate( prompt, DEFAULT_SEQ_LEN, llmCallback, @@ -213,8 +251,8 @@ public int generate(String prompt, LlmCallback llmCallback) { * @param seqLen sequence length * @param llmCallback callback object to receive results. */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback) { - return generate( + public void generate(String prompt, int seqLen, LlmCallback llmCallback) { + generate( null, 0, 0, @@ -235,8 +273,8 @@ public int generate(String prompt, int seqLen, LlmCallback llmCallback) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate(String prompt, LlmCallback llmCallback, boolean echo) { - return generate( + public void generate(String prompt, LlmCallback llmCallback, boolean echo) { + generate( null, 0, 0, @@ -258,9 +296,8 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) { * @param llmCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - return generate( - prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); + public void generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { + generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); } /** @@ -274,7 +311,32 @@ public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public native int generate( + public void generate( + String prompt, + int seqLen, + LlmCallback llmCallback, + boolean echo, + float temperature, + int numBos, + int numEos) { + mLock.lock(); + try { + if (mLock.getHoldCount() > 1) { + throw new IllegalStateException( + "Cannot call generate() re-entrantly from a callback during execution"); + } + checkNotDestroyed(); + int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); + } + } finally { + mLock.unlock(); + } + } + + @DoNotStrip + private native int generateNative( String prompt, int seqLen, LlmCallback llmCallback, @@ -290,13 +352,13 @@ public native int generate( * @param config the config for generation * @param llmCallback callback object to receive results */ - public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { + public void generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { int seqLen = config.getSeqLen(); boolean echo = config.isEcho(); float temperature = config.getTemperature(); int numBos = config.getNumBos(); int numEos = config.getNumEos(); - return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -311,7 +373,7 @@ public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCa * @param llmCallback callback object to receive results. * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ - public int generate( + public void generate( int[] image, int width, int height, @@ -320,7 +382,7 @@ public int generate( int seqLen, LlmCallback llmCallback, boolean echo) { - return generate( + generate( image, width, height, @@ -347,7 +409,7 @@ public int generate( * @param echo indicate whether to echo the input prompt or not (text completion vs chat) * @param temperature temperature for sampling (use negative value to use module default) */ - public int generate( + public void generate( int[] image, int width, int height, @@ -357,7 +419,7 @@ public int generate( LlmCallback llmCallback, boolean echo, float temperature) { - return generate( + generate( image, width, height, @@ -386,7 +448,7 @@ public int generate( * @param numBos number of BOS tokens to prepend * @param numEos number of EOS tokens to append */ - public int generate( + public void generate( int[] image, int width, int height, @@ -398,10 +460,26 @@ public int generate( float temperature, int numBos, int numEos) { - if (image != null) { - prefillImages(image, width, height, channels); + mLock.lock(); + try { + if (mLock.getHoldCount() > 1) { + throw new IllegalStateException( + "Cannot call generate() re-entrantly from a callback during execution"); + } + checkNotDestroyed(); + if (image != null) { + int nativeResult = prefillImagesInput(image, width, height, channels); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } + int err = generateNative(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to generate"); + } + } finally { + mLock.unlock(); } - return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); } /** @@ -411,16 +489,20 @@ public int generate( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillImages(int[] image, int width, int height, int channels) { - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + public void prefillImages(int[] image, int width, int height, int channels) { + mLock.lock(); + try { + checkNotDestroyed(); + int nativeResult = prefillImagesInput(image, width, height, channels); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } - return 0; } /** @@ -436,38 +518,44 @@ public long prefillImages(int[] image, int width, int height, int channels) { * @param channels Input image number of channels * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining * bytes - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental public void prefillImages(ByteBuffer image, int width, int height, int channels) { - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - long expectedBytes; + mLock.lock(); try { - long pixels = Math.multiplyExact((long) width, (long) height); - expectedBytes = Math.multiplyExact(pixels, (long) channels); - } catch (ArithmeticException ex) { - throw new IllegalArgumentException( - "width*height*channels is too large and overflows the allowed range.", ex); - } - if (width <= 0 - || height <= 0 - || channels <= 0 - || expectedBytes > Integer.MAX_VALUE - || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels (" - + expectedBytes - + ")."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + checkNotDestroyed(); + if (!image.isDirect()) { + throw new IllegalArgumentException("Input ByteBuffer must be direct."); + } + long expectedBytes; + try { + long pixels = Math.multiplyExact((long) width, (long) height); + expectedBytes = Math.multiplyExact(pixels, (long) channels); + } catch (ArithmeticException ex) { + throw new IllegalArgumentException( + "width*height*channels is too large and overflows the allowed range.", ex); + } + if (width <= 0 + || height <= 0 + || channels <= 0 + || expectedBytes > Integer.MAX_VALUE + || image.remaining() < expectedBytes) { + throw new IllegalArgumentException( + "ByteBuffer remaining (" + + image.remaining() + + ") must be at least width*height*channels (" + + expectedBytes + + ")."); + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } } @@ -487,53 +575,61 @@ public void prefillImages(ByteBuffer image, int width, int height, int channels) * @param channels Input image number of channels * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining * bytes, is not float-aligned, or does not use native byte order - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) { - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - if (image.order() != java.nio.ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."); - } - if (image.position() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned."); - } - final long expectedBytes; + mLock.lock(); try { - int wh = Math.multiplyExact(width, height); - long whc = Math.multiplyExact((long) wh, (long) channels); - long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES); - if (totalBytes > Integer.MAX_VALUE) { + checkNotDestroyed(); + if (!image.isDirect()) { + throw new IllegalArgumentException("Input ByteBuffer must be direct."); + } + if (image.order() != java.nio.ByteOrder.nativeOrder()) { throw new IllegalArgumentException( - "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: " - + totalBytes); + "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."); } - expectedBytes = totalBytes; - } catch (ArithmeticException e) { - throw new IllegalArgumentException( - "Overflow while computing width*height*channels*4 for ByteBuffer size.", e); - } - if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels*4 (" - + expectedBytes - + ")."); - } - if (image.remaining() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" + image.remaining() + ") must be a multiple of 4 (float size)."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + if (image.position() % Float.BYTES != 0) { + throw new IllegalArgumentException( + "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned."); + } + final long expectedBytes; + try { + int wh = Math.multiplyExact(width, height); + long whc = Math.multiplyExact((long) wh, (long) channels); + long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES); + if (totalBytes > Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: " + + totalBytes); + } + expectedBytes = totalBytes; + } catch (ArithmeticException e) { + throw new IllegalArgumentException( + "Overflow while computing width*height*channels*4 for ByteBuffer size.", e); + } + if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) { + throw new IllegalArgumentException( + "ByteBuffer remaining (" + + image.remaining() + + ") must be at least width*height*channels*4 (" + + expectedBytes + + ")."); + } + if (image.remaining() % Float.BYTES != 0) { + throw new IllegalArgumentException( + "ByteBuffer remaining (" + + image.remaining() + + ") must be a multiple of 4 (float size)."); + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } } @@ -552,16 +648,20 @@ private native int prefillNormalizedImagesInputBuffer( * @param width Input image width * @param height Input image height * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillImages(float[] image, int width, int height, int channels) { - int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + public void prefillImages(float[] image, int width, int height, int channels) { + mLock.lock(); + try { + checkNotDestroyed(); + int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } - return 0; } private native int prefillNormalizedImagesInput( @@ -574,16 +674,20 @@ private native int prefillNormalizedImagesInput( * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + public void prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { + mLock.lock(); + try { + checkNotDestroyed(); + int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } - return 0; } private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); @@ -595,16 +699,20 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) * @param batch_size Input batch size * @param n_bins Input number of bins * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + public void prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { + mLock.lock(); + try { + checkNotDestroyed(); + int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } - return 0; } private native int prefillAudioInputFloat( @@ -617,16 +725,20 @@ private native int prefillAudioInputFloat( * @param batch_size Input batch size * @param n_channels Input number of channels * @param n_samples Input number of samples - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { - int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + public void prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { + mLock.lock(); + try { + checkNotDestroyed(); + int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } - return 0; } private native int prefillRawAudioInput( @@ -636,16 +748,20 @@ private native int prefillRawAudioInput( * Prefill the KV cache with the given text prompt. * * @param prompt The text prompt to prefill. - * @return 0 on success - * @throws RuntimeException if the prefill failed + * @throws ExecutorchRuntimeException if the prefill failed */ @Experimental - public long prefillPrompt(String prompt) { - int nativeResult = prefillTextInput(prompt); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); + public void prefillPrompt(String prompt) { + mLock.lock(); + try { + checkNotDestroyed(); + int nativeResult = prefillTextInput(prompt); + if (nativeResult != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(nativeResult, "Prefill failed"); + } + } finally { + mLock.unlock(); } - return 0; } // returns status @@ -656,7 +772,18 @@ public long prefillPrompt(String prompt) { * *

The startPos will be reset to 0. */ - public native void resetContext(); + public void resetContext() { + mLock.lock(); + try { + checkNotDestroyed(); + resetContextNative(); + } finally { + mLock.unlock(); + } + } + + @DoNotStrip + private native void resetContextNative(); /** Stop current generate() before it finishes. */ @DoNotStrip @@ -664,5 +791,19 @@ public long prefillPrompt(String prompt) { /** Force loading the module. Otherwise the model is loaded during first generate(). */ @DoNotStrip - public native int load(); + public void load() { + mLock.lock(); + try { + checkNotDestroyed(); + int err = loadNative(); + if (err != 0) { + throw ExecutorchRuntimeException.makeExecutorchException(err, "Failed to load model"); + } + } finally { + mLock.unlock(); + } + } + + @DoNotStrip + private native int loadNative(); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java index 8f4292c1bc8..58c7704b83e 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/SGD.java @@ -93,7 +93,7 @@ public static SGD create(Map namedParameters, double learningRat */ public void step(Map namedGradients) { if (!mHybridData.isValid()) { - throw new RuntimeException("Attempt to use a destroyed SGD optimizer"); + throw new IllegalStateException("SGD optimizer has been destroyed"); } stepNative(namedGradients); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java index 4a6653cb7a1..ca4bac9aa54 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/training/TrainingModule.java @@ -8,12 +8,11 @@ package org.pytorch.executorch.training; -import android.util.Log; import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.util.HashMap; +import java.io.Closeable; import java.util.Map; import org.pytorch.executorch.EValue; import org.pytorch.executorch.ExecuTorchRuntime; @@ -26,7 +25,7 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -public class TrainingModule { +public class TrainingModule implements Closeable { static { if (!NativeLoader.isInitialized()) { @@ -37,6 +36,7 @@ public class TrainingModule { } private final HybridData mHybridData; + private boolean mDestroyed = false; @DoNotStrip private static native HybridData initHybrid(String moduleAbsolutePath, String dataAbsolutePath); @@ -45,6 +45,10 @@ private TrainingModule(String moduleAbsolutePath, String dataAbsolutePath) { mHybridData = initHybrid(moduleAbsolutePath, dataAbsolutePath); } + private void checkNotDestroyed() { + if (mDestroyed) throw new IllegalStateException("TrainingModule has been destroyed"); + } + /** * Loads a serialized ExecuTorch Training Module from the specified path on the disk. * @@ -78,10 +82,7 @@ public static TrainingModule load(final String modelPath) { * @return return value(s) from the method. */ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new EValue[0]; - } + checkNotDestroyed(); return executeForwardBackwardNative(methodName, inputs); } @@ -89,10 +90,7 @@ public EValue[] executeForwardBackward(String methodName, EValue... inputs) { private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); public Map namedParameters(String methodName) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new HashMap(); - } + checkNotDestroyed(); return namedParametersNative(methodName); } @@ -100,13 +98,17 @@ public Map namedParameters(String methodName) { private native Map namedParametersNative(String methodName); public Map namedGradients(String methodName) { - if (!mHybridData.isValid()) { - Log.e("ExecuTorch", "Attempt to use a destroyed module"); - return new HashMap(); - } + checkNotDestroyed(); return namedGradientsNative(methodName); } @DoNotStrip private native Map namedGradientsNative(String methodName); + + @Override + public void close() { + if (mDestroyed) return; + mDestroyed = true; + mHybridData.resetNative(); + } } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index beff72119b8..0cf08e41983 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -284,8 +284,18 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #else auto etdump_gen = nullptr; #endif - module_ = std::make_unique( - modelPath->toStdString(), load_mode, std::move(etdump_gen)); + try { + module_ = std::make_unique( + modelPath->toStdString(), load_mode, std::move(etdump_gen)); + } catch (const std::exception& e) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + std::string("Failed to create Module: ") + e.what()); + } catch (...) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + "Failed to create Module: unknown native error"); + } #ifdef ET_USE_THREADPOOL // Default to using cores/2 threadpool threads. The long-term plan is to @@ -385,6 +395,12 @@ class ExecuTorchJni : public facebook::jni::HybridClass { static const auto toBoolMethod = JEValue::javaClassStatic()->getMethod("toBool"); evalues.emplace_back(static_cast(toBoolMethod(jevalue))); + } else { + std::stringstream ss; + ss << "Unsupported input EValue type code: " << typeCode; + jni_helper::throwExecutorchException( + static_cast(Error::InvalidArgument), ss.str()); + return {}; } } @@ -564,8 +580,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), makeNativeMethod( "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic), - makeNativeMethod("etdump", ExecuTorchJni::etdump), - makeNativeMethod("getMethods", ExecuTorchJni::getMethods), + makeNativeMethod("etdumpNative", ExecuTorchJni::etdump), + makeNativeMethod("getMethodsNative", ExecuTorchJni::getMethods), makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index ed144acb14b..0c1ff5c67b9 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -148,103 +149,117 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint num_bos = 0, jint num_eos = 0, jint load_mode = 1) { - temperature_ = temperature; - num_bos_ = num_bos; - num_eos_ = num_eos; + try { + temperature_ = temperature; + num_bos_ = num_bos; + num_eos_ = num_eos; #if defined(ET_USE_THREADPOOL) - // Reserve 1 thread for the main thread. - int32_t num_performant_cores = - ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; - if (num_performant_cores > 0) { - ET_LOG(Info, "Resetting threadpool to %d threads", num_performant_cores); - ::executorch::extension::threadpool::get_threadpool() - ->_unsafe_reset_threadpool(num_performant_cores); - } + // Reserve 1 thread for the main thread. + int32_t num_performant_cores = + ::executorch::extension::cpuinfo::get_num_performant_cores() - 1; + if (num_performant_cores > 0) { + ET_LOG( + Info, "Resetting threadpool to %d threads", num_performant_cores); + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } #endif - model_type_category_ = model_type_category; - auto cpp_load_mode = load_mode_from_int(load_mode); - std::vector data_files_vector; - if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { - runner_ = llm::create_multimodal_runner( - model_path->toStdString().c_str(), - llm::load_tokenizer(tokenizer_path->toStdString()), - std::nullopt, - cpp_load_mode); - } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - if (data_files != nullptr) { - // Convert Java List to C++ std::vector - auto list_class = facebook::jni::findClassStatic("java/util/List"); - auto size_method = list_class->getMethod("size"); - auto get_method = - list_class->getMethod(jint)>( - "get"); - - jint size = size_method(data_files); - for (jint i = 0; i < size; ++i) { - auto str_obj = get_method(data_files, i); - auto jstr = facebook::jni::static_ref_cast(str_obj); - data_files_vector.push_back(jstr->toStdString()); + model_type_category_ = model_type_category; + auto cpp_load_mode = load_mode_from_int(load_mode); + std::vector data_files_vector; + if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { + runner_ = llm::create_multimodal_runner( + model_path->toStdString().c_str(), + llm::load_tokenizer(tokenizer_path->toStdString()), + std::nullopt, + cpp_load_mode); + } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { + if (data_files != nullptr) { + // Convert Java List to C++ std::vector + auto list_class = facebook::jni::findClassStatic("java/util/List"); + auto size_method = list_class->getMethod("size"); + auto get_method = + list_class->getMethod(jint)>( + "get"); + + jint size = size_method(data_files); + for (jint i = 0; i < size; ++i) { + auto str_obj = get_method(data_files, i); + auto jstr = facebook::jni::static_ref_cast(str_obj); + data_files_vector.push_back(jstr->toStdString()); + } } - } - runner_ = executorch::extension::llm::create_text_llm_runner( - model_path->toStdString(), - llm::load_tokenizer(tokenizer_path->toStdString()), - data_files_vector, - /*temperature=*/-1.0f, - /*event_tracer=*/nullptr, - /*method_name=*/"forward", - cpp_load_mode); + runner_ = executorch::extension::llm::create_text_llm_runner( + model_path->toStdString(), + llm::load_tokenizer(tokenizer_path->toStdString()), + data_files_vector, + /*temperature=*/-1.0f, + /*event_tracer=*/nullptr, + /*method_name=*/"forward", + cpp_load_mode); #if defined(EXECUTORCH_BUILD_QNN) - } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { - std::unique_ptr module = - std::make_unique( - model_path->toStdString().c_str(), - data_files_vector, - cpp_load_mode); - std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width").get().toScalar().to()); - } + } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { + std::unique_ptr module = + std::make_unique( + model_path->toStdString().c_str(), + data_files_vector, + cpp_load_mode); + std::string decoder_model = "llama3"; // use llama3 for now + // Using 8bit as default since this meta is introduced with 16bit kv io + // support and older models only have 8bit kv io. + example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; + if (module->method_names()->count("get_kv_io_bit_width") > 0) { + kv_bitwidth = static_cast( + module->get("get_kv_io_bit_width") + .get() + .toScalar() + .to()); + } - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + if (kv_bitwidth == example::KvBitWidth::kWidth8) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { + runner_ = std::make_unique>( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); + } else { + ET_CHECK_MSG( + false, + "Unsupported kv bitwidth: %ld", + static_cast(kv_bitwidth)); + } + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) - } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { - runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str()); - // Interpret the model type as LLM - model_type_category_ = MODEL_TYPE_CATEGORY_LLM; + } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { + runner_ = std::make_unique( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str()); + // Interpret the model type as LLM + model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif + } + } catch (const std::exception& e) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + std::string("Failed to create LlmModule: ") + e.what()); + } catch (...) { + executorch::jni_helper::throwExecutorchException( + static_cast(Error::Internal), + "Failed to create LlmModule: unknown native error"); } } @@ -595,29 +610,28 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { jint load() { if (!runner_) { std::stringstream ss; - ss << "Invalid model type category: " << model_type_category_ - << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or " - << MODEL_TYPE_CATEGORY_MULTIMODAL; + ss << "Model runner was not created. model_type_category=" + << model_type_category_ + << ". Valid values: " << MODEL_TYPE_CATEGORY_LLM << " (LLM), " + << MODEL_TYPE_CATEGORY_MULTIMODAL << " (Multimodal)"; executorch::jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - return -1; + static_cast(Error::InvalidState), ss.str().c_str()); + return static_cast(Error::InvalidState); } - int result = static_cast(runner_->load()); - if (result != 0) { - std::stringstream ss; - ss << "Failed to load runner: [" << result << "]"; + const auto load_result = static_cast(runner_->load()); + if (load_result != static_cast(Error::Ok)) { executorch::jni_helper::throwExecutorchException( - static_cast(result), ss.str().c_str()); + static_cast(load_result), "Failed to load model runner"); } - return result; + return load_result; } static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid), - makeNativeMethod("generate", ExecuTorchLlmJni::generate), + makeNativeMethod("generateNative", ExecuTorchLlmJni::generate), makeNativeMethod("stop", ExecuTorchLlmJni::stop), - makeNativeMethod("load", ExecuTorchLlmJni::load), + makeNativeMethod("loadNative", ExecuTorchLlmJni::load), makeNativeMethod( "prefillImagesInput", ExecuTorchLlmJni::prefill_images_input), makeNativeMethod( @@ -638,7 +652,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { "prefillRawAudioInput", ExecuTorchLlmJni::prefill_raw_audio_input), makeNativeMethod( "prefillTextInput", ExecuTorchLlmJni::prefill_text_input), - makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), + makeNativeMethod("resetContextNative", ExecuTorchLlmJni::reset_context), }); } }; diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java index a1b434a37bf..79927454cb3 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java @@ -87,7 +87,15 @@ public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { @Override public void handleMessage(android.os.Message msg) { if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mLlmModelRunner.mModule.load(); + int status = 0; + try { + mLlmModelRunner.mModule.load(); + } catch (Exception e) { + status = + (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) + ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() + : -1; + } mLlmModelRunner.mCallback.onModelLoaded(status); } else if (msg.what == MESSAGE_GENERATE) { mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 28f4e3728f0..915496a25af 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -27,53 +27,73 @@ public void runBenchmark( long loadStart = System.nanoTime(); Module module = Module.load(model.getPath()); - int errorCode = module.loadMethod("forward"); + int errorCode = 0; + try { + module.loadMethod("forward"); + } catch (Exception e) { + errorCode = + (e instanceof org.pytorch.executorch.ExecutorchRuntimeException) + ? ((org.pytorch.executorch.ExecutorchRuntimeException) e).getErrorCode() + : -1; + } long loadEnd = System.nanoTime(); - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); - } + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - latency.add(forwardMs); + if (errorCode != 0) { + results.add( + new BenchmarkMetric( + benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); + results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); + module.destroy(); + return; } - module.etdump(); + try { + for (int i = 0; i < numWarmupIter; i++) { + module.forward(); + } - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - // The list of metrics we have atm includes: - // Avg inference latency after N iterations - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(latency); - int resultSize = latency.size(); - List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); + for (int i = 0; i < numIter; i++) { + long start = System.nanoTime(); + module.forward(); + double forwardMs = (System.nanoTime() - start) * 1e-6; + latency.add(forwardMs); + } + + module.etdump(); - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + Collections.sort(latency); + int resultSize = latency.size(); + List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); + + results.add( + new BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + latency.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + results.add( + new BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); + // RAM PSS usage + results.add( + new BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); + } finally { + module.destroy(); + } } }