Skip to content

Commit ce760d6

Browse files
authored
Fix GLM-OCR multimodal prompt rendering (#157)
1 parent e9bc29a commit ce760d6

5 files changed

Lines changed: 252 additions & 29 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
## Unreleased
22

3+
* **Fixes**:
4+
* Fixed GLM-OCR and other multimodal chat-template workarounds so image and
5+
audio content parts are preserved when tool-call normalization runs.
36
* **Testing**:
47
* Added `tool/testing/run_local_e2e.dart` as a discovery and orchestration
58
entry point for heavyweight local-only Dart E2E, Flutter device, and

lib/src/core/template/template_workarounds.dart

Lines changed: 94 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,37 @@ class TemplateWorkarounds {
4343
List<LlamaChatMessage> messages,
4444
ChatFormat format,
4545
) {
46+
final needsFuncArgsNormalization = _formatsNeedFuncArgsNormalization
47+
.contains(format);
48+
final needsGenericSchema = _formatsNeedGenericSchema.contains(format);
49+
final needsMoveToolCallsToContent = _formatsNeedMoveToolCallsToContent
50+
.contains(format);
51+
52+
if (!needsFuncArgsNormalization &&
53+
!needsGenericSchema &&
54+
!needsMoveToolCallsToContent) {
55+
return messages;
56+
}
57+
58+
if (!_hasTypedToolCalls(messages)) {
59+
return messages;
60+
}
61+
4662
final jsonMessages = messages.map((m) => m.toJson()).toList();
47-
var changed = false;
4863

49-
if (_formatsNeedFuncArgsNormalization.contains(format)) {
64+
if (needsFuncArgsNormalization) {
5065
normalizeToolCallArgs(jsonMessages);
51-
changed = true;
5266
}
5367

54-
if (_formatsNeedGenericSchema.contains(format)) {
68+
if (needsGenericSchema) {
5569
useGenericSchema(jsonMessages);
56-
changed = true;
5770
}
5871

59-
if (_formatsNeedMoveToolCallsToContent.contains(format)) {
72+
if (needsMoveToolCallsToContent) {
6073
moveToolCallsToContent(jsonMessages);
61-
changed = true;
62-
}
63-
64-
if (!changed) {
65-
return messages;
6674
}
6775

68-
return _messagesFromJson(jsonMessages);
76+
return _messagesFromJson(jsonMessages, messages);
6977
}
7078

7179
/// Ensures tool call arguments are JSON objects, not strings.
@@ -172,6 +180,12 @@ class TemplateWorkarounds {
172180
}
173181
}
174182

183+
/// Whether any message in [messages] has at least one part that is a
/// [LlamaToolCallContent].
///
/// Returns `false` for an empty list or when no message has such a part.
static bool _hasTypedToolCalls(List<LlamaChatMessage> messages) {
  for (final message in messages) {
    for (final part in message.parts) {
      if (part is LlamaToolCallContent) {
        return true;
      }
    }
  }
  return false;
}
188+
175189
static Map<String, dynamic> _argumentsToObject(Object? args) {
176190
final map = ToolCallParsingUtils.decodeJsonMapValue(args);
177191
if (map != null) {
@@ -193,11 +207,18 @@ class TemplateWorkarounds {
193207

194208
/// Rebuilds typed chat messages from their JSON map form.
///
/// Each entry of [messages] is converted by [_messageFromJson] together with
/// the entry of [originals] at the same index. [originals] is indexed without
/// a bounds check, so it must hold at least as many entries as [messages].
static List<LlamaChatMessage> _messagesFromJson(
  List<Map<String, dynamic>> messages,
  List<LlamaChatMessage> originals,
) {
  return List<LlamaChatMessage>.generate(
    messages.length,
    (index) => _messageFromJson(messages[index], original: originals[index]),
  );
}
199217

200-
static LlamaChatMessage _messageFromJson(Map<String, dynamic> message) {
218+
static LlamaChatMessage _messageFromJson(
219+
Map<String, dynamic> message, {
220+
required LlamaChatMessage original,
221+
}) {
201222
final role = _parseRole(message['role'] as String? ?? 'user');
202223
final parts = <LlamaContentPart>[];
203224

@@ -244,10 +265,7 @@ class TemplateWorkarounds {
244265
),
245266
);
246267
} else {
247-
final text = _extractTextContent(content);
248-
if (text.isNotEmpty) {
249-
parts.add(LlamaTextContent(text));
250-
}
268+
parts.addAll(_extractContentParts(content, original: original));
251269
}
252270

253271
if (parts.isEmpty) {
@@ -257,21 +275,68 @@ class TemplateWorkarounds {
257275
return LlamaChatMessage.withContent(role: role, content: parts);
258276
}
259277

260-
static String _extractTextContent(Object? content) {
261-
if (content == null) return '';
262-
if (content is String) return content;
263-
if (content is! List) return content.toString();
278+
/// Converts serialized message [content] back into typed content parts.
///
/// * `null` yields no parts.
/// * A non-list value is turned into a single [LlamaTextContent] from its
///   string form, or no parts when that string is empty.
/// * A list is walked in order. Entries that are not `Map<String, dynamic>`
///   are skipped, as are entries whose `type` is not handled below.
///   - `text`: adds a [LlamaTextContent] when `text` is a non-empty string.
///   - `image` / `image_url`: reuses the next unused [LlamaImageContent] from
///     [original], falling back to [_imageContentFromJson] once those run
///     out.
///   - `input_audio` / `audio`: reuses the next unused [LlamaAudioContent]
///     from [original]; the entry is dropped when none is left.
static List<LlamaContentPart> _extractContentParts(
  Object? content, {
  required LlamaChatMessage original,
}) {
  if (content == null) return const [];
  if (content is! List) {
    // Covers plain strings as well: a String's toString() is itself.
    final text = content is String ? content : content.toString();
    return text.isEmpty ? const [] : [LlamaTextContent(text)];
  }

  // Media parts from the original message are consumed in order, one per
  // matching serialized entry.
  final remainingImages =
      original.parts.whereType<LlamaImageContent>().iterator;
  final remainingAudio =
      original.parts.whereType<LlamaAudioContent>().iterator;
  final parts = <LlamaContentPart>[];

  for (final item in content) {
    if (item is! Map<String, dynamic>) continue;

    final type = item['type'];
    if (type == 'text') {
      final text = item['text'];
      if (text is String && text.isNotEmpty) {
        parts.add(LlamaTextContent(text));
      }
    } else if (type == 'image' || type == 'image_url') {
      if (remainingImages.moveNext()) {
        parts.add(remainingImages.current);
      } else {
        parts.add(_imageContentFromJson(item));
      }
    } else if (type == 'input_audio' || type == 'audio') {
      if (remainingAudio.moveNext()) {
        parts.add(remainingAudio.current);
      }
    }
  }

  return parts;
}
329+
330+
/// Builds a [LlamaImageContent] from a serialized image content [item].
///
/// Reads the `url` entry of the `image_url` map inside [item]:
/// * a string starting with `file://` becomes a path-backed image, with the
///   prefix stripped verbatim (no percent-decoding is applied);
/// * any other non-empty string becomes a URL-backed image;
/// * anything else (missing map, non-string or empty `url`) yields a
///   [LlamaImageContent] constructed with no arguments.
static LlamaImageContent _imageContentFromJson(Map<String, dynamic> item) {
  const fileScheme = 'file://';

  final imageUrl = item['image_url'];
  final Object? url =
      imageUrl is Map<String, dynamic> ? imageUrl['url'] : null;

  if (url is! String || url.isEmpty) {
    return const LlamaImageContent();
  }
  return url.startsWith(fileScheme)
      ? LlamaImageContent(path: url.substring(fileScheme.length))
      : LlamaImageContent(url: url);
}
276341

277342
static LlamaChatRole _parseRole(String role) {

test/unit/core/template/chat_template_engine_test.dart

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import 'dart:convert';
22

33
import 'package:llamadart/src/core/models/chat/chat_message.dart';
44
import 'package:llamadart/src/core/models/chat/chat_role.dart';
5+
import 'package:llamadart/src/core/models/chat/content_part.dart';
56
import 'package:llamadart/src/core/models/inference/tool_choice.dart';
67
import 'package:llamadart/src/core/models/tools/tool_definition.dart';
78
import 'package:llamadart/src/core/models/tools/tool_param.dart';
@@ -27,6 +28,41 @@ void main() {
2728
expect(result.prompt, contains('CUSTOM:hello'));
2829
expect(result.prompt, isNot(contains('BASE:hello')));
2930
});
31+
32+
test('preserves GLM-OCR image markers through format workarounds', () {
33+
const template = '''[gMASK]<sop>
34+
{# GLM detection marker: <arg_key>name</arg_key><arg_value>value</arg_value> #}
35+
{% for m in messages %}
36+
{% if m.role == 'user' %}<|user|>
37+
{% for item in m.content %}
38+
{% if item.type == 'image' %}<|begin_of_image|><|image|><|end_of_image|>{% elif item.type == 'text' %}{{ item.text }}{% endif %}
39+
{% endfor %}
40+
{% endif %}
41+
{% endfor %}
42+
{% if add_generation_prompt %}<|assistant|>{% endif %}''';
43+
const multimodalMessages = [
44+
LlamaChatMessage.withContent(
45+
role: LlamaChatRole.user,
46+
content: [
47+
LlamaImageContent(path: '/tmp/page.png'),
48+
LlamaTextContent('Extract text.'),
49+
],
50+
),
51+
];
52+
53+
final result = ChatTemplateEngine.render(
54+
templateSource: template,
55+
messages: multimodalMessages,
56+
metadata: const {},
57+
);
58+
59+
expect(result.format, equals(ChatFormat.glm45.index));
60+
expect(
61+
result.prompt,
62+
contains('<|begin_of_image|><__media__><|end_of_image|>'),
63+
);
64+
expect(result.prompt, contains('Extract text.'));
65+
});
3066
});
3167

3268
group('ChatTemplateEngine grammar routing', () {

test/unit/core/template/template_workarounds_test.dart

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import 'dart:typed_data';
2+
13
import 'package:llamadart/src/core/models/chat/chat_message.dart';
24
import 'package:llamadart/src/core/models/chat/chat_role.dart';
35
import 'package:llamadart/src/core/models/chat/content_part.dart';
@@ -123,6 +125,121 @@ void main() {
123125
expect(message['content'], contains('"weather"'));
124126
});
125127

128+
test(
129+
'applyFormatWorkarounds returns before serializing byte-backed multimodal content without tool calls',
130+
() {
131+
final imageBytes = Uint8List.fromList([1, 2, 3, 4]);
132+
final input = [
133+
LlamaChatMessage.withContent(
134+
role: LlamaChatRole.user,
135+
content: [
136+
LlamaImageContent(bytes: imageBytes, width: 1, height: 1),
137+
const LlamaTextContent('Extract text.'),
138+
],
139+
),
140+
];
141+
142+
final output = TemplateWorkarounds.applyFormatWorkarounds(
143+
input,
144+
ChatFormat.glm45,
145+
);
146+
147+
expect(identical(output, input), isTrue);
148+
final image = output.first.parts.whereType<LlamaImageContent>().single;
149+
expect(identical(image.bytes, imageBytes), isTrue);
150+
expect(
151+
output.first.parts.whereType<LlamaTextContent>().single.text,
152+
equals('Extract text.'),
153+
);
154+
},
155+
);
156+
157+
test(
158+
'applyFormatWorkarounds preserves multimodal content when tool calls are normalized',
159+
() {
160+
const input = [
161+
LlamaChatMessage.withContent(
162+
role: LlamaChatRole.user,
163+
content: [
164+
LlamaImageContent(path: '/tmp/page.png'),
165+
LlamaTextContent('Extract text.'),
166+
],
167+
),
168+
LlamaChatMessage.withContent(
169+
role: LlamaChatRole.assistant,
170+
content: [
171+
LlamaToolCallContent(
172+
id: 'call_1',
173+
name: 'lookup',
174+
arguments: {'query': 'ocr'},
175+
rawJson: '{"query":"ocr"}',
176+
),
177+
],
178+
),
179+
];
180+
181+
final output = TemplateWorkarounds.applyFormatWorkarounds(
182+
input,
183+
ChatFormat.glm45,
184+
);
185+
186+
expect(output.first.parts[0], isA<LlamaImageContent>());
187+
expect(
188+
output.first.parts.whereType<LlamaTextContent>().single.text,
189+
equals('Extract text.'),
190+
);
191+
final toolCall = output.last.parts
192+
.whereType<LlamaToolCallContent>()
193+
.single;
194+
expect(toolCall.name, equals('lookup'));
195+
expect(toolCall.arguments, equals({'query': 'ocr'}));
196+
},
197+
);
198+
199+
test(
200+
'applyFormatWorkarounds preserves audio content when tool calls are normalized',
201+
() {
202+
final audioBytes = Uint8List.fromList([82, 73, 70, 70]);
203+
final input = [
204+
LlamaChatMessage.withContent(
205+
role: LlamaChatRole.user,
206+
content: [
207+
LlamaAudioContent(bytes: audioBytes),
208+
const LlamaTextContent('Transcribe audio.'),
209+
],
210+
),
211+
const LlamaChatMessage.withContent(
212+
role: LlamaChatRole.assistant,
213+
content: [
214+
LlamaToolCallContent(
215+
id: 'call_1',
216+
name: 'lookup',
217+
arguments: {'query': 'audio'},
218+
rawJson: '{"query":"audio"}',
219+
),
220+
],
221+
),
222+
];
223+
224+
final output = TemplateWorkarounds.applyFormatWorkarounds(
225+
input,
226+
ChatFormat.glm45,
227+
);
228+
229+
final audio = output.first.parts.whereType<LlamaAudioContent>().single;
230+
expect(identical(audio.bytes, audioBytes), isTrue);
231+
expect(
232+
output.first.parts.whereType<LlamaTextContent>().single.text,
233+
equals('Transcribe audio.'),
234+
);
235+
final toolCall = output.last.parts
236+
.whereType<LlamaToolCallContent>()
237+
.single;
238+
expect(toolCall.name, equals('lookup'));
239+
expect(toolCall.arguments, equals({'query': 'audio'}));
240+
},
241+
);
242+
126243
test('applyFormatWorkarounds applies Granite chain', () {
127244
final input = [
128245
LlamaChatMessage.withContent(

website/docs/changelog/recent-releases.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ For canonical full release notes, use:
99

1010
## Unreleased
1111

12+
- Fixed GLM-OCR and other multimodal chat-template workarounds so image and
13+
audio content parts are preserved when tool-call normalization runs.
1214
- Added `tool/testing/run_local_e2e.dart` as a discovery and orchestration
1315
entry point for heavyweight local-only Dart E2E, Flutter device, and
1416
Web/Playwright smoke scenarios.

0 commit comments

Comments
 (0)