Skip to content

Commit f303270

Browse files
refactor phi to base on base model config
1 parent a167825 commit f303270

1 file changed

Lines changed: 2 additions & 34 deletions

File tree

pkg/hfutil/modelconfig/phi.go

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ import (
88

99
// PhiModelConfig represents the configuration for a Phi model
1010
type PhiModelConfig struct {
11-
ConfigPath string `json:"-"`
12-
Architectures []string `json:"architectures"`
13-
ModelType string `json:"model_type"`
11+
BaseModelConfig
12+
1413
AttentionDropout float64 `json:"attention_dropout"`
1514
AttentionProbsDropoutProb float64 `json:"attention_probs_dropout_prob"`
1615
BosTokenId int `json:"bos_token_id"`
@@ -33,8 +32,6 @@ type PhiModelConfig struct {
3332
RopeScaling *struct{} `json:"rope_scaling"`
3433
RopeTheta float64 `json:"rope_theta"`
3534
TieWordEmbeddings bool `json:"tie_word_embeddings"`
36-
TorchDtype string `json:"torch_dtype"`
37-
TransformersVersion string `json:"transformers_version"`
3835
TypeVocabSize int `json:"type_vocab_size"`
3936
UseCache bool `json:"use_cache"`
4037
VocabSize int `json:"vocab_size"`
@@ -72,30 +69,6 @@ func (c *PhiModelConfig) GetParameterCount() int64 {
7269
return 0
7370
}
7471

75-
// GetTransformerVersion returns the transformers library version
76-
func (c *PhiModelConfig) GetTransformerVersion() string {
77-
return c.TransformersVersion
78-
}
79-
80-
// GetQuantizationType returns the quantization method used (if any)
81-
// Phi models typically don't have quantization config directly in the config file
82-
func (c *PhiModelConfig) GetQuantizationType() string {
83-
return ""
84-
}
85-
86-
// GetArchitecture returns the model architecture
87-
func (c *PhiModelConfig) GetArchitecture() string {
88-
if len(c.Architectures) > 0 {
89-
return c.Architectures[0]
90-
}
91-
return ""
92-
}
93-
94-
// GetModelType returns the model type
95-
func (c *PhiModelConfig) GetModelType() string {
96-
return c.ModelType
97-
}
98-
9972
// GetContextLength returns the maximum context length
10073
func (c *PhiModelConfig) GetContextLength() int {
10174
return c.MaxPositionEmbeddings
@@ -106,11 +79,6 @@ func (c *PhiModelConfig) GetModelSizeBytes() int64 {
10679
return EstimateModelSizeBytes(c.GetParameterCount(), c.GetTorchDtype())
10780
}
10881

109-
// GetTorchDtype returns the torch data type used by the model
110-
func (c *PhiModelConfig) GetTorchDtype() string {
111-
return c.TorchDtype
112-
}
113-
11482
// HasVision returns false since this is not a multimodal vision model
11583
func (c *PhiModelConfig) HasVision() bool {
11684
return false

0 commit comments

Comments
 (0)