Pyrefly's tensor shape tracking is designed so that coverage of the PyTorch library can be extended without understanding Pyrefly's internals. This page explains the three mechanisms for specifying shape transforms and how to add new ones.
Shape tracking uses three complementary mechanisms:
- **Fixture stubs**: `.pyi` files with shape-generic type signatures. Covers modules like `nn.Linear` and `nn.Conv2d`, and functions like `torch.mm`.
- **DSL functions**: shape transform specifications written in a tiny Python subset, registered in `tensor_ops_registry.rs`. Covers operations with complex shape logic like `reshape`, `cat`, `transpose`, and `F.interpolate`.
- **Special handlers**: built into Pyrefly for constructs that need deeper type system integration, like `nn.Sequential` chaining, `.shape` attribute access, and `.size()`.
Most contributions involve fixture stubs or DSL functions. Special handlers require changes to Pyrefly's Rust source.
Fixture stubs live in:

```
test/tensor_shapes/fixtures/torch/
├── __init__.pyi
├── nn/
│   ├── __init__.pyi      # nn.Linear, nn.Conv2d, nn.LSTM, etc.
│   └── functional.pyi    # F.relu, F.softmax, F.conv2d, etc.
├── distributions/
│   └── ...               # torch.distributions
└── ...
```

The `search_path` config option tells Pyrefly to look in this directory for type information, overriding the real `torch` stubs.
A fixture stub provides a shape-generic type signature. For example, `nn.Linear`:

```python
class Linear[N, M](Module):
    def __init__(self, in_features: Dim[N], out_features: Dim[M],
                 bias: bool = True) -> None: ...
    def forward[*Xs](self, input: Tensor[*Xs, N]) -> Tensor[*Xs, M]: ...
```

The constructor captures the input and output dimensions as type parameters. The `forward` method uses those parameters plus a variadic `*Xs` for batch dimensions.
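To make the stub's shape rule concrete, here is a plain-Python sketch of what `Tensor[*Xs, N] -> Tensor[*Xs, M]` means for runtime shapes. The helper `linear_output_shape` is hypothetical and not part of Pyrefly. It only mirrors the rule the checker enforces: the last dimension must equal `in_features`, it is replaced by `out_features`, and any leading batch dimensions (`*Xs`) pass through unchanged.

```python
def linear_output_shape(
    input_shape: tuple[int, ...], in_features: int, out_features: int
) -> tuple[int, ...]:
    """Hypothetical helper mirroring Linear.forward: Tensor[*Xs, N] -> Tensor[*Xs, M]."""
    if not input_shape:
        raise ValueError("Linear needs at least one dimension (the feature dim N)")
    *batch, last = input_shape  # batch plays the role of *Xs, last plays N
    if last != in_features:
        # Pyrefly would report this mismatch as a type error on N.
        raise ValueError(f"expected last dim {in_features}, got {last}")
    return (*batch, out_features)  # *Xs unchanged, N replaced by M


# Linear(128, 64) applied to a [32, 10, 128] input yields [32, 10, 64].
print(linear_output_shape((32, 10, 128), 128, 64))  # (32, 10, 64)
```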
To add a fixture stub:

- Identify the shape signature: what are the input dimensions, what are the output dimensions, and how do they relate?
- Make constructor parameters `Dim[X]` for parameters that determine tensor dimensions. Non-shape parameters (`bias`, `dropout`) keep their original types.
- Write the `forward` signature expressing the shape transform. Use `*Xs` or `*Bs` for batch dimensions that pass through unchanged.
- Add the stub to the appropriate `.pyi` file in the fixtures directory.
- Test it by writing a small model that uses the op and running the checker.
Suppose you want to add `nn.GroupNorm`, which preserves the shape of its input:

```python
class GroupNorm[NumGroups, NumChannels](Module):
    def __init__(
        self,
        num_groups: Dim[NumGroups],
        num_channels: Dim[NumChannels],
        eps: float = 1e-5,
        affine: bool = True,
    ) -> None: ...
    def forward[*S](self, input: Tensor[*S]) -> Tensor[*S]: ...
```

Since `GroupNorm` doesn't change the shape, the `forward` signature is simply `Tensor[*S] -> Tensor[*S]`.
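The stub only promises that the output shape equals the input shape. At runtime, `torch.nn.GroupNorm` also requires `num_channels` to be divisible by `num_groups` and expects input laid out as `(N, C, *)` with `C == num_channels`; the `Tensor[*S] -> Tensor[*S]` signature does not encode either constraint. The plain-Python sketch below (the helper name `group_norm_output_shape` is hypothetical) spells out the shape rule alongside those runtime checks.

```python
def group_norm_output_shape(
    input_shape: tuple[int, ...], num_groups: int, num_channels: int
) -> tuple[int, ...]:
    """Hypothetical helper mirroring GroupNorm.forward: Tensor[*S] -> Tensor[*S]."""
    # torch.nn.GroupNorm rejects this at construction; the stub does not check it.
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    # Input is (N, C, *); the stub leaves C unconstrained.
    if len(input_shape) < 2 or input_shape[1] != num_channels:
        raise ValueError(f"expected input of shape (N, {num_channels}, *)")
    return input_shape  # normalization never changes the shape


print(group_norm_output_shape((8, 32, 16, 16), num_groups=4, num_channels=32))
# (8, 32, 16, 16)
```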
DSL functions are registered in `tensor_ops_registry.rs`. Each entry maps a qualified PyTorch function name to a shape transform specification written in a tiny Python subset.
The DSL supports:

- Lists and list comprehensions
- Arithmetic (`+`, `-`, `*`, `//`)
- `zip`, `len`, indexing
- `Tensor(shape=[...])` to construct result shapes
- `self.shape` to access input shapes
- Conditionals (`if`/`else`)
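As an illustration of these features, here is a hypothetical shape transform for `torch.transpose`, written in the same style as the examples in this section. It is not necessarily what `tensor_ops_registry.rs` actually contains, and it borrows `list(...)` and item assignment from the `cat_ir` example. So that it runs as ordinary Python, the sketch includes a minimal stand-in `Tensor` class that only records a `shape`; this stand-in is an assumption, not Pyrefly's DSL evaluator.

```python
class Tensor:
    """Minimal stand-in for the DSL's Tensor: it only carries a shape."""

    def __init__(self, shape: list[int]) -> None:
        self.shape = list(shape)


def transpose_ir(self: Tensor, dim0: int, dim1: int) -> Tensor:
    n = len(self.shape)
    # Conditionals normalize negative dims, e.g. -1 becomes n - 1.
    d0 = dim0 + n if dim0 < 0 else dim0
    d1 = dim1 + n if dim1 < 0 else dim1
    result = list(self.shape)
    # Indexing swaps the two sizes; every other dim is unchanged.
    result[d0] = self.shape[d1]
    result[d1] = self.shape[d0]
    return Tensor(shape=result)


print(transpose_ir(Tensor(shape=[2, 3, 5]), 0, -1).shape)  # [5, 3, 2]
```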
For example, the transform for `repeat`:

```python
def repeat_ir(self: Tensor, sizes: list[int | symint]) -> Tensor:
    return Tensor(shape=[d * r for d, r in zip(self.shape, sizes)])
```

This says: the output shape is the element-wise product of the input shape and the `sizes` argument.
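To see what `repeat_ir` computes, it can be evaluated as ordinary Python with a minimal stand-in `Tensor` class (an assumption for illustration; Pyrefly evaluates the DSL itself), with `symint` dropped from the annotation because it is not a plain Python type. The second call shows a limitation of the function as written: `torch.Tensor.repeat` allows `sizes` to have more entries than the tensor has dimensions (treating the missing leading dimensions as size 1), but `zip` stops at the shorter input, so the extra sizes are silently dropped.

```python
class Tensor:
    """Minimal stand-in for the DSL's Tensor: it only carries a shape."""

    def __init__(self, shape: list[int]) -> None:
        self.shape = list(shape)


def repeat_ir(self: Tensor, sizes: list[int]) -> Tensor:
    # Same body as the DSL example: element-wise product of shape and sizes.
    return Tensor(shape=[d * r for d, r in zip(self.shape, sizes)])


print(repeat_ir(Tensor(shape=[2, 3]), [4, 5]).shape)  # [8, 15]
# PyTorch produces shape [4, 6] for a [3] tensor repeated (4, 2),
# but zip pairs only the first size with the single dim here.
print(repeat_ir(Tensor(shape=[3]), [4, 2]).shape)  # [12]
```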
For `cat`:

```python
def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor:
    shapes = [t.shape for t in tensors]
    result = list(shapes[0])
    for s in shapes[1:]:
        result[dim] = result[dim] + s[dim]
    return Tensor(shape=result)
```

This sums the sizes along the concatenation dimension and preserves all other dimensions.
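Evaluated the same way, with a minimal stand-in `Tensor` class (an assumption for illustration only), `cat_ir` adds up the sizes along `dim` and copies the remaining sizes from the first input. Because it indexes with `dim` directly, a negative `dim` such as `-1` also works through Python's negative indexing. It does not check that the non-concatenated dimensions agree, which `torch.cat` requires at runtime.

```python
class Tensor:
    """Minimal stand-in for the DSL's Tensor: it only carries a shape."""

    def __init__(self, shape: list[int]) -> None:
        self.shape = list(shape)


def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor:
    # Same body as the DSL example.
    shapes = [t.shape for t in tensors]
    result = list(shapes[0])
    for s in shapes[1:]:
        result[dim] = result[dim] + s[dim]
    return Tensor(shape=result)


a = Tensor(shape=[2, 3])
b = Tensor(shape=[4, 3])
print(cat_ir([a, b]).shape)  # [6, 3]
print(cat_ir([a, Tensor(shape=[2, 5])], dim=-1).shape)  # [2, 8]
```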
To add a DSL function:

- Write the shape transform in the DSL subset, focusing on the relationship between input and output shapes.
- Register it in `tensor_ops_registry.rs` with the qualified PyTorch name (e.g., `"torch.nn.functional.adaptive_avg_pool2d"`).
- Test it by using the op in a model and checking that `reveal_type` produces the expected shape.
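As a sketch of the first step, here is how a shape transform for `torch.nn.functional.adaptive_avg_pool2d` (the example name used in the registration step) might look, again run through a minimal stand-in `Tensor` class. The function is hypothetical, not the registered implementation. It keeps the leading dimensions (batch and channels) and replaces the last two spatial dimensions with `output_size`. It assumes `output_size` is always a pair; PyTorch also accepts a single int, which this sketch does not handle.

```python
class Tensor:
    """Minimal stand-in for the DSL's Tensor: it only carries a shape."""

    def __init__(self, shape: list[int]) -> None:
        self.shape = list(shape)


def adaptive_avg_pool2d_ir(input: Tensor, output_size: list[int]) -> Tensor:
    # Hypothetical: keep (N, C) or (C,), replace the two spatial dims.
    shape = input.shape
    leading = shape[: len(shape) - 2]
    return Tensor(shape=leading + [output_size[0], output_size[1]])


print(adaptive_avg_pool2d_ir(Tensor(shape=[8, 64, 32, 32]), [1, 1]).shape)
# [8, 64, 1, 1]
```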
Ported models live in `test/tensor_shapes/models/`. Each file is a fully annotated port of a real-world PyTorch model, with `assert_type` checkpoints and smoke tests.

To add a new ported model:
- Choose a model from TorchBench or another source.
- Port it using the tutorials or the agent skill.
- Add `assert_type` after every shape-changing operation.
- Add smoke tests at the bottom of the file.
- Run `verify_port.sh` to check for issues.
This script checks a ported model for common issues:

```shell
.claude/skills/port-model/verify_port.sh test/tensor_shapes/models/<model>.py
```

It reports:
| Metric | Description |
|---|---|
| `ig` | `type: ignore` count |
| `bs` | Bare `Tensor` in signatures |
| `bv` | Bare `Tensor` in variable annotations |
| `sh` | Shaped `assert_type` count |
| `ba` | Bare `assert_type` count |
| `sm` | Smoke test count |
After adding stubs, DSL functions, or ported models, run the test suite:

```shell
# Run a specific test
buck test pyrefly:pyrefly_library -- tensor_shape

# Run all tests
buck test pyrefly:pyrefly_library
```

For external contributors using Cargo:

```shell
cargo test tensor_shape
```