Allow access to integer parameters in shape constraints #42

@uri-granta

Description

Feature request

Allow access to integer parameters in shape constraints.

Motivation

Defining the shapes of synthetic functions in trieste, whose input dimension is passed as an int parameter, e.g. something like:

@check_shapes(
    "x: [batch..., $d]",
    "return: [batch..., 1]"
)
def levy(x: TensorType, d: int) -> TensorType:
    ...

Proposal

There are at least two ways to handle this.

One is to allow shape specifications to reference int parameters directly (similar to the proposal in #6), using a $-prefixed syntax like the one above.
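To make the $d proposal concrete, here is a minimal pure-Python sketch of how a decorator could resolve `$name` references before checking shapes: each `$name` is replaced by the value of the int argument `name`, so `[batch..., $d]` with `d=3` is checked as `[batch..., 3]`. It is hypothetical and independent of the check_shapes library: `check_shapes_with_refs`, `resolve_int_refs` and `_match` are illustrative names, the toy parser only understands a leading `name...` variadic dimension, named dimensions and integer literals, and `SimpleNamespace(shape=...)` stands in for a real tensor.

```python
import inspect
import re
from functools import wraps
from types import SimpleNamespace
from typing import Any, Callable, Dict, Tuple

# Hypothetical syntax: "$name" refers to the int argument called `name`.
_INT_REF = re.compile(r"\$([A-Za-z_]\w*)")


def resolve_int_refs(spec: str, arguments: Dict[str, Any]) -> str:
    """Replace each "$name" with the value of int argument `name` ($d -> "3")."""

    def substitute(match: "re.Match[str]") -> str:
        value = arguments[match.group(1)]
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(value, bool) or not isinstance(value, int) or value < 0:
            raise TypeError(f"${match.group(1)} must be a non-negative int, got {value!r}")
        return str(value)

    return _INT_REF.sub(substitute, spec)


def _match(shape: Tuple[int, ...], spec: str, env: Dict[str, Any]) -> None:
    """Match a shape against e.g. "[batch..., 3]", binding variables in env."""
    dims = [d.strip() for d in spec.strip("[] ").split(",") if d.strip()]
    variadic = dims[0][:-3] if dims and dims[0].endswith("...") else None
    fixed = dims[1:] if variadic is not None else dims
    n_head = len(shape) - len(fixed)
    if n_head < 0 or (variadic is None and n_head != 0):
        raise ValueError(f"shape {shape} does not match {spec}")
    pairs = list(zip(fixed, shape[n_head:]))
    if variadic is not None:
        # The variadic variable binds to the tuple of leading sizes.
        pairs.insert(0, (variadic, tuple(shape[:n_head])))
    for name, size in pairs:
        # Integer literals must match exactly; names bind on first use.
        expected = int(name) if name.isdigit() else env.setdefault(name, size)
        if expected != size:
            raise ValueError(f"dimension {name!r}: expected {expected}, got {size}")


def check_shapes_with_refs(*specs: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator: a tiny check_shapes look-alike with "$name" support."""
    parsed = [tuple(part.strip() for part in s.split(":", 1)) for s in specs]

    def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
        signature = inspect.signature(func)

        @wraps(func)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            env: Dict[str, Any] = {}
            for name, spec in parsed:
                if name != "return":
                    resolved = resolve_int_refs(spec, bound.arguments)
                    _match(tuple(bound.arguments[name].shape), resolved, env)
            result = func(*args, **kwargs)
            for name, spec in parsed:
                if name == "return":
                    resolved = resolve_int_refs(spec, bound.arguments)
                    _match(tuple(result.shape), resolved, env)
            return result

        return wrapped

    return decorate


@check_shapes_with_refs("x: [batch..., $d]", "return: [batch..., 1]")
def levy(x: Any, d: int) -> Any:
    return SimpleNamespace(shape=tuple(x.shape[:-1]) + (1,))  # stand-in body


print(levy(SimpleNamespace(shape=(5, 3)), d=3).shape)  # (5, 1)
```

One drawback of substituting before parsing shows up in the error for `d=2`: it reports the literal dimension '2' rather than naming the argument d.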

Another is to support value constraints on int parameters, binding each int to a dimension variable that the tensor specs must then agree with, maybe something like:

@check_shapes(
    "x: [batch..., dim]",
    "d: dim",
    "return: [batch..., 1]"
)
def levy(x: TensorType, d: int) -> TensorType:
    ...
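In the value-constraint form, the int is not substituted into a tensor spec; instead d is bound to the dimension variable dim and must agree with the size dim takes in the tensor specs, whichever is checked first. A minimal pure-Python sketch, assuming a hypothetical `check_shapes_with_int_values` decorator (not part of check_shapes), a toy parser `_match` that only handles a leading `name...`, named dimensions and integer literals, and `SimpleNamespace` in place of real tensors:

```python
import inspect
from functools import wraps
from types import SimpleNamespace
from typing import Any, Callable, Dict, Tuple


def _match(shape: Tuple[int, ...], spec: str, env: Dict[str, Any]) -> None:
    """Match a shape against e.g. "[batch..., dim]", binding variables in env."""
    dims = [d.strip() for d in spec.strip("[] ").split(",") if d.strip()]
    variadic = dims[0][:-3] if dims and dims[0].endswith("...") else None
    fixed = dims[1:] if variadic is not None else dims
    n_head = len(shape) - len(fixed)
    if n_head < 0 or (variadic is None and n_head != 0):
        raise ValueError(f"shape {shape} does not match {spec}")
    pairs = list(zip(fixed, shape[n_head:]))
    if variadic is not None:
        pairs.insert(0, (variadic, tuple(shape[:n_head])))
    for name, size in pairs:
        # Integer literals must match exactly; names bind on first use.
        expected = int(name) if name.isdigit() else env.setdefault(name, size)
        if expected != size:
            raise ValueError(f"dimension {name!r}: expected {expected}, got {size}")


def check_shapes_with_int_values(*specs: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Hypothetical decorator: "[...]" specs check tensors, bare names constrain ints."""
    parsed = [tuple(part.strip() for part in s.split(":", 1)) for s in specs]

    def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
        signature = inspect.signature(func)

        @wraps(func)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()
            env: Dict[str, Any] = {}
            for name, spec in parsed:
                if name == "return":
                    continue
                value = bound.arguments[name]
                if spec.startswith("["):
                    _match(tuple(value.shape), spec, env)
                elif env.setdefault(spec, value) != value:
                    # Value constraint: the int must equal the size bound to `spec`.
                    raise ValueError(f"{name}={value}, but {spec} is {env[spec]}")
            result = func(*args, **kwargs)
            for name, spec in parsed:
                if name == "return":
                    _match(tuple(result.shape), spec, env)
            return result

        return wrapped

    return decorate


@check_shapes_with_int_values("x: [batch..., dim]", "d: dim", "return: [batch..., 1]")
def levy(x: Any, d: int) -> Any:
    return SimpleNamespace(shape=tuple(x.shape[:-1]) + (1,))  # stand-in body


print(levy(SimpleNamespace(shape=(5, 3)), d=3).shape)  # (5, 1)
```

Because the int is checked as a binding rather than spliced into a spec, a mismatch such as d=2 can be reported against the argument itself ("d=2, but dim is 3").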

Workarounds

One workaround is to move the dynamic check inside the function body. However, this splits the shape specification between the decorator and the body, and the inner check isn't picked up by docstring rewriting.

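from check_shapes import check_shape, check_shapes

@check_shapes("return: [batch..., 1]")
def levy(x: TensorType, d: int) -> TensorType:
    check_shape(x, f"[batch..., {d}]")  # embed the int value of d as a literal dimension
    ...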

Metadata

Assignees: no one assigned

Labels: enhancement (New feature or request)
