Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions aigateway/types/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,11 @@ func (m Model) MarshalJSON() ([]byte, error) {
Public bool `json:"public,omitempty"`
Endpoint string `json:"endpoint"`
Metadata map[string]any `json:"metadata"`
CSGHubModelID *string `json:"csghub_model_id,omitempty"`
OwnerUUID *string `json:"owner_uuid,omitempty"`
ClusterID *string `json:"cluster_id,omitempty"`
SvcName *string `json:"svc_name,omitempty"`
SvcType *int `json:"svc_type,omitempty"`
ImageID *string `json:"image_id,omitempty"`
AuthHead *string `json:"auth_head,omitempty"`
Provider *string `json:"provider,omitempty"`
Expand All @@ -102,6 +105,12 @@ func (m Model) MarshalJSON() ([]byte, error) {
supportFC := m.SupportFunctionCall
resp.SupportFunctionCall = &supportFC
}
if m.CSGHubModelID != "" {
Comment thread
Rader marked this conversation as resolved.
resp.CSGHubModelID = &m.CSGHubModelID
}
if m.OwnerUUID != "" {
resp.OwnerUUID = &m.OwnerUUID
}
if m.Provider != "" {
resp.Provider = &m.Provider
}
Expand All @@ -114,6 +123,9 @@ func (m Model) MarshalJSON() ([]byte, error) {
if m.SvcName != "" {
resp.SvcName = &m.SvcName
}
if m.SvcType != 0 {
resp.SvcType = &m.SvcType
}
if m.ImageID != "" {
resp.ImageID = &m.ImageID
}
Expand All @@ -136,8 +148,11 @@ func (m *Model) UnmarshalJSON(data []byte) error {
Public bool `json:"public,omitempty"`
Endpoint string `json:"endpoint"`
Metadata map[string]any `json:"metadata"`
CSGHubModelID string `json:"csghub_model_id,omitempty"`
OwnerUUID string `json:"owner_uuid,omitempty"`
ClusterID string `json:"cluster_id,omitempty"`
SvcName string `json:"svc_name,omitempty"`
SvcType int `json:"svc_type,omitempty"`
ImageID string `json:"image_id,omitempty"`
AuthHead string `json:"auth_head,omitempty"`
Provider string `json:"provider,omitempty"`
Expand All @@ -157,8 +172,11 @@ func (m *Model) UnmarshalJSON(data []byte) error {
m.SupportFunctionCall = aux.SupportFunctionCall
m.Endpoint = aux.Endpoint
m.Metadata = aux.Metadata
m.CSGHubModelID = aux.CSGHubModelID
m.OwnerUUID = aux.OwnerUUID
m.ClusterID = aux.ClusterID
m.SvcName = aux.SvcName
m.SvcType = aux.SvcType
m.ImageID = aux.ImageID
m.AuthHead = aux.AuthHead
m.Provider = aux.Provider
Expand Down
51 changes: 51 additions & 0 deletions aigateway/types/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ func TestModelSerialization(t *testing.T) {
},
InternalModelInfo: InternalModelInfo{
CSGHubModelID: "test/repo/path",
OwnerUUID: "test-owner-uuid",
ClusterID: "test-cluster-id",
SvcName: "test-service",
SvcType: 1,
Expand Down Expand Up @@ -79,6 +80,17 @@ func TestModelSerialization(t *testing.T) {
if !contains(jsonStr, "image_id") || !contains(jsonStr, "test-image-id") {
t.Errorf("Internal response should contain expanded InternalModelInfo fields, got: %s", jsonStr)
}
// csghub_model_id, owner_uuid, and svc_type must survive the Redis round-trip so
Comment thread
Rader marked this conversation as resolved.
// that RecordUsage can populate resource_id for inference models.
if !contains(jsonStr, "csghub_model_id") || !contains(jsonStr, "test/repo/path") {
t.Errorf("Internal response should contain csghub_model_id, got: %s", jsonStr)
}
if !contains(jsonStr, "owner_uuid") {
t.Errorf("Internal response should contain owner_uuid, got: %s", jsonStr)
}
if !contains(jsonStr, "svc_type") {
t.Errorf("Internal response should contain svc_type, got: %s", jsonStr)
}
})

// case3: mode switching
Expand Down Expand Up @@ -168,6 +180,45 @@ func TestModelUnmarshal(t *testing.T) {
}
}

// TestInferenceModelRoundTrip verifies that CSGHubModelID, OwnerUUID, and SvcType
Comment thread
Rader marked this conversation as resolved.
// survive a Redis marshal→unmarshal cycle so that RecordUsage can always populate
// resource_id for inference (llm_type=inference) models.
func TestInferenceModelRoundTrip(t *testing.T) {
original := &Model{
BaseModel: BaseModel{
ID: "Qwen/Qwen3Guard-Gen-0.6B:fgufi9nytc00",
Object: "model",
Created: 1633046400,
OwnedBy: "Qwen",
},
InternalModelInfo: InternalModelInfo{
CSGHubModelID: "Qwen/Qwen3Guard-Gen-0.6B",
OwnerUUID: "uuid-owner-123",
ClusterID: "cluster-abc",
SvcName: "fgufi9nytc00",
SvcType: 2,
ImageID: "img-xyz",
},
Endpoint: "http://inference.internal/v1",
InternalUse: true,
}

data, err := json.Marshal(original)
require.NoError(t, err, "marshal should not error")

var restored Model
require.NoError(t, json.Unmarshal(data, &restored), "unmarshal should not error")

require.Equal(t, original.CSGHubModelID, restored.CSGHubModelID, "CSGHubModelID must round-trip")
require.Equal(t, original.OwnerUUID, restored.OwnerUUID, "OwnerUUID must round-trip")
require.Equal(t, original.SvcType, restored.SvcType, "SvcType must round-trip")
require.Equal(t, original.ClusterID, restored.ClusterID, "ClusterID must round-trip")
require.Equal(t, original.SvcName, restored.SvcName, "SvcName must round-trip")
require.Equal(t, original.ImageID, restored.ImageID, "ImageID must round-trip")
require.Equal(t, original.ID, restored.ID, "ID must round-trip")
require.Equal(t, original.Endpoint, restored.Endpoint, "Endpoint must round-trip")
}

func TestModel_SkipBalance(t *testing.T) {
Comment thread
Rader marked this conversation as resolved.
tests := []struct {
name string
Expand Down
Loading