Skip to content

Commit 0ea6812

Browse files
committed
Add streaming iter_load and tighten quant public API
- serialize.py: add iter_load() generator that streams weights one at a time from safetensors, keeping peak memory proportional to the largest single weight instead of loading all weights into memory at once.
- pack_cuda.py: rewrite load_and_pack_for_cuda to use iter_load for streaming, which avoids ~40 GB peak memory when loading the 31B checkpoint.
- __init__.py: remove low-level CUDA packer internals (pack_int4_for_cuda, pack_int8_for_cuda, pack_linear_for_cuda, pack_embedding_for_cuda) from the public API. Tests import these directly from pack_cuda.py.
1 parent 4279fc4 commit 0ea6812

4 files changed

Lines changed: 132 additions & 14 deletions

File tree

examples/models/gemma4_31b/quant/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from .pack import ModulePackerFn, pack_model, pack_one # noqa: F401
8-
from .pack_cuda import ( # noqa: F401
9-
DEFAULT_CUDA_PACKERS,
10-
load_and_pack_for_cuda,
11-
pack_embedding_for_cuda,
12-
pack_int4_for_cuda,
13-
pack_int8_for_cuda,
14-
pack_linear_for_cuda,
15-
)
8+
from .pack_cuda import DEFAULT_CUDA_PACKERS, load_and_pack_for_cuda # noqa: F401
169
from .quantize import dequantize_weight, quantize_model, quantize_weight # noqa: F401
1710
from .recipe import QuantConfig, QuantRecipe, QuantRule # noqa: F401
1811
from .serialize import ( # noqa: F401

examples/models/gemma4_31b/quant/pack_cuda.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn as nn
1818

1919
from .pack import ModulePackerFn, pack_model # noqa: F401
20-
from .serialize import CanonicalQuantizedWeight, load
20+
from .serialize import CanonicalQuantizedWeight
2121

2222

2323
# ---------------------------------------------------------------------------
@@ -202,9 +202,23 @@ def load_and_pack_for_cuda(
202202
model: nn.Module,
203203
packers: dict[type, ModulePackerFn] | None = None,
204204
) -> None:
205-
"""Read a quantized safetensors file and pack into ``model`` for CUDA.
205+
"""Stream weights from a quantized safetensors file and pack for CUDA.
206206
207-
Thin wrapper: ``load`` + ``pack_model``.
207+
Uses ``iter_load`` to process one weight at a time, keeping peak
208+
memory proportional to the largest single weight instead of loading
209+
all weights into memory at once.
208210
"""
209-
quantized, unquantized = load(path)
210-
pack_model(model, quantized, unquantized, packers or DEFAULT_CUDA_PACKERS)
211+
from .pack import pack_one
212+
from .serialize import iter_load
213+
214+
_packers = packers or DEFAULT_CUDA_PACKERS
215+
216+
for fqn, value in iter_load(path):
217+
pack_one(model, fqn, value, _packers)
218+
219+
for fqn, p in model.named_parameters():
220+
if p.device.type == "meta":
221+
raise RuntimeError(
222+
f"Weight '{fqn}' not found in checkpoint "
223+
f"(model/checkpoint version mismatch?)"
224+
)

examples/models/gemma4_31b/quant/serialize.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import json
2424
from dataclasses import dataclass
25-
from typing import Optional
25+
from typing import Iterator, Optional
2626

2727
import torch
2828
from safetensors import safe_open
@@ -233,3 +233,49 @@ def load(
233233
header = f.metadata()
234234
tensors = {k: f.get_tensor(k) for k in f.keys()}
235235
return deserialize(tensors, header)
236+
237+
238+
def iter_load(
    path: str,
) -> Iterator[tuple[str, CanonicalQuantizedWeight | torch.Tensor]]:
    """Stream weights from a safetensors file one at a time.

    Yields ``(fqn, value)`` pairs. Every entry listed in the file's
    ``"quant"`` metadata is yielded first as a ``CanonicalQuantizedWeight``
    built from its ``<fqn>.qdata`` / ``<fqn>.scale`` / optional
    ``<fqn>.zero`` tensors. Every remaining tensor key is then yielded as a
    plain ``torch.Tensor`` in the order reported by ``f.keys()``.

    Tensors are fetched lazily, one weight per iteration, so a caller that
    drops each value before requesting the next does not hold the whole
    checkpoint at once. The file stays open for as long as the generator is
    suspended; exhaust or close the generator to release it.

    Args:
        path: Path to the safetensors file to read.

    Yields:
        ``(fqn, value)`` tuples as described above.

    Raises:
        KeyError: If a quant metadata entry lacks one of the ``bits``,
            ``group_size``, ``symmetric`` or ``method`` fields (or ``shape``
            for 4-bit weights).
    """
    with safe_open(path, framework="pt", device="cpu") as f:
        # metadata() is None when the file has no metadata block at all;
        # treat that the same as "no quantized weights".
        header = f.metadata() or {}
        quant_meta = json.loads(header.get("quant", "{}"))

        # Keep the key list for deterministic iteration of the unquantized
        # tail; the set is only used for O(1) membership tests.
        ordered_keys = list(f.keys())
        all_keys = set(ordered_keys)
        consumed: set[str] = set()

        for fqn, meta in quant_meta.items():
            config = QuantConfig(
                bits=meta["bits"],
                group_size=meta["group_size"],
                symmetric=meta["symmetric"],
                method=meta["method"],
            )

            qdata_key = f"{fqn}.qdata"
            qdata = f.get_tensor(qdata_key)
            consumed.add(qdata_key)
            if config.bits == 4:
                # 4-bit data is stored nibble-packed; expand it back using
                # the original last-dimension size recorded in the metadata.
                qdata = _nibble_unpack(qdata, meta["shape"][-1])

            scale_key = f"{fqn}.scale"
            scale = f.get_tensor(scale_key)
            consumed.add(scale_key)

            # The zero tensor is optional and read only when the file has it.
            zero_key = f"{fqn}.zero"
            zero = None
            if zero_key in all_keys:
                zero = f.get_tensor(zero_key)
                consumed.add(zero_key)

            yield fqn, CanonicalQuantizedWeight(
                qdata=qdata, scale=scale, zero=zero, config=config
            )

        # Whatever was not part of a quantized weight is a plain tensor.
        for key in ordered_keys:
            if key not in consumed:
                yield key, f.get_tensor(key)

examples/models/gemma4_31b/quant/tests/test_serialize.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
_nibble_unpack,
2727
CanonicalQuantizedWeight,
2828
deserialize,
29+
iter_load,
2930
load,
3031
save,
3132
serialize,
@@ -264,5 +265,69 @@ def test_empty_quantized(self):
264265
self.assertTrue(torch.equal(unq["w"], u["w"]))
265266

266267

268+
class TestIterLoad(unittest.TestCase):
    """Checks for ``iter_load``, the streaming counterpart of ``load``."""

    def test_yields_all_weights(self):
        """Both quantized and unquantized entries come out of the stream."""
        cfg_int4 = QuantConfig(
            bits=4, group_size=32, symmetric=False, method="min_max"
        )
        cfg_int8 = QuantConfig(
            bits=8, group_size=32, symmetric=True, method="min_max"
        )
        quantized = {
            "proj.weight": _make_cqw((64, 128), cfg_int4),
            "embed.weight": _make_cqw((32, 64), cfg_int8),
        }
        plain = {"norm.weight": torch.randn(64, dtype=torch.bfloat16)}

        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt = os.path.join(tmp_dir, "m.safetensors")
            save(quantized, plain, ckpt)
            streamed = list(iter_load(ckpt))

        names = {name for name, _ in streamed}
        for expected_name in ("proj.weight", "embed.weight", "norm.weight"):
            self.assertIn(expected_name, names)
        self.assertEqual(len(streamed), 3)

    def test_quantized_matches_load(self):
        """A streamed quantized weight equals the one from batch ``load``."""
        cfg = QuantConfig(
            bits=4, group_size=32, symmetric=False, method="min_max"
        )
        original = _make_cqw((64, 128), cfg)

        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt = os.path.join(tmp_dir, "m.safetensors")
            save({"w": original}, {}, ckpt)

            batch_quantized, _ = load(ckpt)
            streamed = dict(iter_load(ckpt))

            expected = batch_quantized["w"]
            actual = streamed["w"]
            self.assertIsInstance(actual, CanonicalQuantizedWeight)
            # Compare each tensor field of the two weights in turn.
            for field in ("qdata", "scale", "zero"):
                self.assertTrue(
                    torch.equal(getattr(expected, field), getattr(actual, field))
                )
            self.assertEqual(expected.config, actual.config)

    def test_unquantized_matches_load(self):
        """A streamed plain tensor equals the one from batch ``load``."""
        plain = {"a": torch.randn(8, 16, dtype=torch.bfloat16)}

        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt = os.path.join(tmp_dir, "m.safetensors")
            save({}, plain, ckpt)

            _, batch_plain = load(ckpt)
            streamed = dict(iter_load(ckpt))

            self.assertTrue(torch.equal(batch_plain["a"], streamed["a"]))

    def test_empty_file(self):
        """A checkpoint saved with no weights streams zero items."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt = os.path.join(tmp_dir, "m.safetensors")
            save({}, {}, ckpt)
            streamed = list(iter_load(ckpt))
        self.assertEqual(len(streamed), 0)
330+
331+
267332
if __name__ == "__main__":
268333
unittest.main()

0 commit comments

Comments
 (0)