Skip to content

Commit 67abcb2

Browse files
committed
Add x-cog-accept MIME type annotation for Path/File input fields
Add an 'accept' parameter to Input() that specifies allowed MIME types or file extensions for Path/File inputs. The Go static schema generator extracts this and emits it as 'x-cog-accept' in the OpenAPI schema, giving schema consumers (UIs, validators, API clients) visibility into what file types an input expects. Usage: Input(accept="image/*"), Input(accept="audio/wav,audio/mp3"), or Input(accept=".safetensors,.bin"). Using accept on non-Path/File types is a hard build error.
1 parent 84f2c89 commit 67abcb2

9 files changed

Lines changed: 295 additions & 28 deletions

File tree

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Test that the accept parameter on Path/File inputs produces the
2+
# x-cog-accept annotation in the generated OpenAPI schema.
3+
#
4+
# Verifies:
5+
# - accept="image/*" on a Path input emits x-cog-accept in the schema
6+
# - accept with multiple MIME types works
7+
# - accept with file extensions works
8+
# - Fields without accept do not have x-cog-accept
9+
# - Prediction still works end-to-end
10+
11+
cog build -t $TEST_IMAGE
12+
13+
# Extract the schema from the image label
14+
exec docker inspect $TEST_IMAGE --format '{{index .Config.Labels "run.cog.openapi_schema"}}'
15+
16+
# x-cog-accept annotations are present
17+
stdout '"x-cog-accept":"image/\*"'
18+
stdout '"x-cog-accept":"audio/wav,audio/mp3"'
19+
stdout '"x-cog-accept":".safetensors,.bin"'
20+
21+
# The prompt field (str) should NOT have x-cog-accept
22+
# (this only checks that prompt is present in the schema; the absence of x-cog-accept on it is not asserted here)
23+
stdout '"prompt":'
24+
25+
# Path fields still have uri format
26+
stdout '"format":"uri"'
27+
28+
# Prediction works end-to-end
29+
cog predict $TEST_IMAGE -i prompt=hello -i image=@test.png
30+
stdout 'hello-png'
31+
32+
-- cog.yaml --
33+
build:
34+
python_version: "3.12"
35+
predict: "predict.py:Predictor"
36+
37+
-- predict.py --
38+
from typing import Optional
39+
40+
from cog import BasePredictor, Input, Path
41+
42+
43+
class Predictor(BasePredictor):
44+
def predict(
45+
self,
46+
prompt: str = Input(description="Text prompt", default="test"),
47+
image: Path = Input(description="Input image", accept="image/*"),
48+
audio: Optional[Path] = Input(description="Audio clip", accept="audio/wav,audio/mp3", default=None),
49+
weights: Optional[Path] = Input(description="Model weights", accept=".safetensors,.bin", default=None),
50+
) -> str:
51+
ext = str(image).split(".")[-1]
52+
return f"{prompt}-{ext}"
53+
54+
-- test.png --
55+
fake image content
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Test that using accept on a non-Path/File input type causes a build error.
2+
#
3+
# The accept parameter is only valid on Path or File inputs. Using it on
4+
# str, int, float, etc. should produce a clear error at build time.
5+
6+
! cog build -t $TEST_IMAGE
7+
stderr 'accept is only valid on Path or File inputs'
8+
9+
-- cog.yaml --
10+
build:
11+
python_version: "3.12"
12+
predict: "predict.py:Predictor"
13+
14+
-- predict.py --
15+
from cog import BasePredictor, Input
16+
17+
18+
class Predictor(BasePredictor):
19+
def predict(
20+
self,
21+
name: str = Input(description="User name", accept="text/plain"),
22+
) -> str:
23+
return f"hello {name}"

pkg/schema/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const (
2828
ErrChoicesNotResolvable
2929
ErrDefaultNotResolvable
3030
ErrUnresolvableType
31+
ErrAcceptOnNonFileType
3132
ErrOther
3233
)
3334

pkg/schema/openapi.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ func buildInputSchema(info *PredictorInfo) (map[string]any, []enumSchema) {
365365
prop.Set("deprecated", true)
366366
}
367367

368+
// Allowed MIME types or file extensions for Path/File inputs, emitted as x-cog-accept
369+
if field.Accept != nil {
370+
prop.Set("x-cog-accept", *field.Accept)
371+
}
372+
368373
properties.Set(name, prop)
369374
})
370375

pkg/schema/openapi_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,80 @@ func TestMultipleInputTypes(t *testing.T) {
682682
assert.NotContains(t, required, "secret_key")
683683
}
684684

685+
// ---------------------------------------------------------------------------
686+
// Tests: Accept (MIME type) annotation
687+
// ---------------------------------------------------------------------------
688+
689+
func TestAcceptAnnotation(t *testing.T) {
690+
accept := "image/*"
691+
inputs := NewOrderedMap[string, InputField]()
692+
inputs.Set("image", InputField{
693+
Name: "image",
694+
Order: 0,
695+
FieldType: FieldType{Primitive: TypePath, Repetition: Required},
696+
Accept: &accept,
697+
})
698+
699+
info := &PredictorInfo{
700+
Inputs: inputs,
701+
Output: SchemaPrim(TypeString),
702+
Mode: ModePredict,
703+
}
704+
705+
spec := parseSpec(t, info)
706+
props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any)
707+
708+
imageField := props["image"].(map[string]any)
709+
assert.Equal(t, "string", imageField["type"])
710+
assert.Equal(t, "uri", imageField["format"])
711+
assert.Equal(t, "image/*", imageField["x-cog-accept"])
712+
}
713+
714+
func TestAcceptAnnotationMultipleMimeTypes(t *testing.T) {
715+
accept := "audio/wav,audio/mp3,audio/flac"
716+
inputs := NewOrderedMap[string, InputField]()
717+
inputs.Set("audio", InputField{
718+
Name: "audio",
719+
Order: 0,
720+
FieldType: FieldType{Primitive: TypePath, Repetition: Required},
721+
Accept: &accept,
722+
})
723+
724+
info := &PredictorInfo{
725+
Inputs: inputs,
726+
Output: SchemaPrim(TypeString),
727+
Mode: ModePredict,
728+
}
729+
730+
spec := parseSpec(t, info)
731+
props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any)
732+
733+
audioField := props["audio"].(map[string]any)
734+
assert.Equal(t, "audio/wav,audio/mp3,audio/flac", audioField["x-cog-accept"])
735+
}
736+
737+
func TestAcceptAnnotationNotPresentWhenNil(t *testing.T) {
738+
inputs := NewOrderedMap[string, InputField]()
739+
inputs.Set("image", InputField{
740+
Name: "image",
741+
Order: 0,
742+
FieldType: FieldType{Primitive: TypePath, Repetition: Required},
743+
})
744+
745+
info := &PredictorInfo{
746+
Inputs: inputs,
747+
Output: SchemaPrim(TypeString),
748+
Mode: ModePredict,
749+
}
750+
751+
spec := parseSpec(t, info)
752+
props := getPath(spec, "components", "schemas", "Input", "properties").(map[string]any)
753+
754+
imageField := props["image"].(map[string]any)
755+
_, hasAccept := imageField["x-cog-accept"]
756+
assert.False(t, hasAccept, "x-cog-accept should not be present when Accept is nil")
757+
}
758+
685759
// ---------------------------------------------------------------------------
686760
// Tests: Edge cases
687761
// ---------------------------------------------------------------------------

pkg/schema/python/parser.go

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,7 @@ type inputCallInfo struct {
731731
Regex *string
732732
Choices []schema.DefaultValue
733733
Deprecated *bool
734+
Accept *string
734735
}
735736

736737
type inputMethodInfo struct {
@@ -1148,38 +1149,20 @@ func parseTypedDefaultParameter(
11481149
if err != nil {
11491150
return schema.InputField{}, err
11501151
}
1151-
return schema.InputField{
1152-
Name: name,
1153-
Order: order,
1154-
FieldType: fieldType,
1155-
Default: info.Default,
1156-
Description: info.Description,
1157-
GE: info.GE,
1158-
LE: info.LE,
1159-
MinLength: info.MinLength,
1160-
MaxLength: info.MaxLength,
1161-
Regex: info.Regex,
1162-
Choices: info.Choices,
1163-
Deprecated: info.Deprecated,
1164-
}, nil
1152+
field, err := inputCallInfoToField(name, order, fieldType, info)
1153+
if err != nil {
1154+
return schema.InputField{}, err
1155+
}
1156+
return field, nil
11651157
}
11661158

11671159
// 2. Reference to Input() via class attribute or static method
11681160
if info, ok := resolveInputReference(valNode, source, registry); ok {
1169-
return schema.InputField{
1170-
Name: name,
1171-
Order: order,
1172-
FieldType: fieldType,
1173-
Default: info.Default,
1174-
Description: info.Description,
1175-
GE: info.GE,
1176-
LE: info.LE,
1177-
MinLength: info.MinLength,
1178-
MaxLength: info.MaxLength,
1179-
Regex: info.Regex,
1180-
Choices: info.Choices,
1181-
Deprecated: info.Deprecated,
1182-
}, nil
1161+
field, err := inputCallInfoToField(name, order, fieldType, info)
1162+
if err != nil {
1163+
return schema.InputField{}, err
1164+
}
1165+
return field, nil
11831166
}
11841167

11851168
// 3. Plain default — must be statically resolvable
@@ -1339,6 +1322,30 @@ func isInputCall(node *sitter.Node, source []byte, imports *schema.ImportContext
13391322
return false
13401323
}
13411324

1325+
// inputCallInfoToField converts parsed Input() kwargs into an InputField,
1326+
// validating that accept is only used on Path/File types.
1327+
func inputCallInfoToField(name string, order int, fieldType schema.FieldType, info inputCallInfo) (schema.InputField, error) {
1328+
if info.Accept != nil && fieldType.Primitive != schema.TypePath && fieldType.Primitive != schema.TypeFile {
1329+
return schema.InputField{}, schema.NewError(schema.ErrAcceptOnNonFileType,
1330+
fmt.Sprintf("accept is only valid on Path or File inputs (parameter '%s')", name))
1331+
}
1332+
return schema.InputField{
1333+
Name: name,
1334+
Order: order,
1335+
FieldType: fieldType,
1336+
Default: info.Default,
1337+
Description: info.Description,
1338+
GE: info.GE,
1339+
LE: info.LE,
1340+
MinLength: info.MinLength,
1341+
MaxLength: info.MaxLength,
1342+
Regex: info.Regex,
1343+
Choices: info.Choices,
1344+
Deprecated: info.Deprecated,
1345+
Accept: info.Accept,
1346+
}, nil
1347+
}
1348+
13421349
func parseInputCall(node *sitter.Node, source []byte, paramName string, scope moduleScope) (inputCallInfo, error) {
13431350
var info inputCallInfo
13441351

@@ -1406,6 +1413,10 @@ func parseInputCall(node *sitter.Node, source []byte, paramName string, scope mo
14061413
if b, ok := parseBoolLiteral(valNode, source); ok {
14071414
info.Deprecated = &b
14081415
}
1416+
case "accept":
1417+
if s, ok := parseStringLiteral(valNode, source); ok {
1418+
info.Accept = &s
1419+
}
14091420
}
14101421
}
14111422

pkg/schema/python/parser_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,97 @@ class Predictor(BasePredictor):
922922
require.True(t, *old.Deprecated)
923923
}
924924

925+
// ---------------------------------------------------------------------------
926+
// Accept MIME type
927+
// ---------------------------------------------------------------------------
928+
929+
func TestAcceptMimeType(t *testing.T) {
930+
source := `
931+
from cog import BasePredictor, Input, Path
932+
933+
class Predictor(BasePredictor):
934+
def predict(self, image: Path = Input(description="An image", accept="image/*")) -> str:
935+
pass
936+
`
937+
info := parse(t, source, "Predictor")
938+
image, ok := info.Inputs.Get("image")
939+
require.True(t, ok)
940+
require.NotNil(t, image.Accept)
941+
require.Equal(t, "image/*", *image.Accept)
942+
}
943+
944+
func TestAcceptMultipleMimeTypes(t *testing.T) {
945+
source := `
946+
from cog import BasePredictor, Input, Path
947+
948+
class Predictor(BasePredictor):
949+
def predict(self, audio: Path = Input(accept="audio/wav,audio/mp3")) -> str:
950+
pass
951+
`
952+
info := parse(t, source, "Predictor")
953+
audio, ok := info.Inputs.Get("audio")
954+
require.True(t, ok)
955+
require.NotNil(t, audio.Accept)
956+
require.Equal(t, "audio/wav,audio/mp3", *audio.Accept)
957+
}
958+
959+
func TestAcceptFileExtensions(t *testing.T) {
960+
source := `
961+
from cog import BasePredictor, Input, Path
962+
963+
class Predictor(BasePredictor):
964+
def predict(self, weights: Path = Input(accept=".safetensors,.bin")) -> str:
965+
pass
966+
`
967+
info := parse(t, source, "Predictor")
968+
weights, ok := info.Inputs.Get("weights")
969+
require.True(t, ok)
970+
require.NotNil(t, weights.Accept)
971+
require.Equal(t, ".safetensors,.bin", *weights.Accept)
972+
}
973+
974+
func TestAcceptOnFileType(t *testing.T) {
975+
source := `
976+
from cog import BasePredictor, Input, File
977+
978+
class Predictor(BasePredictor):
979+
def predict(self, f: File = Input(accept="image/png")) -> str:
980+
pass
981+
`
982+
info := parse(t, source, "Predictor")
983+
f, ok := info.Inputs.Get("f")
984+
require.True(t, ok)
985+
require.NotNil(t, f.Accept)
986+
require.Equal(t, "image/png", *f.Accept)
987+
}
988+
989+
func TestAcceptOnNonFileTypeErrors(t *testing.T) {
990+
source := `
991+
from cog import BasePredictor, Input
992+
993+
class Predictor(BasePredictor):
994+
def predict(self, name: str = Input(accept="image/*")) -> str:
995+
pass
996+
`
997+
se := parseErr(t, source, "Predictor", schema.ModePredict)
998+
require.Equal(t, schema.ErrAcceptOnNonFileType, se.Kind)
999+
require.Contains(t, se.Error(), "name")
1000+
}
1001+
1002+
func TestAcceptNotSetWhenOmitted(t *testing.T) {
1003+
source := `
1004+
from cog import BasePredictor, Input, Path
1005+
1006+
class Predictor(BasePredictor):
1007+
def predict(self, image: Path = Input(description="An image")) -> str:
1008+
pass
1009+
`
1010+
info := parse(t, source, "Predictor")
1011+
image, ok := info.Inputs.Get("image")
1012+
require.True(t, ok)
1013+
require.Nil(t, image.Accept)
1014+
}
1015+
9251016
// ---------------------------------------------------------------------------
9261017
// File type (deprecated alias for Path)
9271018
// ---------------------------------------------------------------------------

pkg/schema/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ type InputField struct {
187187
Regex *string
188188
Choices []DefaultValue
189189
Deprecated *bool
190+
Accept *string // MIME types / file extensions for Path/File inputs (e.g. "image/*")
190191
}
191192

192193
// IsRequired returns true if this field is required in the schema.

0 commit comments

Comments
 (0)