diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 51c6d83502..ae2cfee95d 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -994,6 +994,8 @@ def meta(self) -> _metadata.MetadataStore: class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): + """Immutable symbolic dimension that can be shared across multiple shapes.""" + __slots__ = ("_value",) def __init__(self, value: str | None) -> None: @@ -1054,6 +1056,53 @@ def _maybe_convert_to_symbolic_dim( class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): + """The shape of a tensor, including its dimensions and optional denotations. + + The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or + symbolic dimensions. + + A shape can be compared to another shape or plain Python list. + + A shape can be frozen (made immutable). When the shape is frozen, it cannot be + unfrozen, making it suitable to be shared across tensors or values. + Call :method:`freeze` to freeze the shape. + + To update the dimension of a frozen shape, call :method:`copy` to create a + new shape with the same dimensions that can be modified. + + Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations. + + Example:: + + >>> from onnxscript import ir + >>> shape = ir.Shape(["B", None, 3]) + >>> shape.rank() + 3 + >>> shape.is_static() + False + >>> shape.is_dynamic() + True + >>> shape.is_static(dim=2) + True + >>> shape[0] = 1 + >>> shape[1] = 2 + >>> shape.dims + (1, 2, 3) + >>> shape == [1, 2, 3] + True + >>> shape.frozen + False + >>> shape.freeze() + >>> shape.frozen + True + + Attributes: + dims: A tuple of dimensions representing the shape. + Each dimension can be an integer, None or a :class:`SymbolicDim`. + frozen: Indicates whether the shape is immutable. When frozen, the shape + cannot be modified or unfrozen. + """ + __slots__ = ("_dims", "_frozen") def __init__( @@ -1076,7 +1125,8 @@ def __init__( Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition for pre-defined dimension denotations. frozen: If True, the shape is immutable and cannot be modified. This - is useful when the shape is initialized by a Tensor. + is useful when the shape is initialized by a Tensor or when the shape + is shared across multiple tensors. The default is False. """ self._dims: list[int | SymbolicDim] = [ _maybe_convert_to_symbolic_dim(dim) for dim in dims @@ -1090,10 +1140,6 @@ def __init__( ) self._frozen: bool = frozen - def copy(self): - """Return a copy of the shape.""" - return Shape(self._dims, self._denotations, self._frozen) - @property def dims(self) -> tuple[int | SymbolicDim, ...]: """All dimensions in the shape. @@ -1102,8 +1148,29 @@ def dims(self) -> tuple[int | SymbolicDim, ...]: """ return tuple(self._dims) + @property + def frozen(self) -> bool: + """Whether the shape is frozen. + + When the shape is frozen, it cannot be unfrozen, making it suitable to be shared. + Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a + new shape with the same dimensions that can be modified. + """ + return self._frozen + + def freeze(self) -> None: + """Freeze the shape. + + When the shape is frozen, it cannot be unfrozen, making it suitable to be shared. + """ + self._frozen = True + + def copy(self, frozen: bool = False): + """Return a copy of the shape.""" + return Shape(self._dims, self._denotations, frozen=frozen) + def rank(self) -> int: - """The rank of the shape.""" + """The rank of the tensor this shape represents.""" return len(self._dims) def numpy(self) -> tuple[int, ...]: diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 7068a8da8f..ee2b0f389c 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -622,6 +622,9 @@ def test_setitem_raises_when_shape_is_frozen(self): with self.assertRaisesRegex(TypeError, "frozen"): shape[0] = 1 + with self.assertRaisesRegex(TypeError, "frozen"): + shape[0] = "some_string" + def test_getitem(self): shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) self.assertEqual(shape[0], 42)