diff --git a/extension/android/BUCK b/extension/android/BUCK index c7e275805e2..a0f3a6411d1 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -47,13 +47,13 @@ non_fbcode_target(_kind = fb_android_library, name = "executorch_llama", warnings_as_errors = False, srcs = [ - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java", - "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt", + "executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt", ], autoglob = False, - language = "JAVA", + language = "KOTLIN", deps = [ ":executorch", "//fbandroid/java/com/facebook/jni:jni", diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt similarity index 53% rename from extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java rename to extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt index 4e834d06721..3b56986bf14 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.kt @@ -6,45 +6,42 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.executorch.extension.llm; +package org.pytorch.executorch.extension.llm -import com.facebook.jni.annotations.DoNotStrip; -import org.pytorch.executorch.annotations.Experimental; +import com.facebook.jni.annotations.DoNotStrip +import org.pytorch.executorch.annotations.Experimental /** - * Callback interface for Llama model. Users can implement this interface to receive the generated + * Callback interface for Llm model. Users can implement this interface to receive the generated * tokens and statistics. * - *
<p>
Warning: These APIs are experimental and subject to change without notice + * Warning: These APIs are experimental and subject to change without notice */ @Experimental -public interface LlmCallback { +interface LlmCallback { /** * Called when a new result is available from JNI. Users will keep getting onResult() invocations * until generate() finishes. * * @param result Last generated token */ - @DoNotStrip - public void onResult(String result); + @DoNotStrip fun onResult(result: String) /** * Called when the statistics for the generate() is available. * - *
<p>
The result will be a JSON string. See extension/llm/stats.h for the field definitions. + * The result will be a JSON string. See extension/llm/stats.h for the field definitions. * * @param stats JSON string containing the statistics for the generate() */ - @DoNotStrip - default void onStats(String stats) {} + @DoNotStrip fun onStats(stats: String) {} /** * Called when an error occurs during generate(). * - * @param errorCode Error code from the ExecuTorch runtime (see {@link - * org.pytorch.executorch.ExecutorchRuntimeException}) + * @param errorCode Error code from the ExecuTorch runtime (see + * [org.pytorch.executorch.ExecutorchRuntimeException]) * @param message Human-readable error description */ - @DoNotStrip - default void onError(int errorCode, String message) {} + @DoNotStrip fun onError(errorCode: Int, message: String) {} } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java deleted file mode 100644 index db7941aadad..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -/** - * Configuration class for controlling text generation parameters in LLM operations. - * - *
<p>
This class provides settings for text generation behavior including output formatting, - * generation limits, and sampling parameters. Instances should be created using the {@link - * #create()} method and the fluent builder pattern. - */ -public class LlmGenerationConfig { - private final boolean echo; - private final int maxNewTokens; - private final boolean warming; - private final int seqLen; - private final float temperature; - private final int numBos; - private final int numEos; - - private LlmGenerationConfig(Builder builder) { - this.echo = builder.echo; - this.maxNewTokens = builder.maxNewTokens; - this.warming = builder.warming; - this.seqLen = builder.seqLen; - this.temperature = builder.temperature; - this.numBos = builder.numBos; - this.numEos = builder.numEos; - } - - /** - * Creates a new Builder instance for constructing generation configurations. - * - * @return a new Builder with default configuration values - */ - public static Builder create() { - return new Builder(); - } - - /** - * @return true if input prompt should be included in the output - */ - public boolean isEcho() { - return echo; - } - - /** - * @return maximum number of tokens to generate (-1 for unlimited) - */ - public int getMaxNewTokens() { - return maxNewTokens; - } - - /** - * @return true if model warming is enabled - */ - public boolean isWarming() { - return warming; - } - - /** - * @return maximum sequence length for generation (-1 for default) - */ - public int getSeqLen() { - return seqLen; - } - - /** - * @return temperature value for sampling (higher = more random) - */ - public float getTemperature() { - return temperature; - } - - /** - * @return number of BOS tokens to prepend - */ - public int getNumBos() { - return numBos; - } - - /** - * @return number of EOS tokens to append - */ - public int getNumEos() { - return numEos; - } - - /** - * Builder class for constructing LlmGenerationConfig instances. - * - *
<p>
Provides a fluent interface for configuring generation parameters with sensible defaults. - * All methods return the builder instance to enable method chaining. - */ - public static class Builder { - private boolean echo = true; - private int maxNewTokens = -1; - private boolean warming = false; - private int seqLen = -1; - private float temperature = 0.8f; - private int numBos = 0; - private int numEos = 0; - - Builder() {} - - /** - * Sets whether to include the input prompt in the generated output. - * - * @param echo true to include input prompt, false to return only new tokens - * @return this builder instance - */ - public Builder echo(boolean echo) { - this.echo = echo; - return this; - } - - /** - * Sets the maximum number of new tokens to generate. - * - * @param maxNewTokens the token limit (-1 for unlimited generation) - * @return this builder instance - */ - public Builder maxNewTokens(int maxNewTokens) { - this.maxNewTokens = maxNewTokens; - return this; - } - - /** - * Enables or disables model warming. - * - * @param warming true to generate initial tokens for model warmup - * @return this builder instance - */ - public Builder warming(boolean warming) { - this.warming = warming; - return this; - } - - /** - * Sets the maximum sequence length for generation. - * - * @param seqLen maximum sequence length (-1 for default behavior) - * @return this builder instance - */ - public Builder seqLen(int seqLen) { - this.seqLen = seqLen; - return this; - } - - /** - * Sets the temperature for random sampling. - * - * @param temperature sampling temperature (typical range 0.0-1.0) - * @return this builder instance - */ - public Builder temperature(float temperature) { - this.temperature = temperature; - return this; - } - - /** - * Sets the number of BOS tokens to prepend. - * - * @param numBos number of BOS tokens - * @return this builder instance - */ - public Builder numBos(int numBos) { - this.numBos = numBos; - return this; - } - - /** - * Sets the number of EOS tokens to append. - * - * @param numEos number of EOS tokens - * @return this builder instance - */ - public Builder numEos(int numEos) { - this.numEos = numEos; - return this; - } - - /** - * Constructs the LlmGenerationConfig instance with the configured parameters. - * - * @return new LlmGenerationConfig instance with current builder settings - */ - public LlmGenerationConfig build() { - return new LlmGenerationConfig(this); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt new file mode 100644 index 00000000000..c0f8956fb7f --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmGenerationConfig.kt @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +/** + * Configuration class for controlling text generation parameters in LLM operations. + * + * This class provides settings for text generation behavior including output formatting, generation + * limits, and sampling parameters. Instances should be created using the [create] method and the + * fluent builder pattern. 
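+ *
+ * Illustrative sketch of the builder pattern described above (the values are arbitrary):
+ * ```
+ * val config = LlmGenerationConfig.create()
+ *     .maxNewTokens(256)
+ *     .temperature(0.7f)
+ *     .echo(false)
+ *     .build()
+ * ```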
+ */ +class LlmGenerationConfig +private constructor( + @get:JvmName("isEcho") val echo: Boolean, + val maxNewTokens: Int, + @get:JvmName("isWarming") val warming: Boolean, + val seqLen: Int, + val temperature: Float, + val numBos: Int, + val numEos: Int, +) { + + companion object { + /** + * Creates a new Builder instance for constructing generation configurations. + * + * @return a new Builder with default configuration values + */ + @JvmStatic fun create(): Builder = Builder() + } + + /** + * Builder class for constructing LlmGenerationConfig instances. + * + * Provides a fluent interface for configuring generation parameters with sensible defaults. All + * methods return the builder instance to enable method chaining. + */ + class Builder internal constructor() { + private var echo: Boolean = true + private var maxNewTokens: Int = -1 + private var warming: Boolean = false + private var seqLen: Int = -1 + private var temperature: Float = 0.8f + private var numBos: Int = 0 + private var numEos: Int = 0 + + /** Sets whether to include the input prompt in the generated output. */ + fun echo(echo: Boolean): Builder = apply { this.echo = echo } + + /** Sets the maximum number of new tokens to generate. */ + fun maxNewTokens(maxNewTokens: Int): Builder = apply { this.maxNewTokens = maxNewTokens } + + /** Enables or disables model warming. */ + fun warming(warming: Boolean): Builder = apply { this.warming = warming } + + /** Sets the maximum sequence length for generation. */ + fun seqLen(seqLen: Int): Builder = apply { this.seqLen = seqLen } + + /** Sets the temperature for random sampling. */ + fun temperature(temperature: Float): Builder = apply { this.temperature = temperature } + + /** Sets the number of BOS tokens to prepend. */ + fun numBos(numBos: Int): Builder = apply { this.numBos = numBos } + + /** Sets the number of EOS tokens to append. */ + fun numEos(numEos: Int): Builder = apply { this.numEos = numEos } + + /** Constructs the LlmGenerationConfig instance with the configured parameters. */ + fun build(): LlmGenerationConfig = + LlmGenerationConfig(echo, maxNewTokens, warming, seqLen, temperature, numBos, numEos) + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java deleted file mode 100644 index a563dc6bcc7..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ /dev/null @@ -1,668 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; -import java.nio.ByteBuffer; -import java.util.List; -import org.pytorch.executorch.ExecuTorchRuntime; -import org.pytorch.executorch.annotations.Experimental; - -/** - * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text - * from the model. - * - *
<p>
Warning: These APIs are experimental and subject to change without notice - */ -@Experimental -public class LlmModule { - - public static final int MODEL_TYPE_TEXT = 1; - public static final int MODEL_TYPE_TEXT_VISION = 2; - public static final int MODEL_TYPE_MULTIMODAL = 2; - - private final HybridData mHybridData; - private static final int DEFAULT_SEQ_LEN = 128; - private static final boolean DEFAULT_ECHO = true; - private static final float DEFAULT_TEMPERATURE = -1.0f; - private static final int DEFAULT_BOS = 0; - private static final int DEFAULT_EOS = 0; - private static final int DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP; - - @DoNotStrip - private static native HybridData initHybrid( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos, - int loadMode); - - private LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos, - int loadMode) { - ExecuTorchRuntime.getRuntime(); - ExecuTorchRuntime.validateFilePath(modulePath, "model path"); - ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path"); - - mHybridData = - initHybrid( - modelType, modulePath, tokenizerPath, temperature, dataFiles, numBos, numEos, loadMode); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * dataFiles. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles, - int numBos, - int numEos) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataFiles, - numBos, - numEos, - DEFAULT_LOAD_MODE); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * dataFiles. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - List dataFiles) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataFiles, - DEFAULT_BOS, - DEFAULT_EOS, - DEFAULT_LOAD_MODE); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * data path. - */ - public LlmModule( - int modelType, - String modulePath, - String tokenizerPath, - float temperature, - String dataPath, - int numBos, - int numEos) { - this( - modelType, - modulePath, - tokenizerPath, - temperature, - dataPath != null ? List.of(dataPath) : List.of(), - numBos, - numEos); - } - - /** - * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * data path. - */ - public LlmModule( - int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { - this(modelType, modulePath, tokenizerPath, temperature, dataPath, DEFAULT_BOS, DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ - public LlmModule(String modulePath, String tokenizerPath, float temperature) { - this( - MODEL_TYPE_TEXT, - modulePath, - tokenizerPath, - temperature, - List.of(), - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data - * path. 
- */ - public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) { - this( - MODEL_TYPE_TEXT, - modulePath, - tokenizerPath, - temperature, - List.of(dataPath), - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */ - public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) { - this(modelType, modulePath, tokenizerPath, temperature, List.of(), DEFAULT_BOS, DEFAULT_EOS); - } - - /** Constructs a LLM Module for a model with the given LlmModuleConfig */ - public LlmModule(LlmModuleConfig config) { - this( - config.getModelType(), - config.getModulePath(), - config.getTokenizerPath(), - config.getTemperature(), - config.getDataPath() != null ? List.of(config.getDataPath()) : List.of(), - config.getNumBos(), - config.getNumEos(), - config.getLoadMode()); - } - - public void resetNative() { - mHybridData.resetNative(); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param llmCallback callback object to receive results. - */ - public int generate(String prompt, LlmCallback llmCallback) { - return generate( - prompt, - DEFAULT_SEQ_LEN, - llmCallback, - DEFAULT_ECHO, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback) { - return generate( - null, - 0, - 0, - 0, - prompt, - seqLen, - llmCallback, - DEFAULT_ECHO, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public int generate(String prompt, LlmCallback llmCallback, boolean echo) { - return generate( - null, - 0, - 0, - 0, - prompt, - DEFAULT_SEQ_LEN, - llmCallback, - echo, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) { - return generate( - prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - * @param numBos number of BOS tokens to prepend - * @param numEos number of EOS tokens to append - */ - public native int generate( - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos); - - /** - * Start generating tokens from the module. 
- * - * @param prompt Input prompt - * @param config the config for generation - * @param llmCallback callback object to receive results - */ - public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCallback) { - int seqLen = config.getSeqLen(); - boolean echo = config.isEcho(); - float temperature = config.getTemperature(); - int numBos = config.getNumBos(); - int numEos = config.getNumEos(); - return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - */ - public int generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo) { - return generate( - image, - width, - height, - channels, - prompt, - seqLen, - llmCallback, - echo, - DEFAULT_TEMPERATURE, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - */ - public int generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature) { - return generate( - image, - width, - height, - channels, - prompt, - seqLen, - llmCallback, - echo, - temperature, - DEFAULT_BOS, - DEFAULT_EOS); - } - - /** - * Start generating tokens from the module. - * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @param prompt Input prompt - * @param seqLen sequence length - * @param llmCallback callback object to receive results. - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param temperature temperature for sampling (use negative value to use module default) - * @param numBos number of BOS tokens to prepend - * @param numEos number of EOS tokens to append - */ - public int generate( - int[] image, - int width, - int height, - int channels, - String prompt, - int seqLen, - LlmCallback llmCallback, - boolean echo, - float temperature, - int numBos, - int numEos) { - if (image != null) { - prefillImages(image, width, height, channels); - } - return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos); - } - - /** - * Prefill the KV cache with the given image input. 
- * - * @param image Input image as a byte array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed - */ - @Experimental - public long prefillImages(int[] image, int width, int height, int channels) { - int nativeResult = prefillImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - return 0; - } - - /** - * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data - * is accessed directly without JNI array copies, unlike {@link #prefillImages(int[], int, int, - * int)}. The ByteBuffer must contain raw uint8 pixel data in CHW format with at least channels * - * height * width bytes remaining. Only the first channels * height * width bytes from the - * buffer's current position are read; the position of the original ByteBuffer is not modified. - * - * @param image Input image as a direct ByteBuffer containing uint8 pixel data - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining - * bytes - * @throws RuntimeException if the prefill failed - */ - @Experimental - public void prefillImages(ByteBuffer image, int width, int height, int channels) { - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - long expectedBytes; - try { - long pixels = Math.multiplyExact((long) width, (long) height); - expectedBytes = Math.multiplyExact(pixels, (long) channels); - } catch (ArithmeticException ex) { - throw new IllegalArgumentException( - "width*height*channels is too large and overflows the allowed range.", ex); - } - if (width <= 0 - || height <= 0 - || channels <= 0 - || expectedBytes > Integer.MAX_VALUE - || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels (" - + expectedBytes - + ")."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - } - - /** - * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The - * buffer data is accessed directly without JNI array copies, unlike {@link - * #prefillImages(float[], int, int, int)}. The ByteBuffer must contain normalized float pixel - * data in CHW format with at least channels * height * width * 4 bytes remaining. Only the first - * channels * height * width floats from the buffer's current position are consumed. The buffer - * must use the platform's native byte order (set via {@code - * buffer.order(ByteOrder.nativeOrder())}). 
- * - * @param image Input normalized image as a direct ByteBuffer containing float pixel data in - * native byte order - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining - * bytes, is not float-aligned, or does not use native byte order - * @throws RuntimeException if the prefill failed - */ - @Experimental - public void prefillNormalizedImage(ByteBuffer image, int width, int height, int channels) { - if (!image.isDirect()) { - throw new IllegalArgumentException("Input ByteBuffer must be direct."); - } - if (image.order() != java.nio.ByteOrder.nativeOrder()) { - throw new IllegalArgumentException( - "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())."); - } - if (image.position() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "Input ByteBuffer position (" + image.position() + ") must be 4-byte aligned."); - } - final long expectedBytes; - try { - int wh = Math.multiplyExact(width, height); - long whc = Math.multiplyExact((long) wh, (long) channels); - long totalBytes = Math.multiplyExact(whc, (long) Float.BYTES); - if (totalBytes > Integer.MAX_VALUE) { - throw new IllegalArgumentException( - "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: " - + totalBytes); - } - expectedBytes = totalBytes; - } catch (ArithmeticException e) { - throw new IllegalArgumentException( - "Overflow while computing width*height*channels*4 for ByteBuffer size.", e); - } - if (width <= 0 || height <= 0 || channels <= 0 || image.remaining() < expectedBytes) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" - + image.remaining() - + ") must be at least width*height*channels*4 (" - + expectedBytes - + ")."); - } - if (image.remaining() % Float.BYTES != 0) { - throw new IllegalArgumentException( - "ByteBuffer remaining (" + image.remaining() + ") must be a multiple of 4 (float size)."); - } - // slice() so that getDirectBufferAddress on the native side returns a pointer - // starting at the current position, not the base address. - int nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - } - - private native int prefillImagesInput(int[] image, int width, int height, int channels); - - private native int prefillImagesInputBuffer( - ByteBuffer image, int width, int height, int channels); - - private native int prefillNormalizedImagesInputBuffer( - ByteBuffer image, int width, int height, int channels); - - /** - * Prefill the KV cache with the given normalized image input. - * - * @param image Input normalized image as a float array - * @param width Input image width - * @param height Input image height - * @param channels Input image number of channels - * @return 0 on success - * @throws RuntimeException if the prefill failed - */ - @Experimental - public long prefillImages(float[] image, int width, int height, int channels) { - int nativeResult = prefillNormalizedImagesInput(image, width, height, channels); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - return 0; - } - - private native int prefillNormalizedImagesInput( - float[] image, int width, int height, int channels); - - /** - * Prefill the KV cache with the given preprocessed audio input. 
- * - * @param audio Input preprocessed audio as a byte array - * @param batch_size Input batch size - * @param n_bins Input number of bins - * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed - */ - @Experimental - public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = prefillAudioInput(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - return 0; - } - - private native int prefillAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames); - - /** - * Prefill the KV cache with the given preprocessed audio input. - * - * @param audio Input preprocessed audio as a float array - * @param batch_size Input batch size - * @param n_bins Input number of bins - * @param n_frames Input number of frames - * @return 0 on success - * @throws RuntimeException if the prefill failed - */ - @Experimental - public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) { - int nativeResult = prefillAudioInputFloat(audio, batch_size, n_bins, n_frames); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - return 0; - } - - private native int prefillAudioInputFloat( - float[] audio, int batch_size, int n_bins, int n_frames); - - /** - * Prefill the KV cache with the given raw audio input. - * - * @param audio Input raw audio as a byte array - * @param batch_size Input batch size - * @param n_channels Input number of channels - * @param n_samples Input number of samples - * @return 0 on success - * @throws RuntimeException if the prefill failed - */ - @Experimental - public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) { - int nativeResult = prefillRawAudioInput(audio, batch_size, n_channels, n_samples); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - return 0; - } - - private native int prefillRawAudioInput( - byte[] audio, int batch_size, int n_channels, int n_samples); - - /** - * Prefill the KV cache with the given text prompt. - * - * @param prompt The text prompt to prefill. - * @return 0 on success - * @throws RuntimeException if the prefill failed - */ - @Experimental - public long prefillPrompt(String prompt) { - int nativeResult = prefillTextInput(prompt); - if (nativeResult != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult); - } - return 0; - } - - // returns status - private native int prefillTextInput(String prompt); - - /** - * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. - * - *
<p>
The startPos will be reset to 0. - */ - public native void resetContext(); - - /** Stop current generate() before it finishes. */ - @DoNotStrip - public native void stop(); - - /** Force loading the module. Otherwise the model is loaded during first generate(). */ - @DoNotStrip - public native int load(); -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt new file mode 100644 index 00000000000..c1e9e62e5e1 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.kt @@ -0,0 +1,733 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +import com.facebook.jni.HybridData +import com.facebook.jni.annotations.DoNotStrip +import java.nio.ByteBuffer +import java.nio.ByteOrder +import org.pytorch.executorch.ExecuTorchRuntime +import org.pytorch.executorch.annotations.Experimental + +/** + * LlmModule is a wrapper around the Executorch LLM. It provides a simple interface to generate text + * from the model. + * + * Warning: These APIs are experimental and subject to change without notice + */ +@Experimental +class LlmModule +private constructor( + modelType: Int, + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + loadMode: Int, +) { + + private val mHybridData: HybridData + + init { + ExecuTorchRuntime.getRuntime() + requireNotNull(modulePath) { "model path must not be null" } + requireNotNull(tokenizerPath) { "tokenizer path must not be null" } + ExecuTorchRuntime.validateFilePath(modulePath, "model path") + ExecuTorchRuntime.validateFilePath(tokenizerPath, "tokenizer path") + mHybridData = + initHybrid( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + numBos, + numEos, + loadMode, + ) + } + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * dataFiles. + */ + constructor( + modelType: Int, + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + numBos, + numEos, + DEFAULT_LOAD_MODE, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * dataFiles. + */ + constructor( + modelType: Int, + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + dataFiles: List, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataFiles, + DEFAULT_BOS, + DEFAULT_EOS, + DEFAULT_LOAD_MODE, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. + */ + constructor( + modelType: Int, + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + dataPath: String?, + numBos: Int, + numEos: Int, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + listOfNotNull(dataPath), + numBos, + numEos, + ) + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. 
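+ *
+ * Illustrative sketch (the paths below are placeholders, not real files):
+ * ```
+ * val module = LlmModule(LlmModule.MODEL_TYPE_TEXT,
+ *     "/data/local/tmp/model.pte", "/data/local/tmp/tokenizer.bin", 0.8f, null)
+ * ```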
+ */ + constructor( + modelType: Int, + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + dataPath: String?, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + dataPath, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ + constructor( + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + ) : this( + MODEL_TYPE_TEXT, + modulePath, + tokenizerPath, + temperature, + emptyList(), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Constructs a LLM Module for a model with given model path, tokenizer, temperature and data + * path. + */ + constructor( + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + dataPath: String, + ) : this( + MODEL_TYPE_TEXT, + modulePath, + tokenizerPath, + temperature, + listOf(dataPath), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */ + constructor( + modelType: Int, + modulePath: String?, + tokenizerPath: String?, + temperature: Float, + ) : this( + modelType, + modulePath, + tokenizerPath, + temperature, + emptyList(), + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** Constructs a LLM Module for a model with the given LlmModuleConfig */ + constructor(config: LlmModuleConfig) : this( + config.modelType, + config.modulePath, + config.tokenizerPath, + config.temperature, + listOfNotNull(config.dataPath), + config.numBos, + config.numEos, + config.loadMode, + ) + + fun resetNative() { + mHybridData.resetNative() + } + + // --- generate overloads --- + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param llmCallback callback object to receive results. + */ + fun generate(prompt: String, llmCallback: LlmCallback): Int = + generate( + prompt, + DEFAULT_SEQ_LEN, + llmCallback, + DEFAULT_ECHO, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + */ + fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback): Int = + generate( + null, + 0, + 0, + 0, + prompt, + seqLen, + llmCallback, + DEFAULT_ECHO, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate(prompt: String, llmCallback: LlmCallback, echo: Boolean): Int = + generate( + null, + 0, + 0, + 0, + prompt, + DEFAULT_SEQ_LEN, + llmCallback, + echo, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate(prompt: String, seqLen: Int, llmCallback: LlmCallback, echo: Boolean): Int = + generate(prompt, seqLen, llmCallback, echo, DEFAULT_TEMPERATURE, DEFAULT_BOS, DEFAULT_EOS) + + /** + * Start generating tokens from the module. 
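+ *
+ * Illustrative call (assumes an existing module and callback):
+ * ```
+ * val status = module.generate("Hello", 128, callback, true, 0.8f, 0, 0)
+ * ```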
+ * + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + * @param numBos number of BOS tokens to prepend + * @param numEos number of EOS tokens to append + */ + external fun generate( + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ): Int + + /** + * Start generating tokens from the module. + * + * @param prompt Input prompt + * @param config the config for generation + * @param llmCallback callback object to receive results + */ + fun generate(prompt: String, config: LlmGenerationConfig, llmCallback: LlmCallback): Int = + generate( + null, + 0, + 0, + 0, + prompt, + config.seqLen, + llmCallback, + config.echo, + config.temperature, + config.numBos, + config.numEos, + ) + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + ): Int = + generate( + image, + width, + height, + channels, + prompt, + seqLen, + llmCallback, + echo, + DEFAULT_TEMPERATURE, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + ): Int = + generate( + image, + width, + height, + channels, + prompt, + seqLen, + llmCallback, + echo, + temperature, + DEFAULT_BOS, + DEFAULT_EOS, + ) + + /** + * Start generating tokens from the module. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @param prompt Input prompt + * @param seqLen sequence length + * @param llmCallback callback object to receive results. 
+ * @param echo indicate whether to echo the input prompt or not (text completion vs chat) + * @param temperature temperature for sampling (use negative value to use module default) + * @param numBos number of BOS tokens to prepend + * @param numEos number of EOS tokens to append + */ + fun generate( + image: IntArray?, + width: Int, + height: Int, + channels: Int, + prompt: String, + seqLen: Int, + llmCallback: LlmCallback, + echo: Boolean, + temperature: Float, + numBos: Int, + numEos: Int, + ): Int { + if (image != null) { + prefillImages(image, width, height, channels) + } + return generate(prompt, seqLen, llmCallback, echo, temperature, numBos, numEos) + } + + // --- prefill methods --- + + /** + * Prefill the KV cache with the given image input. + * + * @param image Input image as a byte array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @return 0 on success + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: IntArray, width: Int, height: Int, channels: Int): Long { + val nativeResult = prefillImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + return 0 + } + + /** + * Prefill a multimodal Module with the given image input via a direct ByteBuffer. The buffer data + * is accessed directly without JNI array copies, unlike [prefillImages]. The ByteBuffer must + * contain raw uint8 pixel data in CHW format with at least channels * height * width bytes + * remaining. Only the first channels * height * width bytes from the buffer's current position + * are read; the position of the original ByteBuffer is not modified. + * + * @param image Input image as a direct ByteBuffer containing uint8 pixel data + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws IllegalArgumentException if the ByteBuffer is not direct or has insufficient remaining + * bytes + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: ByteBuffer, width: Int, height: Int, channels: Int) { + require(image.isDirect) { "Input ByteBuffer must be direct." } + val expectedBytes: Long + try { + val pixels = Math.multiplyExact(width.toLong(), height.toLong()) + expectedBytes = Math.multiplyExact(pixels, channels.toLong()) + } catch (ex: ArithmeticException) { + throw IllegalArgumentException( + "width*height*channels is too large and overflows the allowed range.", ex) + } + require( + width > 0 && + height > 0 && + channels > 0 && + expectedBytes <= Int.MAX_VALUE && + image.remaining() >= expectedBytes) { + "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels ($expectedBytes)." + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + val nativeResult = prefillImagesInputBuffer(image.slice(), width, height, channels) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + } + + /** + * Prefill a multimodal Module with the given normalized image input via a direct ByteBuffer. The + * buffer data is accessed directly without JNI array copies, unlike [prefillImages]. The + * ByteBuffer must contain normalized float pixel data in CHW format with at least channels * + * height * width * 4 bytes remaining. 
Only the first channels * height * width floats from the + * buffer's current position are consumed. The buffer must use the platform's native byte order + * (set via `buffer.order(ByteOrder.nativeOrder())`). + * + * @param image Input normalized image as a direct ByteBuffer containing float pixel data in + * native byte order + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @throws IllegalArgumentException if the ByteBuffer is not direct, has insufficient remaining + * bytes, is not float-aligned, or does not use native byte order + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillNormalizedImage(image: ByteBuffer, width: Int, height: Int, channels: Int) { + require(image.isDirect) { "Input ByteBuffer must be direct." } + require(image.order() == ByteOrder.nativeOrder()) { + "Input ByteBuffer must use native byte order (ByteOrder.nativeOrder())." + } + require(image.position() % Float.SIZE_BYTES == 0) { + "Input ByteBuffer position (${image.position()}) must be 4-byte aligned." + } + val expectedBytes: Long + try { + val wh = Math.multiplyExact(width, height) + val whc = Math.multiplyExact(wh.toLong(), channels.toLong()) + val totalBytes = Math.multiplyExact(whc, Float.SIZE_BYTES.toLong()) + if (totalBytes > Int.MAX_VALUE) { + throw IllegalArgumentException( + "ByteBuffer size (width*height*channels*4) exceeds Integer.MAX_VALUE bytes: $totalBytes") + } + expectedBytes = totalBytes + } catch (e: ArithmeticException) { + throw IllegalArgumentException( + "Overflow while computing width*height*channels*4 for ByteBuffer size.", e) + } + require(width > 0 && height > 0 && channels > 0 && image.remaining() >= expectedBytes) { + "ByteBuffer remaining (${image.remaining()}) must be at least width*height*channels*4 ($expectedBytes)." + } + require(image.remaining() % Float.SIZE_BYTES == 0) { + "ByteBuffer remaining (${image.remaining()}) must be a multiple of 4 (float size)." + } + // slice() so that getDirectBufferAddress on the native side returns a pointer + // starting at the current position, not the base address. + val nativeResult = prefillNormalizedImagesInputBuffer(image.slice(), width, height, channels) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + } + + private external fun prefillImagesInput( + image: IntArray, + width: Int, + height: Int, + channels: Int, + ): Int + + private external fun prefillImagesInputBuffer( + image: ByteBuffer, + width: Int, + height: Int, + channels: Int, + ): Int + + private external fun prefillNormalizedImagesInputBuffer( + image: ByteBuffer, + width: Int, + height: Int, + channels: Int, + ): Int + + /** + * Prefill the KV cache with the given normalized image input. 
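+ *
+ * Illustrative sketch (a 3x224x224 CHW image; the dimensions are arbitrary):
+ * ```
+ * val pixels = FloatArray(3 * 224 * 224) // already-normalized CHW floats
+ * module.prefillImages(pixels, 224, 224, 3)
+ * ```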
+ * + * @param image Input normalized image as a float array + * @param width Input image width + * @param height Input image height + * @param channels Input image number of channels + * @return 0 on success + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillImages(image: FloatArray, width: Int, height: Int, channels: Int): Long { + val nativeResult = prefillNormalizedImagesInput(image, width, height, channels) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + return 0 + } + + private external fun prefillNormalizedImagesInput( + image: FloatArray, + width: Int, + height: Int, + channels: Int, + ): Int + + /** + * Prefill the KV cache with the given preprocessed audio input. + * + * @param audio Input preprocessed audio as a byte array + * @param batchSize Input batch size + * @param nBins Input number of bins + * @param nFrames Input number of frames + * @return 0 on success + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillAudio(audio: ByteArray, batchSize: Int, nBins: Int, nFrames: Int): Long { + val nativeResult = prefillAudioInput(audio, batchSize, nBins, nFrames) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + return 0 + } + + private external fun prefillAudioInput( + audio: ByteArray, + batchSize: Int, + nBins: Int, + nFrames: Int, + ): Int + + /** + * Prefill the KV cache with the given preprocessed audio input. + * + * @param audio Input preprocessed audio as a float array + * @param batchSize Input batch size + * @param nBins Input number of bins + * @param nFrames Input number of frames + * @return 0 on success + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillAudio(audio: FloatArray, batchSize: Int, nBins: Int, nFrames: Int): Long { + val nativeResult = prefillAudioInputFloat(audio, batchSize, nBins, nFrames) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + return 0 + } + + private external fun prefillAudioInputFloat( + audio: FloatArray, + batchSize: Int, + nBins: Int, + nFrames: Int, + ): Int + + /** + * Prefill the KV cache with the given raw audio input. + * + * @param audio Input raw audio as a byte array + * @param batchSize Input batch size + * @param nChannels Input number of channels + * @param nSamples Input number of samples + * @return 0 on success + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillRawAudio(audio: ByteArray, batchSize: Int, nChannels: Int, nSamples: Int): Long { + val nativeResult = prefillRawAudioInput(audio, batchSize, nChannels, nSamples) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + return 0 + } + + private external fun prefillRawAudioInput( + audio: ByteArray, + batchSize: Int, + nChannels: Int, + nSamples: Int, + ): Int + + /** + * Prefill the KV cache with the given text prompt. + * + * @param prompt The text prompt to prefill + * @return 0 on success + * @throws RuntimeException if the prefill failed + */ + @Experimental + fun prefillPrompt(prompt: String): Long { + val nativeResult = prefillTextInput(prompt) + if (nativeResult != 0) { + throw RuntimeException("Prefill failed with error code: $nativeResult") + } + return 0 + } + + // returns status + private external fun prefillTextInput(prompt: String): Int + + /** + * Reset the context of the LLM. 
This will clear the KV cache and reset the state of the LLM. The + * startPos will be reset to 0. + */ + external fun resetContext() + + /** Stop current generate() before it finishes. */ + @DoNotStrip external fun stop() + + /** Force loading the module. Otherwise the model is loaded during first generate(). */ + @DoNotStrip external fun load(): Int + + companion object { + const val MODEL_TYPE_TEXT = 1 + const val MODEL_TYPE_TEXT_VISION = 2 + const val MODEL_TYPE_MULTIMODAL = 2 + + private const val DEFAULT_SEQ_LEN = 128 + private const val DEFAULT_ECHO = true + private const val DEFAULT_TEMPERATURE = -1.0f + private const val DEFAULT_BOS = 0 + private const val DEFAULT_EOS = 0 + private const val DEFAULT_LOAD_MODE = LlmModuleConfig.LOAD_MODE_MMAP + + @DoNotStrip + @JvmStatic + private external fun initHybrid( + modelType: Int, + modulePath: String, + tokenizerPath: String, + temperature: Float, + dataFiles: List, + numBos: Int, + numEos: Int, + loadMode: Int, + ): HybridData + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java deleted file mode 100644 index feb52a2b34b..00000000000 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.java +++ /dev/null @@ -1,252 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch.extension.llm; - -/** - * Configuration class for initializing a LlmModule. - * - *
<p>
{@link #create()} method and the fluent builder pattern. - */ -public class LlmModuleConfig { - private final String modulePath; - private final String tokenizerPath; - private final float temperature; - private final String dataPath; - private final int modelType; - private final int numBos; - private final int numEos; - private final int loadMode; - - /** Load entire model file into a buffer (no mmap). */ - public static final int LOAD_MODE_FILE = 0; - - /** Load model via mmap without mlock (default). Pages faulted in on demand. */ - public static final int LOAD_MODE_MMAP = 1; - - /** Load model via mmap and pin all pages with mlock. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK = 2; - - /** Load model via mmap and attempt mlock, ignoring mlock failures. */ - public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; - - private LlmModuleConfig(Builder builder) { - this.modulePath = builder.modulePath; - this.tokenizerPath = builder.tokenizerPath; - this.temperature = builder.temperature; - this.dataPath = builder.dataPath; - this.modelType = builder.modelType; - this.numBos = builder.numBos; - this.numEos = builder.numEos; - this.loadMode = builder.loadMode; - } - - /** Model type constant for text-only models. */ - public static final int MODEL_TYPE_TEXT = 1; - - /** Model type constant for text-and-vision multimodal models. */ - public static final int MODEL_TYPE_TEXT_VISION = 2; - - /** Model type constant for generic multimodal models. */ - public static final int MODEL_TYPE_MULTIMODAL = 2; - - /** - * Creates a new Builder instance for constructing LlmModuleConfig objects. - * - * @return a new Builder instance with default configuration values - */ - public static Builder create() { - return new Builder(); - } - - // Getters with documentation - /** - * @return Path to the compiled model module (.pte file) - */ - public String getModulePath() { - return modulePath; - } - - /** - * @return Path to the tokenizer file or directory - */ - public String getTokenizerPath() { - return tokenizerPath; - } - - /** - * @return Temperature value for sampling (higher = more random) - */ - public float getTemperature() { - return temperature; - } - - /** - * @return Optional path to additional data files - */ - public String getDataPath() { - return dataPath; - } - - /** - * @return Type of model (text-only or text-vision) - */ - public int getModelType() { - return modelType; - } - - /** - * @return Number of BOS tokens to prepend - */ - public int getNumBos() { - return numBos; - } - - /** - * @return Number of EOS tokens to append - */ - public int getNumEos() { - return numEos; - } - - /** - * @return Load mode for the model file (one of LOAD_MODE_* constants) - */ - public int getLoadMode() { - return loadMode; - } - - /** - * Builder class for constructing LlmModuleConfig instances with optional parameters. - * - *

The builder provides a fluent interface for configuring model parameters and validates - * required fields before construction. - */ - public static class Builder { - private String modulePath; - private String tokenizerPath; - private float temperature = 0.8f; - private String dataPath = ""; - private int modelType = MODEL_TYPE_TEXT; - private int numBos = 0; - private int numEos = 0; - private int loadMode = LOAD_MODE_MMAP; - - Builder() {} - - /** - * Sets the path to the module. - * - * @param modulePath Path to module - * @return This builder instance for method chaining - */ - public Builder modulePath(String modulePath) { - this.modulePath = modulePath; - return this; - } - - /** - * Sets the path to the tokenizer. - * - * @param tokenizerPath Path to tokenizer - * @return This builder instance for method chaining - */ - public Builder tokenizerPath(String tokenizerPath) { - this.tokenizerPath = tokenizerPath; - return this; - } - - /** - * Sets the temperature for sampling generation. - * - * @param temperature Temperature value (typical range 0.0-1.0) - * @return This builder instance for method chaining - */ - public Builder temperature(float temperature) { - this.temperature = temperature; - return this; - } - - /** - * Sets the path to optional additional data files. - * - * @param dataPath Path to supplementary data resources - * @return This builder instance for method chaining - */ - public Builder dataPath(String dataPath) { - this.dataPath = dataPath; - return this; - } - - /** - * Sets the model type (text-only or multimodal). - * - * @param modelType One of MODEL_TYPE_TEXT, MODEL_TYPE_TEXT_VISION, MODEL_TYPE_MULTIMODAL - * @return This builder instance for method chaining - */ - public Builder modelType(int modelType) { - this.modelType = modelType; - return this; - } - - /** - * Sets the number of BOS tokens to prepend. - * - * @param numBos number of BOS tokens - * @return This builder instance for method chaining - */ - public Builder numBos(int numBos) { - this.numBos = numBos; - return this; - } - - /** - * Sets the number of EOS tokens to append. - * - * @param numEos number of EOS tokens - * @return This builder instance for method chaining - */ - public Builder numEos(int numEos) { - this.numEos = numEos; - return this; - } - - /** - * Sets the load mode for the model file. Defaults to {@link #LOAD_MODE_MMAP} (mmap without - * mlock), which avoids pinning model pages in RAM. - * - * @param loadMode One of LOAD_MODE_FILE, LOAD_MODE_MMAP, LOAD_MODE_MMAP_USE_MLOCK, - * LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS - * @return This builder instance for method chaining - * @throws IllegalArgumentException if {@code loadMode} is not one of the supported constants - */ - public Builder loadMode(int loadMode) { - if (loadMode != LOAD_MODE_FILE - && loadMode != LOAD_MODE_MMAP - && loadMode != LOAD_MODE_MMAP_USE_MLOCK - && loadMode != LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS) { - throw new IllegalArgumentException("Unknown load mode: " + loadMode); - } - this.loadMode = loadMode; - return this; - } - - /** - * Constructs the LlmModuleConfig instance with validated parameters. 
- * - * @return New LlmModuleConfig instance with configured values - * @throws IllegalArgumentException if required fields are missing - */ - public LlmModuleConfig build() { - if (modulePath == null || tokenizerPath == null) { - throw new IllegalArgumentException("Module path and tokenizer path are required"); - } - return new LlmModuleConfig(this); - } - } -} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt new file mode 100644 index 00000000000..a8a9d6065a8 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt @@ -0,0 +1,135 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch.extension.llm + +/** + * Configuration class for initializing a LlmModule. + * + * Use [create] method and the fluent builder pattern. + */ +class LlmModuleConfig +private constructor( + val modulePath: String, + val tokenizerPath: String, + val temperature: Float, + val dataPath: String?, + val modelType: Int, + val numBos: Int, + val numEos: Int, + val loadMode: Int, +) { + + companion object { + /** Load entire model file into a buffer (no mmap). */ + const val LOAD_MODE_FILE = 0 + + /** Load model via mmap without mlock (default). Pages faulted in on demand. */ + const val LOAD_MODE_MMAP = 1 + + /** Load model via mmap and pin all pages with mlock. */ + const val LOAD_MODE_MMAP_USE_MLOCK = 2 + + /** Load model via mmap and attempt mlock, ignoring mlock failures. */ + const val LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3 + + /** Model type constant for text-only models. */ + const val MODEL_TYPE_TEXT = 1 + + /** Model type constant for text-and-vision multimodal models. */ + const val MODEL_TYPE_TEXT_VISION = 2 + + /** Model type constant for generic multimodal models. */ + const val MODEL_TYPE_MULTIMODAL = 2 + + /** + * Creates a new Builder instance for constructing LlmModuleConfig objects. + * + * @return a new Builder instance with default configuration values + */ + @JvmStatic fun create(): Builder = Builder() + } + + /** + * Builder class for constructing LlmModuleConfig instances with optional parameters. + * + * The builder provides a fluent interface for configuring model parameters and validates required + * fields before construction. + */ + class Builder internal constructor() { + internal var modulePath: String? = null + internal var tokenizerPath: String? = null + internal var temperature: Float = 0.8f + internal var dataPath: String? = "" + internal var modelType: Int = MODEL_TYPE_TEXT + internal var numBos: Int = 0 + internal var numEos: Int = 0 + internal var loadMode: Int = LOAD_MODE_MMAP + + /** Sets the path to the module. */ + fun modulePath(modulePath: String): Builder = apply { this.modulePath = modulePath } + + /** Sets the path to the tokenizer. */ + fun tokenizerPath(tokenizerPath: String): Builder = apply { + this.tokenizerPath = tokenizerPath + } + + /** Sets the temperature for sampling generation. */ + fun temperature(temperature: Float): Builder = apply { this.temperature = temperature } + + /** Sets the path to optional additional data files. 
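The deleted Java class's loadMode validation (the if/throw chain above) is carried over in the Kotlin replacement below as a require block with the same observable behavior. A short sketch, using only the constants and builder introduced in this patch:

```kotlin
import org.pytorch.executorch.extension.llm.LlmModuleConfig

fun main() {
  // Both the old Java if/throw and the new Kotlin require(...) reject unknown modes.
  val builder = LlmModuleConfig.create()
  builder.loadMode(LlmModuleConfig.LOAD_MODE_MMAP_USE_MLOCK) // accepted, returns the builder
  try {
    builder.loadMode(42) // not one of the LOAD_MODE_* constants
  } catch (e: IllegalArgumentException) {
    println(e.message) // prints: Unknown load mode: 42
  }
}
```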
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt
new file mode 100644
index 00000000000..a8a9d6065a8
--- /dev/null
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModuleConfig.kt
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch.extension.llm
+
+/**
+ * Configuration class for initializing a LlmModule.
+ *
+ * Use [create] method and the fluent builder pattern.
+ */
+class LlmModuleConfig
+private constructor(
+    val modulePath: String,
+    val tokenizerPath: String,
+    val temperature: Float,
+    val dataPath: String?,
+    val modelType: Int,
+    val numBos: Int,
+    val numEos: Int,
+    val loadMode: Int,
+) {
+
+  companion object {
+    /** Load entire model file into a buffer (no mmap). */
+    const val LOAD_MODE_FILE = 0
+
+    /** Load model via mmap without mlock (default). Pages faulted in on demand. */
+    const val LOAD_MODE_MMAP = 1
+
+    /** Load model via mmap and pin all pages with mlock. */
+    const val LOAD_MODE_MMAP_USE_MLOCK = 2
+
+    /** Load model via mmap and attempt mlock, ignoring mlock failures. */
+    const val LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3
+
+    /** Model type constant for text-only models. */
+    const val MODEL_TYPE_TEXT = 1
+
+    /** Model type constant for text-and-vision multimodal models. */
+    const val MODEL_TYPE_TEXT_VISION = 2
+
+    /** Model type constant for generic multimodal models. */
+    const val MODEL_TYPE_MULTIMODAL = 2
+
+    /**
+     * Creates a new Builder instance for constructing LlmModuleConfig objects.
+     *
+     * @return a new Builder instance with default configuration values
+     */
+    @JvmStatic fun create(): Builder = Builder()
+  }
+
+  /**
+   * Builder class for constructing LlmModuleConfig instances with optional parameters.
+   *
+   * The builder provides a fluent interface for configuring model parameters and validates
+   * required fields before construction.
+   */
+  class Builder internal constructor() {
+    internal var modulePath: String? = null
+    internal var tokenizerPath: String? = null
+    internal var temperature: Float = 0.8f
+    internal var dataPath: String? = ""
+    internal var modelType: Int = MODEL_TYPE_TEXT
+    internal var numBos: Int = 0
+    internal var numEos: Int = 0
+    internal var loadMode: Int = LOAD_MODE_MMAP
+
+    /** Sets the path to the module. */
+    fun modulePath(modulePath: String): Builder = apply { this.modulePath = modulePath }
+
+    /** Sets the path to the tokenizer. */
+    fun tokenizerPath(tokenizerPath: String): Builder = apply {
+      this.tokenizerPath = tokenizerPath
+    }
+
+    /** Sets the temperature for sampling generation. */
+    fun temperature(temperature: Float): Builder = apply { this.temperature = temperature }
+
+    /** Sets the path to optional additional data files. */
+    fun dataPath(dataPath: String?): Builder = apply { this.dataPath = dataPath }
+
+    /** Sets the model type (text-only or multimodal). */
+    fun modelType(modelType: Int): Builder = apply { this.modelType = modelType }
+
+    /** Sets the number of BOS tokens to prepend. */
+    fun numBos(numBos: Int): Builder = apply { this.numBos = numBos }
+
+    /** Sets the number of EOS tokens to append. */
+    fun numEos(numEos: Int): Builder = apply { this.numEos = numEos }
+
+    /**
+     * Sets the load mode for the model file. Defaults to [LOAD_MODE_MMAP] (mmap without mlock),
+     * which avoids pinning model pages in RAM.
+     *
+     * @throws IllegalArgumentException if loadMode is not one of the supported constants
+     */
+    fun loadMode(loadMode: Int): Builder {
+      require(
+          loadMode == LOAD_MODE_FILE ||
+              loadMode == LOAD_MODE_MMAP ||
+              loadMode == LOAD_MODE_MMAP_USE_MLOCK ||
+              loadMode == LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS) {
+            "Unknown load mode: $loadMode"
+          }
+      return apply { this.loadMode = loadMode }
+    }
+
+    /**
+     * Constructs the LlmModuleConfig instance with validated parameters.
+     *
+     * @throws IllegalArgumentException if required fields are missing
+     */
+    fun build(): LlmModuleConfig {
+      require(modulePath != null && tokenizerPath != null) {
+        "Module path and tokenizer path are required"
+      }
+      return LlmModuleConfig(
+          modulePath!!,
+          tokenizerPath!!,
+          temperature,
+          dataPath,
+          modelType,
+          numBos,
+          numEos,
+          loadMode,
+      )
+    }
+  }
+}
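Finally, a short sketch of the new builder's contract, using only the API introduced above; the file paths are placeholders:

```kotlin
import org.pytorch.executorch.extension.llm.LlmModuleConfig

fun main() {
  // build() enforces the two required fields.
  try {
    LlmModuleConfig.create().modulePath("/data/local/tmp/model.pte").build()
  } catch (e: IllegalArgumentException) {
    println(e.message) // prints: Module path and tokenizer path are required
  }

  // A fully specified config builds fine.
  val config =
      LlmModuleConfig.create()
          .modulePath("/data/local/tmp/model.pte")
          .tokenizerPath("/data/local/tmp/tokenizer.bin")
          .temperature(0.7f)
          .numBos(1)
          .loadMode(LlmModuleConfig.LOAD_MODE_MMAP)
          .build()
  println(config.loadMode) // 1
}
```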