Skip to content

Commit ac5b4c0

Browse files
move quant parsing to base model config
1 parent 508a118 commit ac5b4c0

1 file changed

Lines changed: 11 additions & 15 deletions

File tree

pkg/hfutil/modelconfig/interface.go

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ type BaseModelConfig struct {
6767
TorchDtype string `json:"torch_dtype"`
6868
TransformerVersion string `json:"transformers_version"`
6969

70+
// Quantization config (optional, shared across all model types)
71+
QuantizationConfig *QuantizationConfig `json:"quantization_config,omitempty"`
72+
7073
// Internal fields (not in JSON)
7174
ConfigPath string `json:"-"`
7275
}
@@ -91,6 +94,14 @@ func (c *BaseModelConfig) GetTorchDtype() string {
9194
return c.TorchDtype
9295
}
9396

97+
// GetQuantizationType returns the quantization method from the quantization config, or an empty string if none is set.
98+
func (c *BaseModelConfig) GetQuantizationType() string {
99+
if c.QuantizationConfig != nil && c.QuantizationConfig.QuantMethod != "" {
100+
return c.QuantizationConfig.QuantMethod
101+
}
102+
return ""
103+
}
104+
94105
// HasVision is the default implementation and reports false, since most models don't have vision capabilities.
95106
func (c *BaseModelConfig) HasVision() bool {
96107
return false
@@ -238,9 +249,6 @@ type GenericModelConfig struct {
238249
IntermediateSize int `json:"intermediate_size"`
239250
MaxPositionEmbeddings int `json:"max_position_embeddings"`
240251
VocabSize int `json:"vocab_size"`
241-
242-
// Quantization config (optional)
243-
QuantizationConfig *QuantizationConfig `json:"quantization_config,omitempty"`
244252
}
245253

246254
// GetParameterCount attempts to get parameter count from safetensors, falls back to estimation
@@ -278,13 +286,6 @@ func estimateGenericParams(hiddenSize, numLayers, intermediateSize, vocabSize in
278286
return embeddingParams + totalLayerParams
279287
}
280288

281-
func (c *GenericModelConfig) GetQuantizationType() string {
282-
if c.QuantizationConfig != nil && c.QuantizationConfig.QuantMethod != "" {
283-
return c.QuantizationConfig.QuantMethod
284-
}
285-
return ""
286-
}
287-
288289
func (c *GenericModelConfig) GetContextLength() int {
289290
return c.MaxPositionEmbeddings
290291
}
@@ -380,11 +381,6 @@ func (c *GenericDiffusionModelConfig) GetParameterCount() int64 {
380381
return total
381382
}
382383

383-
func (c *GenericDiffusionModelConfig) GetQuantizationType() string {
384-
// Not supported. Doesn't seem to be standardized in HF.
385-
return ""
386-
}
387-
388384
func (c *GenericDiffusionModelConfig) GetContextLength() int {
389385
if c.ConfigPath == "" {
390386
return 0

0 commit comments

Comments
 (0)