forked from beehive-lab/GPULlama3.java
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDevstral.java
More file actions
73 lines (59 loc) · 2.74 KB
/
Devstral.java
File metadata and controls
73 lines (59 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package org.beehive.gpullama3.model.devstral;
import org.beehive.gpullama3.inference.InferenceCore;
import org.beehive.gpullama3.inference.InferenceEngine;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.inference.state.DevstralState;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.AbstractModel;
import org.beehive.gpullama3.model.ModelType;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.tokenizer.DevstralTokenizer;
import org.beehive.gpullama3.tokenizer.Tokenizer;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import java.util.List;
import java.util.Set;
import java.util.function.IntConsumer;
/**
 * Model implementation for Devstral, built on {@link AbstractModel}.
 *
 * <p>The class stores the Devstral-specific configuration and forwards inference work to the
 * static entry points on {@link InferenceCore} and {@link InferenceEngine}.
 */
public class Devstral extends AbstractModel {

    /** Configuration received by the constructor and handed back by {@link #configuration()}. */
    DevstralConfiguration configuration;

    /**
     * Builds a Devstral model from its parts.
     *
     * @param configuration Devstral configuration kept by this instance
     * @param tokenizer     tokenizer passed on to the superclass
     * @param weights       weights passed on to the superclass
     * @param chatFormat    chat format passed on to the superclass
     */
    public Devstral(DevstralConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
        // The fourth superclass argument is deliberately left null, exactly as before.
        super(tokenizer, weights, chatFormat, null);
        this.configuration = configuration;
    }

    /**
     * @return the configuration this model was constructed with
     */
    @Override
    public DevstralConfiguration configuration() {
        return this.configuration;
    }

    /**
     * @return the inherited tokenizer, cast to {@link DevstralTokenizer}
     */
    @Override
    public DevstralTokenizer tokenizer() {
        return (DevstralTokenizer) tokenizer;
    }

    /**
     * @return {@link ModelType#DEVSTRAL_2}
     */
    @Override
    public ModelType getModelType() {
        return ModelType.DEVSTRAL_2;
    }

    /**
     * Creates a fresh {@link DevstralState}, passing {@code -1} where the other overload passes
     * its batch size.
     *
     * @return the new state with {@code latestToken} set from the {@code "<s>"} special token
     */
    public State createNewState() {
        return newStateSeededWithStartToken(-1);
    }

    /**
     * Creates a fresh {@link DevstralState} for the given batch size.
     *
     * @param batchsize value forwarded to the {@link DevstralState} constructor
     * @return the new state with {@code latestToken} set from the {@code "<s>"} special token
     */
    public State createNewState(int batchsize) {
        return newStateSeededWithStartToken(batchsize);
    }

    /**
     * Shared construction path for both {@code createNewState} overloads.
     *
     * @param stateSizeArgument second argument for the {@link DevstralState} constructor
     * @return the initialised state
     */
    private State newStateSeededWithStartToken(int stateSizeArgument) {
        State fresh = new DevstralState(configuration(), stateSizeArgument);
        // "<s>" is looked up among the tokenizer's special tokens; presumably the
        // beginning-of-sequence marker -- confirm against the tokenizer vocabulary.
        fresh.latestToken = tokenizer.getSpecialTokens().get("<s>");
        return fresh;
    }

    /**
     * Runs one forward step by delegating to {@link InferenceCore}.
     *
     * @param state    inference state handed to the forward routine
     * @param token    token id handed to the forward routine
     * @param position position handed to the forward routine
     */
    @Override
    public void forward(State state, int token, int position) {
        InferenceCore.forwardJavaDevstral(this, state, token, position);
    }

    /**
     * Generates tokens by delegating to {@code InferenceEngine.generateTokensLlama}; every
     * argument is forwarded unchanged, with this model prepended.
     *
     * @return whatever the delegate returns
     */
    @Override
    public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
            IntConsumer onTokenGenerated) {
        return InferenceEngine.generateTokensLlama(
                this,
                state,
                startPosition,
                promptTokens,
                stopTokens,
                maxTokens,
                sampler,
                echo,
                onTokenGenerated);
    }

    /**
     * Generates tokens by delegating to {@code InferenceEngine.generateTokensGPULlama}; every
     * argument, including the TornadoVM plan, is forwarded unchanged with this model prepended.
     *
     * @return whatever the delegate returns
     */
    @Override
    public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
        return InferenceEngine.generateTokensGPULlama(
                this,
                state,
                startPosition,
                promptTokens,
                stopTokens,
                maxTokens,
                sampler,
                echo,
                onTokenGenerated,
                tornadoVMPlan);
    }
}