@@ -994,6 +994,8 @@ def meta(self) -> _metadata.MetadataStore:
994994
995995
996996class SymbolicDim (_protocols .SymbolicDimProtocol , _display .PrettyPrintable ):
997+ """Immutable symbolic dimension that can be shared across multiple shapes."""
998+
997999 __slots__ = ("_value" ,)
9981000
9991001 def __init__ (self , value : str | None ) -> None :
@@ -1054,6 +1056,53 @@ def _maybe_convert_to_symbolic_dim(
10541056
10551057
10561058class Shape (_protocols .ShapeProtocol , _display .PrettyPrintable ):
1059+ """The shape of a tensor, including its dimensions and optional denotations.
1060+
1061+ The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
1062+ symbolic dimensions.
1063+
1064+ A shape can be compared to another shape or plain Python list.
1065+
1066+ A shape can be frozen (made immutable). When the shape is frozen, it cannot be
1067+ unfrozen, making it suitable to be shared across tensors or values.
1068+ Call :method:`freeze` to freeze the shape.
1069+
1070+ To update the dimension of a frozen shape, call :method:`copy` to create a
1071+ new shape with the same dimensions that can be modified.
1072+
1073+ Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
1074+
1075+ Example::
1076+
1077+ >>> from onnxscript import ir
1078+ >>> shape = ir.Shape(["B", None, 3])
1079+ >>> shape.rank()
1080+ 3
1081+ >>> shape.is_static()
1082+ False
1083+ >>> shape.is_dynamic()
1084+ True
1085+ >>> shape.is_static(dim=2)
1086+ True
1087+ >>> shape[0] = 1
1088+ >>> shape[1] = 2
1089+ >>> shape.dims
1090+ (1, 2, 3)
1091+ >>> shape == [1, 2, 3]
1092+ True
1093+ >>> shape.frozen
1094+ False
1095+ >>> shape.freeze()
1096+ >>> shape.frozen
1097+ True
1098+
1099+ Attributes:
1100+ dims: A tuple of dimensions representing the shape.
1101+ Each dimension can be an integer, None or a :class:`SymbolicDim`.
1102+ frozen: Indicates whether the shape is immutable. When frozen, the shape
1103+ cannot be modified or unfrozen.
1104+ """
1105+
10571106 __slots__ = ("_dims" , "_frozen" )
10581107
10591108 def __init__ (
@@ -1076,7 +1125,8 @@ def __init__(
10761125 Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
10771126 for pre-defined dimension denotations.
10781127 frozen: If True, the shape is immutable and cannot be modified. This
1079- is useful when the shape is initialized by a Tensor.
1128+ is useful when the shape is initialized by a Tensor or when the shape
1129+ is shared across multiple tensors. The default is False.
10801130 """
10811131 self ._dims : list [int | SymbolicDim ] = [
10821132 _maybe_convert_to_symbolic_dim (dim ) for dim in dims
@@ -1090,10 +1140,6 @@ def __init__(
10901140 )
10911141 self ._frozen : bool = frozen
10921142
1093- def copy (self ):
1094- """Return a copy of the shape."""
1095- return Shape (self ._dims , self ._denotations , self ._frozen )
1096-
10971143 @property
10981144 def dims (self ) -> tuple [int | SymbolicDim , ...]:
10991145 """All dimensions in the shape.
@@ -1102,8 +1148,29 @@ def dims(self) -> tuple[int | SymbolicDim, ...]:
11021148 """
11031149 return tuple (self ._dims )
11041150
1151+ @property
1152+ def frozen (self ) -> bool :
1153+ """Whether the shape is frozen.
1154+
1155+ When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
1156+ Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
1157+ new shape with the same dimensions that can be modified.
1158+ """
1159+ return self ._frozen
1160+
1161+ def freeze (self ) -> None :
1162+ """Freeze the shape.
1163+
1164+ When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
1165+ """
1166+ self ._frozen = True
1167+
1168+ def copy (self , frozen : bool = False ):
1169+ """Return a copy of the shape."""
1170+ return Shape (self ._dims , self ._denotations , frozen = frozen )
1171+
11051172 def rank (self ) -> int :
1106- """The rank of the shape."""
1173+ """The rank of the tensor this shape represents ."""
11071174 return len (self ._dims )
11081175
11091176 def numpy (self ) -> tuple [int , ...]:
0 commit comments