-
Notifications
You must be signed in to change notification settings - Fork 30
Expand file tree
/
Copy pathOptions.java
More file actions
103 lines (95 loc) · 5.3 KB
/
Options.java
File metadata and controls
103 lines (95 loc) · 5.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
package com.example;
import java.io.PrintStream;
import java.nio.file.Path;
import java.nio.file.Paths;
public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, boolean guiMode) {
public static final int DEFAULT_MAX_TOKENS = 1024;
public Options {
require(modelPath != null, "Missing argument: --model <path> is required");
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
}
static void require(boolean condition, String messageFormat, Object... args) {
if (!condition) {
System.out.println("ERROR " + messageFormat.formatted(args));
System.out.println();
printUsage(System.out);
System.exit(-1);
}
}
static void printUsage(PrintStream out) {
out.println("Usage: jbang Llama3.java [options]");
out.println();
out.println("Options:");
out.println(" --model, -m <path> required, path to .gguf file");
out.println(" --interactive, --chat, -i run in chat mode");
out.println(" --instruct run in instruct (once) mode, default mode");
out.println(" --prompt, -p <string> input prompt");
out.println(" --system-prompt, -sp <string> (optional) system prompt (Llama models)");
out.println(" --suffix <string> suffix for fill-in-the-middle request (Codestral)");
out.println(" --temperature, -temp <float> temperature in [0,inf], default 0.1");
out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.95");
out.println(" --seed <long> random seed, default System.nanoTime()");
out.println(" --max-tokens, -n <int> number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS);
out.println(" --stream <boolean> print tokens during generation; may cause encoding artifacts for non ASCII text, default true");
out.println(" --echo <boolean> print ALL tokens to stderr, if true, recommended to set --stream=false, default false");
out.println(" --gui <boolean> run the GUI chatbox");
out.println();
}
public static Options parseOptions(String[] args) {
String prompt = "Tell me a story with Java"; // Hardcoded for testing
String systemPrompt = null;
String suffix = null;
float temperature = 0.1f;
float topp = 0.95f;
Path modelPath = null;
long seed = System.nanoTime();
// Keep max context length small for low-memory devices.
int maxTokens = DEFAULT_MAX_TOKENS;
boolean interactive = false;
boolean stream = true;
boolean echo = false;
boolean guiMode = false;
for (int i = 0; i < args.length; i++) {
String optionName = args[i];
require(optionName.startsWith("-"), "Invalid option %s", optionName);
switch (optionName) {
case "--interactive", "--chat", "-i" -> interactive = true;
case "--instruct" -> interactive = false;
case "--gui" -> guiMode = true;
case "--help", "-h" -> {
printUsage(System.out);
System.exit(0);
}
default -> {
String nextArg;
if (optionName.contains("=")) {
String[] parts = optionName.split("=", 2);
optionName = parts[0];
nextArg = parts[1];
} else {
require(i + 1 < args.length, "Missing argument for option %s", optionName);
nextArg = args[i + 1];
i += 1; // skip arg
}
switch (optionName) {
case "--prompt", "-p" -> prompt = nextArg;
case "--system-prompt", "-sp" -> systemPrompt = nextArg;
case "--suffix" -> suffix = nextArg;
case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg);
case "--top-p" -> topp = Float.parseFloat(nextArg);
case "--model", "-m" -> modelPath = Paths.get(nextArg);
case "--seed", "-s" -> seed = Long.parseLong(nextArg);
case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
case "--stream" -> stream = Boolean.parseBoolean(nextArg);
case "--echo" -> echo = Boolean.parseBoolean(nextArg);
default -> require(false, "Unknown option: %s", optionName);
}
}
}
}
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, guiMode);
}
}