
Commit 52429d6

feat: add roadmap primitives for triton quantization and memory-aware gguf export
Agent-Logs-Url: https://github.com/codewithdark-git/QuantLLM/sessions/fc1f2077-187e-4757-9927-2c60ef187666
Co-authored-by: codewithdark-git <144595403+codewithdark-git@users.noreply.github.com>
1 parent 6bad9ad · commit 52429d6

6 files changed

Lines changed: 154 additions & 7 deletions


docs/guide/gguf-export.md

Lines changed: 18 additions & 2 deletions

@@ -212,8 +212,24 @@ For very large models:
 # Use lower quantization
 model.export("gguf", "model.Q3_K_M.gguf", quantization="Q3_K_M")
 
-# Or export with streaming (reduces memory)
-model.export("gguf", "model.gguf", quantization="Q4_K_M", streaming=True)
+# Enable chunked conversion + smart ordering
+model.export(
+    "gguf",
+    "model.gguf",
+    quantization="Q4_K_M",
+    chunked_conversion=True,
+    max_shard_size="2GB",
+    smart_tensor_ordering=True,
+)
+
+# Force intermediate files to a dedicated disk offload directory
+model.export(
+    "gguf",
+    "model.gguf",
+    quantization="Q4_K_M",
+    disk_offloading=True,
+    disk_offload_dir="./quantllm_offload",
+)
 ```
 
 ### Windows Issues

quantllm/core/memory.py

Lines changed: 22 additions & 1 deletion

@@ -12,8 +12,9 @@
 """
 
 import gc
-from typing import Optional, Dict, Any, List, Union, Callable
+from typing import Optional, Dict, Any, List, Union, Callable, Iterable
 from contextlib import contextmanager
+from collections import OrderedDict
 import torch
 import torch.nn as nn
 

@@ -186,6 +187,26 @@ def estimate_model_memory(
     }
 
 
+def memory_optimized_tensor_order(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    prioritize_large_tensors: bool = True,
+) -> "OrderedDict[str, torch.Tensor]":
+    """
+    Return an ordered state dict to reduce peak memory pressure during serialization.
+
+    By default, larger tensors are emitted first to reduce long-lived allocator pressure
+    in shard-based writes on very large checkpoints.
+    """
+    items: Iterable = state_dict.items()
+    sorted_items = sorted(
+        items,
+        key=lambda kv: kv[1].numel() * kv[1].element_size(),
+        reverse=prioritize_large_tensors,
+    )
+    return OrderedDict(sorted_items)
+
+
 class DynamicOffloader:
     """
     Dynamic layer offloading for large models.
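A quick sanity check of the new helper's behavior — this standalone sketch reproduces the same sort key (bytes = numel × element_size) on a toy state dict; the tensor names here are invented for illustration:

import torch
from collections import OrderedDict

# Toy state dict with tensors of different byte sizes (names are hypothetical).
state = {
    "lm_head.weight": torch.zeros(10, 10),    # 400 bytes
    "norm.weight": torch.zeros(4),            # 16 bytes
    "embed.weight": torch.zeros(100, 100),    # 40,000 bytes
}

# Same key as memory_optimized_tensor_order: largest tensors first.
ordered = OrderedDict(
    sorted(state.items(), key=lambda kv: kv[1].numel() * kv[1].element_size(), reverse=True)
)
print(list(ordered))  # ['embed.weight', 'lm_head.weight', 'norm.weight']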

quantllm/core/turbo_model.py

Lines changed: 34 additions & 3 deletions

@@ -1144,13 +1144,24 @@ def _export_gguf(
             output_path: Output file path for GGUF
             quantization: Quantization type (Q4_K_M, Q5_K_M, Q8_0, etc.)
             fast_mode: Skip intermediate F16 step for faster export (slightly less optimal)
+            chunked_conversion: Save model shards during conversion for large checkpoints
+            max_shard_size: Max shard size used when chunked conversion is active
+            smart_tensor_ordering: Save tensors in memory-optimized order
+            disk_offloading: Use a dedicated temp/offload directory for intermediate artifacts
+            disk_offload_dir: Directory used when disk_offloading=True
         """
         from ..quant import convert_to_gguf, quantize_gguf, ensure_llama_cpp_installed, GGUF_QUANT_TYPES
         from ..utils import QuantLLMProgress, format_time, format_size
         import time
 
         start_time = time.time()
 
+        chunked_conversion = bool(kwargs.pop("chunked_conversion", False))
+        max_shard_size = kwargs.pop("max_shard_size", "2GB" if chunked_conversion else "50GB")
+        smart_tensor_ordering = bool(kwargs.pop("smart_tensor_ordering", False))
+        disk_offloading = bool(kwargs.pop("disk_offloading", False))
+        disk_offload_dir = kwargs.pop("disk_offload_dir", None)
+
         quant_type = quantization or self.config.quant_type or "q4_k_m"
         quant_type_upper = quant_type.upper()
         quant_type_lower = quant_type.lower()

@@ -1163,6 +1174,12 @@ def _export_gguf(
         print_info(f"Target quantization: {quant_type_upper}")
         if fast_mode:
             print_info("Fast mode enabled")
+        if chunked_conversion:
+            print_info(f"Chunked conversion enabled (max_shard_size={max_shard_size})")
+        if smart_tensor_ordering:
+            print_info("Smart tensor ordering enabled")
+        if disk_offloading:
+            print_info(f"Disk offloading enabled ({disk_offload_dir or 'system temp'})")
 
         # Ensure llama.cpp
         if self.verbose:

@@ -1188,21 +1205,35 @@ def _export_gguf(
         # Get model name for file naming
         model_name = self.model.config._name_or_path.split('/')[-1]
 
+        temp_parent = disk_offload_dir if disk_offloading else None
+        if temp_parent:
+            os.makedirs(temp_parent, exist_ok=True)
+
         # Create temp dir for conversion
-        with tempfile.TemporaryDirectory() as temp_dir:
+        with tempfile.TemporaryDirectory(dir=temp_parent) as temp_dir:
             # Step 1: Save model to temp directory
             if self.verbose:
                 print_header("Step 1/3: Saving Model", icon="💾")
                 print_info(f"Staging model to {temp_dir}...")
 
             with QuantLLMProgress() as progress:
                 task = progress.add_task("Saving model weights...", total=None)
+                save_kwargs = {
+                    "safe_serialization": True,
+                    "max_shard_size": max_shard_size,
+                }
+
+                if smart_tensor_ordering:
+                    from .memory import memory_optimized_tensor_order
+                    save_kwargs["state_dict"] = memory_optimized_tensor_order(model_to_save.state_dict())
+
                 try:
-                    model_to_save.save_pretrained(temp_dir, safe_serialization=True)
+                    model_to_save.save_pretrained(temp_dir, **save_kwargs)
                 except Exception as e:
                     if self.verbose:
                         print_warning(f"SafeTensors save failed ({e}), using PyTorch format...")
-                    model_to_save.save_pretrained(temp_dir, safe_serialization=False)
+                    save_kwargs["safe_serialization"] = False
+                    model_to_save.save_pretrained(temp_dir, **save_kwargs)
 
                 self.tokenizer.save_pretrained(temp_dir)
                 progress.update(task, completed=100)
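To see how the disk_offloading path behaves in isolation — intermediates staged under a caller-chosen directory via tempfile.TemporaryDirectory(dir=...) and cleaned up on exit — here is a minimal sketch; the ./quantllm_offload path is just an example value for disk_offload_dir:

import os
import tempfile

offload_dir = "./quantllm_offload"  # example value, any writable path works
os.makedirs(offload_dir, exist_ok=True)

# Mirrors the export path: with dir=None this falls back to the system temp
# location; with a real directory, all staged files land on that disk.
with tempfile.TemporaryDirectory(dir=offload_dir) as temp_dir:
    print(temp_dir)  # e.g. ./quantllm_offload/tmpab12cd34
    # model shards and tokenizer files would be staged here
# the directory tree is removed automatically on exit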

quantllm/kernels/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -7,11 +7,17 @@
 from .triton import (
     TritonQuantizedLinear,
     fused_dequant_matmul,
+    int4_matmul,
     is_triton_available,
+    triton_q4_0_quantize,
+    triton_q8_0_quantize,
 )
 
 __all__ = [
     "TritonQuantizedLinear",
     "fused_dequant_matmul",
+    "int4_matmul",
     "is_triton_available",
+    "triton_q4_0_quantize",
+    "triton_q8_0_quantize",
 ]

quantllm/kernels/triton/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -7,11 +7,17 @@
 from .quantized_linear import (
     TritonQuantizedLinear,
     fused_dequant_matmul,
+    int4_matmul,
     is_triton_available,
+    triton_q4_0_quantize,
+    triton_q8_0_quantize,
 )
 
 __all__ = [
     "TritonQuantizedLinear",
     "fused_dequant_matmul",
+    "int4_matmul",
     "is_triton_available",
+    "triton_q4_0_quantize",
+    "triton_q8_0_quantize",
 ]
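Both __init__.py files now re-export the same symbols, so either import path resolves; a short import check, assuming the package layout shown in this commit:

from quantllm.kernels import int4_matmul, triton_q8_0_quantize
from quantllm.kernels.triton import is_triton_available, triton_q4_0_quantize

print(is_triton_available())  # True only when Triton and a CUDA device are present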

quantllm/kernels/triton/quantized_linear.py

Lines changed: 68 additions & 1 deletion

@@ -8,7 +8,7 @@
 Performance: ~2-3x faster than separate dequant + matmul
 """
 
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 import torch
 import torch.nn as nn
 

@@ -27,6 +27,67 @@ def is_triton_available() -> bool:
     return _TRITON_AVAILABLE
 
 
+def triton_q8_0_quantize(weight: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize a weight matrix to Q8_0 format (per-column symmetric int8).
+
+    Returns:
+        qweight: int8 tensor [in_features, out_features]
+        scales: fp tensor [1, out_features]
+    """
+    if weight.dim() != 2:
+        raise ValueError(f"Q8_0 quantization expects a 2D tensor, got shape={tuple(weight.shape)}")
+
+    max_abs = weight.abs().amax(dim=0, keepdim=True).clamp(min=eps)
+    scale = max_abs / 127.0
+    qweight = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
+    return qweight, scale.to(weight.dtype)
+
+
+def triton_q4_0_quantize(weight: torch.Tensor, eps: float = 1e-8) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize a weight matrix to Q4_0 format (per-column symmetric 4-bit stored in int8).
+
+    Returns:
+        qweight: int8 tensor [in_features, out_features] with values in [-8, 7]
+        scales: fp tensor [1, out_features]
+    """
+    if weight.dim() != 2:
+        raise ValueError(f"Q4_0 quantization expects a 2D tensor, got shape={tuple(weight.shape)}")
+
+    max_abs = weight.abs().amax(dim=0, keepdim=True).clamp(min=eps)
+    scale = max_abs / 7.0
+    qweight = torch.clamp(torch.round(weight / scale), -8, 7).to(torch.int8)
+    return qweight, scale.to(weight.dtype)
+
+
+def int4_matmul(
+    x: torch.Tensor,
+    qweight: torch.Tensor,
+    scales: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    INT4 matmul path backed by fused dequant+matmul on CUDA/Triton when available.
+
+    Args:
+        x: Input [..., in_features]
+        qweight: Quantized int4 values stored in int8, shape [in_features, out_features]
+        scales: Per-column scales, shape [1, out_features] or [in_features/group, out_features]
+        bias: Optional bias [out_features]
+    """
+    zeros = torch.zeros_like(scales)
+    group_size = qweight.shape[0] if scales.shape[0] == 1 else max(qweight.shape[0] // scales.shape[0], 1)
+    return fused_dequant_matmul(
+        x=x,
+        qweight=qweight,
+        scales=scales,
+        zeros=zeros,
+        bias=bias,
+        group_size=group_size,
+    )
+
+
 if _TRITON_AVAILABLE:
     @triton.jit
     def _fused_dequant_matmul_kernel(

@@ -462,3 +523,9 @@ def extra_repr(self) -> str:
             f'group_size={self.group_size}, '
             f'triton={self._use_triton}'
         )
+
+
+TRITON_QUANT_KERNELS: Dict[str, Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]] = {
+    "q4_0": triton_q4_0_quantize,
+    "q8_0": triton_q8_0_quantize,
+}
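A round-trip check for the quantizers added above — per-column symmetric quantization bounds the reconstruction error at half a quantization step, which this standalone sketch verifies for the Q8_0 scheme using the same math as triton_q8_0_quantize:

import torch

w = torch.randn(64, 32)

# Same math as triton_q8_0_quantize: per-column scale = max|w| / 127.
max_abs = w.abs().amax(dim=0, keepdim=True).clamp(min=1e-8)
scale = max_abs / 127.0
q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)

# Dequantize and measure the error in units of the per-column step size.
w_hat = q.to(w.dtype) * scale
steps = ((w - w_hat).abs() / scale).amax().item()
assert steps <= 0.5 + 1e-4  # rounding error never exceeds half a step
print(f"max error: {steps:.4f} steps")

Note also how int4_matmul derives group_size: with per-column scales of shape [1, out_features] it uses the full in_features as one group, so the fused kernel applies a single scale per output column.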
