Skip to content

Commit 81fba5e

Browse files
authored
Fix Gemma 4 system message and modality order (#211)
* Align Gemma4 messages with chat template * Respect Gemma 4 modality order
1 parent b94661e commit 81fba5e

2 files changed

Lines changed: 91 additions & 1 deletion

File tree

Libraries/MLXVLM/Models/Gemma4.swift

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1769,6 +1769,32 @@ public final class Gemma4: Module, VLMModel, KVCacheDimensionProvider {
17691769

17701770
// MARK: - Processor
17711771

1772+
public struct Gemma4MessageGenerator: MessageGenerator {
1773+
public init() {}
1774+
1775+
public func generate(message: Chat.Message) -> MLXLMCommon.Message {
1776+
if message.role == .system {
1777+
[
1778+
"role": message.role.rawValue,
1779+
"content": message.content,
1780+
]
1781+
} else {
1782+
[
1783+
"role": message.role.rawValue,
1784+
"content": message.images.map { _ in
1785+
["type": "image"]
1786+
}
1787+
+ message.videos.map { _ in
1788+
["type": "video"]
1789+
}
1790+
+ [
1791+
["type": "text", "text": message.content]
1792+
],
1793+
]
1794+
}
1795+
}
1796+
}
1797+
17721798
public struct Gemma4Processor: UserInputProcessor {
17731799
private let config: Gemma4ProcessorConfiguration
17741800
private let tokenizer: any Tokenizer
@@ -1805,7 +1831,7 @@ public struct Gemma4Processor: UserInputProcessor {
18051831
}
18061832

18071833
public func prepare(input: UserInput) async throws -> LMInput {
1808-
let messages = Qwen2VLMessageGenerator().generate(from: input)
1834+
let messages = Gemma4MessageGenerator().generate(from: input)
18091835

18101836
var promptTokens = try tokenizer.applyChatTemplate(
18111837
messages: messages, tools: input.tools,

Tests/MLXLMTests/UserInputTests.swift

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,33 @@ public class UserInputTests: XCTestCase {
9595
assertEqual(expected, messages)
9696
}
9797

98+
public func testGemma4ConversionText() {
99+
let chat: [Chat.Message] = [
100+
.system("You are a useful agent."),
101+
.user("Tell me a story."),
102+
]
103+
104+
let messages = Gemma4MessageGenerator().generate(messages: chat)
105+
106+
let expected: [[String: any Sendable]] = [
107+
[
108+
"role": "system",
109+
"content": "You are a useful agent.",
110+
],
111+
[
112+
"role": "user",
113+
"content": [
114+
[
115+
"type": "text",
116+
"text": "Tell me a story.",
117+
]
118+
],
119+
],
120+
]
121+
122+
assertEqual(expected, messages)
123+
}
124+
98125
// MARK: - Mistral3 Message Generator Tests
99126

100127
public func testMistral3ConversionText() {
@@ -230,4 +257,41 @@ public class UserInputTests: XCTestCase {
230257
XCTAssertEqual(userInput.images.count, 1)
231258
}
232259

260+
public func testGemma4ConversionImage() {
261+
let chat: [Chat.Message] = [
262+
.system("You are a useful agent."),
263+
.user(
264+
"What is this?",
265+
images: [
266+
.url(
267+
URL(
268+
string: "https://opensource.apple.com/images/projects/mlx.f5c59d8b.png")!
269+
)
270+
]),
271+
]
272+
273+
let messages = Gemma4MessageGenerator().generate(messages: chat)
274+
275+
let expected: [[String: any Sendable]] = [
276+
[
277+
"role": "system",
278+
"content": "You are a useful agent.",
279+
],
280+
[
281+
"role": "user",
282+
"content": [
283+
[
284+
"type": "image"
285+
],
286+
[
287+
"type": "text",
288+
"text": "What is this?",
289+
],
290+
],
291+
],
292+
]
293+
294+
assertEqual(expected, messages)
295+
}
296+
233297
}

0 commit comments

Comments
 (0)