Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,9 @@ func (c *Client) Run(ctx context.Context, model string, args []string, opts RunO
if err := schema.ValidateEnums(reqSchema, payload); err != nil {
return nil, err
}
if err := schema.ValidateNumericConstraints(reqSchema, payload); err != nil {
return nil, err
}
}

// Resolve delivery method: payload value > opts override > schema default.
Expand Down
103 changes: 103 additions & 0 deletions internal/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"math"
"slices"
"sort"
"strconv"
Expand Down Expand Up @@ -71,6 +72,12 @@ type Node struct {
// Items holds the schema for elements of an array-typed property.
// Used to coerce values when dot-notation paths descend into arrays.
Items *Node `json:"items"`
// Minimum, Maximum, and MultipleOf carry numeric value constraints. Pointers
// so an absent keyword is distinguishable from a real zero bound (e.g. a
// minimum of 0 is meaningful and must not be treated as "no minimum").
Minimum *float64 `json:"minimum"`
Maximum *float64 `json:"maximum"`
MultipleOf *float64 `json:"multipleOf"`
// AllOf, OneOf, and DependentRequired support structural constraints used by
// model schemas to express mutually-exclusive option sets (e.g. dimension
// combinations) and co-dependent fields.
Expand Down Expand Up @@ -780,3 +787,99 @@ func ResolveDeliveryMethod(flagVal string, payload map[string]any, node Node) st
}
return ""
}

// ValidateNumericConstraints checks every numeric value in payload against the
// minimum, maximum, and multipleOf bounds declared on its schema property. It
// recurses into nested objects and array items, mirroring ValidateEnums.
func ValidateNumericConstraints(node Node, payload map[string]any) error {
return validateNumericInObject(node, payload, "")
}

func validateNumericInObject(node Node, obj map[string]any, prefix string) error {
keys := make([]string, 0, len(obj))
for key := range obj {
keys = append(keys, key)
}
sort.Strings(keys)
for _, key := range keys {
prop, ok := node.Properties[key]
if !ok {
continue
}
path := key
if prefix != "" {
path = prefix + "." + key
}
if err := validateNumericInValue(prop, obj[key], path); err != nil {
return err
}
}
return nil
}

func validateNumericInValue(prop Node, val any, path string) error {
switch v := val.(type) {
case map[string]any:
return validateNumericInObject(prop, v, path)
case []any:
if prop.Items != nil {
for i, item := range v {
if err := validateNumericInValue(*prop.Items, item, fmt.Sprintf("%s.%d", path, i)); err != nil {
return err
}
}
}
default:
if f, ok := asFloat(v); ok {
return checkNumericBounds(prop, f, path)
}
}
return nil
}

// checkNumericBounds reports the first violated bound on a single numeric value.
func checkNumericBounds(prop Node, val float64, path string) error {
if prop.Minimum != nil && val < *prop.Minimum {
Comment thread
danmrichards marked this conversation as resolved.
return fmt.Errorf("invalid value for %q: %s is below the minimum of %s", path, formatNumber(val), formatNumber(*prop.Minimum))
}
if prop.Maximum != nil && val > *prop.Maximum {
return fmt.Errorf("invalid value for %q: %s is above the maximum of %s", path, formatNumber(val), formatNumber(*prop.Maximum))
}
if prop.MultipleOf != nil && *prop.MultipleOf > 0 {
ratio := val / *prop.MultipleOf
// Tolerance absorbs float representation error (e.g. 0.29/0.01) while
// still catching genuine off-grid values.
if math.Abs(ratio-math.Round(ratio)) > 1e-9 {
return fmt.Errorf("invalid value for %q: %s must be a multiple of %s", path, formatNumber(val), formatNumber(*prop.MultipleOf))
}
}
return nil
}

// asFloat extracts a float64 from the Go types coerceValue and JSON decoding
// produce for numbers: int64 for integers, float64 for numbers. Non-numeric
// values return ok=false and are skipped.
func asFloat(v any) (float64, bool) {
switch n := v.(type) {
case float64:
return n, true
case int64:
return float64(n), true
case int:
return float64(n), true
case json.Number:
f, err := n.Float64()
return f, err == nil
}
return 0, false
}

// formatNumber renders a bound for an error message: integral values print
// without a decimal point (512, -4), fractional values use their shortest
// representation (0.01).
func formatNumber(f float64) string {
if f == math.Trunc(f) && math.Abs(f) < 1e15 {
return strconv.FormatInt(int64(f), 10)
}
return strconv.FormatFloat(f, 'g', -1, 64)
}
143 changes: 143 additions & 0 deletions internal/schema/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1244,3 +1244,146 @@ func TestNode_UnmarshalJSON_StringTypeNull(t *testing.T) {
t.Errorf("string null should normalise to empty, got %q", node.Type)
}
}

// ---- ValidateNumericConstraints tests ----

const (
testFieldCFGScale = "CFGScale"
testFieldWeight = "weight"
testFieldLora = "lora"
testFieldAccel = "acceleratorOptions"
testFieldCachePct = "cacheEndStepPercentage"
testTypeNumber = "number"
)

func f64(v float64) *float64 { return &v }

// boundedWidthNode mirrors a real image model's width: [512, 2048], step 16.
func boundedWidthNode() schema.Node {
return schema.Node{
Properties: map[string]schema.Node{
testFieldWidth: {Type: schema.TypeInteger, Minimum: f64(512), Maximum: f64(2048), MultipleOf: f64(16)},
},
}
}

func TestValidateNumericConstraints_BelowMinimum(t *testing.T) {
err := schema.ValidateNumericConstraints(boundedWidthNode(), map[string]any{testFieldWidth: int64(-1)})
if err == nil {
t.Fatal("expected error for width below minimum")
}
if !containsString(err.Error(), "minimum") || !containsString(err.Error(), "512") {
t.Errorf("error should report the minimum bound; got: %v", err)
}
}

func TestValidateNumericConstraints_AboveMaximum(t *testing.T) {
err := schema.ValidateNumericConstraints(boundedWidthNode(), map[string]any{testFieldWidth: int64(4096)})
if err == nil {
t.Fatal("expected error for width above maximum")
}
if !containsString(err.Error(), "maximum") {
t.Errorf("error should report the maximum bound; got: %v", err)
}
}

func TestValidateNumericConstraints_NotMultipleOf(t *testing.T) {
err := schema.ValidateNumericConstraints(boundedWidthNode(), map[string]any{testFieldWidth: int64(1000)})
if err == nil {
t.Fatal("expected error for width not a multiple of 16")
}
if !containsString(err.Error(), "multiple") {
t.Errorf("error should report the multiple-of bound; got: %v", err)
}
}

func TestValidateNumericConstraints_ValidValue(t *testing.T) {
if err := schema.ValidateNumericConstraints(boundedWidthNode(), map[string]any{testFieldWidth: int64(1024)}); err != nil {
t.Errorf("1024 satisfies [512,2048] step 16; got: %v", err)
}
}

func TestValidateNumericConstraints_NoConstraint_Passes(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldWidth: {Type: schema.TypeInteger},
},
}
if err := schema.ValidateNumericConstraints(node, map[string]any{testFieldWidth: int64(-1)}); err != nil {
t.Errorf("unconstrained integer should pass, got: %v", err)
}
}

func TestValidateNumericConstraints_NegativeMinimumAllowed(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldWeight: {Type: testTypeNumber, Minimum: f64(-4), Maximum: f64(4), MultipleOf: f64(0.01)},
},
}
if err := schema.ValidateNumericConstraints(node, map[string]any{testFieldWeight: -3.5}); err != nil {
t.Errorf("a value inside a negative-minimum range should pass, got: %v", err)
}
}

func TestValidateNumericConstraints_NestedObject(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldAccel: {
Type: schema.TypeObject,
Properties: map[string]schema.Node{
testFieldCachePct: {Type: schema.TypeInteger, Minimum: f64(1), Maximum: f64(100)},
},
},
},
}
payload := map[string]any{testFieldAccel: map[string]any{testFieldCachePct: int64(150)}}
err := schema.ValidateNumericConstraints(node, payload)
if err == nil {
t.Fatal("expected error for nested value above maximum")
}
if !containsString(err.Error(), testFieldAccel+"."+testFieldCachePct) {
t.Errorf("error should include the nested path; got: %v", err)
}
}

func TestValidateNumericConstraints_ArrayItems(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldLora: {
Type: schema.TypeArray,
Items: &schema.Node{
Type: schema.TypeObject,
Properties: map[string]schema.Node{
testFieldWeight: {Type: testTypeNumber, Minimum: f64(-4), Maximum: f64(4)},
},
},
},
},
}
payload := map[string]any{testFieldLora: []any{map[string]any{testFieldWeight: -10.0}}}
err := schema.ValidateNumericConstraints(node, payload)
if err == nil {
t.Fatal("expected error for array-item value below minimum")
}
if !containsString(err.Error(), testFieldLora+".0."+testFieldWeight) {
t.Errorf("error should include the array-item path; got: %v", err)
}
}

func TestValidateNumericConstraints_FloatMultipleOf(t *testing.T) {
node := schema.Node{
Properties: map[string]schema.Node{
testFieldCFGScale: {Type: testTypeNumber, Minimum: f64(1), Maximum: f64(20), MultipleOf: f64(0.01)},
},
}
if err := schema.ValidateNumericConstraints(node, map[string]any{testFieldCFGScale: 4.5}); err != nil {
t.Errorf("4.5 is a valid multiple of 0.01, got: %v", err)
}
err := schema.ValidateNumericConstraints(node, map[string]any{testFieldCFGScale: 4.567})
if err == nil {
t.Fatal("expected error for 4.567, not a multiple of 0.01")
}
if !containsString(err.Error(), "multiple") {
t.Errorf("error should report the multiple-of bound; got: %v", err)
}
}
Loading