Skip to content

Commit c4de50d

Browse files
Don't zero-initialize in export_llama (#16886)
Summary: Zero initialization is non-standard with PyTorch models, and in particular with ET is frustrating because ET looks to greedily deduplicate weights. That means if you zero-initialize a transformer model, the pte size will be a lot smaller than you would expect if you didn't know about the deduplication. Differential Revision: D91518961
1 parent f2438a9 commit c4de50d

1 file changed

Lines changed: 19 additions & 16 deletions

File tree

examples/models/llama/model.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -245,24 +245,27 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
245245
for param in self.model_.parameters():
246246
if isinstance(param, TorchAOBaseTensor):
247247
param.requires_grad = False
248+
if missing:
249+
missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
250+
if missing_weights:
251+
raise ValueError(
252+
f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
253+
)
254+
if unexpected:
255+
if self.verbose:
256+
print(f"Unexpected keys: {unexpected}")
248257
else:
249-
print("Checkpoint not provided, defaulting weights to zeros.")
258+
print("Checkpoint not provided, using default initialization.")
259+
# Because we loaded onto meta device, it is annoying to now load onto cpu
260+
# with the standard random initialization.
250261
self.model_.to_empty(device="cpu")
251-
# Need to provide concrete values for meta-initialized tensors for quantization.
252-
# otherwise it is just filled with nan's.
253-
for p in self.model_.parameters():
254-
p.data.fill_(0)
255-
for b in self.model_.buffers():
256-
b.data.fill_(0)
257-
if missing:
258-
missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
259-
if missing_weights:
260-
raise ValueError(
261-
f"The provided checkpoint is missing the following weights that are expected by the model: {missing_weights}. Please fix the fqn's in your checkpoint to match."
262-
)
263-
if unexpected:
264-
if self.verbose:
265-
print(f"Unexpected keys: {unexpected}")
262+
263+
def weight_reset(m):
264+
reset_parameters = getattr(m, "reset_parameters", None)
265+
if callable(reset_parameters):
266+
m.reset_parameters()
267+
268+
self.model_.apply(weight_reset)
266269

267270
# Prune the input layer if input_prune_map is provided
268271
if input_prune_map is not None:

0 commit comments

Comments
 (0)