diff --git a/extension/android/BUCK b/extension/android/BUCK
index c7e275805e2..a0f3a6411d1 100644
--- a/extension/android/BUCK
+++ b/extension/android/BUCK
@@ -47,13 +47,13 @@ non_fbcode_target(_kind = fb_android_library,
name = "executorch_llama",
warnings_as_errors = False,
srcs = [
- "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java",
- "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java",
- "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java",
- "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java",
+ "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt",
+ "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt",
+ "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt",
+ "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt",
],
autoglob = False,
- language = "JAVA",
+ language = "KOTLIN",
deps = [
":executorch",
"//fbandroid/java/com/facebook/jni:jni",
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt
similarity index 53%
rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java
rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt
index 4e834d06721..3b56986bf14 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt
@@ -6,45 +6,42 @@
* LICENSE file in the root directory of this source tree.
*/
-package org.pytorch.executorch.extension.llm;
+package org.pytorch.executorch.extension.llm
-import com.facebook.jni.annotations.DoNotStrip;
-import org.pytorch.executorch.annotations.Experimental;
+import com.facebook.jni.annotations.DoNotStrip
+import org.pytorch.executorch.annotations.Experimental
/**
- * Callback interface for Llama model. Users can implement this interface to receive the generated
+ * Callback interface for Llm model. Users can implement this interface to receive the generated
* tokens and statistics.
*
- * <p>Warning: These APIs are experimental and subject to change without notice
+ * Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
-public interface LlmCallback {
+interface LlmCallback {
/**
* Called when a new result is available from JNI. Users will keep getting onResult() invocations
* until generate() finishes.
*
* @param result Last generated token
*/
- @DoNotStrip
- public void onResult(String result);
+ @DoNotStrip fun onResult(result: String)
/**
* Called when the statistics for the generate() is available.
*
- * <p>The result will be a JSON string. See extension/llm/stats.h for the field definitions.
+ * The result will be a JSON string. See extension/llm/stats.h for the field definitions.
*
* @param stats JSON string containing the statistics for the generate()
*/
- @DoNotStrip
- default void onStats(String stats) {}
+ @DoNotStrip fun onStats(stats: String) {}
/**
* Called when an error occurs during generate().
*
- * @param errorCode Error code from the ExecuTorch runtime (see {@link
- * org.pytorch.executorch.ExecutorchRuntimeException})
+ * @param errorCode Error code from the ExecuTorch runtime (see
+ * [org.pytorch.executorch.ExecutorchRuntimeException])
* @param message Human-readable error description
*/
- @DoNotStrip
- default void onError(int errorCode, String message) {}
+ @DoNotStrip fun onError(errorCode: Int, message: String) {}
}
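
Note: because onStats() and onError() now carry default bodies in the Kotlin interface, a Kotlin implementer only has to override onResult(). A minimal sketch (illustrative, not part of this patch):

```kotlin
// Minimal LlmCallback implementation; only onResult() is mandatory.
val callback = object : LlmCallback {
    override fun onResult(result: String) {
        print(result) // stream each generated token as it arrives
    }

    override fun onStats(stats: String) {
        // stats is a JSON string; see extension/llm/stats.h for field definitions
        println("generation stats: $stats")
    }
}
```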
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java
deleted file mode 100644
index db7941aadad..00000000000
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java
+++ /dev/null
@@ -1,198 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package org.pytorch.executorch.extension.llm;
-
-/**
- * Configuration class for controlling text generation parameters in LLM operations.
- *
- * <p>This class provides settings for text generation behavior including output formatting,
- * generation limits, and sampling parameters. Instances should be created using the {@link
- * #create()} method and the fluent builder pattern.
- */
-public class LlmGenerationConfig {
- private final boolean echo;
- private final int maxNewTokens;
- private final boolean warming;
- private final int seqLen;
- private final float temperature;
- private final int numBos;
- private final int numEos;
-
- private LlmGenerationConfig(Builder builder) {
- this.echo = builder.echo;
- this.maxNewTokens = builder.maxNewTokens;
- this.warming = builder.warming;
- this.seqLen = builder.seqLen;
- this.temperature = builder.temperature;
- this.numBos = builder.numBos;
- this.numEos = builder.numEos;
- }
-
- /**
- * Creates a new Builder instance for constructing generation configurations.
- *
- * @return a new Builder with default configuration values
- */
- public static Builder create() {
- return new Builder();
- }
-
- /**
- * @return true if input prompt should be included in the output
- */
- public boolean isEcho() {
- return echo;
- }
-
- /**
- * @return maximum number of tokens to generate (-1 for unlimited)
- */
- public int getMaxNewTokens() {
- return maxNewTokens;
- }
-
- /**
- * @return true if model warming is enabled
- */
- public boolean isWarming() {
- return warming;
- }
-
- /**
- * @return maximum sequence length for generation (-1 for default)
- */
- public int getSeqLen() {
- return seqLen;
- }
-
- /**
- * @return temperature value for sampling (higher = more random)
- */
- public float getTemperature() {
- return temperature;
- }
-
- /**
- * @return number of BOS tokens to prepend
- */
- public int getNumBos() {
- return numBos;
- }
-
- /**
- * @return number of EOS tokens to append
- */
- public int getNumEos() {
- return numEos;
- }
-
- /**
- * Builder class for constructing LlmGenerationConfig instances.
- *
- * <p>Provides a fluent interface for configuring generation parameters with sensible defaults.
- * All methods return the builder instance to enable method chaining.
- */
- public static class Builder {
- private boolean echo = true;
- private int maxNewTokens = -1;
- private boolean warming = false;
- private int seqLen = -1;
- private float temperature = 0.8f;
- private int numBos = 0;
- private int numEos = 0;
-
- Builder() {}
-
- /**
- * Sets whether to include the input prompt in the generated output.
- *
- * @param echo true to include input prompt, false to return only new tokens
- * @return this builder instance
- */
- public Builder echo(boolean echo) {
- this.echo = echo;
- return this;
- }
-
- /**
- * Sets the maximum number of new tokens to generate.
- *
- * @param maxNewTokens the token limit (-1 for unlimited generation)
- * @return this builder instance
- */
- public Builder maxNewTokens(int maxNewTokens) {
- this.maxNewTokens = maxNewTokens;
- return this;
- }
-
- /**
- * Enables or disables model warming.
- *
- * @param warming true to generate initial tokens for model warmup
- * @return this builder instance
- */
- public Builder warming(boolean warming) {
- this.warming = warming;
- return this;
- }
-
- /**
- * Sets the maximum sequence length for generation.
- *
- * @param seqLen maximum sequence length (-1 for default behavior)
- * @return this builder instance
- */
- public Builder seqLen(int seqLen) {
- this.seqLen = seqLen;
- return this;
- }
-
- /**
- * Sets the temperature for random sampling.
- *
- * @param temperature sampling temperature (typical range 0.0-1.0)
- * @return this builder instance
- */
- public Builder temperature(float temperature) {
- this.temperature = temperature;
- return this;
- }
-
- /**
- * Sets the number of BOS tokens to prepend.
- *
- * @param numBos number of BOS tokens
- * @return this builder instance
- */
- public Builder numBos(int numBos) {
- this.numBos = numBos;
- return this;
- }
-
- /**
- * Sets the number of EOS tokens to append.
- *
- * @param numEos number of EOS tokens
- * @return this builder instance
- */
- public Builder numEos(int numEos) {
- this.numEos = numEos;
- return this;
- }
-
- /**
- * Constructs the LlmGenerationConfig instance with the configured parameters.
- *
- * @return new LlmGenerationConfig instance with current builder settings
- */
- public LlmGenerationConfig build() {
- return new LlmGenerationConfig(this);
- }
- }
-}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt
new file mode 100644
index 00000000000..c0f8956fb7f
--- /dev/null
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch.extension.llm
+
+/**
+ * Configuration class for controlling text generation parameters in LLM operations.
+ *
+ * This class provides settings for text generation behavior including output formatting, generation
+ * limits, and sampling parameters. Instances should be created using the [create] method and the
+ * fluent builder pattern.
+ */
+class LlmGenerationConfig
+private constructor(
+ @get:JvmName("isEcho") val echo: Boolean,
+ val maxNewTokens: Int,
+ @get:JvmName("isWarming") val warming: Boolean,
+ val seqLen: Int,
+ val temperature: Float,
+ val numBos: Int,
+ val numEos: Int,
+) {
+
+ companion object {
+ /**
+ * Creates a new Builder instance for constructing generation configurations.
+ *
+ * @return a new Builder with default configuration values
+ */
+ @JvmStatic fun create(): Builder = Builder()
+ }
+
+ /**
+ * Builder class for constructing LlmGenerationConfig instances.
+ *
+ * Provides a fluent interface for configuring generation parameters with sensible defaults. All
+ * methods return the builder instance to enable method chaining.
+ */
+ class Builder internal constructor() {
+ private var echo: Boolean = true
+ private var maxNewTokens: Int = -1
+ private var warming: Boolean = false
+ private var seqLen: Int = -1
+ private var temperature: Float = 0.8f
+ private var numBos: Int = 0
+ private var numEos: Int = 0
+
+ /** Sets whether to include the input prompt in the generated output. */
+ fun echo(echo: Boolean): Builder = apply { this.echo = echo }
+
+ /** Sets the maximum number of new tokens to generate. */
+ fun maxNewTokens(maxNewTokens: Int): Builder = apply { this.maxNewTokens = maxNewTokens }
+
+ /** Enables or disables model warming. */
+ fun warming(warming: Boolean): Builder = apply { this.warming = warming }
+
+ /** Sets the maximum sequence length for generation. */
+ fun seqLen(seqLen: Int): Builder = apply { this.seqLen = seqLen }
+
+ /** Sets the temperature for random sampling. */
+ fun temperature(temperature: Float): Builder = apply { this.temperature = temperature }
+
+ /** Sets the number of BOS tokens to prepend. */
+ fun numBos(numBos: Int): Builder = apply { this.numBos = numBos }
+
+ /** Sets the number of EOS tokens to append. */
+ fun numEos(numEos: Int): Builder = apply { this.numEos = numEos }
+
+ /** Constructs the LlmGenerationConfig instance with the configured parameters. */
+ fun build(): LlmGenerationConfig =
+ LlmGenerationConfig(echo, maxNewTokens, warming, seqLen, temperature, numBos, numEos)
+ }
+}
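
For reference, a hypothetical usage sketch of the converted builder (values are illustrative). The @get:JvmName annotations preserve the Java-visible isEcho()/isWarming() getter names, while Kotlin callers read the properties directly:

```kotlin
// Build a generation config via the fluent builder exposed by create().
val config: LlmGenerationConfig =
    LlmGenerationConfig.create()
        .echo(false)       // return only newly generated tokens
        .maxNewTokens(256) // cap output length; -1 means unlimited
        .temperature(0.7f) // lower values sample more deterministically
        .build()

// The config plugs into the corresponding generate() overload:
// llmModule.generate(prompt, config, llmCallback)
```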
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
deleted file mode 100644
index a563dc6bcc7..00000000000
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ /dev/null
@@ -1,668 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package org.pytorch.executorch.extension.llm;
-
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
-import java.nio.ByteBuffer;
-import java.util.List;
-import org.pytorch.executorch.ExecuTorchRuntime;
-import org.pytorch.executorch.annotations.Experimental;
-
-/**
- * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text
- * from the model.
- *
- * <p>Warning: These APIs are experimental and subject to change without notice
- */
-@Experimental
-public class LlmModule {
-
- public static final int MODEL_TYPE_TEXT = 1;
- public static final int MODEL_TYPE_TEXT_VISION = 2;
- public static final int MODEL_TYPE_MULTIMODAL = 2;
-
- private final HybridData mHybridData;
- private static final int DEFAULT_SEQ_LEN = 128;
- private static final boolean DEFAULT_ECHO = true;
- private static final float DEFAULT_TEMPERATURE = -1.0f;
- private static final int DEFAULT_BOS = 0;
- private static final int DEFAULT_EOS = 0;
- private static final int DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP;
-
- @DoNotStrip
- private static native HybridData initHybrid(
- int modelType,
- String modulePath,
- String tokenizerPath,
- float temperature,
- List<String> dataFiles,
- int numBos,
- int numEos,
- int loadMode);
-
- private LlmModule(
- int modelType,
- String modulePath,
- String tokenizerPath,
- float temperature,
- List<String> dataFiles,
- int numBos,
- int numEos,
- int loadMode) {
- ExecuTorchRuntime.getRuntime();
- ExecuTorchRuntime.validateFilePath(modulePath, "model path");
- ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path");
-
- mHybridData =
- initHybrid(
- modelType, modulePath, tokenizerPath, temperature, dataFiles, numBos, numEos, loadMode);
- }
-
- /**
- * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
- * dataFiles.
- */
- public LlmModule(
- int modelType,
- String modulePath,
- String tokenizerPath,
- float temperature,
- List<String> dataFiles,
- int numBos,
- int numEos) {
- this(
- modelType,
- modulePath,
- tokenizerPath,
- temperature,
- dataFiles,
- numBos,
- numEos,
- DEFAULT_LOAD_MODE);
- }
-
- /**
- * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
- * dataFiles.
- */
- public LlmModule(
- int modelType,
- String modulePath,
- String tokenizerPath,
- float temperature,
- List<String> dataFiles) {
- this(
- modelType,
- modulePath,
- tokenizerPath,
- temperature,
- dataFiles,
- DEFAULT_BOS,
- DEFAULT_EOS,
- DEFAULT_LOAD_MODE);
- }
-
- /**
- * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
- * data path.
- */
- public LlmModule(
- int modelType,
- String modulePath,
- String tokenizerPath,
- float temperature,
- String dataPath,
- int numBos,
- int numEos) {
- this(
- modelType,
- modulePath,
- tokenizerPath,
- temperature,
- dataPath != null ? List.of(dataPath) : List.of(),
- numBos,
- numEos);
- }
-
- /**
- * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
- * data path.
- */
- public LlmModule(
- int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) {
- this(modelType, modulePath, tokenizerPath, temperature, dataPath, DEFAULT_BOS, DEFAULT_EOS);
- }
-
- /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */
- public LlmModule(String modulePath, String tokenizerPath, float temperature) {
- this(
- MODEL_TYPE_TEXT,
- modulePath,
- tokenizerPath,
- temperature,
- List.of(),
- DEFAULT_BOS,
- DEFAULT_EOS);
- }
-
- /**
- * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data
- * path.
- */
- public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
- this(
- MODEL_TYPE_TEXT,
- modulePath,
- tokenizerPath,
- temperature,
- List.of(dataPath),
- DEFAULT_BOS,
- DEFAULT_EOS);
- }
-
- /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
- public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
- this(modelType, modulePath, tokenizerPath, temperature, List.of(), DEFAULT_BOS, DEFAULT_EOS);
- }
-
- /** Constructs a LLM Module for a model with the given LlmModuleConfig */
- public LlmModule(LlmModuleConfig config) {
- this(
- config.getModelType(),
- config.getModulePath(),
- config.getTokenizerPath(),
- config.getTemperature(),
- config.getDataPath() != null ? List.of(config.getDataPath()) : List.of(),
- config.getNumBos(),
- config.getNumEos(),
- config.getLoadMode());
- }
-
- public void resetNative() {
- mHybridData.resetNative();
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param prompt Input prompt
- * @param llmCallback callback object to receive results.
- */
- public int generate(String prompt, LlmCallback llmCallback) {
- return generate(
- prompt,
- DEFAULT_SEQ_LEN,
- llmCallback,
- DEFAULT_ECHO,
- DEFAULT_TEMPERATURE,
- DEFAULT_BOS,
- DEFAULT_EOS);
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param prompt Input prompt
- * @param seqLen sequence length
- * @param llmCallback callback object to receive results.
- */
- public int generate(String prompt, int seqLen, LlmCallback llmCallback) {
- return generate(
- null,
- 0,
- 0,
- 0,
- prompt,
- seqLen,
- llmCallback,
- DEFAULT_ECHO,
- DEFAULT_TEMPERATURE,
- DEFAULT_BOS,
- DEFAULT_EOS);
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param prompt Input prompt
- * @param llmCallback callback object to receive results
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
- */
- public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
- return generate(
- null,
- 0,
- 0,
- 0,
- prompt,
- DEFAULT_SEQ_LEN,
- llmCallback,
- echo,
- DEFAULT_TEMPERATURE,
- DEFAULT_BOS,
- DEFAULT_EOS);
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param prompt Input prompt
- * @param seqLen sequence length
- * @param llmCallback callback object to receive results
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
- */
- public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
- return generate(
- prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS);
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param prompt Input prompt
- * @param seqLen sequence length
- * @param llmCallback callback object to receive results
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
- * @param temperature temperature for sampling (use negative value to use module default)
- * @param numBos number of BOS tokens to prepend
- * @param numEos number of EOS tokens to append
- */
- public native int generate(
- String prompt,
- int seqLen,
- LlmCallback llmCallback,
- boolean echo,
- float temperature,
- int numBos,
- int numEos);
-
- /**
- * Start generating tokens from the module.
- *
- * @param prompt Input prompt
- * @param config the config for generation
- * @param llmCallback callback object to receive results
- */
- public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) {
- int seqLen = config.getSeqLen();
- boolean echo = config.isEcho();
- float temperature = config.getTemperature();
- int numBos = config.getNumBos();
- int numEos = config.getNumEos();
- return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos);
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param image Input image as a byte array
- * @param width Input image width
- * @param height Input image height
- * @param channels Input image number of channels
- * @param prompt Input prompt
- * @param seqLen sequence length
- * @param llmCallback callback object to receive results.
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
- */
- public int generate(
- int[] image,
- int width,
- int height,
- int channels,
- String prompt,
- int seqLen,
- LlmCallback llmCallback,
- boolean echo) {
- return generate(
- image,
- width,
- height,
- channels,
- prompt,
- seqLen,
- llmCallback,
- echo,
- DEFAULT_TEMPERATURE,
- DEFAULT_BOS,
- DEFAULT_EOS);
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param image Input image as a byte array
- * @param width Input image width
- * @param height Input image height
- * @param channels Input image number of channels
- * @param prompt Input prompt
- * @param seqLen sequence length
- * @param llmCallback callback object to receive results.
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
- * @param temperature temperature for sampling (use negative value to use module default)
- */
- public int generate(
- int[] image,
- int width,
- int height,
- int channels,
- String prompt,
- int seqLen,
- LlmCallback llmCallback,
- boolean echo,
- float temperature) {
- return generate(
- image,
- width,
- height,
- channels,
- prompt,
- seqLen,
- llmCallback,
- echo,
- temperature,
- DEFAULT_BOS,
- DEFAULT_EOS);
- }
-
- /**
- * Start generating tokens from the module.
- *
- * @param image Input image as a byte array
- * @param width Input image width
- * @param height Input image height
- * @param channels Input image number of channels
- * @param prompt Input prompt
- * @param seqLen sequence length
- * @param llmCallback callback object to receive results.
- * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
- * @param temperature temperature for sampling (use negative value to use module default)
- * @param numBos number of BOS tokens to prepend
- * @param numEos number of EOS tokens to append
- */
- public int generate(
- int[] image,
- int width,
- int height,
- int channels,
- String prompt,
- int seqLen,
- LlmCallback llmCallback,
- boolean echo,
- float temperature,
- int numBos,
- int numEos) {
- if (image != null) {
- prefillImages(image, width, height, channels);
- }
- return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos);
- }
-
- /**
- * Prefill the KV cache with the given image input.
- *
- * @param image Input image as a byte array
- * @param width Input image width
- * @param height Input image height
- * @param channels Input image number of channels
- * @return 0 on success
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public long prefillImages(int[] image, int width, int height, int channels) {
- int nativeResult = prefillImagesInput(image, width, height, channels);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- return 0;
- }
-
- /**
- * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data
- * is accessed directly without JNI array copies, unlike {@link #prefillImages(int[], int, int,
- * int)}. The ByteBuffer must contain raw uint8 pixel data in CHW format with at least channels *
- * height * width bytes remaining. Only the first channels * height * width bytes from the
- * buffer's current position are read; the position of the original ByteBuffer is not modified.
- *
- * @param image Input image as a direct ByteBuffer containing uint8 pixel data
- * @param width Input image width
- * @param height Input image height
- * @param channels Input image number of channels
- * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining
- * bytes
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public void prefillImages(ByteBuffer image, int width, int height, int channels) {
- if (!image.isDirect()) {
- throw new IllegalArgumentException("Input ByteBuffer must be direct.");
- }
- long expectedBytes;
- try {
- long pixels = Math.multiplyExact((long) width, (long) height);
- expectedBytes = Math.multiplyExact(pixels, (long) channels);
- } catch (ArithmeticException ex) {
- throw new IllegalArgumentException(
- "width*height*channels is too large and overflows the allowed range.", ex);
- }
- if (width <= 0
- || height <= 0
- || channels <= 0
- || expectedBytes > Integer.MAX_VALUE
- || image.remaining() < expectedBytes) {
- throw new IllegalArgumentException(
- "ByteBuffer remaining ("
- + image.remaining()
- + ") must be at least width*height*channels ("
- + expectedBytes
- + ").");
- }
- // slice() so that getDirectBufferAddress on the native side returns a pointer
- // starting at the current position, not the base address.
- int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- }
-
- /**
- * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The
- * buffer data is accessed directly without JNI array copies, unlike {@link
- * #prefillImages(float[], int, int, int)}. The ByteBuffer must contain normalized float pixel
- * data in CHW format with at least channels * height * width * 4 bytes remaining. Only the first
- * channels * height * width floats from the buffer's current position are consumed. The buffer
- * must use the platform's native byte order (set via {@code
- * buffer.order(ByteOrder.nativeOrder())}).
- *
- * @param image Input normalized image as a direct ByteBuffer containing float pixel data in
- * native byte order
- * @param width Input image width
- * @param height Input image height
- * @param channels Input image number of channels
- * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining
- * bytes, is not float-aligned, or does not use native byte order
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) {
- if (!image.isDirect()) {
- throw new IllegalArgumentException("Input ByteBuffer must be direct.");
- }
- if (image.order() != java.nio.ByteOrder.nativeOrder()) {
- throw new IllegalArgumentException(
- "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder()).");
- }
- if (image.position() % Float.BYTES != 0) {
- throw new IllegalArgumentException(
- "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned.");
- }
- final long expectedBytes;
- try {
- int wh = Math.multiplyExact(width, height);
- long whc = Math.multiplyExact((long) wh, (long) channels);
- long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES);
- if (totalBytes > Integer.MAX_VALUE) {
- throw new IllegalArgumentException(
- "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: "
- + totalBytes);
- }
- expectedBytes = totalBytes;
- } catch (ArithmeticException e) {
- throw new IllegalArgumentException(
- "Overflow while computing width*height*channels*4 for ByteBuffer size.", e);
- }
- if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) {
- throw new IllegalArgumentException(
- "ByteBuffer remaining ("
- + image.remaining()
- + ") must be at least width*height*channels*4 ("
- + expectedBytes
- + ").");
- }
- if (image.remaining() % Float.BYTES != 0) {
- throw new IllegalArgumentException(
- "ByteBuffer remaining (" + image.remaining() + ") must be a multiple of 4 (float size).");
- }
- // slice() so that getDirectBufferAddress on the native side returns a pointer
- // starting at the current position, not the base address.
- int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- }
-
- private native int prefillImagesInput(int[] image, int width, int height, int channels);
-
- private native int prefillImagesInputBuffer(
- ByteBuffer image, int width, int height, int channels);
-
- private native int prefillNormalizedImagesInputBuffer(
- ByteBuffer image, int width, int height, int channels);
-
- /**
- * Prefill the KV cache with the given normalized image input.
- *
- * @param image Input normalized image as a float array
- * @param width Input image width
- * @param height Input image height
- * @param channels Input image number of channels
- * @return 0 on success
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public long prefillImages(float[] image, int width, int height, int channels) {
- int nativeResult = prefillNormalizedImagesInput(image, width, height, channels);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- return 0;
- }
-
- private native int prefillNormalizedImagesInput(
- float[] image, int width, int height, int channels);
-
- /**
- * Prefill the KV cache with the given preprocessed audio input.
- *
- * @param audio Input preprocessed audio as a byte array
- * @param batch_size Input batch size
- * @param n_bins Input number of bins
- * @param n_frames Input number of frames
- * @return 0 on success
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
- int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- return 0;
- }
-
- private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
-
- /**
- * Prefill the KV cache with the given preprocessed audio input.
- *
- * @param audio Input preprocessed audio as a float array
- * @param batch_size Input batch size
- * @param n_bins Input number of bins
- * @param n_frames Input number of frames
- * @return 0 on success
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
- int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- return 0;
- }
-
- private native int prefillAudioInputFloat(
- float[] audio, int batch_size, int n_bins, int n_frames);
-
- /**
- * Prefill the KV cache with the given raw audio input.
- *
- * @param audio Input raw audio as a byte array
- * @param batch_size Input batch size
- * @param n_channels Input number of channels
- * @param n_samples Input number of samples
- * @return 0 on success
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
- int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- return 0;
- }
-
- private native int prefillRawAudioInput(
- byte[] audio, int batch_size, int n_channels, int n_samples);
-
- /**
- * Prefill the KV cache with the given text prompt.
- *
- * @param prompt The text prompt to prefill.
- * @return 0 on success
- * @throws RuntimeException if the prefill failed
- */
- @Experimental
- public long prefillPrompt(String prompt) {
- int nativeResult = prefillTextInput(prompt);
- if (nativeResult != 0) {
- throw new RuntimeException("Prefill failed with error code: " + nativeResult);
- }
- return 0;
- }
-
- // returns status
- private native int prefillTextInput(String prompt);
-
- /**
- * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
- *
- * The startPos will be reset to 0.
- */
- public native void resetContext();
-
- /** Stop current generate() before it finishes. */
- @DoNotStrip
- public native void stop();
-
- /** Force loading the module. Otherwise the model is loaded during first generate(). */
- @DoNotStrip
- public native int load();
-}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt
new file mode 100644
index 00000000000..c1e9e62e5e1
--- /dev/null
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt
@@ -0,0 +1,733 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch.extension.llm
+
+import com.facebook.jni.HybridData
+import com.facebook.jni.annotations.DoNotStrip
+import java.nio.ByteBuffer
+import java.nio.ByteOrder
+import org.pytorch.executorch.ExecuTorchRuntime
+import org.pytorch.executorch.annotations.Experimental
+
+/**
+ * LlmModule is a wrapper around the ExecuTorch LLM. It provides a simple interface to generate text
+ * from the model.
+ *
+ * Warning: These APIs are experimental and subject to change without notice
+ */
+@Experimental
+class LlmModule
+private constructor(
+ modelType: Int,
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ dataFiles: List<String>,
+ numBos: Int,
+ numEos: Int,
+ loadMode: Int,
+) {
+
+ private val mHybridData: HybridData
+
+ init {
+ ExecuTorchRuntime.getRuntime()
+ requireNotNull(modulePath) { "model path must not be null" }
+ requireNotNull(tokenizerPath) { "tokenizer path must not be null" }
+ ExecuTorchRuntime.validateFilePath(modulePath, "model path")
+ ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path")
+ mHybridData =
+ initHybrid(
+ modelType,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ dataFiles,
+ numBos,
+ numEos,
+ loadMode,
+ )
+ }
+
+ /**
+ * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
+ * dataFiles.
+ */
+ constructor(
+ modelType: Int,
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ dataFiles: List<String>,
+ numBos: Int,
+ numEos: Int,
+ ) : this(
+ modelType,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ dataFiles,
+ numBos,
+ numEos,
+ DEFAULT_LOAD_MODE,
+ )
+
+ /**
+ * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
+ * dataFiles.
+ */
+ constructor(
+ modelType: Int,
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ dataFiles: List<String>,
+ ) : this(
+ modelType,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ dataFiles,
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ DEFAULT_LOAD_MODE,
+ )
+
+ /**
+ * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
+ * data path.
+ */
+ constructor(
+ modelType: Int,
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ dataPath: String?,
+ numBos: Int,
+ numEos: Int,
+ ) : this(
+ modelType,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ listOfNotNull(dataPath),
+ numBos,
+ numEos,
+ )
+
+ /**
+ * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
+ * data path.
+ */
+ constructor(
+ modelType: Int,
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ dataPath: String?,
+ ) : this(
+ modelType,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ dataPath,
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */
+ constructor(
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ ) : this(
+ MODEL_TYPE_TEXT,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ emptyList(),
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /**
+ * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data
+ * path.
+ */
+ constructor(
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ dataPath: String,
+ ) : this(
+ MODEL_TYPE_TEXT,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ listOf(dataPath),
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
+ constructor(
+ modelType: Int,
+ modulePath: String?,
+ tokenizerPath: String?,
+ temperature: Float,
+ ) : this(
+ modelType,
+ modulePath,
+ tokenizerPath,
+ temperature,
+ emptyList(),
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /** Constructs a LLM Module for a model with the given LlmModuleConfig */
+ constructor(config: LlmModuleConfig) : this(
+ config.modelType,
+ config.modulePath,
+ config.tokenizerPath,
+ config.temperature,
+ listOfNotNull(config.dataPath),
+ config.numBos,
+ config.numEos,
+ config.loadMode,
+ )
+
+ fun resetNative() {
+ mHybridData.resetNative()
+ }
+
+ // --- generate overloads ---
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param prompt Input prompt
+ * @param llmCallback callback object to receive results.
+ */
+ fun generate(prompt: String, llmCallback: LlmCallback): Int =
+ generate(
+ prompt,
+ DEFAULT_SEQ_LEN,
+ llmCallback,
+ DEFAULT_ECHO,
+ DEFAULT_TEMPERATURE,
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param prompt Input prompt
+ * @param seqLen sequence length
+ * @param llmCallback callback object to receive results.
+ */
+ fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback): Int =
+ generate(
+ null,
+ 0,
+ 0,
+ 0,
+ prompt,
+ seqLen,
+ llmCallback,
+ DEFAULT_ECHO,
+ DEFAULT_TEMPERATURE,
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param prompt Input prompt
+ * @param llmCallback callback object to receive results
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+ */
+ fun generate(prompt: String, llmCallback: LlmCallback, echo: Boolean): Int =
+ generate(
+ null,
+ 0,
+ 0,
+ 0,
+ prompt,
+ DEFAULT_SEQ_LEN,
+ llmCallback,
+ echo,
+ DEFAULT_TEMPERATURE,
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param prompt Input prompt
+ * @param seqLen sequence length
+ * @param llmCallback callback object to receive results
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+ */
+ fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback, echo: Boolean): Int =
+ generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS)
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param prompt Input prompt
+ * @param seqLen sequence length
+ * @param llmCallback callback object to receive results
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+ * @param temperature temperature for sampling (use negative value to use module default)
+ * @param numBos number of BOS tokens to prepend
+ * @param numEos number of EOS tokens to append
+ */
+ external fun generate(
+ prompt: String,
+ seqLen: Int,
+ llmCallback: LlmCallback,
+ echo: Boolean,
+ temperature: Float,
+ numBos: Int,
+ numEos: Int,
+ ): Int
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param prompt Input prompt
+ * @param config the config for generation
+ * @param llmCallback callback object to receive results
+ */
+ fun generate(prompt: String, config: LlmGenerationConfig, llmCallback: LlmCallback): Int =
+ generate(
+ null,
+ 0,
+ 0,
+ 0,
+ prompt,
+ config.seqLen,
+ llmCallback,
+ config.echo,
+ config.temperature,
+ config.numBos,
+ config.numEos,
+ )
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param image Input image as a byte array
+ * @param width Input image width
+ * @param height Input image height
+ * @param channels Input image number of channels
+ * @param prompt Input prompt
+ * @param seqLen sequence length
+ * @param llmCallback callback object to receive results.
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+ */
+ fun generate(
+ image: IntArray?,
+ width: Int,
+ height: Int,
+ channels: Int,
+ prompt: String,
+ seqLen: Int,
+ llmCallback: LlmCallback,
+ echo: Boolean,
+ ): Int =
+ generate(
+ image,
+ width,
+ height,
+ channels,
+ prompt,
+ seqLen,
+ llmCallback,
+ echo,
+ DEFAULT_TEMPERATURE,
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param image Input image as a byte array
+ * @param width Input image width
+ * @param height Input image height
+ * @param channels Input image number of channels
+ * @param prompt Input prompt
+ * @param seqLen sequence length
+ * @param llmCallback callback object to receive results.
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+ * @param temperature temperature for sampling (use negative value to use module default)
+ */
+ fun generate(
+ image: IntArray?,
+ width: Int,
+ height: Int,
+ channels: Int,
+ prompt: String,
+ seqLen: Int,
+ llmCallback: LlmCallback,
+ echo: Boolean,
+ temperature: Float,
+ ): Int =
+ generate(
+ image,
+ width,
+ height,
+ channels,
+ prompt,
+ seqLen,
+ llmCallback,
+ echo,
+ temperature,
+ DEFAULT_BOS,
+ DEFAULT_EOS,
+ )
+
+ /**
+ * Start generating tokens from the module.
+ *
+ * @param image Input image as a byte array
+ * @param width Input image width
+ * @param height Input image height
+ * @param channels Input image number of channels
+ * @param prompt Input prompt
+ * @param seqLen sequence length
+ * @param llmCallback callback object to receive results.
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+ * @param temperature temperature for sampling (use negative value to use module default)
+ * @param numBos number of BOS tokens to prepend
+ * @param numEos number of EOS tokens to append
+ */
+ fun generate(
+ image: IntArray?,
+ width: Int,
+ height: Int,
+ channels: Int,
+ prompt: String,
+ seqLen: Int,
+ llmCallback: LlmCallback,
+ echo: Boolean,
+ temperature: Float,
+ numBos: Int,
+ numEos: Int,
+ ): Int {
+ if (image != null) {
+ prefillImages(image, width, height, channels)
+ }
+ return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos)
+ }
+
+ // --- prefill methods ---
+
+ /**
+ * Prefill the KV cache with the given image input.
+ *
+ * @param image Input image as a byte array
+ * @param width Input image width
+ * @param height Input image height
+ * @param channels Input image number of channels
+ * @return 0 on success
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillImages(image: IntArray, width: Int, height: Int, channels: Int): Long {
+ val nativeResult = prefillImagesInput(image, width, height, channels)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ return 0
+ }
+
+ /**
+ * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data
+ * is accessed directly without JNI array copies, unlike [prefillImages]. The ByteBuffer must
+ * contain raw uint8 pixel data in CHW format with at least channels * height * width bytes
+ * remaining. Only the first channels * height * width bytes from the buffer's current position
+ * are read; the position of the original ByteBuffer is not modified.
+ *
+ * @param image Input image as a direct ByteBuffer containing uint8 pixel data
+ * @param width Input image width
+ * @param height Input image height
+ * @param channels Input image number of channels
+ * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining
+ * bytes
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillImages(image: ByteBuffer, width: Int, height: Int, channels: Int) {
+ require(image.isDirect) { "Input ByteBuffer must be direct." }
+ val expectedBytes: Long
+ try {
+ val pixels = Math.multiplyExact(width.toLong(), height.toLong())
+ expectedBytes = Math.multiplyExact(pixels, channels.toLong())
+ } catch (ex: ArithmeticException) {
+ throw IllegalArgumentException(
+ "width*height*channels is too large and overflows the allowed range.", ex)
+ }
+ require(
+ width > 0 &&
+ height > 0 &&
+ channels > 0 &&
+ expectedBytes <= Int.MAX_VALUE &&
+ image.remaining() >= expectedBytes) {
+ "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels ($expectedBytes)."
+ }
+ // slice() so that getDirectBufferAddress on the native side returns a pointer
+ // starting at the current position, not the base address.
+ val nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ }
+
+ /**
+ * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The
+ * buffer data is accessed directly without JNI array copies, unlike [prefillImages]. The
+ * ByteBuffer must contain normalized float pixel data in CHW format with at least channels *
+ * height * width * 4 bytes remaining. Only the first channels * height * width floats from the
+ * buffer's current position are consumed. The buffer must use the platform's native byte order
+ * (set via `buffer.order(ByteOrder.nativeOrder())`).
+ *
+ * @param image Input normalized image as a direct ByteBuffer containing float pixel data in
+ * native byte order
+ * @param width Input image width
+ * @param height Input image height
+ * @param channels Input image number of channels
+ * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining
+ * bytes, is not float-aligned, or does not use native byte order
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillNormalizedImage(image: ByteBuffer, width: Int, height: Int, channels: Int) {
+ require(image.isDirect) { "Input ByteBuffer must be direct." }
+ require(image.order() == ByteOrder.nativeOrder()) {
+ "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."
+ }
+ require(image.position() % Float.SIZE_BYTES == 0) {
+ "Input ByteBuffer position (${image.position()}) must be 4-byte aligned."
+ }
+ val expectedBytes: Long
+ try {
+ val wh = Math.multiplyExact(width, height)
+ val whc = Math.multiplyExact(wh.toLong(), channels.toLong())
+ val totalBytes = Math.multiplyExact(whc, Float.SIZE_BYTES.toLong())
+ if (totalBytes > Int.MAX_VALUE) {
+ throw IllegalArgumentException(
+ "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: $totalBytes")
+ }
+ expectedBytes = totalBytes
+ } catch (e: ArithmeticException) {
+ throw IllegalArgumentException(
+ "Overflow while computing width*height*channels*4 for ByteBuffer size.", e)
+ }
+ require(width > 0 && height > 0 && channels > 0 && image.remaining() >= expectedBytes) {
+ "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels*4 ($expectedBytes)."
+ }
+ require(image.remaining() % Float.SIZE_BYTES == 0) {
+ "ByteBuffer remaining (${image.remaining()}) must be a multiple of 4 (float size)."
+ }
+ // slice() so that getDirectBufferAddress on the native side returns a pointer
+ // starting at the current position, not the base address.
+ val nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ }
+
+ private external fun prefillImagesInput(
+ image: IntArray,
+ width: Int,
+ height: Int,
+ channels: Int,
+ ): Int
+
+ private external fun prefillImagesInputBuffer(
+ image: ByteBuffer,
+ width: Int,
+ height: Int,
+ channels: Int,
+ ): Int
+
+ private external fun prefillNormalizedImagesInputBuffer(
+ image: ByteBuffer,
+ width: Int,
+ height: Int,
+ channels: Int,
+ ): Int
+
+ /**
+ * Prefill the KV cache with the given normalized image input.
+ *
+ * @param image Input normalized image as a float array
+ * @param width Input image width
+ * @param height Input image height
+ * @param channels Input image number of channels
+ * @return 0 on success
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillImages(image: FloatArray, width: Int, height: Int, channels: Int): Long {
+ val nativeResult = prefillNormalizedImagesInput(image, width, height, channels)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ return 0
+ }
+
+ private external fun prefillNormalizedImagesInput(
+ image: FloatArray,
+ width: Int,
+ height: Int,
+ channels: Int,
+ ): Int
+
+ /**
+ * Prefill the KV cache with the given preprocessed audio input.
+ *
+ * @param audio Input preprocessed audio as a byte array
+ * @param batchSize Input batch size
+ * @param nBins Input number of bins
+ * @param nFrames Input number of frames
+ * @return 0 on success
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillAudio(audio: ByteArray, batchSize: Int, nBins: Int, nFrames: Int): Long {
+ val nativeResult = prefillAudioInput(audio, batchSize, nBins, nFrames)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ return 0
+ }
+
+ private external fun prefillAudioInput(
+ audio: ByteArray,
+ batchSize: Int,
+ nBins: Int,
+ nFrames: Int,
+ ): Int
+
+ /**
+ * Prefill the KV cache with the given preprocessed audio input.
+ *
+ * @param audio Input preprocessed audio as a float array
+ * @param batchSize Input batch size
+ * @param nBins Input number of bins
+ * @param nFrames Input number of frames
+ * @return 0 on success
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillAudio(audio: FloatArray, batchSize: Int, nBins: Int, nFrames: Int): Long {
+ val nativeResult = prefillAudioInputFloat(audio, batchSize, nBins, nFrames)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ return 0
+ }
+
+ private external fun prefillAudioInputFloat(
+ audio: FloatArray,
+ batchSize: Int,
+ nBins: Int,
+ nFrames: Int,
+ ): Int
+
+ /**
+ * Prefill the KV cache with the given raw audio input.
+ *
+ * @param audio Input raw audio as a byte array
+ * @param batchSize Input batch size
+ * @param nChannels Input number of channels
+ * @param nSamples Input number of samples
+ * @return 0 on success
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillRawAudio(audio: ByteArray, batchSize: Int, nChannels: Int, nSamples: Int): Long {
+ val nativeResult = prefillRawAudioInput(audio, batchSize, nChannels, nSamples)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ return 0
+ }
+
+ private external fun prefillRawAudioInput(
+ audio: ByteArray,
+ batchSize: Int,
+ nChannels: Int,
+ nSamples: Int,
+ ): Int
+
+ /**
+ * Prefill the KV cache with the given text prompt.
+ *
+ * @param prompt The text prompt to prefill
+ * @return 0 on success
+ * @throws RuntimeException if the prefill failed
+ */
+ @Experimental
+ fun prefillPrompt(prompt: String): Long {
+ val nativeResult = prefillTextInput(prompt)
+ if (nativeResult != 0) {
+ throw RuntimeException("Prefill failed with error code: $nativeResult")
+ }
+ return 0
+ }
+
+ // returns status
+ private external fun prefillTextInput(prompt: String): Int
+
+ /**
+ * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. The
+ * startPos will be reset to 0.
+ */
+ external fun resetContext()
+
+ /** Stop current generate() before it finishes. */
+ @DoNotStrip external fun stop()
+
+ /** Force loading the module. Otherwise the model is loaded during first generate(). */
+ @DoNotStrip external fun load(): Int
+
+ companion object {
+ const val MODEL_TYPE_TEXT = 1
+ const val MODEL_TYPE_TEXT_VISION = 2
+ const val MODEL_TYPE_MULTIMODAL = 2
+
+ private const val DEFAULT_SEQ_LEN = 128
+ private const val DEFAULT_ECHO = true
+ private const val DEFAULT_TEMPERATURE = -1.0f
+ private const val DEFAULT_BOS = 0
+ private const val DEFAULT_EOS = 0
+ private const val DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP
+
+ @DoNotStrip
+ @JvmStatic
+ private external fun initHybrid(
+ modelType: Int,
+ modulePath: String,
+ tokenizerPath: String,
+ temperature: Float,
+ dataFiles: List<String>,
+ numBos: Int,
+ numEos: Int,
+ loadMode: Int,
+ ): HybridData
+ }
+}
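
A hypothetical end-to-end sketch of the converted class (the paths, image dimensions, and multimodal model type are assumptions for illustration, not part of this patch):

```kotlin
import java.nio.ByteBuffer
import java.nio.ByteOrder

fun demo() {
    val module = LlmModule(
        LlmModule.MODEL_TYPE_MULTIMODAL,
        "/data/local/tmp/model.pte",     // placeholder model path
        "/data/local/tmp/tokenizer.bin", // placeholder tokenizer path
        0.8f,                            // sampling temperature
    )
    module.load() // optional eager load; otherwise deferred to the first generate()

    // The uint8 prefill overload requires a *direct* buffer holding CHW pixel
    // data with at least channels * height * width bytes remaining.
    val (channels, height, width) = Triple(3, 336, 336)
    val pixels = ByteBuffer.allocateDirect(channels * height * width)
        .order(ByteOrder.nativeOrder())
    // ... fill `pixels` with CHW uint8 data ...
    module.prefillImages(pixels, width, height, channels)

    module.generate("Describe the image.", object : LlmCallback {
        override fun onResult(result: String) = print(result)
    })

    module.resetNative() // release the native peer when done
}
```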
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java
deleted file mode 100644
index feb52a2b34b..00000000000
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java
+++ /dev/null
@@ -1,252 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package org.pytorch.executorch.extension.llm;
-
-/**
- * Configuration class for initializing a LlmModule.
- *
- * <p>Instances should be created using the {@link
- * #create()} method and the fluent builder pattern.
- */
-public class LlmModuleConfig {
- private final String modulePath;
- private final String tokenizerPath;
- private final float temperature;
- private final String dataPath;
- private final int modelType;
- private final int numBos;
- private final int numEos;
- private final int loadMode;
-
- /** Load entire model file into a buffer (no mmap). */
- public static final int LOAD_MODE_FILE = 0;
-
- /** Load model via mmap without mlock (default). Pages faulted in on demand. */
- public static final int LOAD_MODE_MMAP = 1;
-
- /** Load model via mmap and pin all pages with mlock. */
- public static final int LOAD_MODE_MMAP_USE_MLOCK = 2;
-
- /** Load model via mmap and attempt mlock, ignoring mlock failures. */
- public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
-
- private LlmModuleConfig(Builder builder) {
- this.modulePath = builder.modulePath;
- this.tokenizerPath = builder.tokenizerPath;
- this.temperature = builder.temperature;
- this.dataPath = builder.dataPath;
- this.modelType = builder.modelType;
- this.numBos = builder.numBos;
- this.numEos = builder.numEos;
- this.loadMode = builder.loadMode;
- }
-
- /** Model type constant for text-only models. */
- public static final int MODEL_TYPE_TEXT = 1;
-
- /** Model type constant for text-and-vision multimodal models. */
- public static final int MODEL_TYPE_TEXT_VISION = 2;
-
- /** Model type constant for generic multimodal models. */
- public static final int MODEL_TYPE_MULTIMODAL = 2;
-
- /**
- * Creates a new Builder instance for constructing LlmModuleConfig objects.
- *
- * @return a new Builder instance with default configuration values
- */
- public static Builder create() {
- return new Builder();
- }
-
- // Getters with documentation
- /**
- * @return Path to the compiled model module (.pte file)
- */
- public String getModulePath() {
- return modulePath;
- }
-
- /**
- * @return Path to the tokenizer file or directory
- */
- public String getTokenizerPath() {
- return tokenizerPath;
- }
-
- /**
- * @return Temperature value for sampling (higher = more random)
- */
- public float getTemperature() {
- return temperature;
- }
-
- /**
- * @return Optional path to additional data files
- */
- public String getDataPath() {
- return dataPath;
- }
-
- /**
- * @return Type of model (text-only or text-vision)
- */
- public int getModelType() {
- return modelType;
- }
-
- /**
- * @return Number of BOS tokens to prepend
- */
- public int getNumBos() {
- return numBos;
- }
-
- /**
- * @return Number of EOS tokens to append
- */
- public int getNumEos() {
- return numEos;
- }
-
- /**
- * @return Load mode for the model file (one of LOAD_MODE_* constants)
- */
- public int getLoadMode() {
- return loadMode;
- }
-
- /**
- * Builder class for constructing LlmModuleConfig instances with optional parameters.
- *
- * <p>The builder provides a fluent interface for configuring model parameters and validates
- * required fields before construction.
- */
- public static class Builder {
- private String modulePath;
- private String tokenizerPath;
- private float temperature = 0.8f;
- private String dataPath = "";
- private int modelType = MODEL_TYPE_TEXT;
- private int numBos = 0;
- private int numEos = 0;
- private int loadMode = LOAD_MODE_MMAP;
-
- Builder() {}
-
- /**
- * Sets the path to the module.
- *
- * @param modulePath Path to module
- * @return This builder instance for method chaining
- */
- public Builder modulePath(String modulePath) {
- this.modulePath = modulePath;
- return this;
- }
-
- /**
- * Sets the path to the tokenizer.
- *
- * @param tokenizerPath Path to tokenizer
- * @return This builder instance for method chaining
- */
- public Builder tokenizerPath(String tokenizerPath) {
- this.tokenizerPath = tokenizerPath;
- return this;
- }
-
- /**
- * Sets the temperature for sampling generation.
- *
- * @param temperature Temperature value (typical range 0.0-1.0)
- * @return This builder instance for method chaining
- */
- public Builder temperature(float temperature) {
- this.temperature = temperature;
- return this;
- }
-
- /**
- * Sets the path to optional additional data files.
- *
- * @param dataPath Path to supplementary data resources
- * @return This builder instance for method chaining
- */
- public Builder dataPath(String dataPath) {
- this.dataPath = dataPath;
- return this;
- }
-
- /**
- * Sets the model type (text-only or multimodal).
- *
- * @param modelType One of MODEL_TYPE_TEXT, MODEL_TYPE_TEXT_VISION, MODEL_TYPE_MULTIMODAL
- * @return This builder instance for method chaining
- */
- public Builder modelType(int modelType) {
- this.modelType = modelType;
- return this;
- }
-
- /**
- * Sets the number of BOS tokens to prepend.
- *
- * @param numBos number of BOS tokens
- * @return This builder instance for method chaining
- */
- public Builder numBos(int numBos) {
- this.numBos = numBos;
- return this;
- }
-
- /**
- * Sets the number of EOS tokens to append.
- *
- * @param numEos number of EOS tokens
- * @return This builder instance for method chaining
- */
- public Builder numEos(int numEos) {
- this.numEos = numEos;
- return this;
- }
-
- /**
- * Sets the load mode for the model file. Defaults to {@link #LOAD_MODE_MMAP} (mmap without
- * mlock), which avoids pinning model pages in RAM.
- *
- * @param loadMode One of LOAD_MODE_FILE, LOAD_MODE_MMAP, LOAD_MODE_MMAP_USE_MLOCK,
- * LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS
- * @return This builder instance for method chaining
- * @throws IllegalArgumentException if {@code loadMode} is not one of the supported constants
- */
- public Builder loadMode(int loadMode) {
- if (loadMode != LOAD_MODE_FILE
- && loadMode != LOAD_MODE_MMAP
- && loadMode != LOAD_MODE_MMAP_USE_MLOCK
- && loadMode != LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS) {
- throw new IllegalArgumentException("Unknown load mode: " + loadMode);
- }
- this.loadMode = loadMode;
- return this;
- }
-
- /**
- * Constructs the LlmModuleConfig instance with validated parameters.
- *
- * @return New LlmModuleConfig instance with configured values
- * @throws IllegalArgumentException if required fields are missing
- */
- public LlmModuleConfig build() {
- if (modulePath == null || tokenizerPath == null) {
- throw new IllegalArgumentException("Module path and tokenizer path are required");
- }
- return new LlmModuleConfig(this);
- }
- }
-}
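
The Kotlin replacement below re-expresses each Java setter (`this.field = value; return this;`) with `apply`, the scope function that runs a block on its receiver and then returns that receiver, so chaining behaves exactly as before. A self-contained illustration of the idiom, with a class invented for this example:

// GreetingBuilder exists only for this illustration; the real builder is in
// LlmModuleConfig.kt below.
class GreetingBuilder {
    private var name: String = "world"

    // `apply { ... }` mutates the receiver and returns it, matching Java's
    // `this.name = value; return this;` chaining.
    fun name(value: String): GreetingBuilder = apply { name = value }

    fun build(): String = "Hello, $name!"
}

fun main() {
    println(GreetingBuilder().name("ExecuTorch").build()) // Hello, ExecuTorch!
}
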
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt
new file mode 100644
index 00000000000..a8a9d6065a8
--- /dev/null
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch.extension.llm
+
+/**
+ * Configuration class for initializing a LlmModule.
+ *
+ * Use the [create] method and the fluent builder pattern.
+ */
+class LlmModuleConfig
+private constructor(
+ val modulePath: String,
+ val tokenizerPath: String,
+ val temperature: Float,
+ val dataPath: String?,
+ val modelType: Int,
+ val numBos: Int,
+ val numEos: Int,
+ val loadMode: Int,
+) {
+
+ companion object {
+ /** Load entire model file into a buffer (no mmap). */
+ const val LOAD_MODE_FILE = 0
+
+ /** Load model via mmap without mlock (default). Pages faulted in on demand. */
+ const val LOAD_MODE_MMAP = 1
+
+ /** Load model via mmap and pin all pages with mlock. */
+ const val LOAD_MODE_MMAP_USE_MLOCK = 2
+
+ /** Load model via mmap and attempt mlock, ignoring mlock failures. */
+ const val LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3
+
+ /** Model type constant for text-only models. */
+ const val MODEL_TYPE_TEXT = 1
+
+ /** Model type constant for text-and-vision multimodal models. */
+ const val MODEL_TYPE_TEXT_VISION = 2
+
+ /** Model type constant for generic multimodal models. */
+ const val MODEL_TYPE_MULTIMODAL = 2
+
+ /**
+ * Creates a new Builder instance for constructing LlmModuleConfig objects.
+ *
+ * @return a new Builder instance with default configuration values
+ */
+ @JvmStatic fun create(): Builder = Builder()
+ }
+
+ /**
+ * Builder class for constructing LlmModuleConfig instances with optional parameters.
+ *
+ * The builder provides a fluent interface for configuring model parameters and validates required
+ * fields before construction.
+ */
+ class Builder internal constructor() {
+ internal var modulePath: String? = null
+ internal var tokenizerPath: String? = null
+ internal var temperature: Float = 0.8f
+ internal var dataPath: String? = ""
+ internal var modelType: Int = MODEL_TYPE_TEXT
+ internal var numBos: Int = 0
+ internal var numEos: Int = 0
+ internal var loadMode: Int = LOAD_MODE_MMAP
+
+ /** Sets the path to the module. */
+ fun modulePath(modulePath: String): Builder = apply { this.modulePath = modulePath }
+
+ /** Sets the path to the tokenizer. */
+ fun tokenizerPath(tokenizerPath: String): Builder = apply {
+ this.tokenizerPath = tokenizerPath
+ }
+
+ /** Sets the temperature for sampling generation. */
+ fun temperature(temperature: Float): Builder = apply { this.temperature = temperature }
+
+ /** Sets the path to optional additional data files. */
+ fun dataPath(dataPath: String?): Builder = apply { this.dataPath = dataPath }
+
+ /** Sets the model type (text-only or multimodal). */
+ fun modelType(modelType: Int): Builder = apply { this.modelType = modelType }
+
+ /** Sets the number of BOS tokens to prepend. */
+ fun numBos(numBos: Int): Builder = apply { this.numBos = numBos }
+
+ /** Sets the number of EOS tokens to append. */
+ fun numEos(numEos: Int): Builder = apply { this.numEos = numEos }
+
+ /**
+ * Sets the load mode for the model file. Defaults to [LOAD_MODE_MMAP] (mmap without mlock),
+ * which avoids pinning model pages in RAM.
+ *
+ * @throws IllegalArgumentException if loadMode is not one of the supported constants
+ */
+ fun loadMode(loadMode: Int): Builder {
+ require(
+ loadMode == LOAD_MODE_FILE ||
+ loadMode == LOAD_MODE_MMAP ||
+ loadMode == LOAD_MODE_MMAP_USE_MLOCK ||
+ loadMode == LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS) {
+ "Unknown load mode: $loadMode"
+ }
+ return apply { this.loadMode = loadMode }
+ }
+
+ /**
+ * Constructs the LlmModuleConfig instance with validated parameters.
+ *
+ * @throws IllegalArgumentException if required fields are missing
+ */
+ fun build(): LlmModuleConfig {
+ require(modulePath != null && tokenizerPath != null) {
+ "Module path and tokenizer path are required"
+ }
+ return LlmModuleConfig(
+ modulePath!!,
+ tokenizerPath!!,
+ temperature,
+ dataPath,
+ modelType,
+ numBos,
+ numEos,
+ loadMode,
+ )
+ }
+ }
+}
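
A usage sketch of the migrated API, relying only on what LlmModuleConfig.kt above defines; the file paths are placeholders, not paths this diff prescribes:

import org.pytorch.executorch.extension.llm.LlmModuleConfig

fun main() {
    val config = LlmModuleConfig.create()
        .modulePath("/data/local/tmp/model.pte")        // placeholder path
        .tokenizerPath("/data/local/tmp/tokenizer.bin") // placeholder path
        .temperature(0.8f)
        .loadMode(LlmModuleConfig.LOAD_MODE_MMAP)       // the default, shown for clarity
        .build()

    check(config.loadMode == LlmModuleConfig.LOAD_MODE_MMAP)
}

Because create() is annotated @JvmStatic and every setter returns Builder, existing Java call sites of the form LlmModuleConfig.create().modulePath(...).build() should continue to compile against the Kotlin version unchanged.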