Skip to content

Commit 6893058

Browse files
committed
test: Add FP8 model tests and tiny model generator
- Add fp8_aware_dense layer unit tests - Add FP8 Qwen3 model loading test using roulis/tiny-fp8-qwen3 - Include Python script to generate tiny FP8 test models
1 parent cb36413 commit 6893058

3 files changed

Lines changed: 203 additions & 2 deletions

File tree

lib/bumblebee/conversion/pytorch_params.ex

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,16 @@ defmodule Bumblebee.Conversion.PyTorchParams do
6565

6666
params_expr = model_state.data
6767
preserve_source_types = opts[:preserve_source_types] || false
68-
{params, diff} = init_params(model, params_expr, pytorch_state, opts[:params_mapping], preserve_source_types)
68+
69+
{params, diff} =
70+
init_params(
71+
model,
72+
params_expr,
73+
pytorch_state,
74+
opts[:params_mapping],
75+
preserve_source_types
76+
)
77+
6978
model_state = %{model_state | data: params}
7079

7180
params_complete? = diff.missing == [] and diff.mismatched == []
@@ -110,7 +119,12 @@ defmodule Bumblebee.Conversion.PyTorchParams do
110119

111120
prefixes = infer_prefixes(layers, pytorch_state, params_mapping)
112121

113-
diff = %{missing: [], mismatched: [], used_keys: [], preserve_source_types: preserve_source_types}
122+
diff = %{
123+
missing: [],
124+
mismatched: [],
125+
used_keys: [],
126+
preserve_source_types: preserve_source_types
127+
}
114128

115129
{params, diff} =
116130
layers

test/bumblebee/layers_test.exs

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
defmodule Bumblebee.LayersTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  describe "fp8_aware_dense/3" do
    test "dequantizes FP8 kernel with scale_inv" do
      # Dense layer mapping 4 input features to 8 outputs, quantized in
      # 2x2 blocks, so scale_inv has shape {ceil(4/2), ceil(8/2)} = {2, 4}.
      model =
        Axon.input("input", shape: {nil, 4})
        |> Bumblebee.Layers.fp8_aware_dense(8, name: "dense", block_size: 2)

      # Kernel shape {4, 8}: every row is [1..8].
      row = [1, 2, 3, 4, 5, 6, 7, 8]
      kernel = Nx.tensor([row, row, row, row], type: {:f, 32})

      # A uniform scale of 2.0 should exactly double the unscaled result.
      scale_inv =
        Nx.tensor(
          [
            [2.0, 2.0, 2.0, 2.0],
            [2.0, 2.0, 2.0, 2.0]
          ],
          type: {:f, 32}
        )

      params = %{"dense" => %{"kernel" => kernel, "scale_inv" => scale_inv}}

      input = Nx.tensor([[1.0, 1.0, 1.0, 1.0]])
      output = Axon.predict(model, params, %{"input" => input})

      # Unscaled, an all-ones input yields [4, 8, 12, 16, 20, 24, 28, 32];
      # the 2.0 scale doubles every entry.
      expected = Nx.tensor([[8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 64.0]])

      assert_all_close(output, expected)
    end

    test "dequantizes with identity scale (1.0)" do
      model =
        Axon.input("input", shape: {nil, 4})
        |> Bumblebee.Layers.fp8_aware_dense(4, name: "dense", block_size: 2)

      # 4x4 identity kernel.
      kernel = Nx.eye(4, type: {:f, 32})

      # Unit scale in every block makes dequantization a no-op
      # (floats default to f32 in Nx).
      scale_inv = Nx.broadcast(1.0, {2, 2})

      params = %{"dense" => %{"kernel" => kernel, "scale_inv" => scale_inv}}

      input = Nx.tensor([[2.0, 3.0, 4.0, 5.0]])
      output = Axon.predict(model, params, %{"input" => input})

      # Identity weights with unit scale must pass the input through unchanged.
      assert_all_close(output, input)
    end

    test "handles non-block-aligned dimensions" do
      # 3 inputs and 5 outputs with block_size 2 exercises the slicing
      # path for dimensions that are not a multiple of the block size.
      model =
        Axon.input("input", shape: {nil, 3})
        |> Bumblebee.Layers.fp8_aware_dense(5, name: "dense", block_size: 2)

      # kernel: {3, 5}, all ones
      kernel = Nx.broadcast(1.0, {3, 5})

      # scale_inv: {ceil(3/2), ceil(5/2)} = {2, 3}
      scale_inv = Nx.broadcast(1.0, {2, 3})

      params = %{"dense" => %{"kernel" => kernel, "scale_inv" => scale_inv}}

      input = Nx.tensor([[1.0, 1.0, 1.0]])
      output = Axon.predict(model, params, %{"input" => input})

      # Each output is the sum of three ones.
      expected = Nx.tensor([[3.0, 3.0, 3.0, 3.0, 3.0]])

      assert_all_close(output, expected)
    end

    test "includes bias when use_bias is true" do
      model =
        Axon.input("input", shape: {nil, 2})
        |> Bumblebee.Layers.fp8_aware_dense(2, name: "dense", block_size: 2, use_bias: true)

      kernel =
        Nx.tensor(
          [
            [1, 0],
            [0, 1]
          ],
          type: {:f, 32}
        )

      # Single 2x2 block -> a single unit scale entry.
      scale_inv = Nx.tensor([[1.0]], type: {:f, 32})
      bias = Nx.tensor([10.0, 20.0], type: {:f, 32})

      params = %{
        "dense" => %{"kernel" => kernel, "scale_inv" => scale_inv, "bias" => bias}
      }

      input = Nx.tensor([[1.0, 2.0]])
      output = Axon.predict(model, params, %{"input" => input})

      # The identity kernel passes [1, 2] through; the bias shifts it to [11, 22].
      expected = Nx.tensor([[11.0, 22.0]])

      assert_all_close(output, expected)
    end
  end
end

test/bumblebee/text/qwen3_test.exs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,33 @@ defmodule Bumblebee.Text.Qwen3Test do
7575
Nx.tensor([[-0.1487, -0.0071]])
7676
)
7777
end
78+
79+
test ":for_causal_language_modeling with FP8 weights" do
  # With preserve_source_types: true, quantized tensors must keep their
  # on-disk FP8 representation instead of being upcast at load time.
  assert {:ok, %{model: model, params: params, spec: spec}} =
           Bumblebee.load_model(
             {:hf, "roulis/tiny-fp8-qwen3"},
             preserve_source_types: true
           )

  assert %Bumblebee.Text.Qwen3{architecture: :for_causal_language_modeling} = spec
  assert %Axon.ModelState{data: params_data} = params

  query = params_data["decoder.blocks.0.self_attention.query"]

  # The quantized kernel keeps its FP8 storage type...
  assert Nx.type(query["kernel"]) == {:f8_e4m3fn, 8}

  # ...while the per-block inverse scales load as f32.
  assert Nx.type(query["scale_inv"]) == {:f, 32}

  inputs = %{
    "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
    "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
  }

  # Dequantization happens inside the forward pass; the model only needs
  # to run cleanly and emit logits of the expected shape.
  outputs = Axon.predict(model, params, inputs)

  assert Nx.shape(outputs.logits) == {1, 10, 1024}
end
78107
end

0 commit comments

Comments (0)