Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vitepress/config/apiReferenceSidebar.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ const chatWrappersOrder = [
"Llama3ChatWrapper",
"Llama2ChatWrapper",
"MistralChatWrapper",
"Gemma4ChatWrapper",
"GemmaChatWrapper",
"ChatMLChatWrapper",
"FalconChatWrapper",
Expand Down
253 changes: 253 additions & 0 deletions src/chatWrappers/Gemma4ChatWrapper.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import {ChatWrapper, ChatWrapperJinjaMatchConfiguration} from "../ChatWrapper.js";
import {
ChatModelFunctionCall, ChatModelFunctions, ChatModelResponse, ChatWrapperGenerateContextStateOptions, ChatWrapperGeneratedContextState,
ChatWrapperSettings
} from "../types.js";
import {LlamaText, SpecialToken, SpecialTokensText} from "../utils/LlamaText.js";
import {jsonDumps} from "./utils/jsonDumps.js";

// source: https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4
export class Gemma4ChatWrapper extends ChatWrapper {
    public readonly wrapperName: string = "Gemma 4";

    /** Whether the model is prompted (via the `<|think|>` token) to produce a chain of thought. */
    public readonly reasoning: boolean;

    /** Whether only the chain of thought of the last model response is kept in the context state. */
    public readonly keepOnlyLastThought: boolean;

    public override readonly settings: ChatWrapperSettings = {
        supportsSystemMessages: true,
        functions: {
            call: {
                optionalPrefixSpace: false,
                prefix: LlamaText(new SpecialTokensText("<|tool_call>call:")),
                paramsPrefix: "{",
                suffix: LlamaText(new SpecialTokensText("}<tool_call|>")),
                emptyCallParamsPlaceholder: undefined
            },
            result: {
                prefix: LlamaText(new SpecialTokensText("<tool_response>response:"), "{{functionName}}", "{"),
                suffix: LlamaText(new SpecialTokensText("}</tool_response>"))
            }
        },
        segments: {
            reiterateStackAfterFunctionCalls: true,
            thought: {
                prefix: LlamaText(new SpecialTokensText("<|channel>thought\n")),
                suffix: LlamaText(new SpecialTokensText("<channel|>"))
            }
        }
    };

    public constructor(options: {
        /**
         * Whether to promote the model to perform reasoning.
         *
         * Defaults to `true`.
         */
        reasoning?: boolean,

        /**
         * Whether to keep only the chain of thought from the last model response.
         *
         * Setting this to `false` will keep all the chain of thoughts from the model responses in the context state.
         *
         * Defaults to `true`.
         */
        keepOnlyLastThought?: boolean
    } = {}) {
        super();

        const {
            reasoning = true,
            keepOnlyLastThought = true
        } = options;

        this.reasoning = reasoning;
        this.keepOnlyLastThought = keepOnlyLastThought;
    }

    /**
     * Renders the given chat history into the Gemma 4 context state format.
     *
     * The first system message (if any) is merged with the generated function documentation
     * and the optional `<|think|>` reasoning directive into a single leading system turn.
     */
    public override generateContextState({
        chatHistory, availableFunctions, documentFunctionParams
    }: ChatWrapperGenerateContextStateOptions): ChatWrapperGeneratedContextState {
        const hasFunctions = Object.keys(availableFunctions ?? {}).length > 0;
        const modifiedChatHistory = chatHistory.slice();

        // pull out an existing leading system message so function docs and the
        // reasoning directive can be folded into it
        let systemMessage: LlamaText = LlamaText();
        if (modifiedChatHistory[0]?.type === "system") {
            systemMessage = LlamaText.fromJSON(modifiedChatHistory[0].text);
            modifiedChatHistory.shift();
        }

        if (hasFunctions)
            systemMessage = LlamaText([
                systemMessage,
                this.generateAvailableFunctionsSystemText(availableFunctions ?? {}, {documentParams: documentFunctionParams})
            ]);

        if (this.reasoning)
            systemMessage = LlamaText([
                new SpecialTokensText("<|think|>"),
                systemMessage
            ]);

        if (systemMessage.values.length > 0)
            modifiedChatHistory.unshift({
                type: "system",
                text: systemMessage.toJSON()
            });

        const contextContent: LlamaText[] = [
            LlamaText(new SpecialToken("BOS"))
        ];

        for (let i = 0; i < modifiedChatHistory.length; i++) {
            const isLastItem = i === modifiedChatHistory.length - 1;
            const item = modifiedChatHistory[i];

            if (item == null)
                continue;

            // the last turn is left open (no closing `<turn|>`) so the model continues it
            if (item.type === "system")
                contextContent.push(
                    LlamaText([
                        new SpecialTokensText("<|turn>system\n"),
                        LlamaText.fromJSON(item.text),
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("<turn|>\n")
                    ])
                );
            else if (item.type === "user")
                contextContent.push(
                    LlamaText([
                        new SpecialTokensText("<|turn>user\n"),
                        item.text,
                        isLastItem
                            ? LlamaText([])
                            : new SpecialTokensText("<turn|>\n")
                    ])
                );
            else if (item.type === "model")
                contextContent.push(this._getModelResponse(item.response, true, isLastItem, this.keepOnlyLastThought));
            else
                void (item satisfies never);
        }

        return {
            contextText: LlamaText(contextContent),
            stopGenerationTriggers: [
                LlamaText(new SpecialToken("EOS")),
                LlamaText(new SpecialToken("EOT")),
                LlamaText(new SpecialTokensText("<turn|>")),
                LlamaText(new SpecialTokensText("<turn|>\n")),

                // NOTE(review): plain text rather than `SpecialTokensText`, unlike the other
                // triggers - confirm whether `<|return|>` is meant to match literal output
                LlamaText("<|return|>")
            ]
        };
    }

    /**
     * Renders the available functions as `<|tool>declaration:...<tool|>` entries
     * for inclusion in the system turn.
     */
    public override generateAvailableFunctionsSystemText(availableFunctions: ChatModelFunctions, {documentParams = true}: {
        documentParams?: boolean
    }): LlamaText {
        return LlamaText(
            Object.entries(availableFunctions)
                .map(([name, definition]) => {
                    return LlamaText([
                        new SpecialTokensText("<|tool>"),
                        "declaration:", name, "{",
                        jsonDumps({
                            description: definition.description || undefined,
                            parameters: documentParams
                                ? (definition.params || {})
                                : undefined
                        }),
                        "}", new SpecialTokensText("<tool|>")
                    ]);
                })
        );
    }

    public override generateModelResponseText(modelResponse: ChatModelResponse["response"], useRawValues: boolean = true): LlamaText {
        return this._getModelResponse(modelResponse, useRawValues, false, false);
    }

    /**
     * Renders a single model turn: text, thought segments, and function calls (with their results).
     *
     * Function calls are buffered and flushed in chunks so that calls belonging to the same
     * parallel chunk are rendered together.
     * @internal
     */
    private _getModelResponse(
        modelResponse: ChatModelResponse["response"],
        useRawValues: boolean,
        isLastItem: boolean,
        keepOnlyLastThought: boolean
    ) {
        const res: LlamaText[] = [
            LlamaText(new SpecialTokensText("<|turn>model\n"))
        ];
        const pendingFunctionCalls: ChatModelFunctionCall[] = [];

        const addPendingFunctions = () => {
            if (pendingFunctionCalls.length === 0)
                return;

            res.push(this.generateFunctionCallsAndResults(pendingFunctionCalls, useRawValues));

            pendingFunctionCalls.length = 0;
        };

        for (let index = 0; index < modelResponse.length; index++) {
            const isLastResponse = index === modelResponse.length - 1;
            const response = modelResponse[index];

            if (response == null)
                continue;
            else if (response === "" && (!isLastResponse || !isLastItem))
                continue;

            if (typeof response === "string") {
                addPendingFunctions();
                res.push(LlamaText(response));
            } else if (response.type === "segment") {
                addPendingFunctions();

                if (response.ended && response.raw != null && useRawValues)
                    res.push(LlamaText.fromJSON(response.raw));
                else if (response.segmentType === "thought") {
                    if (keepOnlyLastThought && !isLastItem)
                        continue;

                    res.push(
                        LlamaText([
                            // must match `settings.segments.thought.prefix` (including the
                            // trailing "\n") so segment detection on the rendered context
                            // state stays consistent; the prefix was previously missing it
                            new SpecialTokensText("<|channel>thought\n"),
                            response.text,

                            // leave an unfinished last thought open so the model continues it
                            (isLastItem && !response.ended)
                                ? LlamaText([])
                                : new SpecialTokensText("<channel|>")
                        ])
                    );
                } else if (response.segmentType === "comment")
                    continue; // unsupported
                else
                    void (response.segmentType satisfies never);
            } else if (response.type === "functionCall") {
                if (response.startsNewChunk)
                    addPendingFunctions();

                pendingFunctionCalls.push(response);
            } else
                void (response satisfies never);
        }

        addPendingFunctions();

        return LlamaText(res);
    }

    /** @internal */
    public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate(): ChatWrapperJinjaMatchConfiguration<typeof this> {
        return [
            [{}, {}],
            [{reasoning: false}, {}],
            [
                {reasoning: true},
                {},
                {additionalRenderParameters: {"enable_thinking": true}}
            ]
        ];
    }
}
20 changes: 16 additions & 4 deletions src/chatWrappers/utils/resolveChatWrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {FalconChatWrapper} from "../FalconChatWrapper.js";
import {FunctionaryChatWrapper} from "../FunctionaryChatWrapper.js";
import {AlpacaChatWrapper} from "../AlpacaChatWrapper.js";
import {GemmaChatWrapper} from "../GemmaChatWrapper.js";
import {Gemma4ChatWrapper} from "../Gemma4ChatWrapper.js";
import {JinjaTemplateChatWrapper, JinjaTemplateChatWrapperOptions} from "../generic/JinjaTemplateChatWrapper.js";
import {TemplateChatWrapper} from "../generic/TemplateChatWrapper.js";
import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js";
Expand All @@ -27,7 +28,7 @@ import type {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js";

export const specializedChatWrapperTypeNames = Object.freeze([
"general", "deepSeek", "qwen", "llama3.2-lightweight", "llama3.1", "llama3", "llama2Chat", "mistral", "alpacaChat", "functionary",
"chatML", "falconChat", "gemma", "harmony", "seed"
"chatML", "falconChat", "gemma4", "gemma", "harmony", "seed"
] as const);
export type SpecializedChatWrapperTypeName = (typeof specializedChatWrapperTypeNames)[number];

Expand Down Expand Up @@ -56,6 +57,7 @@ export const chatWrappers = Object.freeze({
"functionary": FunctionaryChatWrapper,
"chatML": ChatMLChatWrapper,
"falconChat": FalconChatWrapper,
"gemma4": Gemma4ChatWrapper,
"gemma": GemmaChatWrapper,
"harmony": HarmonyChatWrapper,
"seed": SeedChatWrapper,
Expand All @@ -70,7 +72,8 @@ const chatWrapperToConfigType = new Map(
);

const specializedChatWrapperRelatedTexts = {
"harmony": ["gpt", "gpt-oss"]
"harmony": ["gpt", "gpt-oss"],
"gemma4": ["gemma 4", "gemma-4"]
} satisfies Partial<Record<ResolvableChatWrapperTypeName, string[]>>;

export type BuiltInChatWrapperType = InstanceType<typeof chatWrappers[keyof typeof chatWrappers]>;
Expand Down Expand Up @@ -364,6 +367,8 @@ export function resolveChatWrapper(
return createSpecializedChatWrapper(Llama3ChatWrapper);
else if (includesText(modelNames, ["Mistral", "Mistral Large", "Mistral Large Instruct", "Mistral-Large", "Codestral"]))
return createSpecializedChatWrapper(MistralChatWrapper);
else if (includesText(modelNames, ["Gemma 4", "Gemma-4", "gemma-4"]))
return createSpecializedChatWrapper(Gemma4ChatWrapper);
else if (includesText(modelNames, ["Gemma", "Gemma 2"]))
return createSpecializedChatWrapper(GemmaChatWrapper);
else if (includesText(modelNames, ["gpt-oss", "Gpt Oss", "Gpt-Oss", "openai_gpt-oss", "Openai_Gpt Oss", "openai.gpt-oss", "Openai.Gpt Oss"]))
Expand All @@ -381,6 +386,8 @@ export function resolveChatWrapper(
return createSpecializedChatWrapper(SeedChatWrapper);
else if (modelJinjaTemplate.includes("<|start|>") && modelJinjaTemplate.includes("<|channel|>"))
return createSpecializedChatWrapper(HarmonyChatWrapper);
else if (modelJinjaTemplate.includes("<|turn>") && modelJinjaTemplate.includes("<|tool_call>call:"))
return createSpecializedChatWrapper(Gemma4ChatWrapper);
else if (modelJinjaTemplate.includes("<|im_start|>"))
return createSpecializedChatWrapper(ChatMLChatWrapper);
else if (modelJinjaTemplate.includes("[INST]"))
Expand Down Expand Up @@ -430,9 +437,12 @@ export function resolveChatWrapper(
return createSpecializedChatWrapper(FunctionaryChatWrapper);
else if (lowercaseName === "dolphin" && splitLowercaseSubType.includes("mistral"))
return createSpecializedChatWrapper(ChatMLChatWrapper);
else if (lowercaseName === "gemma")
else if (lowercaseName === "gemma") {
if (firstSplitLowercaseSubType === "4")
return createSpecializedChatWrapper(Gemma4ChatWrapper);

return createSpecializedChatWrapper(GemmaChatWrapper);
else if (splitLowercaseSubType.includes("chatml"))
} else if (splitLowercaseSubType.includes("chatml"))
return createSpecializedChatWrapper(ChatMLChatWrapper);
}
}
Expand All @@ -454,6 +464,8 @@ export function resolveChatWrapper(
return createSpecializedChatWrapper(FalconChatWrapper);
else if (arch === "gemma" || arch === "gemma2")
return createSpecializedChatWrapper(GemmaChatWrapper);
else if (arch === "gemma4")
return createSpecializedChatWrapper(Gemma4ChatWrapper);
}

return null;
Expand Down
1 change: 1 addition & 0 deletions src/gguf/types/GgufMetadataTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export const enum GgufArchitectureType {
gemma2 = "gemma2",
gemma3 = "gemma3",
gemma3n = "gemma3n",
gemma4 = "gemma4",
gemmaEmbedding = "gemma-embedding",
starcoder2 = "starcoder2",
mamba = "mamba",
Expand Down
2 changes: 2 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import {FalconChatWrapper} from "./chatWrappers/FalconChatWrapper.js";
import {AlpacaChatWrapper} from "./chatWrappers/AlpacaChatWrapper.js";
import {FunctionaryChatWrapper} from "./chatWrappers/FunctionaryChatWrapper.js";
import {GemmaChatWrapper} from "./chatWrappers/GemmaChatWrapper.js";
import {Gemma4ChatWrapper} from "./chatWrappers/Gemma4ChatWrapper.js";
import {HarmonyChatWrapper} from "./chatWrappers/HarmonyChatWrapper.js";
import {TemplateChatWrapper, type TemplateChatWrapperOptions} from "./chatWrappers/generic/TemplateChatWrapper.js";
import {
Expand Down Expand Up @@ -231,6 +232,7 @@ export {
AlpacaChatWrapper,
FunctionaryChatWrapper,
GemmaChatWrapper,
Gemma4ChatWrapper,
HarmonyChatWrapper,
TemplateChatWrapper,
type TemplateChatWrapperOptions,
Expand Down
4 changes: 4 additions & 0 deletions src/utils/LlamaText.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class LlamaText {
return LlamaTextConstructor.compare(this, other);
}

/**
 * Creates a new `LlamaText` with whitespace trimmed from both ends.
 */
public trim(): LlamaText {
    const withoutLeadingWhitespace = this.trimStart();
    return withoutLeadingWhitespace.trimEnd();
}

public trimStart(): LlamaText {
const newValues = this.values.slice();

Expand Down
Loading