Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

167 changes: 118 additions & 49 deletions aigateway/component/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type OpenAIComponent interface {
GetAvailableModels(c context.Context, user string) ([]types.Model, error)
ListModels(c context.Context, user string, req types.ListModelsReq) (types.ModelList, error)
GetModelByID(c context.Context, username, modelID string) (*types.Model, error)
RecordUsage(c context.Context, userUUID string, model *types.Model, tokenCounter token.Counter, sceneValue string) error
RecordUsage(c context.Context, userUUID string, model *types.Model, tokenCounter token.Counter) error
CheckBalance(ctx context.Context, username, userUUID string) error
}

Expand Down Expand Up @@ -442,88 +442,157 @@ func getSceneFromSvcType(svcType int) int {
}
}

func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, model *types.Model, counter token.Counter, sceneValue string) error {
usage, err := counter.Usage(c)
if err != nil {
return fmt.Errorf("failed to get token usage from counter,error:%w", err)
// csghubMeteringLLMTypeFromModel returns metadata llm_type (e.g. serverless, inference) used as the path component in csghub://… metering URIs.
func csghubMeteringLLMTypeFromModel(m *types.Model) (string, error) {
if m == nil {
return "", fmt.Errorf("model is nil")
}
if m.Metadata == nil {
return "", fmt.Errorf("model metadata is nil: cannot resolve %s for resource path", types.MetaKeyLLMType)
}
llmType, ok := m.Metadata[types.MetaKeyLLMType].(string)
if !ok {
return "", fmt.Errorf("model metadata %s missing or not a string", types.MetaKeyLLMType)
}
return llmType, nil
}

scene := parseScene(sceneValue)
slog.DebugContext(c, "token usage", slog.Any("usage", usage), slog.Any("scene", scene))
var tokenUsageExtra = struct {
PromptTokenNum string `json:"prompt_token_num"`
CompletionTokenNum string `json:"completion_token_num"`
// 0: external, 1: owner is user, 2: other user is inference, 3: serverless
OwnerType commontypes.TokenUsageType `json:"owner_type"`
}{
PromptTokenNum: fmt.Sprintf("%d", usage.PromptTokens),
CompletionTokenNum: fmt.Sprintf("%d", usage.CompletionTokens),
// meteringResourceFromModel builds a MeteringResource from an OpenAI gateway model (see types.MeteringResource).
func meteringResourceFromModel(model *types.Model) (types.MeteringResource, error) {
if model == nil {
return types.MeteringResource{}, fmt.Errorf("model is nil")
}
if model.CSGHubModelID != "" {
llmType, err := csghubMeteringLLMTypeFromModel(model)
if err != nil {
return types.MeteringResource{}, err
}
id := fmt.Sprintf(types.CSGHubResourceFmt, llmType, model.CSGHubModelID)
return types.MeteringResource{
ResourceID: id,
ResourceName: id,
CustomerID: model.SvcName,
}, nil
}
if model.Provider != "" {
id := fmt.Sprintf(types.ExternalLLMResourceFmt, model.Provider, model.ID)
return types.MeteringResource{
ResourceID: id,
ResourceName: id,
CustomerID: id,
}, nil
}
return types.MeteringResource{}, nil
}

// tokenUsageMeteringExtra is serialized into MeteringEvent.Extra for token billing breakdown.
type tokenUsageMeteringExtra struct {
PromptTokenNum string `json:"prompt_token_num"`
CompletionTokenNum string `json:"completion_token_num"`
OwnerType commontypes.TokenUsageType `json:"owner_type"`
}

func validateModelForUsageRecord(c context.Context, model *types.Model) error {
if model == nil {
return fmt.Errorf("record usage: model is nil")
}
if model.CSGHubModelID != "" && model.Provider != "" {
slog.WarnContext(c, "bad model info, both csghub model id and external model provider is set",
slog.Any("model info", model))
slog.Any("model", model))
return fmt.Errorf("record usage: conflicting csghub model id and external provider")
}
if model.CSGHubModelID == "" && model.Provider == "" {
slog.WarnContext(c, "bad model info, both csghub model id and external model provider is not set",
slog.Any("model info", model))
slog.Any("model", model))
return fmt.Errorf("record usage: model missing resource identifiers")
}
return nil
}

func (m *openaiComponentImpl) tokenUsageMeteringExtraAndScene(c context.Context, userUUID string, model *types.Model, usage *token.Usage) (tokenUsageMeteringExtra, commontypes.SceneType, error) {
scene := commontypes.SceneModelServerless
extra := tokenUsageMeteringExtra{
PromptTokenNum: fmt.Sprintf("%d", usage.PromptTokens),
CompletionTokenNum: fmt.Sprintf("%d", usage.CompletionTokens),
}
if model.CSGHubModelID != "" {
switch model.SvcType {
case commontypes.ServerlessType:
tokenUsageExtra.OwnerType = commontypes.CSGHubServerlessInference
extra.OwnerType = commontypes.CSGHubServerlessInference
case commontypes.InferenceType:
if model.OwnerUUID == userUUID {
tokenUsageExtra.OwnerType = commontypes.CSGHubUserDeployedInference
extra.OwnerType = commontypes.CSGHubUserDeployedInference
} else {
belong, err := m.checkOrganization(c, userUUID, model.OwnerUUID)
if err != nil {
return fmt.Errorf("failed to check organization,error:%w", err)
return tokenUsageMeteringExtra{}, 0, fmt.Errorf("failed to check organization: %w", err)
}
if belong {
tokenUsageExtra.OwnerType = commontypes.CSGHubOrganFellowDeployedInference
extra.OwnerType = commontypes.CSGHubOrganFellowDeployedInference
} else {
tokenUsageExtra.OwnerType = commontypes.CSGHubOtherDeployedInference
extra.OwnerType = commontypes.CSGHubOtherDeployedInference
}
}
scene = commontypes.SceneModelInference
default:
slog.WarnContext(c, "bad model info, csghub model missing service type",
slog.Any("model info", model))
slog.ErrorContext(c, "bad model info, csghub model missing service type", slog.Any("model", model))
return tokenUsageMeteringExtra{}, 0, fmt.Errorf("record usage: csghub model has invalid or missing service type")
}
} else if model.Provider != "" {
extra.OwnerType = commontypes.ExternalInference
}
if model.Provider != "" {
tokenUsageExtra.OwnerType = commontypes.ExternalInference
}
return extra, scene, nil
}

extraData, _ := json.Marshal(tokenUsageExtra)
event := commontypes.MeteringEvent{
Uuid: uuid.New(),
UserUUID: userUUID,
Value: usage.TotalTokens,
ValueType: commontypes.TokenNumberType, // count by token
Scene: int(scene),
OpUID: "aigateway",
CreatedAt: time.Now(),
Extra: string(extraData),
func (m *openaiComponentImpl) RecordUsage(c context.Context, userUUID string, model *types.Model, counter token.Counter) error {
usage, err := counter.Usage(c)
if err != nil {
return fmt.Errorf("failed to get token usage from counter: %w", err)
}
if model.CSGHubModelID != "" {
event.ResourceID = model.CSGHubModelID
event.ResourceName = model.CSGHubModelID
event.CustomerID = model.SvcName
if err := validateModelForUsageRecord(c, model); err != nil {
return err
}
if model.Provider != "" {
extendModelKey := fmt.Sprintf("%s:%s", model.Provider, model.ID)
event.ResourceID = extendModelKey
event.ResourceName = extendModelKey
event.CustomerID = extendModelKey
res, ridErr := meteringResourceFromModel(model)
if ridErr != nil {
slog.ErrorContext(c, "cannot record usage: invalid model for resource id", slog.Any("error", ridErr), slog.Any("model", model))
return fmt.Errorf("cannot record usage: %w", ridErr)
}
if res.ResourceID == "" {
slog.ErrorContext(c, "cannot record usage: empty resource id for model", slog.Any("model", model))
return fmt.Errorf("cannot record usage: empty resource id")
}
extra, scene, err := m.tokenUsageMeteringExtraAndScene(c, userUUID, model, usage)
if err != nil {
return err
}
extraData, err := json.Marshal(extra)
if err != nil {
return fmt.Errorf("failed to marshal token usage extra: %w", err)
}
event := commontypes.MeteringEvent{
Uuid: uuid.New(),
UserUUID: userUUID,
Value: usage.TotalTokens,
ValueType: commontypes.TokenNumberType,
Scene: int(scene),
OpUID: "aigateway",
CreatedAt: time.Now(),
Extra: string(extraData),
ResourceID: res.ResourceID,
ResourceName: res.ResourceName,
CustomerID: res.CustomerID,
}
eventData, err := json.Marshal(event)
if err != nil {
return fmt.Errorf("failed to marshal metering event: %w", err)
}
eventData, _ := json.Marshal(event)
err = m.eventPub.PublishMeteringEvent(eventData)
if err != nil {
slog.ErrorContext(c, "failed to publish token usage event", slog.Any("event", event), slog.Any("error", err))
return fmt.Errorf("failed to publish token usage event,error:%w", err)
return fmt.Errorf("failed to publish token usage event: %w", err)
}

slog.InfoContext(c, "public token usage event success", slog.Any("event", event))
slog.InfoContext(c, "published token usage event success", slog.Any("event", event))
return nil
}

Expand Down
7 changes: 0 additions & 7 deletions aigateway/component/openai_ce.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"opencsg.com/csghub-server/builder/store/cache"
"opencsg.com/csghub-server/builder/store/database"
"opencsg.com/csghub-server/common/config"
common_types "opencsg.com/csghub-server/common/types"
)

type extendOpenai struct{}
Expand Down Expand Up @@ -39,12 +38,6 @@ func (e *openaiComponentImpl) userPreference(ctx context.Context, req *types.Use
return req.Models, nil
}

// parseScene parses the scene value from the HTTP header
// return SceneModelServerless
func parseScene(sceneValue string) common_types.SceneType {
return common_types.SceneModelServerless
}

func (e *extendOpenai) CheckBalance(ctx context.Context, username, userUUID string) error {
return nil
}
Expand Down
Loading
Loading