
Commit d5a18d5

qbmm context manager auto check
Signed-off-by: cliu-us <cliu@us.ibm.com>
1 parent d6c0fe9 commit d5a18d5

3 files changed

Lines changed: 29 additions & 0 deletions


fms_mo/fx/dynamo_utils.py

Lines changed: 26 additions & 0 deletions
@@ -1242,6 +1242,32 @@ def call_seq_hook(mod, *_args, **_kwargs):
             )
             setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
 
+    # add auto QBmm check to last layer if any QBmms in model (only for transformers)
+    def qbmm_auto_check(_mod, *_args, **_kwargs):
+        """Automatic QBmm check. This hook will be attached to the last module and check once
+        only at the end of first forward() call. Throw a "warning" if a model has QBmm attached
+        but not called (as it could be intentional.)
+        """
+        num_called_qbmms = []
+        for lay, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
+            for ln in line_nums:
+                qbmm_i = model.get_submodule(f"{lay}.QBmm{ln}")
+                num_called_qbmms.append(qbmm_i.num_module_called == 1)
+
+        if not all(num_called_qbmms):
+            err_msg = (
+                "QBmms were attached but not called during forward()."
+                "Possibly patch_torch_bmm() context manager is missing."
+            )
+            if qcfg["force_stop_if_qbmm_auto_check_failed"]:
+                raise RuntimeError(err_msg)
+            logger.warning(err_msg)
+
+        qcfg["hook_qbmm_auto_check"].remove()
+
+    last_mod = model.get_submodule(qcfg["mod_call_seq"][-1])
+    qcfg["hook_qbmm_auto_check"] = last_mod.register_forward_hook(qbmm_auto_check)
+
     # c) identify RPN/FPN
     # TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()
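The new qbmm_auto_check above is a one-shot forward hook: it is registered on the model's last submodule, fires after the first forward() call, checks whether every attached QBmm was actually invoked, and then detaches itself through the stored handle. A minimal, self-contained sketch of that self-removing-hook pattern in plain PyTorch (the names state and auto_check are illustrative only, not part of fms_mo):

import torch
from torch import nn

state = {"handle": None}  # stands in for qcfg["hook_qbmm_auto_check"]

def auto_check(_mod, *_args, **_kwargs):
    # Fires once, right after the first forward() of the hooked module,
    # then removes itself so later forward() calls skip the check.
    print("auto check fired")
    state["handle"].remove()

model = nn.Linear(4, 4)
state["handle"] = model.register_forward_hook(auto_check)

model(torch.randn(2, 4))  # triggers the check once
model(torch.randn(2, 4))  # hook already removed, nothing fires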

fms_mo/utils/qconfig_utils.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ def qconfig_init(recipe: str = None, args: Any = None):
     qcfg["which2patch_contextmanager"] = (
         None  # an internal var that should not be set by user
     )
+    qcfg["force_stop_if_qbmm_auto_check_failed"] = False
 
     # LSTM related, if any of these is not None, then last layer (FC) will not be skipped.
     qcfg["nbits_w_lstm"] = None

tests/models/test_qmodelprep.py

Lines changed: 2 additions & 0 deletions
@@ -300,6 +300,8 @@ def test_bert_dynamo_wi_qbmm(
             other_qmodules.append(m)
 
     # check 2: model call without our "patch" context manager, will not reach QBmm
+    # we have an auto check in place, but it will only log warning, unless this flag
+    # qcfg["force_stop_if_qbmm_auto_check_failed"] = True
     with torch.no_grad():
         model_bert_eager(**input_bert)
     assert all(
