Skip to content

Commit af2146d

Browse files
committed
feat: Gemma 4 support
1 parent 57bea3d commit af2146d

6 files changed

Lines changed: 277 additions & 4 deletions

File tree

.vitepress/config/apiReferenceSidebar.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ const chatWrappersOrder = [
5353
"Llama3ChatWrapper",
5454
"Llama2ChatWrapper",
5555
"MistralChatWrapper",
56+
"Gemma4ChatWrapper",
5657
"GemmaChatWrapper",
5758
"ChatMLChatWrapper",
5859
"FalconChatWrapper",
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
import {ChatWrapper, ChatWrapperJinjaMatchConfiguration} from "../ChatWrapper.js";
2+
import {
3+
ChatModelFunctionCall, ChatModelFunctions, ChatModelResponse, ChatWrapperGenerateContextStateOptions, ChatWrapperGeneratedContextState,
4+
ChatWrapperSettings
5+
} from "../types.js";
6+
import {LlamaText, SpecialToken, SpecialTokensText} from "../utils/LlamaText.js";
7+
import {jsonDumps} from "./utils/jsonDumps.js";
8+
9+
// source: https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4
/**
 * Chat wrapper that serializes chat history into the Gemma 4 prompt format.
 *
 * Each turn is wrapped in `<|turn>{role}\n` ... `<turn|>\n` markers and the context starts with a BOS token.
 * Function declarations, function calls/results and thought segments use the
 * `<|tool>` / `<|tool_call>` / `<tool_response>` / `<|channel>thought` markers configured in {@link settings}.
 */
export class Gemma4ChatWrapper extends ChatWrapper {
    public readonly wrapperName: string = "Gemma 4";

    // When `true`, a `<|think|>` token is prepended to the system message (see `generateContextState`)
    public readonly reasoning: boolean;

    // When `true`, thought segments are dropped from every model response except the last one
    public readonly keepOnlyLastThought: boolean;

    public override readonly settings: ChatWrapperSettings = {
        supportsSystemMessages: true,
        functions: {
            call: {
                optionalPrefixSpace: false,
                prefix: LlamaText(new SpecialTokensText("<|tool_call>call:")),
                paramsPrefix: "{",
                suffix: LlamaText(new SpecialTokensText("}<tool_call|>")),
                emptyCallParamsPlaceholder: undefined
            },
            result: {
                // "{{functionName}}" is a template placeholder — presumably substituted by the framework
                // with the called function's name; TODO confirm against the ChatWrapperSettings contract
                prefix: LlamaText(new SpecialTokensText("<tool_response>response:"), "{{functionName}}", "{"),
                suffix: LlamaText(new SpecialTokensText("}</tool_response>"))
            }
        },
        segments: {
            reiterateStackAfterFunctionCalls: true,
            thought: {
                prefix: LlamaText(new SpecialTokensText("<|channel>thought\n")),
                suffix: LlamaText(new SpecialTokensText("<channel|>"))
            }
        }
    };

    public constructor(options: {
        /**
         * Whether to prompt the model to perform reasoning.
         *
         * Defaults to `true`.
         */
        reasoning?: boolean,

        /**
         * Whether to keep only the chain of thought from the last model response.
         *
         * Setting this to `false` will keep all the chain of thoughts from the model responses in the context state.
         *
         * Defaults to `true`.
         */
        keepOnlyLastThought?: boolean
    } = {}) {
        super();

        const {
            reasoning = true,
            keepOnlyLastThought = true
        } = options;

        this.reasoning = reasoning;
        this.keepOnlyLastThought = keepOnlyLastThought;
    }

    /**
     * Builds the full context text (and stop triggers) for the given chat history.
     *
     * The leading system message, the rendered function declarations and the optional
     * `<|think|>` reasoning token are merged into a single system turn before serialization.
     */
    public override generateContextState({
        chatHistory, availableFunctions, documentFunctionParams
    }: ChatWrapperGenerateContextStateOptions): ChatWrapperGeneratedContextState {
        const hasFunctions = Object.keys(availableFunctions ?? {}).length > 0;
        const modifiedChatHistory = chatHistory.slice();

        // pull out an existing leading system message so additions can be merged into it
        let systemMessage: LlamaText = LlamaText();
        if (modifiedChatHistory[0]?.type === "system") {
            systemMessage = LlamaText.fromJSON(modifiedChatHistory[0].text);
            modifiedChatHistory.shift();
        }

        // append the function declarations to the system message
        if (hasFunctions)
            systemMessage = LlamaText([
                systemMessage,
                this.generateAvailableFunctionsSystemText(availableFunctions ?? {}, {documentParams: documentFunctionParams})
            ]);

        // prepend the `<|think|>` token to enable reasoning.
        // NOTE(review): when `reasoning` is `true`, this makes the system message non-empty
        // even if there was no original system message and no functions — confirm intended
        if (this.reasoning)
            systemMessage = LlamaText([
                new SpecialTokensText("<|think|>"),
                systemMessage
            ]);

        // re-insert the (possibly augmented) system message at the start of the history
        if (systemMessage.values.length > 0)
            modifiedChatHistory.unshift({
                type: "system",
                text: systemMessage.toJSON()
            });

        const contextContent: LlamaText[] = [
            LlamaText(new SpecialToken("BOS"))
        ];

        for (let i = 0; i < modifiedChatHistory.length; i++) {
            const isLastItem = i === modifiedChatHistory.length - 1;
            const item = modifiedChatHistory[i];

            if (item == null)
                continue;

            // the last item is left without a closing `<turn|>` marker so generation can continue from it
            if (item.type === "system")
                contextContent.push(
                    LlamaText([
                        new SpecialTokensText("<|turn>system\n"),
                        LlamaText.fromJSON(item.text),
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("<turn|>\n")
                    ])
                );
            else if (item.type === "user")
                contextContent.push(
                    LlamaText([
                        new SpecialTokensText("<|turn>user\n"),
                        item.text,
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("<turn|>\n")
                    ])
                );
            else if (item.type === "model")
                contextContent.push(this._getModelResponse(item.response, true, isLastItem, this.keepOnlyLastThought));
            else
                // exhaustiveness check — fails to compile if a new history item type is added
                void (item satisfies never);
        }

        return {
            contextText: LlamaText(contextContent),
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialToken("EOT")),
                LlamaText(new SpecialTokensText("<turn|>")),
                LlamaText(new SpecialTokensText("<turn|>\n")),
                // NOTE(review): unlike the triggers above, this is matched as plain text,
                // not as a special token — confirm intended
                LlamaText("<|return|>")
            ]
        };
    }

    /**
     * Renders the available functions as `<|tool>declaration:{name}{...json...}<tool|>` entries
     * for inclusion in the system message.
     */
    public override generateAvailableFunctionsSystemText(availableFunctions: ChatModelFunctions, {documentParams = true}: {
        documentParams?: boolean
    }): LlamaText {
        return LlamaText(
            Object.entries(availableFunctions)
                .map(([name, definition]) => {
                    return LlamaText([
                        new SpecialTokensText("<|tool>"),
                        "declaration:", name, "{",
                        jsonDumps({
                            description: definition.description || undefined,
                            parameters: documentParams
                                ? (definition.params || {})
                                : undefined
                        }),
                        "}", new SpecialTokensText("<tool|>")
                    ]);
                })
        );
    }

    /**
     * Renders a single model response on its own — not treated as the open last turn,
     * and without dropping any thought segments.
     */
    public override generateModelResponseText(modelResponse: ChatModelResponse["response"], useRawValues: boolean = true): LlamaText {
        return this._getModelResponse(modelResponse, useRawValues, false, false);
    }

    /** @internal */
    private _getModelResponse(
        modelResponse: ChatModelResponse["response"],
        useRawValues: boolean,
        isLastItem: boolean,
        keepOnlyLastThought: boolean
    ) {
        // a model turn always opens with the model turn marker;
        // the closing `<turn|>` is appended by the caller (`generateContextState`) when needed
        const res: LlamaText[] = [
            LlamaText(new SpecialTokensText("<|turn>model\n"))
        ];
        // consecutive function calls are buffered so they can be rendered as a single chunk
        const pendingFunctionCalls: ChatModelFunctionCall[] = [];

        const addPendingFunctions = () => {
            if (pendingFunctionCalls.length === 0)
                return;

            res.push(this.generateFunctionCallsAndResults(pendingFunctionCalls, useRawValues));

            pendingFunctionCalls.length = 0;
        };

        for (let index = 0; index < modelResponse.length; index++) {
            const isLastResponse = index === modelResponse.length - 1;
            const response = modelResponse[index];

            if (response == null)
                continue;
            // skip empty text, except when it's the trailing entry of the open last turn
            else if (response === "" && (!isLastResponse || !isLastItem))
                continue;

            if (typeof response === "string") {
                addPendingFunctions();
                res.push(LlamaText(response));
            } else if (response.type === "segment") {
                addPendingFunctions();

                // prefer the raw recorded text of a finished segment when available
                if (response.ended && response.raw != null && useRawValues)
                    res.push(LlamaText.fromJSON(response.raw));
                else if (response.segmentType === "thought") {
                    // drop thoughts from all but the last model response when configured to do so
                    if (keepOnlyLastThought && !isLastItem)
                        continue;

                    // NOTE(review): `settings.segments.thought.prefix` includes a trailing "\n",
                    // but this manual render does not — confirm which form is correct
                    res.push(
                        LlamaText([
                            new SpecialTokensText("<|channel>thought"),
                            response.text,
                            // an unfinished thought in the open last turn is left without a closing marker
                            (isLastItem && !response.ended)
                                ? LlamaText([])
                                : new SpecialTokensText("<channel|>")
                        ])
                    );
                } else if (response.segmentType === "comment")
                    continue; // unsupported
                else
                    // exhaustiveness check for segment types
                    void (response.segmentType satisfies never);
            } else if (response.type === "functionCall") {
                // a call that starts a new chunk flushes the previously buffered calls first
                if (response.startsNewChunk)
                    addPendingFunctions();

                pendingFunctionCalls.push(response);
            } else
                // exhaustiveness check for response item types
                void (response satisfies never);
        }

        // flush any calls still buffered at the end of the response
        addPendingFunctions();

        return LlamaText(res);
    }

    /** @internal */
    public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate(): ChatWrapperJinjaMatchConfiguration<typeof this> {
        return [
            [{}, {}],
            [{reasoning: false}, {}],
            [
                {reasoning: true},
                {},
                {additionalRenderParameters: {"enable_thinking": true}}
            ]
        ];
    }
}

src/chatWrappers/utils/resolveChatWrapper.ts

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {FalconChatWrapper} from "../FalconChatWrapper.js";
77
import {FunctionaryChatWrapper} from "../FunctionaryChatWrapper.js";
88
import {AlpacaChatWrapper} from "../AlpacaChatWrapper.js";
99
import {GemmaChatWrapper} from "../GemmaChatWrapper.js";
10+
import {Gemma4ChatWrapper} from "../Gemma4ChatWrapper.js";
1011
import {JinjaTemplateChatWrapper, JinjaTemplateChatWrapperOptions} from "../generic/JinjaTemplateChatWrapper.js";
1112
import {TemplateChatWrapper} from "../generic/TemplateChatWrapper.js";
1213
import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js";
@@ -27,7 +28,7 @@ import type {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js";
2728

2829
export const specializedChatWrapperTypeNames = Object.freeze([
2930
"general", "deepSeek", "qwen", "llama3.2-lightweight", "llama3.1", "llama3", "llama2Chat", "mistral", "alpacaChat", "functionary",
30-
"chatML", "falconChat", "gemma", "harmony", "seed"
31+
"chatML", "falconChat", "gemma4", "gemma", "harmony", "seed"
3132
] as const);
3233
export type SpecializedChatWrapperTypeName = (typeof specializedChatWrapperTypeNames)[number];
3334

@@ -56,6 +57,7 @@ export const chatWrappers = Object.freeze({
5657
"functionary": FunctionaryChatWrapper,
5758
"chatML": ChatMLChatWrapper,
5859
"falconChat": FalconChatWrapper,
60+
"gemma4": Gemma4ChatWrapper,
5961
"gemma": GemmaChatWrapper,
6062
"harmony": HarmonyChatWrapper,
6163
"seed": SeedChatWrapper,
@@ -70,7 +72,8 @@ const chatWrapperToConfigType = new Map(
7072
);
7173

7274
const specializedChatWrapperRelatedTexts = {
73-
"harmony": ["gpt", "gpt-oss"]
75+
"harmony": ["gpt", "gpt-oss"],
76+
"gemma4": ["gemma 4", "gemma-4"]
7477
} satisfies Partial<Record<ResolvableChatWrapperTypeName, string[]>>;
7578

7679
export type BuiltInChatWrapperType = InstanceType<typeof chatWrappers[keyof typeof chatWrappers]>;
@@ -364,6 +367,8 @@ export function resolveChatWrapper(
364367
return createSpecializedChatWrapper(Llama3ChatWrapper);
365368
else if (includesText(modelNames, ["Mistral", "Mistral Large", "Mistral Large Instruct", "Mistral-Large", "Codestral"]))
366369
return createSpecializedChatWrapper(MistralChatWrapper);
370+
else if (includesText(modelNames, ["Gemma 4", "Gemma-4", "gemma-4"]))
371+
return createSpecializedChatWrapper(Gemma4ChatWrapper);
367372
else if (includesText(modelNames, ["Gemma", "Gemma 2"]))
368373
return createSpecializedChatWrapper(GemmaChatWrapper);
369374
else if (includesText(modelNames, ["gpt-oss", "Gpt Oss", "Gpt-Oss", "openai_gpt-oss", "Openai_Gpt Oss", "openai.gpt-oss", "Openai.Gpt Oss"]))
@@ -381,6 +386,8 @@ export function resolveChatWrapper(
381386
return createSpecializedChatWrapper(SeedChatWrapper);
382387
else if (modelJinjaTemplate.includes("<|start|>") && modelJinjaTemplate.includes("<|channel|>"))
383388
return createSpecializedChatWrapper(HarmonyChatWrapper);
389+
else if (modelJinjaTemplate.includes("<|turn>") && modelJinjaTemplate.includes("<|tool_call>call:"))
390+
return createSpecializedChatWrapper(Gemma4ChatWrapper);
384391
else if (modelJinjaTemplate.includes("<|im_start|>"))
385392
return createSpecializedChatWrapper(ChatMLChatWrapper);
386393
else if (modelJinjaTemplate.includes("[INST]"))
@@ -430,9 +437,12 @@ export function resolveChatWrapper(
430437
return createSpecializedChatWrapper(FunctionaryChatWrapper);
431438
else if (lowercaseName === "dolphin" && splitLowercaseSubType.includes("mistral"))
432439
return createSpecializedChatWrapper(ChatMLChatWrapper);
433-
else if (lowercaseName === "gemma")
440+
else if (lowercaseName === "gemma") {
441+
if (firstSplitLowercaseSubType === "4")
442+
return createSpecializedChatWrapper(Gemma4ChatWrapper);
443+
434444
return createSpecializedChatWrapper(GemmaChatWrapper);
435-
else if (splitLowercaseSubType.includes("chatml"))
445+
} else if (splitLowercaseSubType.includes("chatml"))
436446
return createSpecializedChatWrapper(ChatMLChatWrapper);
437447
}
438448
}
@@ -454,6 +464,8 @@ export function resolveChatWrapper(
454464
return createSpecializedChatWrapper(FalconChatWrapper);
455465
else if (arch === "gemma" || arch === "gemma2")
456466
return createSpecializedChatWrapper(GemmaChatWrapper);
467+
else if (arch === "gemma4")
468+
return createSpecializedChatWrapper(Gemma4ChatWrapper);
457469
}
458470

459471
return null;

src/gguf/types/GgufMetadataTypes.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export const enum GgufArchitectureType {
4747
gemma2 = "gemma2",
4848
gemma3 = "gemma3",
4949
gemma3n = "gemma3n",
50+
gemma4 = "gemma4",
5051
gemmaEmbedding = "gemma-embedding",
5152
starcoder2 = "starcoder2",
5253
mamba = "mamba",

src/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ import {FalconChatWrapper} from "./chatWrappers/FalconChatWrapper.js";
6262
import {AlpacaChatWrapper} from "./chatWrappers/AlpacaChatWrapper.js";
6363
import {FunctionaryChatWrapper} from "./chatWrappers/FunctionaryChatWrapper.js";
6464
import {GemmaChatWrapper} from "./chatWrappers/GemmaChatWrapper.js";
65+
import {Gemma4ChatWrapper} from "./chatWrappers/Gemma4ChatWrapper.js";
6566
import {HarmonyChatWrapper} from "./chatWrappers/HarmonyChatWrapper.js";
6667
import {TemplateChatWrapper, type TemplateChatWrapperOptions} from "./chatWrappers/generic/TemplateChatWrapper.js";
6768
import {
@@ -231,6 +232,7 @@ export {
231232
AlpacaChatWrapper,
232233
FunctionaryChatWrapper,
233234
GemmaChatWrapper,
235+
Gemma4ChatWrapper,
234236
HarmonyChatWrapper,
235237
TemplateChatWrapper,
236238
type TemplateChatWrapperOptions,

src/utils/LlamaText.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ class LlamaText {
122122
return LlamaTextConstructor.compare(this, other);
123123
}
124124

125+
/**
 * Returns a copy of this `LlamaText` with whitespace trimmed from both ends.
 */
public trim(): LlamaText {
    const leadingTrimmed = this.trimStart();
    return leadingTrimmed.trimEnd();
}
128+
125129
public trimStart(): LlamaText {
126130
const newValues = this.values.slice();
127131

0 commit comments

Comments
 (0)