forked from foundation-model-stack/fms-model-optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path simple_mx_example.py
More file actions
110 lines (91 loc) · 3.95 KB
/
simple_mx_example.py
File metadata and controls
110 lines (91 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple example using a toy model to demo how to trigger mx in fms-mo."""
# Third Party
import numpy as np
import torch
import torch.nn.functional as F
class ResidualMLP(torch.nn.Module):
    """Toy pre-LayerNorm MLP block with a residual connection.

    The forward pass computes::

        inputs + dummy(dense_h(gelu(dense_4h(layernorm(inputs)))))

    Args:
        hidden_size: Size of the last dimension of the input and output.
        device: Device on which all parameters are created (default ``"cuda"``).
    """

    def __init__(self, hidden_size, device="cuda"):
        super().__init__()
        self.layernorm = torch.nn.LayerNorm(hidden_size, device=device)
        # Expand to 4x hidden size, then project back down.
        self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device)
        self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device)
        # Extra layer: by default the first and last layers are skipped during
        # quantization, so with only two Linear layers nothing would be quantized.
        self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device)

    def forward(self, inputs):
        """Return ``inputs`` plus the MLP branch applied to ``layernorm(inputs)``."""
        norm_outputs = self.layernorm(inputs)
        # MLP branch
        proj_outputs = F.gelu(self.dense_4h(norm_outputs))
        mlp_outputs = self.dummy(self.dense_h(proj_outputs))
        # Residual connection
        return inputs + mlp_outputs
if __name__ == "__main__":
    # Standard
    import copy

    # Third Party
    from tabulate import tabulate

    # Local
    from fms_mo import qconfig_init, qmodel_prep

    HIDDEN_DIM = 128
    SEED = 0

    # Seed both RNGs so the input and the randomly initialized weights are
    # reproducible from run to run.
    torch.manual_seed(SEED)
    x = np.random.default_rng(SEED).standard_normal((16, HIDDEN_DIM))
    x = torch.tensor(x, dtype=torch.float32, device="cuda")

    results = {"dtype": [], "output[0, :5]": [], "||ref - out_dtype||_2": []}

    def record(label, output, reference=None):
        """Append one table row; the error column is "-" when no reference is given."""
        results["dtype"].append(label)
        results["output[0, :5]"].append(output[0, :5].tolist())
        results["||ref - out_dtype||_2"].append(
            "-" if reference is None else torch.norm(reference - output).item()
        )

    # --- Test 0. Run MLP as is; its output is the reference for every other row
    ref_mlp = ResidualMLP(HIDDEN_DIM)
    with torch.no_grad():
        out = ref_mlp(x)
    record("fp32", out)
    print(ref_mlp)

    def run_quantized(label, cfg):
        """Quantize a copy of ``ref_mlp`` with ``cfg``, run it on ``x``, record the row."""
        # Every variant must start from the SAME weights as the reference,
        # otherwise the error column mixes weight-init noise into the
        # quantization error. Deep-copy because qmodel_prep may modify the
        # model it is given in place (NOTE(review): confirm in fms_mo).
        model = qmodel_prep(copy.deepcopy(ref_mlp), x, cfg)
        with torch.no_grad():
            out_dtype = model(x)
        record(label, out_dtype, out)
        return model

    # --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
    qcfg = qconfig_init()
    qcfg["nbits_a"] = 8
    qcfg["nbits_w"] = 8
    run_quantized("fms_int8", qcfg)

    qcfg["nbits_a"] = 4
    qcfg["nbits_w"] = 4
    print(run_quantized("fms_int4", qcfg))

    # --- Test 2. now change mapping to MX
    # NOTE simply use qa_mode or qw_mode to trigger the use of mx, e.g. use "mx_" prefixed mode,
    # qcfg["mapping"] and other qcfg["mx_specs"] content will be updated automatically
    for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
        qcfg["qw_mode"] = f"mx_{dtype_to_test}"
        qcfg["qa_mode"] = f"mx_{dtype_to_test}"
        print(run_quantized(f"mx{dtype_to_test}", qcfg))

    print(tabulate(results, headers="keys"))
    print("DONE!")