Skip to content

Commit ea838e6

Browse files
committed
fix: support tool call syntax with optional whitespace prefix
1 parent e252ec3 commit ea838e6

5 files changed

Lines changed: 125 additions & 44 deletions

File tree

src/chatWrappers/generic/utils/extractFunctionCallSettingsFromJinjaTemplate.ts

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
398398
const callPrefixText = func1ParamsToFunc2Name.text.slice(-callPrefixLength);
399399
const parallelismCallPrefix = modelMessage1ToFunc1Name.text.slice(0, -callPrefixLength);
400400

401-
const callSuffixLength = findCommandStartLength(func1ParamsToFunc2Name.text, func2ParamsToFunc1Result.text);
401+
const callSuffixLength = findCommonStartLength(func1ParamsToFunc2Name.text, func2ParamsToFunc1Result.text);
402402
const callSuffixText = func1ParamsToFunc2Name.text.slice(0, callSuffixLength);
403403

404404
const parallelismBetweenCallsText = func1ParamsToFunc2Name.text.slice(callSuffixLength, -callPrefixLength);
@@ -407,7 +407,7 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
407407
const resultPrefixLength = findCommonEndLength(func2ParamsToFunc1Result.text, func1ResultToFunc2Result.text);
408408
const resultPrefixText = func2ParamsToFunc1Result.text.slice(-resultPrefixLength);
409409

410-
const resultSuffixLength = findCommandStartLength(func1ResultToFunc2Result.text, func2ResultToModelMessage2.text);
410+
const resultSuffixLength = findCommonStartLength(func1ResultToFunc2Result.text, func2ResultToModelMessage2.text);
411411
const resultSuffixText = func1ResultToFunc2Result.text.slice(0, resultSuffixLength);
412412
const parallelismResultBetweenResultsText = func1ResultToFunc2Result.text.slice(resultSuffixLength, -resultPrefixLength);
413413
const parallelismResultSuffixText = func2ResultToModelMessage2.text.slice(resultSuffixLength);
@@ -452,7 +452,7 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
452452
const onlyCallUserMessage1ToFunc1Name = getTextBetweenIds(renderedOnlyCall, userMessage1, func1name);
453453

454454
if (userMessage1ToModelMessage1Start.text != null && onlyCallUserMessage1ToFunc1Name.text != null) {
455-
const onlyCallModelMessagePrefixLength = findCommandStartLength(
455+
const onlyCallModelMessagePrefixLength = findCommonStartLength(
456456
userMessage1ToModelMessage1Start.text,
457457
onlyCallUserMessage1ToFunc1Name.text
458458
);
@@ -470,14 +470,29 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
470470
}
471471
}
472472

473+
const {
474+
whitespacePrefix: revivedCallWhitespacePrefix,
475+
newTarget: cleanRevivedCallPrefix
476+
} = extractWhitespacePrefixFromRevivedText(revivedCallPrefix);
477+
478+
const {
479+
whitespacePrefix: revivedParallelismWhitespacePrefix,
480+
newTarget: cleanRevivedParallelismCallSectionPrefix
481+
} = extractWhitespacePrefixFromRevivedText(
482+
LlamaText([
483+
revivedParallelismCallSectionPrefix,
484+
revivedCallWhitespacePrefix
485+
])
486+
);
487+
473488
return {
474489
stringifyParams,
475490
stringifyResult,
476491
combineModelMessageAndToolCalls,
477492
settings: {
478493
call: {
479494
optionalPrefixSpace: true,
480-
prefix: revivedCallPrefix,
495+
prefix: cleanRevivedCallPrefix,
481496
paramsPrefix: reviveSeparatorText(callParamsPrefixText, idToStaticContent, contentIds),
482497
suffix: reviveSeparatorText(callSuffixText, idToStaticContent, contentIds),
483498
emptyCallParamsPlaceholder: {}
@@ -504,8 +519,17 @@ export function extractFunctionCallSettingsFromJinjaTemplate({
504519
},
505520
parallelism: {
506521
call: {
507-
sectionPrefix: revivedParallelismCallSectionPrefix,
508-
betweenCalls: revivedParallelismCallBetweenCalls,
522+
sectionPrefix: LlamaText([
523+
revivedParallelismWhitespacePrefix,
524+
cleanRevivedParallelismCallSectionPrefix
525+
]),
526+
sectionPrefixAlternateMatches: revivedParallelismWhitespacePrefix.values.length === 0
527+
? undefined
528+
: [cleanRevivedParallelismCallSectionPrefix],
529+
betweenCalls: LlamaText([
530+
revivedParallelismCallBetweenCalls,
531+
revivedCallWhitespacePrefix
532+
]),
509533
sectionSuffix: reviveSeparatorText(parallelismCallSuffixText, idToStaticContent, contentIds)
510534
},
511535
result: {
@@ -567,7 +591,7 @@ function removeCommonRevivedPrefix(target: LlamaText, matchStart: LlamaText) {
567591
if (targetValue === matchStartValue)
568592
continue;
569593
} else if (targetValue instanceof SpecialTokensText && matchStartValue instanceof SpecialTokensText) {
570-
const commonLength = findCommandStartLength(targetValue.value, matchStartValue.value);
594+
const commonLength = findCommonStartLength(targetValue.value, matchStartValue.value);
571595
if (commonLength === targetValue.value.length && commonLength === matchStartValue.value.length)
572596
continue;
573597

@@ -620,7 +644,56 @@ function removeCommonRevivedSuffix(target: LlamaText, matchEnd: LlamaText) {
620644
return LlamaText(target.values.slice(0, target.values.length - matchEnd.values.length));
621645
}
622646

623-
function findCommandStartLength(text1: string, text2: string) {
647+
function extractWhitespacePrefixFromRevivedText(target: LlamaText) {
648+
for (let i = 0; i < target.values.length; i++) {
649+
const value = target.values[i];
650+
if (typeof value === "string") {
651+
const trimmedValueLength = value.trimStart().length;
652+
if (trimmedValueLength === 0)
653+
continue;
654+
655+
const whitespaceLength = value.length - trimmedValueLength;
656+
return {
657+
whitespacePrefix: LlamaText([
658+
...target.values.slice(0, i),
659+
value.slice(0, whitespaceLength)
660+
]),
661+
newTarget: LlamaText([
662+
value.slice(whitespaceLength),
663+
...target.values.slice(i + 1)
664+
])
665+
};
666+
} else if (value instanceof SpecialTokensText) {
667+
const trimmedValue = value.value.trimStart();
668+
if (trimmedValue.length === 0)
669+
continue;
670+
671+
const whitespaceLength = value.value.length - trimmedValue.length;
672+
return {
673+
whitespacePrefix: LlamaText([
674+
...target.values.slice(0, i),
675+
new SpecialTokensText(value.value.slice(0, whitespaceLength))
676+
]),
677+
newTarget: LlamaText([
678+
new SpecialTokensText(value.value.slice(whitespaceLength)),
679+
...target.values.slice(i + 1)
680+
])
681+
};
682+
}
683+
684+
return {
685+
whitespacePrefix: LlamaText(target.values.slice(0, i)),
686+
newTarget: LlamaText(target.values.slice(i))
687+
};
688+
}
689+
690+
return {
691+
whitespacePrefix: target,
692+
newTarget: LlamaText([])
693+
};
694+
}
695+
696+
function findCommonStartLength(text1: string, text2: string) {
624697
let commonStartLength = 0;
625698
while (commonStartLength < text1.length && commonStartLength < text2.length) {
626699
if (text1[commonStartLength] !== text2[commonStartLength])

src/evaluator/LlamaChat/LlamaChat.ts

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,16 +1861,21 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
18611861
StopGenerationDetector.resolveStopTriggers(this.grammar.stopGenerationTriggers, this.llamaChat.model.tokenizer)
18621862
.map((stopTrigger) => this.stopGenerationDetector.addStopTrigger(stopTrigger));
18631863

1864-
if (this.functions != null && Object.keys(this.functions).length > 0 && !this.abortOnNonText)
1865-
this.functionSyntaxStartDetector.addStopTrigger(
1866-
StopGenerationDetector.resolveLlamaTextTrigger(
1867-
LlamaText([
1868-
this.chatWrapper.settings.functions?.parallelism?.call?.sectionPrefix ?? "",
1869-
this.chatWrapper.settings.functions.call.prefix
1870-
]),
1871-
this.llamaChat.model.tokenizer
1872-
)
1873-
);
1864+
if (this.functions != null && Object.keys(this.functions).length > 0 && !this.abortOnNonText) {
1865+
for (const sectionPrefix of [
1866+
this.chatWrapper.settings.functions?.parallelism?.call?.sectionPrefix ?? "",
1867+
...(this.chatWrapper.settings.functions?.parallelism?.call.sectionPrefixAlternateMatches ?? [])
1868+
])
1869+
this.functionSyntaxStartDetector.addStopTrigger(
1870+
StopGenerationDetector.resolveLlamaTextTrigger(
1871+
LlamaText([
1872+
sectionPrefix,
1873+
this.chatWrapper.settings.functions.call.prefix
1874+
]),
1875+
this.llamaChat.model.tokenizer
1876+
)
1877+
);
1878+
}
18741879

18751880
const segmentDefinitions: ConstructorParameters<typeof SegmentHandler>[0]["segmentDefinitions"] = new Map();
18761881
for (const segmentType of allSegmentTypes) {
@@ -1895,15 +1900,19 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
18951900
});
18961901

18971902
if (this.abortOnNonText) {
1898-
this.stopGenerationDetector.addStopTrigger(
1899-
StopGenerationDetector.resolveLlamaTextTrigger(
1900-
LlamaText([
1901-
this.chatWrapper.settings.functions?.parallelism?.call?.sectionPrefix ?? "",
1902-
this.chatWrapper.settings.functions.call.prefix
1903-
]),
1904-
this.llamaChat.model.tokenizer
1905-
)
1906-
);
1903+
for (const sectionPrefix of [
1904+
this.chatWrapper.settings.functions?.parallelism?.call?.sectionPrefix ?? "",
1905+
...(this.chatWrapper.settings.functions?.parallelism?.call.sectionPrefixAlternateMatches ?? [])
1906+
])
1907+
this.stopGenerationDetector.addStopTrigger(
1908+
StopGenerationDetector.resolveLlamaTextTrigger(
1909+
LlamaText([
1910+
sectionPrefix,
1911+
this.chatWrapper.settings.functions.call.prefix
1912+
]),
1913+
this.llamaChat.model.tokenizer
1914+
)
1915+
);
19071916

19081917
for (const segmentType of allSegmentTypes) {
19091918
const segmentDefinition = getChatWrapperSegmentDefinition(this.chatWrapper.settings, segmentType);

src/types.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@ export type ChatWrapperSettings = {
8080
readonly parallelism?: {
8181
readonly call: {
8282
readonly sectionPrefix: string | LlamaText,
83+
84+
/**
85+
* Alternate section prefixes that can be used to detect a function call section,
86+
* but won't be used to construct the context when building it from scratch.
87+
*/
88+
readonly sectionPrefixAlternateMatches?: Array<string | LlamaText>,
89+
8390
readonly betweenCalls?: string | LlamaText,
8491
readonly sectionSuffix?: string | LlamaText
8592
},

test/modelDependent/qwen3-0.6b/functions.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ describe("qwen3 0.6b", () => {
9999
}
100100
} as const;
101101

102-
const res = await chatSession.prompt("What is the second word? No yapping, no formatting", {
102+
const res = await chatSession.prompt("What is the second word? No yapping, no formatting, use the function", {
103103
...promptOptions,
104104
maxTokens: 250,
105105
budgets: {

vitest.config.ts

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,20 @@ export default defineConfig({
88
],
99
pool: "forks",
1010
maxWorkers: 1,
11-
minWorkers: 1,
1211
maxConcurrency: 1,
13-
poolOptions: {
14-
forks: {
15-
minForks: 1,
16-
maxForks: 1,
17-
singleFork: true
18-
19-
// uncomment for profiling
20-
// execArgv: [
21-
// "--cpu-prof",
22-
// "--cpu-prof-dir=test-runner-profile",
23-
// "--heap-prof",
24-
// "--heap-prof-dir=test-runner-profile"
25-
// ]
26-
}
27-
},
2812
snapshotSerializers: [
2913
"./test/utils/helpers/llamaTextSerializer.ts",
3014
"./test/utils/helpers/SpecialTokensTextSerializer.ts",
3115
"./test/utils/helpers/SpecialTokenSerializer.ts"
3216
],
3317
setupFiles: ["./test/utils/helpers/testSetup.ts"]
18+
19+
// uncomment for profiling
20+
// execArgv: [
21+
// "--cpu-prof",
22+
// "--cpu-prof-dir=test-runner-profile",
23+
// "--heap-prof",
24+
// "--heap-prof-dir=test-runner-profile"
25+
// ]
3426
}
3527
});

0 commit comments

Comments
 (0)