Feat (utils): replace weights with quantized ones#1505

Open
Giuseppe5 wants to merge 4 commits into Xilinx:dev from Giuseppe5:merge_ln

Giuseppe5 (Collaborator) commented Apr 4, 2026

Reason for this PR

In certain cases, such as learned round, the weight tensor has extra parameters attached to it that can complicate export.

Changes Made in this PR

We perform a destructive replacement of the original weight tensor with its quantized counterpart. This makes export easier in many scenarios.

There is an optional flag to keep track of the original weights.

After merging, we reset the quantizers, which includes setting the rounding mode to round, the mode most commonly supported during the export process.
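
The idea can be illustrated with a minimal, library-free sketch. It is not the Brevitas implementation: the helper names (`learned_round_quant`, `plain_round_quant`, `merge_weights`) and the per-element 0/1 offsets are assumptions made for illustration only. It shows why resetting to plain round after the merge is safe: once the weights already sit on the quantization grid, round-to-nearest reproduces the same values that learned round produced.

```python
import math
from typing import List, Optional, Tuple


def learned_round_quant(weights: List[float], offsets: List[int], scale: float) -> List[float]:
    """Fake-quantize with a per-element rounding decision (0 = floor, 1 = ceil)."""
    return [(math.floor(w / scale) + o) * scale for w, o in zip(weights, offsets)]


def plain_round_quant(weights: List[float], scale: float) -> List[float]:
    """Fake-quantize with standard round-to-nearest."""
    return [round(w / scale) * scale for w in weights]


def merge_weights(
    weights: List[float],
    offsets: List[int],
    scale: float,
    keep_original: bool = False,
) -> Tuple[List[float], Optional[List[float]]]:
    """Destructively overwrite `weights` with their quantized values.

    Optionally returns a copy of the original weights (hypothetical flag).
    """
    original = list(weights) if keep_original else None
    weights[:] = learned_round_quant(weights, offsets, scale)
    return weights, original


scale = 0.25
weights = [0.30, -0.55, 0.90, 0.10]
offsets = [1, 0, 0, 1]  # hypothetical learned decisions, not always round-to-nearest

before = learned_round_quant(weights, offsets, scale)
merged, original = merge_weights(weights, offsets, scale, keep_original=True)
after = plain_round_quant(merged, scale)  # quantizer "reset" to plain round
```

Here `after` equals `before`, while applying plain round directly to the original weights would give a different result, which is why the merge has to happen before the quantizer is reset.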

What is missing:

  • Integration in the LLM entrypoint.

Testing Summary

New tests added.

Giuseppe5 requested a review from pablomlago April 9, 2026 08:38
Giuseppe5 self-assigned this Apr 9, 2026
nickfraser (Collaborator) left a comment

See comments

Comment thread src/brevitas/nn/utils.py Outdated
Giuseppe5 added the "next release" label (PRs which should be merged for the next release) Apr 20, 2026
Comment thread src/brevitas/nn/utils.py
state_dict[keys_map[old_key]] = state_dict.pop(old_key)


class merge_quant_weights:
Collaborator

I do not see a reason why this should be a context manager, especially considering you need to add a check to verify that only one forward pass is run. For me, this should be a function that accepts a sample and runs the forward pass within the function.
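
A rough sketch of what the function-style API could look like. Everything here is hypothetical: `ToyQuantLayer` is a plain-Python stand-in rather than a Brevitas module, and the signature of `merge_quant_weights` is an assumption, not the one in the PR. The point is only that the function owns the single forward pass, so no "exactly one forward" check is needed.

```python
from typing import Callable, Dict, List, Optional, Sequence


class ToyQuantLayer:
    """Hypothetical stand-in for a quantized layer that supports forward hooks."""

    def __init__(self, weight: List[float], scale: float) -> None:
        self.weight = weight
        self.scale = scale
        self.hooks: List[Callable[["ToyQuantLayer", List[float]], None]] = []

    def quant_weight(self) -> List[float]:
        # Round-to-nearest fake quantization of the stored weight.
        return [round(w / self.scale) * self.scale for w in self.weight]

    def __call__(self, x: Sequence[float]) -> float:
        q = self.quant_weight()
        for hook in self.hooks:
            hook(self, q)  # expose the quantized weight used in this pass
        return sum(qi * xi for qi, xi in zip(q, x))


def merge_quant_weights(
    layer: ToyQuantLayer,
    sample: Sequence[float],
    keep_original: bool = False,
) -> Optional[List[float]]:
    """Run one forward pass on `sample`, then overwrite the weight with its quantized value."""
    captured: Dict[int, List[float]] = {}

    def capture(module: ToyQuantLayer, q: List[float]) -> None:
        captured[id(module)] = list(q)

    layer.hooks.append(capture)
    try:
        layer(sample)  # the only forward pass; it happens inside the function
    finally:
        layer.hooks.remove(capture)  # always clean up the hook

    original = list(layer.weight) if keep_original else None
    layer.weight = captured[id(layer)]
    return original


layer = ToyQuantLayer(weight=[0.30, -0.55], scale=0.25)
sample = [1.0, 2.0]
out_before = layer(sample)
original = merge_quant_weights(layer, sample, keep_original=True)
out_after = layer(sample)
```

With this shape, the caller cannot run zero or several forward passes by accident, and hook cleanup stays local to one function.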

Comment thread src/brevitas/nn/utils.py
class merge_quant_weights:
    """Context manager that merges quantized weights into model weights.

    This could be useful for example with Learned Round.
Collaborator

From what I see, this is the main purpose of the context manager. What else is this context manager intended to cover? Also, this docstring does not reflect the fact that the scales are also updated to PARAMETER_FROM_STATS.

Comment thread src/brevitas/nn/utils.py
self._model = model
self._hooks: List[RemovableHandle] = []
self._module_tensor_id_mapping = {}
self.disable_quant = disable_quant
Collaborator

Is there a reason why some attributes are public and others private? E.g., _model and disable_quant.

Comment thread src/brevitas/nn/utils.py
        for module in self._module_tensor_id_mapping:
            self._reset_quantizer(module)

    def change_scale_impl_type(self, proxy) -> None:
Collaborator

Suggested change
-    def change_scale_impl_type(self, proxy) -> None:
+    def change_scale_impl_type(self, proxy: WeightQuantProxyFromInjectorBase) -> None:

Comment thread src/brevitas/nn/utils.py
    @staticmethod
    def _reset_quantizer(proxy) -> None:
        """Switch a weight quant proxy from LearnedRound back to standard Round."""
        reinit_on_state_dict = config.REINIT_ON_STATE_DICT_LOAD
Collaborator

This pattern of overriding values in config and then restoring the original values appears multiple times. Can we extract this common functionality? E.g.:

from contextlib import contextmanager
@contextmanager
def override_config(**overrides):
    old = {}
    try:
        for k, v in overrides.items():
            old[k] = getattr(config, k)
            setattr(config, k, v)
        yield
    finally:
        for k, v in old.items():
            setattr(config, k, v)

and then use it like:

with override_config(
    REINIT_ON_STATE_DICT_LOAD=False,
    IGNORE_MISSING_KEYS=True,
):
    ...  # code that relies on the overridden config values
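
For reference, a self-contained version of the helper above that can be run in isolation. The `config` object here is a `SimpleNamespace` stand-in, assumed purely for illustration; in the PR it would be the real config module. The sketch checks that the original values come back even when the body raises.

```python
from contextlib import contextmanager
from types import SimpleNamespace

# Stand-in for the real config module (assumption for this sketch only).
config = SimpleNamespace(REINIT_ON_STATE_DICT_LOAD=True, IGNORE_MISSING_KEYS=False)


@contextmanager
def override_config(**overrides):
    """Temporarily set attributes on `config`, restoring them on exit."""
    old = {}
    try:
        for k, v in overrides.items():
            old[k] = getattr(config, k)
            setattr(config, k, v)
        yield
    finally:
        # Runs on normal exit and on exceptions alike.
        for k, v in old.items():
            setattr(config, k, v)


inside = None
try:
    with override_config(REINIT_ON_STATE_DICT_LOAD=False, IGNORE_MISSING_KEYS=True):
        inside = (config.REINIT_ON_STATE_DICT_LOAD, config.IGNORE_MISSING_KEYS)
        raise RuntimeError("simulated failure inside the block")
except RuntimeError:
    pass

after = (config.REINIT_ON_STATE_DICT_LOAD, config.IGNORE_MISSING_KEYS)
```

Because `old` is filled in as each override is applied, only the keys that were actually changed are restored, including when an unknown key makes `getattr` fail partway through.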

LearnedRoundImplType.HARD_SIGMOID, LearnedRoundImplType.SIGMOID, LearnedRoundImplType.IDENTITY]


def _insert_learned_round(model, learned_round_param):
Collaborator

I think there is no need to create a new function. Use insert_learned_round_quantizers.



@pytest.mark.parametrize("learned_round_param", LEARNED_ROUND_OPTIONS)
def test_merge_quant_weights_reset(learned_round_param):
Collaborator

This test could be merged into the previous one.



@pytest.mark.parametrize("learned_round_param", LEARNED_ROUND_OPTIONS)
def test_merge_quant_weights_forward_equivalence(learned_round_param):
Collaborator

Probably this test can also be merged into the previous one.

Comment thread src/brevitas/nn/utils.py
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Dict
Collaborator

Please remove unused imports.
