Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ func (c *Client) Run(ctx context.Context, model string, args []string, opts RunO
if err := schema.ValidateAllOf(reqSchema, payload); err != nil {
return nil, err
}
if dm := schema.ResolveDeliveryMethod(opts.DeliveryMethod, payload, reqSchema); dm != "" {
payload[fieldDeliveryMethod] = dm
}
if err := schema.ValidateEnums(reqSchema, payload); err != nil {
return nil, err
}
Comment thread
danmrichards marked this conversation as resolved.
}

// Resolve delivery method: payload value > opts override > schema default.
Expand All @@ -232,10 +238,7 @@ func (c *Client) Run(ctx context.Context, model string, args []string, opts RunO
opts.OnSubmit(taskUUID)
}
interval := opts.PollInterval
minResults := extractInt(payload, "numberResults")
if minResults < 1 {
minResults = 1
}
minResults := max(extractInt(payload, "numberResults"), 1)
return c.Poll(ctx, taskUUID, interval, minResults, opts.OnProgress)
}

Expand Down
111 changes: 107 additions & 4 deletions internal/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"slices"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -172,10 +173,8 @@ func ParseKV(arg string, node Node) (path []string, value any, err error) {
}

segments := strings.Split(k, ".")
for _, seg := range segments {
if seg == "" {
return nil, nil, fmt.Errorf("key contains empty segment (got %q)", arg)
}
if slices.Contains(segments, "") {
return nil, nil, fmt.Errorf("key contains empty segment (got %q)", arg)
}

segments = normalisePathSegments(segments, node)
Expand Down Expand Up @@ -438,6 +437,110 @@ func ValidateAllOf(node Node, payload map[string]any) error {
return nil
}

// ValidateEnums checks that every string value in payload that maps to a
// schema property with a string-enum constraint (oneOf consts or direct enum)
// is one of the allowed values. It recurses into nested objects and arrays.
func ValidateEnums(node Node, payload map[string]any) error {
return validateEnumsInObject(node, payload, "")
}

func validateEnumsInObject(node Node, obj map[string]any, prefix string) error {
keys := make([]string, 0, len(obj))
for key := range obj {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
prop, ok := node.Properties[key]
if !ok {
continue
}
path := key
if prefix != "" {
path = prefix + "." + key
}
if err := validateEnumsInValue(prop, obj[key], path); err != nil {
return err
}
}
return nil
}

func validateEnumsInValue(prop Node, val any, path string) error {
switch v := val.(type) {
case string:
if allowed := stringEnumValues(prop); len(allowed) > 0 {
if !slices.Contains(allowed, v) {
return fmt.Errorf("invalid value for %q: must be one of: %s", path, strings.Join(allowed, ", "))
}
}
case map[string]any:
if err := validateEnumsInObject(prop, v, path); err != nil {
return err
}
case []any:
if prop.Items != nil {
for i, item := range v {
if err := validateEnumsInValue(*prop.Items, item, fmt.Sprintf("%s.%d", path, i)); err != nil {
return err
}
}
}
}
return nil
}

// stringEnumValues returns the allowed string values for a string-typed schema
// node. It collects from three sources in priority order:
//
// 1. A bare top-level const (single allowed value, e.g. const: "async")
// 2. String const values from oneOf branches
// 3. Direct enum entries
//
// Returns nil when the node has no string enum constraint.
func stringEnumValues(prop Node) []string {
if len(prop.Const) > 0 {
var s string
if err := json.Unmarshal(prop.Const, &s); err == nil {
return []string{s}
}
}

var vals []string
for i := range prop.OneOf {
if len(prop.OneOf[i].Const) == 0 {
continue
}
var s string
if err := json.Unmarshal(prop.OneOf[i].Const, &s); err == nil {
vals = append(vals, s)
}
}
if len(vals) == 0 {
for _, raw := range prop.Enum {
var s string
if err := json.Unmarshal(raw, &s); err == nil {
vals = append(vals, s)
}
}
}
if len(vals) == 0 {
return nil
}

// De-dupe while preserving schema order.
seen := make(map[string]struct{}, len(vals))
out := make([]string, 0, len(vals))
for _, s := range vals {
if _, ok := seen[s]; ok {
continue
}
seen[s] = struct{}{}
out = append(out, s)
}
return out
}
Comment thread
danmrichards marked this conversation as resolved.

func checkDependentRequired(depReq map[string][]string, payload map[string]any) error {
triggers := make([]string, 0, len(depReq))
for k := range depReq {
Expand Down
185 changes: 183 additions & 2 deletions internal/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ const (
testFieldText = "text"
testFieldRole = "role"
testFieldContent = "content"
testFieldSettings = "settings"
testFieldStyle = "style"
testFieldThinking = "thinking"
testValHello = "Hello"
testValUser = "user"
testMsg0Role = "messages.0.role"
Expand Down Expand Up @@ -376,7 +379,7 @@ func TestParseKV_DotNotation_ArrayAutoIndex(t *testing.T) {
func TestParseKV_DotNotation_NestedObjectCoercesType(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
"settings": {
testFieldSettings: {
Type: schema.TypeObject,
Properties: map[string]schema.Node{
"maxTokens": {Type: schema.TypeInteger},
Expand All @@ -388,14 +391,192 @@ func TestParseKV_DotNotation_NestedObjectCoercesType(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(path) != 2 || path[0] != "settings" || path[1] != "maxTokens" {
if len(path) != 2 || path[0] != testFieldSettings || path[1] != "maxTokens" {
t.Errorf("path: want [settings maxTokens], got %v", path)
}
if v != int64(512) {
t.Errorf("value: want int64(512), got %v (%T)", v, v)
}
}

func TestParseKV_StringNoEnum_AnyValueAccepted(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldPositivePrompt: {Type: schema.TypeString},
},
}
_, v, err := schema.ParseKV("positivePrompt=any value goes", node)
if err != nil {
t.Fatalf("unexpected error for string field with no enum: %v", err)
}
if v != "any value goes" {
t.Errorf("want 'any value goes', got %v", v)
}
}

func TestParseKV_StringEnum_PassthroughWithoutValidate(t *testing.T) {
// ParseKV does not validate enum constraints — that is deferred to
// ValidateEnums (--validate path). An invalid value is coerced as-is.
node := schema.Node{
Properties: map[string]schema.Node{
testFieldThinking: {
Type: schema.TypeString,
OneOf: []schema.Node{
{Const: json.RawMessage(`"enabled"`)},
{Const: json.RawMessage(`"disabled"`)},
},
},
},
}
_, v, err := schema.ParseKV(testFieldThinking+"=true", node)
if err != nil {
t.Fatalf("ParseKV must not validate enums: %v", err)
}
if v != "true" {
t.Errorf("want string 'true' passed through, got %v", v)
}
}

// ---- ValidateEnums tests ----

func thinkingNode() schema.Node {
return schema.Node{
Properties: map[string]schema.Node{
testFieldSettings: {
Type: schema.TypeObject,
Properties: map[string]schema.Node{
testFieldThinking: {
Type: schema.TypeString,
OneOf: []schema.Node{
{Const: json.RawMessage(`"enabled"`)},
{Const: json.RawMessage(`"disabled"`)},
{Const: json.RawMessage(`"auto"`)},
},
},
},
},
},
}
}

func TestValidateEnums_ValidNestedOneOf(t *testing.T) {
payload := map[string]any{
testFieldSettings: map[string]any{testFieldThinking: "enabled"},
}
if err := schema.ValidateEnums(thinkingNode(), payload); err != nil {
t.Errorf("unexpected error for valid enum value: %v", err)
}
}

func TestValidateEnums_InvalidNestedOneOf(t *testing.T) {
payload := map[string]any{
testFieldSettings: map[string]any{testFieldThinking: "true"},
}
err := schema.ValidateEnums(thinkingNode(), payload)
if err == nil {
t.Fatal("expected error for invalid enum value 'true'")
}
for _, want := range []string{"enabled", "disabled", "auto"} {
if !containsString(err.Error(), want) {
t.Errorf("error should list allowed value %q; got: %v", want, err)
}
}
}

func TestValidateEnums_ValidDirectEnum(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldStyle: {
Type: schema.TypeString,
Enum: []json.RawMessage{
json.RawMessage(`"anime"`),
json.RawMessage(`"cyberpunk"`),
},
},
},
}
payload := map[string]any{testFieldStyle: "anime"}
if err := schema.ValidateEnums(node, payload); err != nil {
t.Errorf("unexpected error: %v", err)
}
}

func TestValidateEnums_InvalidDirectEnum(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldStyle: {
Type: schema.TypeString,
Enum: []json.RawMessage{
json.RawMessage(`"anime"`),
json.RawMessage(`"cyberpunk"`),
},
},
},
}
payload := map[string]any{testFieldStyle: "realistic"}
err := schema.ValidateEnums(node, payload)
if err == nil {
t.Fatal("expected error for value not in enum")
}
if !containsString(err.Error(), "anime") || !containsString(err.Error(), "cyberpunk") {
t.Errorf("error should list allowed values; got: %v", err)
}
}

func TestValidateEnums_NoConstraint_Passes(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldPositivePrompt: {Type: schema.TypeString},
},
}
if err := schema.ValidateEnums(node, map[string]any{testFieldPositivePrompt: "any text"}); err != nil {
t.Errorf("unexpected error for unconstrained string: %v", err)
}
}

func TestValidateEnums_BareConst_ValidValue(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldDeliveryMethod: {
Type: schema.TypeString,
Const: json.RawMessage(`"async"`),
},
},
}
if err := schema.ValidateEnums(node, map[string]any{testFieldDeliveryMethod: "async"}); err != nil {
t.Errorf("unexpected error for matching const: %v", err)
}
}

func TestValidateEnums_BareConst_InvalidValue(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldDeliveryMethod: {
Type: schema.TypeString,
Const: json.RawMessage(`"async"`),
},
},
}
err := schema.ValidateEnums(node, map[string]any{testFieldDeliveryMethod: "sync"})
if err == nil {
t.Fatal("expected error for value not matching const")
}
if !containsString(err.Error(), "async") {
t.Errorf("error should show allowed const value; got: %v", err)
}
}

func TestValidateEnums_UnknownKeyIgnored(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
"known": {Type: schema.TypeString},
},
}
if err := schema.ValidateEnums(node, map[string]any{"unknown": "whatever"}); err != nil {
t.Errorf("unexpected error for key not in schema: %v", err)
}
}

func TestParseKV_EmptySegment(t *testing.T) {
_, _, err := schema.ParseKV("speech..text=Hello", schema.Node{})
if err == nil {
Expand Down
Loading