Skip to content

Commit 10ff17b

Browse files
refactor: use shared Turbo config defaults for export and push
Agent-Logs-Url: https://github.com/codewithdark-git/QuantLLM/sessions/aa78d528-be1d-4467-813d-711a55ade22a
Co-authored-by: codewithdark-git <144595403+codewithdark-git@users.noreply.github.com>
1 parent 2cd4bd2 commit 10ff17b

2 files changed

Lines changed: 23 additions & 11 deletions

File tree

quantllm/core/turbo_model.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -122,8 +122,6 @@ def from_pretrained(
122122
quantize: Whether to quantize the model
123123
config_override: Dict to override any auto-detected settings
124124
config: Shared export/push config (format, quantization, push_format, etc.)
125-
quantize: Whether to quantize the model
126-
config_override: Dict to override any auto-detected settings
127125
verbose: Print loading progress
128126
129127
Returns:
@@ -1009,10 +1007,16 @@ def export(
10091007
>>> model.export("onnx", "./my_model_onnx/")
10101008
>>> model.export("mlx", "./my_model_mlx/", quantization="4bit")
10111009
"""
1012-
format = (format or self.export_push_config["format"]).lower()
1010+
format = (
1011+
format
1012+
if format is not None
1013+
else self.export_push_config.get("format", DEFAULT_EXPORT_PUSH_CONFIG["format"])
1014+
).lower()
10131015
effective_quantization = quantization
10141016
if effective_quantization is None and format == "gguf":
1015-
effective_quantization = self.export_push_config["quantization"]
1017+
effective_quantization = self.export_push_config.get(
1018+
"quantization", DEFAULT_EXPORT_PUSH_CONFIG["quantization"]
1019+
)
10161020

10171021
# Merge LoRA if applied
10181022
if self._lora_applied:
@@ -1025,7 +1029,7 @@ def export(
10251029
if output_path is None:
10261030
model_name = self.model.config._name_or_path.split('/')[-1]
10271031
if format == "gguf":
1028-
quant = effective_quantization or "Q4_K_M"
1032+
quant = effective_quantization
10291033
output_path = f"{model_name}.{quant.upper()}.gguf"
10301034
elif format == "safetensors":
10311035
output_path = f"./{model_name}-quantllm/"
@@ -1086,8 +1090,14 @@ def push_to_hub(
10861090
"""
10871091
from ..hub import QuantLLMHubManager
10881092

1089-
format_lower = (format or self.export_push_config["push_format"]).lower()
1090-
push_quantization = quantization or self.export_push_config["push_quantization"]
1093+
format_lower = (
1094+
format
1095+
if format is not None
1096+
else self.export_push_config.get("push_format", DEFAULT_EXPORT_PUSH_CONFIG["push_format"])
1097+
).lower()
1098+
push_quantization = quantization or self.export_push_config.get(
1099+
"push_quantization", DEFAULT_EXPORT_PUSH_CONFIG["push_quantization"]
1100+
)
10911101

10921102
# Get the original base model name (full path for HuggingFace link)
10931103
base_model_full = self.model.config._name_or_path
@@ -1101,7 +1111,7 @@ def push_to_hub(
11011111

11021112
if format_lower == "gguf":
11031113
# Export GGUF directly to staging
1104-
quant_label = push_quantization or "Q4_K_M"
1114+
quant_label = push_quantization
11051115
filename = f"{model_name}.{quant_label.upper()}.gguf"
11061116
save_path = os.path.join(manager.staging_dir, filename)
11071117

tests/test_export_push_config.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,8 @@ def _stub_turbo(export_push_config):
1111
model = TurboModel.__new__(TurboModel)
1212
model.model = _stub_model()
1313
model.tokenizer = None
14-
model.config = SimpleNamespace(quant_type="Q8_0")
14+
smart_config = SimpleNamespace(quant_type="Q8_0")
15+
model.config = smart_config
1516
model._lora_applied = False
1617
model.verbose = False
1718
model.export_push_config = export_push_config
@@ -36,7 +37,7 @@ def test_build_export_push_config_aligns_push_values_with_export_values():
3637
assert resolved["push_quantization"] == "Q5_K_M"
3738

3839

39-
def test_export_uses_shared_config_when_format_and_quantization_are_omitted():
40+
def test_export_prefers_shared_quantization_over_smart_config_quant_type():
4041
model = _stub_turbo(
4142
{
4243
"format": "gguf",
@@ -60,11 +61,12 @@ def fake_export_gguf(output_path, quantization=None, **kwargs):
6061

6162
output = model.export()
6263

64+
assert model.config.quant_type == "Q8_0"
6365
assert output.endswith(".Q4_K_M.gguf")
6466
assert captured["quantization"] == "Q4_K_M"
6567

6668

67-
def test_push_uses_shared_config_when_omitted(monkeypatch):
69+
def test_gguf_push_uses_shared_config_when_omitted(monkeypatch):
6870
model = _stub_turbo({
6971
"format": "gguf",
7072
"push_format": "gguf",

0 commit comments

Comments (0)