Skip to content

Commit 6f094d7

Browse files
authored
[For RL] Keep attrs after folding weight and fix empty extra state for Megatron (#779)
## What does this PR do? **Type of change:** improvement **Overview:** - For quantization-aware reinforcement learning, after folding the weights of the rollout, we want to keep the quantization attrs for the next step. - Minor fix for empty extra state - Support getting a dataloader from a JSONL file, useful for using training data as calibration data. I can separate this into another PR if necessary. ## Usage `mtq.fold_weight(keep_attrs=True)` will keep quantizer attrs after folding the weights. ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes <!--- If No, explain why. --> - **Did you write any new necessary tests?**: NA - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for loading dataset samples directly from JSONL/JSONL.GZ files * Added optional parameter to skip logits return in generation prefill operations * Enhanced weight folding operations to optionally preserve quantization attributes during model optimization * **Bug Fixes** * Fixed handling of empty tensor states to prevent deserialization errors in Megatron module <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Meng Xin <mxin@nvidia.com>
1 parent a538f2e commit 6f094d7

8 files changed

Lines changed: 78 additions & 25 deletions

File tree

modelopt/torch/opt/plugins/megatron.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def _modelopt_set_extra_state(self, state: Any):
9999
return
100100

101101
if isinstance(state, torch.Tensor):
102+
if state.numel() == 0:
103+
return
102104
# Default format: byte tensor with pickled data
103105
#
104106
# TODO: possible deserialization improvement

modelopt/torch/quantization/model_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -530,8 +530,8 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
530530
print("\n".join(lines))
531531

532532

533-
def fold_weight(model: nn.Module):
533+
def fold_weight(model: nn.Module, keep_attrs: bool = False):
534534
"""Fold weight quantizer for fast evaluation."""
535535
for name, module in model.named_modules():
536536
if isinstance(module, QuantModule):
537-
module.fold_weight()
537+
module.fold_weight(keep_attrs)

modelopt/torch/quantization/nn/modules/quant_linear.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,9 @@ def forward(self, input, *args, **kwargs):
162162
output = super().forward(input, *args, **kwargs)
163163
return output
164164

165-
def fold_weight(self):
165+
def fold_weight(self, keep_attrs: bool = False):
166166
"""Fold the weight for faster eval."""
167-
super().fold_weight()
167+
super().fold_weight(keep_attrs)
168168
if (
169169
hasattr(self, "weight_quantizer")
170170
and hasattr(self, "weight")
@@ -179,13 +179,14 @@ def fold_weight(self):
179179
self.weight
180180
+ self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
181181
)
182-
_attrs = [
183-
"_svdquant_lora_a",
184-
"_svdquant_lora_b",
185-
]
186-
for attr in _attrs:
187-
if hasattr(self.weight_quantizer, attr):
188-
delattr(self.weight_quantizer, attr)
182+
if not keep_attrs:
183+
_attrs = [
184+
"_svdquant_lora_a",
185+
"_svdquant_lora_b",
186+
]
187+
for attr in _attrs:
188+
if hasattr(self.weight_quantizer, attr):
189+
delattr(self.weight_quantizer, attr)
189190

190191

191192
class RealQuantLinear(QuantModule):

modelopt/torch/quantization/nn/modules/quant_module.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def modelopt_post_restore(self, prefix: str = ""):
119119
if isinstance(module, TensorQuantizer):
120120
module.to(non_tq_param_or_buffer.device)
121121

122-
def fold_weight(self):
122+
def fold_weight(self, keep_attrs: bool = False):
123123
"""Fold the weight for faster eval."""
124124
# Handle all attributes that end with _weight_quantizer
125125
for name in dir(self):
@@ -138,13 +138,14 @@ def fold_weight(self):
138138
weight = getattr(self, weight_name)
139139
weight.data.copy_(attr(weight.float()).to(weight.dtype))
140140
attr.disable()
141-
_attrs = [
142-
"_pre_quant_scale",
143-
"_amax",
144-
]
145-
for attr_name in _attrs:
146-
if hasattr(attr, attr_name):
147-
delattr(attr, attr_name)
141+
if not keep_attrs:
142+
_attrs = [
143+
"_pre_quant_scale",
144+
"_amax",
145+
]
146+
for attr_name in _attrs:
147+
if hasattr(attr, attr_name):
148+
delattr(attr, attr_name)
148149

149150

150151
QuantModuleRegistry = _DMRegistryCls("Quant", QuantModule)

modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,9 @@ class HFRowParallelLinear(HFParallelLinear):
363363
class _QuantHFParallelLinear(_ParallelLinear):
364364
_functionals_to_replace = [(torch.nn.functional, "linear")]
365365

366-
def fold_weight(self):
366+
def fold_weight(self, keep_attrs: bool = False):
367367
with self.enable_weight_access_and_writeback():
368-
super().fold_weight()
368+
super().fold_weight(keep_attrs)
369369

370370
@contextmanager
371371
def enable_weight_access_and_writeback(self):

modelopt/torch/quantization/plugins/vllm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
228228
)
229229

230230
@torch.no_grad()
231-
def fold_weight(self):
231+
def fold_weight(self, keep_attrs: bool = False):
232232
# the MoE weights can be super large, it consumes too much memory, so we need to fold the weight one by one
233233
for i in range(self.w13_weight.shape[0]):
234234
self.w13_weight[i].copy_(

modelopt/torch/utils/dataset_utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"""Utility functions for getting samples and forward loop function for different datasets."""
1717

1818
import copy
19+
import json
1920
from collections.abc import Callable
2021
from typing import TYPE_CHECKING, Any
2122
from warnings import warn
@@ -110,6 +111,47 @@
110111
]
111112

112113

114+
def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]:
115+
"""Load up to ``num_samples`` entries from a JSONL file using the ``text`` field.
116+
117+
Each non-empty line must be a JSON object containing a ``text`` field.
118+
"""
119+
if num_samples <= 0:
120+
return []
121+
122+
samples: list[str] = []
123+
124+
with open(jsonl_path, encoding="utf-8") as f:
125+
for line_idx, line in enumerate(f, start=1):
126+
if len(samples) >= num_samples:
127+
break
128+
line = line.strip()
129+
if not line:
130+
continue
131+
132+
try:
133+
obj = json.loads(line)
134+
except json.JSONDecodeError as e:
135+
raise ValueError(
136+
f"Invalid JSON in JSONL file {jsonl_path} at line {line_idx}: {e}"
137+
) from e
138+
139+
if not isinstance(obj, dict):
140+
raise ValueError(
141+
f"Expected a JSON object in JSONL file {jsonl_path} at line {line_idx}, "
142+
f"got {type(obj)}."
143+
)
144+
145+
if "text" not in obj:
146+
raise ValueError(
147+
f"Missing required field 'text' in JSONL file {jsonl_path} at line {line_idx}."
148+
)
149+
150+
samples.append(str(obj["text"]))
151+
152+
return samples
153+
154+
113155
def _normalize_splits(split: str | list[str]) -> list[str]:
114156
"""Ensure split is always a list."""
115157
return [split] if isinstance(split, str) else list(split)
@@ -181,7 +223,7 @@ def get_dataset_samples(
181223
``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.
182224
183225
Args:
184-
dataset_name: Name or HuggingFace path of the dataset to load.
226+
dataset_name: Name or HuggingFace path of the dataset to load, or a path to a ``.jsonl``/``.jsonl.gz`` file.
185227
num_samples: Number of samples to load from the dataset.
186228
apply_chat_template: Whether to apply the chat template to the samples
187229
(if supported by the dataset). For unregistered datasets with a
@@ -196,6 +238,10 @@ def get_dataset_samples(
196238
Returns:
197239
Samples: The list of samples.
198240
"""
241+
# Local JSONL file path support (each line is a JSON object with a `text` field).
242+
if dataset_name.endswith(".jsonl"):
243+
return _get_jsonl_text_samples(dataset_name, num_samples)
244+
199245
from datasets import load_dataset
200246

201247
is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
@@ -284,7 +330,8 @@ def get_dataset_dataloader(
284330
"""Get a dataloader with the dataset name and tokenizer of the target model.
285331
286332
Args:
287-
dataset_name: Name of the dataset to load.
333+
dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file.
334+
If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field.
288335
tokenizer: Instance of HuggingFace tokenizer.
289336
batch_size: Batch size of the returned dataloader.
290337
num_samples: Number of samples from the dataset.

modelopt/torch/utils/plugins/megatron_generate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def megatron_prefill(
4646
pixel_values: torch.FloatTensor | None = None,
4747
image_grid_thw: torch.LongTensor | None = None,
4848
image_sizes: torch.LongTensor | None = None,
49+
skip_return_logits: bool = False,
4950
) -> torch.Tensor:
5051
"""A simple prefill function for Megatron Core V(LM) models."""
5152
if not isinstance(model, MegatronModule):
@@ -112,6 +113,8 @@ def _forward_step_func(data, model):
112113
forward_only=True,
113114
collect_non_loss_data=True,
114115
)
116+
if skip_return_logits:
117+
return None
115118

116119
if mpu.is_pipeline_last_stage():
117120
logits = list_of_logits[0][:, :seq_length, :].detach()
@@ -124,7 +127,6 @@ def _forward_step_func(data, model):
124127
logits_dtype = torch.float16
125128
else:
126129
logits_dtype = torch.float32
127-
128130
logits = broadcast_from_last_pipeline_stage(
129131
[max_batch_size, seq_length, model.vocab_size], logits_dtype, logits
130132
)

0 commit comments

Comments (0)