Skip to content

Commit 957173c

Browse files
authored
Merge pull request #126 from BrandonGroth/tiny_models
test: Add tests for save_for_aiu functionality w/ tiny models
2 parents b2a3c85 + e7aed34 commit 957173c

11 files changed

Lines changed: 612 additions & 91 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ htmlcov/
1414
durations/*
1515
coverage*.xml
1616
qcfg.json
17-
models
1817
configs
18+
pytest.out
1919

2020
# IDEs
2121
.vscode/

fms_mo/quant/quantizers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3029,7 +3029,7 @@ def __init__(
30293029
self.register_buffer("clip_valn", torch.zeros(perGp[0]))
30303030
else:
30313031
self.register_buffer(
3032-
"clip_val", torch.zeros(perCh) if perCh else torch.Tensor([1.0])
3032+
"clip_val", torch.zeros(perCh) if perCh else torch.Tensor([0.0])
30333033
)
30343034
self.register_buffer(
30353035
"clip_valn", torch.zeros(perCh) if perCh else torch.Tensor([0.0])

fms_mo/utils/aiu_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,14 @@ def process_weight(
232232
is_w_recomputed = False
233233
if layer_name + ".quantize_weight.clip_val" in model.state_dict():
234234
w_cv = model.state_dict()[layer_name + ".quantize_weight.clip_val"]
235+
236+
# Check that clip values are initialized
237+
if torch.any(w_cv.isclose(torch.tensor(0.0))):
238+
raise ValueError(
239+
f"Quantization clip values for {layer_name=} have near-zero values and "
240+
"are likely uninitialized."
241+
)
242+
235243
if w_cv.numel() > 1:
236244
w_cv = w_cv.unsqueeze(dim=1)
237245
weight_int_as_fp = torch.clamp(127 / w_cv * weight_pre_quant, -127, 127).round()

fms_mo/utils/qconfig_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,8 +816,7 @@ def qconfig_save(
816816

817817
# Save config as json
818818
if os.path.isfile(fname):
819-
message = f"{fname} already exist, will overwrite."
820-
warnings.warn(message, UserWarning)
819+
logger.info(f"{fname} already exist, will overwrite.")
821820
with open(fname, "w", encoding="utf-8") as outfile:
822821
json.dump(temp_qcfg, outfile, indent=4)
823822

0 commit comments

Comments
 (0)