Skip to content

Commit ae8750c

Browse files
authored
Merge pull request #2 from mikepapadim/main
Change default max tokens to 1024 and enable verbose output in interactive mode
2 parents a61a109 + f36972e commit ae8750c

File tree

2 files changed

+10
-15
lines changed

2 files changed

+10
-15
lines changed

src/main/java/com/example/LlamaApp.java

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import com.example.inference.engine.impl.Options;
1111
import com.example.loader.weights.ModelLoader;
1212
import com.example.loader.weights.State;
13-
import com.example.tokenizer.impl.Tokenizer;
1413
import com.example.tornadovm.FloatArrayUtils;
1514
import com.example.tornadovm.TornadoVMMasterPlan;
1615
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
@@ -29,7 +28,8 @@ public class LlamaApp {
2928
public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
3029
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
3130
public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
32-
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "false")); // Show performance metrics in interactive mode
31+
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode
32+
3333
/**
3434
* Creates and configures a sampler for token generation based on specified parameters.
3535
*
@@ -115,7 +115,6 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
115115
return sampler;
116116
}
117117

118-
119118
static void runInteractive(Llama model, Sampler sampler, Options options) {
120119
State state = null;
121120
List<Integer> conversationTokens = new ArrayList<>();
@@ -162,15 +161,12 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
162161
// Choose between GPU and CPU path based on configuration
163162
if (USE_TORNADOVM) {
164163
// GPU path using TornadoVM
165-
responseTokens = Llama.generateTokensGPU(model, state, startPosition,
166-
conversationTokens.subList(startPosition, conversationTokens.size()),
167-
stopTokens, options.maxTokens(), sampler, options.echo(),
168-
options.stream() ? tokenConsumer : null, tornadoVMPlan);
164+
responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
165+
sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
169166
} else {
170167
// CPU path
171-
responseTokens = Llama.generateTokens(model, state, startPosition,
172-
conversationTokens.subList(startPosition, conversationTokens.size()),
173-
stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
168+
responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
169+
options.echo(), tokenConsumer);
174170
}
175171

176172
// Include stop token in the prompt history, but not in the response displayed to the user.
@@ -211,7 +207,7 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
211207
static void runInstructOnce(Llama model, Sampler sampler, Options options) {
212208
State state = model.createNewState();
213209
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
214-
TornadoVMMasterPlan tornadoVMPlan =null;
210+
TornadoVMMasterPlan tornadoVMPlan = null;
215211

216212
List<Integer> promptTokens = new ArrayList<>();
217213
promptTokens.add(chatFormat.beginOfText);
@@ -233,10 +229,9 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) {
233229

234230
Set<Integer> stopTokens = chatFormat.getStopTokens();
235231
if (USE_TORNADOVM) {
236-
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
232+
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
237233
// Call generateTokensGPU without the token consumer parameter
238-
responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
239-
sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
234+
responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
240235
} else {
241236
// CPU path still uses the token consumer
242237
responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);

src/main/java/com/example/inference/engine/impl/Options.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
88
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
99

10-
public static final int DEFAULT_MAX_TOKENS = 512;
10+
public static final int DEFAULT_MAX_TOKENS = 1024;
1111

1212
public Options {
1313
require(modelPath != null, "Missing argument: --model <path> is required");

0 commit comments

Comments (0)