Feat (utils): replace weights with quantized ones#1505
Conversation
| state_dict[keys_map[old_key]] = state_dict.pop(old_key) | ||
|
|
||
|
|
||
| class merge_quant_weights: |
There was a problem hiding this comment.
I do not see a reason why this should be a context manager, specially considering you need to add a check to verify that one one forward pass is run. For me, this should be a function that accept a sample and runs the forward within the function.
| class merge_quant_weights: | ||
| """Context manager that merges quantized weights into model weights. | ||
|
|
||
| This could be useful for example with Learned Round. |
There was a problem hiding this comment.
From what I see, this is the main purpose of the context manager. What else is this context manager intended to cover? Also, this docstring does not reflect the fact that the scales are also updated to PARAMETER_FROM_STATS.
| self._model = model | ||
| self._hooks: List[RemovableHandle] = [] | ||
| self._module_tensor_id_mapping = {} | ||
| self.disable_quant = disable_quant |
There was a problem hiding this comment.
Is there a reason why some attributes are public and other private? E.g., _model and disable_quant.
| for module in self._module_tensor_id_mapping: | ||
| self._reset_quantizer(module) | ||
|
|
||
| def change_scale_impl_type(self, proxy) -> None: |
There was a problem hiding this comment.
| def change_scale_impl_type(self, proxy) -> None: | |
| def change_scale_impl_type(self, proxy: WeightQuantProxyFromInjectorBase) -> None: |
| @staticmethod | ||
| def _reset_quantizer(proxy) -> None: | ||
| """Switch a weight quant proxy from LearnedRound back to standard Round.""" | ||
| reinit_on_state_dict = config.REINIT_ON_STATE_DICT_LOAD |
There was a problem hiding this comment.
This pattern of overriding values in config and then restoring to the original values appears multiple times. Can we extract this common functionality? E.g.:
from contextlib import contextmanager
@contextmanager
def override_config(**overrides):
old = {}
try:
for k, v in overrides.items():
old[k] = getattr(config, k)
setattr(config, k, v)
yield
finally:
for k, v in old.items():
setattr(config, k, v)
and then use it like:
with override_config(
REINIT_ON_STATE_DICT_LOAD=False,
IGNORE_MISSING_KEYS=True,
):
| LearnedRoundImplType.HARD_SIGMOID, LearnedRoundImplType.SIGMOID, LearnedRoundImplType.IDENTITY] | ||
|
|
||
|
|
||
| def _insert_learned_round(model, learned_round_param): |
There was a problem hiding this comment.
I think there is no need to create a new function. Use insert_learned_round_quantizers.
|
|
||
|
|
||
| @pytest.mark.parametrize("learned_round_param", LEARNED_ROUND_OPTIONS) | ||
| def test_merge_quant_weights_reset(learned_round_param): |
There was a problem hiding this comment.
This test could be merged into the previous one.
|
|
||
|
|
||
| @pytest.mark.parametrize("learned_round_param", LEARNED_ROUND_OPTIONS) | ||
| def test_merge_quant_weights_forward_equivalence(learned_round_param): |
There was a problem hiding this comment.
Probably this test can also be merged into the previous one.
| # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from typing import Dict |
There was a problem hiding this comment.
Please remove unused imports.
Reason for this PR
In certain instances, like learned round, the weight tensors has extra parameters attached to it that could make export somewhat complicated.
Changes Made in this PR
We perform a destructive replacement of the original weight tensor with its quantized counterparts. This allows for easier exports in many scenarios.
There is an optional flag to keeep track of the original weights.
After merging, we reset the quantizers, including setting rounding mode to round, which is what is most commonly supported during export process.
What is missing:
Testing Summary
New tests added.