Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 251 additions & 6 deletions components/model/deepseek/deepseek_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,257 @@
assert.Equal(t, "panic error: info, \nstack: stack", err.Error())
}

func TestWithTools(t *testing.T) {
cm := &ChatModel{conf: &ChatModelConfig{Model: "test model"}}
ncm, err := cm.WithTools([]*schema.ToolInfo{{Name: "test tool name"}})
assert.Nil(t, err)
assert.Equal(t, "test model", ncm.(*ChatModel).conf.Model)
assert.Equal(t, "test tool name", ncm.(*ChatModel).rawTools[0].Name)
func TestIsCallbacksEnabled(t *testing.T) {
cm := &ChatModel{}
assert.True(t, cm.IsCallbacksEnabled())
}

func TestConcatTextParts(t *testing.T) {

Check failure on line 266 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

other declaration of TestConcatTextParts

Check failure on line 266 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

other declaration of TestConcatTextParts
// all text parts
result, err := concatTextParts([]schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeText, Text: "hello"},
{Type: schema.ChatMessagePartTypeText, Text: "world"},
}, func(p schema.MessageInputPart) (schema.ChatMessagePartType, string) {
return p.Type, p.Text
})
assert.NoError(t, err)
assert.Equal(t, "hello\n\nworld", result)

// unsupported type
_, err = concatTextParts([]schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeText, Text: "hello"},
{Type: schema.ChatMessagePartTypeImageURL, Text: "url"},
}, func(p schema.MessageInputPart) (schema.ChatMessagePartType, string) {
return p.Type, p.Text
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "does not support")
}

func TestToDeepSeekMessage(t *testing.T) {

Check failure on line 288 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

other declaration of TestToDeepSeekMessage

Check failure on line 288 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

other declaration of TestToDeepSeekMessage
t.Run("multi content not supported", func(t *testing.T) {
_, err := toDeepSeekMessage(&schema.Message{
MultiContent: []schema.ChatMessagePart{{}},
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "multi content is not supported")
})

t.Run("user input multi content text only", func(t *testing.T) {
msg, err := toDeepSeekMessage(&schema.Message{
Role: schema.User,
UserInputMultiContent: []schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeText, Text: "part1"},
{Type: schema.ChatMessagePartTypeText, Text: "part2"},
},
})
assert.NoError(t, err)
assert.Equal(t, "part1\n\npart2", msg.Content)
assert.Equal(t, "user", msg.Role)
})

t.Run("user input multi content unsupported type", func(t *testing.T) {
_, err := toDeepSeekMessage(&schema.Message{
Role: schema.User,
UserInputMultiContent: []schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeImageURL, Text: "url"},
},
})
assert.Error(t, err)
})

t.Run("assistant gen multi content", func(t *testing.T) {
msg, err := toDeepSeekMessage(&schema.Message{
Role: schema.Assistant,
AssistantGenMultiContent: []schema.MessageOutputPart{
{Type: schema.ChatMessagePartTypeText, Text: "gen1"},
{Type: schema.ChatMessagePartTypeText, Text: "gen2"},
},
})
assert.NoError(t, err)
assert.Equal(t, "gen1\n\ngen2", msg.Content)
})

t.Run("assistant gen multi content unsupported type", func(t *testing.T) {
_, err := toDeepSeekMessage(&schema.Message{
Role: schema.Assistant,
AssistantGenMultiContent: []schema.MessageOutputPart{
{Type: schema.ChatMessagePartTypeImageURL, Text: "url"},
},
})
assert.Error(t, err)
})

t.Run("unknown role", func(t *testing.T) {
_, err := toDeepSeekMessage(&schema.Message{Role: schema.RoleType("unknown")})
assert.Error(t, err)
assert.Contains(t, err.Error(), "unknown role type")
})

t.Run("prefix on non-assistant", func(t *testing.T) {
m := schema.UserMessage("hi")
SetPrefix(m)
_, err := toDeepSeekMessage(m)
assert.Error(t, err)
assert.Contains(t, err.Error(), "prefix only supported for assistant")
})

t.Run("reasoning content from extra", func(t *testing.T) {
m := schema.AssistantMessage("hi", nil)
SetReasoningContent(m, "reasoning from extra")
msg, err := toDeepSeekMessage(m)
assert.NoError(t, err)
assert.Equal(t, "reasoning from extra", msg.ReasoningContent)
})

t.Run("tool message with tool call id", func(t *testing.T) {
m := &schema.Message{Role: schema.Tool, ToolCallID: "call-123", Content: "result"}
msg, err := toDeepSeekMessage(m)
assert.NoError(t, err)
assert.Equal(t, "call-123", msg.ToolCallID)
})

t.Run("assistant with tool calls", func(t *testing.T) {
idx := 5
m := &schema.Message{
Role: schema.Assistant,
ToolCalls: []schema.ToolCall{
{Index: &idx, ID: "tc-1", Type: "function", Function: schema.FunctionCall{Name: "fn", Arguments: "{}"}},
},
}
msg, err := toDeepSeekMessage(m)
assert.NoError(t, err)
assert.Len(t, msg.ToolCalls, 1)
assert.Equal(t, "tc-1", msg.ToolCalls[0].ID)
})

t.Run("all role types", func(t *testing.T) {
for _, tc := range []struct {
role schema.RoleType
expected string
}{
{schema.System, "system"},
{schema.User, "user"},
{schema.Assistant, "assistant"},
{schema.Tool, "tool"},
} {
msg, err := toDeepSeekMessage(&schema.Message{Role: tc.role, Content: "hi"})
assert.NoError(t, err)
assert.Equal(t, tc.expected, msg.Role)
}
})
}

func TestToMessageRole(t *testing.T) {

Check failure on line 402 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

other declaration of TestToMessageRole

Check failure on line 402 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

other declaration of TestToMessageRole
assert.Equal(t, schema.User, toMessageRole("user"))
assert.Equal(t, schema.Assistant, toMessageRole("assistant"))
assert.Equal(t, schema.System, toMessageRole("system"))
assert.Equal(t, schema.Tool, toMessageRole("tool"))
assert.Equal(t, schema.RoleType("custom"), toMessageRole("custom"))
}

func TestExtractLogProbs(t *testing.T) {

Check failure on line 410 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

other declaration of TestExtractLogProbs

Check failure on line 410 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

other declaration of TestExtractLogProbs
// non-map input
_, err := extractLogProbs("not a map")
assert.Error(t, err)

// valid map input
lp, err := extractLogProbs(map[string]any{
"content": []any{
map[string]any{
"token": "hello",
"logprob": 0.9,
},
},
})
assert.NoError(t, err)
assert.NotNil(t, lp)
}

func TestToLogProbsNil(t *testing.T) {
assert.Nil(t, toLogProbs(nil))
}

func TestDereferenceOrZero(t *testing.T) {
v := 42
assert.Equal(t, 42, dereferenceOrZero(&v))
assert.Equal(t, 0, dereferenceOrZero[int](nil))
}

func TestToEinoTokenUsageNil(t *testing.T) {
assert.Nil(t, toEinoTokenUsage(nil))
}

func TestToCallbackUsageNil(t *testing.T) {
assert.Nil(t, toCallbackUsage(nil))
}

func TestToModelCallbackUsageNil(t *testing.T) {
assert.Nil(t, toModelCallbackUsage(nil))
assert.Nil(t, toModelCallbackUsage(&schema.ResponseMeta{}))
}

func TestNewChatModelOptions(t *testing.T) {
t.Run("missing model", func(t *testing.T) {
_, err := NewChatModel(context.Background(), &ChatModelConfig{})
assert.Error(t, err)
})

t.Run("with base url no trailing slash", func(t *testing.T) {
cm, err := NewChatModel(context.Background(), &ChatModelConfig{
APIKey: "key",
Model: "model",
BaseURL: "https://example.com/api",
})
assert.NoError(t, err)
assert.NotNil(t, cm)
})

t.Run("with base url trailing slash", func(t *testing.T) {
cm, err := NewChatModel(context.Background(), &ChatModelConfig{
APIKey: "key",
Model: "model",
BaseURL: "https://example.com/api/",
})
assert.NoError(t, err)
assert.NotNil(t, cm)
})

t.Run("with path", func(t *testing.T) {
cm, err := NewChatModel(context.Background(), &ChatModelConfig{
APIKey: "key",
Model: "model",
Path: "/v1/chat",
})
assert.NoError(t, err)
assert.NotNil(t, cm)
})
}

func TestWithToolsEmpty(t *testing.T) {
cm := &ChatModel{conf: &ChatModelConfig{Model: "test"}}
_, err := cm.WithTools(nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no tools to bind")
}

func TestBindToolsEmpty(t *testing.T) {
cm := &ChatModel{conf: &ChatModelConfig{Model: "test"}}
err := cm.BindTools(nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no tools to bind")
}

func TestBindForcedToolsEmpty(t *testing.T) {
cm := &ChatModel{conf: &ChatModelConfig{Model: "test"}}
err := cm.BindForcedTools(nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no tools to bind")
}

func TestStreamToEinoTokenUsageNil(t *testing.T) {
assert.Nil(t, streamToEinoTokenUsage(nil))
assert.Nil(t, streamToEinoTokenUsage(&deepseek.StreamUsage{}))
}

func TestLogProbs(t *testing.T) {
Expand Down Expand Up @@ -491,7 +736,7 @@
})
}

func TestToDeepSeekMessage(t *testing.T) {

Check failure on line 739 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

TestToDeepSeekMessage redeclared in this block

Check failure on line 739 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

TestToDeepSeekMessage redeclared in this block
t.Run("role mapping", func(t *testing.T) {
cases := []struct {
role schema.RoleType
Expand Down Expand Up @@ -688,7 +933,7 @@
})
}

func TestConcatTextParts(t *testing.T) {

Check failure on line 936 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

TestConcatTextParts redeclared in this block

Check failure on line 936 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

TestConcatTextParts redeclared in this block
t.Run("text parts joined", func(t *testing.T) {
parts := []schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeText, Text: "a"},
Expand All @@ -713,7 +958,7 @@
})
}

func TestExtractLogProbs(t *testing.T) {

Check failure on line 961 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

TestExtractLogProbs redeclared in this block

Check failure on line 961 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

TestExtractLogProbs redeclared in this block
t.Run("nil returns nil", func(t *testing.T) {
result, err := extractLogProbs(nil)
assert.Nil(t, err)
Expand Down Expand Up @@ -868,7 +1113,7 @@
})
}

func TestToMessageRole(t *testing.T) {

Check failure on line 1116 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-test

TestToMessageRole redeclared in this block

Check failure on line 1116 in components/model/deepseek/deepseek_test.go

View workflow job for this annotation

GitHub Actions / unit-benchmark-test

TestToMessageRole redeclared in this block
assert.Equal(t, schema.User, toMessageRole("user"))
assert.Equal(t, schema.Assistant, toMessageRole("assistant"))
assert.Equal(t, schema.System, toMessageRole("system"))
Expand Down
33 changes: 22 additions & 11 deletions components/model/ollama/chatmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"strings"
"time"

"github.com/eino-contrib/jsonschema"
"github.com/eino-contrib/ollama/api"

"github.com/cloudwego/eino/callbacks"
Expand Down Expand Up @@ -493,6 +494,26 @@ func parseJSONToObject(jsonStr string) (map[string]any, error) {
return result, err
}

func schemaToToolProperty(s *jsonschema.Schema) api.ToolProperty {
var tp api.ToolProperty
if s.TypeEnhanced != nil {
tp.Type = s.TypeEnhanced
} else if s.Type != "" {
tp.Type = api.PropertyType{s.Type}
}
tp.Description = s.Description
tp.Enum = s.Enum
if len(s.AnyOf) > 0 {
for _, ao := range s.AnyOf {
tp.AnyOf = append(tp.AnyOf, schemaToToolProperty(ao))
}
}
if s.Items != nil {
tp.Items = schemaToToolProperty(s.Items)
}
return tp
}

func toOllamaTools(einoTools []*schema.ToolInfo) ([]api.Tool, error) {
var ollamaTools []api.Tool
for _, einoTool := range einoTools {
Expand All @@ -508,17 +529,7 @@ func toOllamaTools(einoTools []*schema.ToolInfo) ([]api.Tool, error) {
required = openTool.Required

for pair := openTool.Properties.Oldest(); pair != nil; pair = pair.Next() {
var typ []string
if pair.Value.TypeEnhanced != nil {
typ = pair.Value.TypeEnhanced
} else {
typ = []string{pair.Value.Type}
}
properties[pair.Key] = api.ToolProperty{
Type: typ,
Description: pair.Value.Description,
Enum: pair.Value.Enum,
}
properties[pair.Key] = schemaToToolProperty(pair.Value)
}
}

Expand Down
Loading