forked from bernardladenthin/java-llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLlamaModelBackend.java
More file actions
105 lines (93 loc) · 4.2 KB
/
Copy pathLlamaModelBackend.java
File metadata and controls
105 lines (93 loc) · 4.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// SPDX-FileCopyrightText: 2026 Bernard Ladenthin <bernard.ladenthin@gmail.com>
//
// SPDX-License-Identifier: MIT
package net.ladenthin.llama.server;
import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import net.ladenthin.llama.LlamaModel;
import net.ladenthin.llama.parameters.InferenceParameters;
/**
* Production {@link OpenAiBackend} that runs requests against a loaded {@link LlamaModel}.
*
* <p>Non-streaming chat reuses {@link LlamaModel#chatComplete(InferenceParameters)}, whose return value
* is already a verbatim OpenAI {@code chat.completion} body. Streaming chat uses
* {@link LlamaModel#streamChatCompletion(InferenceParameters, java.util.function.Consumer)}, which emits
* OpenAI {@code chat.completion.chunk} objects (including {@code delta.tool_calls}). Text completions and
* embeddings forward the request body verbatim to {@link LlamaModel#handleCompletionsOai(String)} /
* {@link LlamaModel#handleEmbeddings(String, boolean)}, which already return OpenAI-shaped JSON.
*
* <p>The streaming sink may fail with {@link IOException} (client disconnect); because the underlying
* model API takes a {@link java.util.function.Consumer} (no checked exceptions), that failure is
* relayed across the boundary via {@link java.io.UncheckedIOException} and unwrapped here so the
* in-flight native task is cancelled.
*/
final class LlamaModelBackend implements OpenAiBackend {
private final LlamaModel model;
private final OpenAiRequestMapper mapper;
/**
* Create a backend over the given model.
*
* @param model the loaded model to run completions against
* @param mapper the OpenAI-request to {@link InferenceParameters} mapper
*/
LlamaModelBackend(LlamaModel model, OpenAiRequestMapper mapper) {
this.model = model;
this.mapper = mapper;
}
@Override
public String metrics() {
return model.getMetrics();
}
@Override
public String complete(JsonNode request) {
return model.chatComplete(mapper.toInferenceParameters(request));
}
@Override
public void stream(JsonNode request, ChunkSink sink) throws IOException {
InferenceParameters params = mapper.toInferenceParameters(request);
// Holds an IOException thrown by the sink so it can be rethrown after the model API (which
// only understands unchecked exceptions) unwinds and cancels the native task.
final IOException[] sinkFailure = new IOException[1];
try {
model.streamChatCompletion(params, chunkJson -> {
try {
sink.accept(chunkJson);
} catch (IOException e) {
sinkFailure[0] = e;
throw new java.io.UncheckedIOException(e);
}
});
} catch (java.io.UncheckedIOException e) {
IOException cause = sinkFailure[0];
if (cause != null) {
throw cause;
}
throw e;
}
}
@Override
public String completions(JsonNode request) {
// The native /v1/completions handler parses the OpenAI body itself; forward it verbatim.
return model.handleCompletionsOai(request.toString());
}
@Override
public String embeddings(JsonNode request) {
// oaiCompat=true so the response uses the OpenAI {"object":"list","data":[{embedding}]} shape.
return model.handleEmbeddings(request.toString(), true);
}
@Override
public String infill(JsonNode request) {
// The native /infill handler parses the body itself (input_prefix/input_suffix/...) and applies
// the model's FIM tokens from GGUF metadata; forward verbatim.
return model.handleInfill(request.toString());
}
@Override
public String rerank(JsonNode request) {
final String query = OaiRerankSupport.readQuery(request);
final String[] documents = OaiRerankSupport.readDocuments(request);
final int topN = OaiRerankSupport.readTopN(request);
final String requestModel = request.path("model").asText("");
final String nativeJson = model.handleRerank(query, documents);
return OaiRerankSupport.toOaiResponse(nativeJson, requestModel, topN);
}
}