Skip to content

Commit fa888ee

Browse files
authored
[IR] Allow to copy an unfrozen version of the Shape (#2238)
When a shape is frozen, the dims of the shape cannot be modified. Users can call ``` new_shape = shape.copy() new_shape[0] = 1 ``` to assign to the new shape. Added examples and the `frozen` property.
1 parent 02cf905 commit fa888ee

2 files changed

Lines changed: 76 additions & 6 deletions

File tree

onnxscript/ir/_core.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,8 @@ def meta(self) -> _metadata.MetadataStore:
994994

995995

996996
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
997+
"""Immutable symbolic dimension that can be shared across multiple shapes."""
998+
997999
__slots__ = ("_value",)
9981000

9991001
def __init__(self, value: str | None) -> None:
@@ -1054,6 +1056,53 @@ def _maybe_convert_to_symbolic_dim(
10541056

10551057

10561058
class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
1059+
"""The shape of a tensor, including its dimensions and optional denotations.
1060+
1061+
The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
1062+
symbolic dimensions.
1063+
1064+
A shape can be compared to another shape or plain Python list.
1065+
1066+
A shape can be frozen (made immutable). When the shape is frozen, it cannot be
1067+
unfrozen, making it suitable to be shared across tensors or values.
1068+
Call :method:`freeze` to freeze the shape.
1069+
1070+
To update the dimension of a frozen shape, call :method:`copy` to create a
1071+
new shape with the same dimensions that can be modified.
1072+
1073+
Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
1074+
1075+
Example::
1076+
1077+
>>> from onnxscript import ir
1078+
>>> shape = ir.Shape(["B", None, 3])
1079+
>>> shape.rank()
1080+
3
1081+
>>> shape.is_static()
1082+
False
1083+
>>> shape.is_dynamic()
1084+
True
1085+
>>> shape.is_static(dim=2)
1086+
True
1087+
>>> shape[0] = 1
1088+
>>> shape[1] = 2
1089+
>>> shape.dims
1090+
(1, 2, 3)
1091+
>>> shape == [1, 2, 3]
1092+
True
1093+
>>> shape.frozen
1094+
False
1095+
>>> shape.freeze()
1096+
>>> shape.frozen
1097+
True
1098+
1099+
Attributes:
1100+
dims: A tuple of dimensions representing the shape.
1101+
Each dimension can be an integer, None or a :class:`SymbolicDim`.
1102+
frozen: Indicates whether the shape is immutable. When frozen, the shape
1103+
cannot be modified or unfrozen.
1104+
"""
1105+
10571106
__slots__ = ("_dims", "_frozen")
10581107

10591108
def __init__(
@@ -1076,7 +1125,8 @@ def __init__(
10761125
Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
10771126
for pre-defined dimension denotations.
10781127
frozen: If True, the shape is immutable and cannot be modified. This
1079-
is useful when the shape is initialized by a Tensor.
1128+
is useful when the shape is initialized by a Tensor or when the shape
1129+
is shared across multiple tensors. The default is False.
10801130
"""
10811131
self._dims: list[int | SymbolicDim] = [
10821132
_maybe_convert_to_symbolic_dim(dim) for dim in dims
@@ -1090,10 +1140,6 @@ def __init__(
10901140
)
10911141
self._frozen: bool = frozen
10921142

1093-
def copy(self):
1094-
"""Return a copy of the shape."""
1095-
return Shape(self._dims, self._denotations, self._frozen)
1096-
10971143
@property
10981144
def dims(self) -> tuple[int | SymbolicDim, ...]:
10991145
"""All dimensions in the shape.
@@ -1102,8 +1148,29 @@ def dims(self) -> tuple[int | SymbolicDim, ...]:
11021148
"""
11031149
return tuple(self._dims)
11041150

1151+
@property
1152+
def frozen(self) -> bool:
1153+
"""Whether the shape is frozen.
1154+
1155+
When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
1156+
Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
1157+
new shape with the same dimensions that can be modified.
1158+
"""
1159+
return self._frozen
1160+
1161+
def freeze(self) -> None:
1162+
"""Freeze the shape.
1163+
1164+
When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
1165+
"""
1166+
self._frozen = True
1167+
1168+
def copy(self, frozen: bool = False):
1169+
"""Return a copy of the shape."""
1170+
return Shape(self._dims, self._denotations, frozen=frozen)
1171+
11051172
def rank(self) -> int:
1106-
"""The rank of the shape."""
1173+
"""The rank of the tensor this shape represents."""
11071174
return len(self._dims)
11081175

11091176
def numpy(self) -> tuple[int, ...]:

onnxscript/ir/_core_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,9 @@ def test_setitem_raises_when_shape_is_frozen(self):
622622
with self.assertRaisesRegex(TypeError, "frozen"):
623623
shape[0] = 1
624624

625+
with self.assertRaisesRegex(TypeError, "frozen"):
626+
shape[0] = "some_string"
627+
625628
def test_getitem(self):
626629
shape = _core.Shape([42], denotations=("DATA_CHANNEL",))
627630
self.assertEqual(shape[0], 42)

0 commit comments

Comments
 (0)