44# LICENSE file in the root directory of this source tree.
55"""Provide TOSA support checks for ``aten.index.Tensor``.
66
7- Reject unsupported patterns such as high-rank index tensors, front-positioned
8- slice/ellipsis/None markers, and cases that exceed ``int32`` element limits.
7+ Reject unsupported patterns such as front-positioned slice/ellipsis/None
8+ markers and cases that exceed ``int32`` element limits.
99
1010"""
1111
@@ -30,18 +30,12 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
3030 This support check is intended to prevent the partitioning of
3131 currently unsupported usages of the index.Tensor operator.
3232
33- 1. Usages where indexing tensors are of rank 4 or higher.
34- This is due to the AnnotateChannelsLastDimOrder pass and
35- the rarity of such operation.
36- Support is possible but would require further changes to the above
37- pass which can be added at such a time as is necessary.
38-
39- 2. Usages where slice, ellipsis or None are present before an indexing tensor:
33+ 1. Usages where slice, ellipsis or None are present before an indexing tensor:
4034 t[{start}:{end}, indexTensor] - slicing
4135 t[None, indexTensor] - unsqueeze
4236 t[..., indexTensor] - ellipsis
4337
44- 3 . Usages where the value tensor contains more than int32.max elements
38+ 2 . Usages where the value tensor contains more than int32.max elements
4539 This is due to int32 TOSA limitation and the fact that we flatten out
4640 and accumulate all index tensors.
4741 As such to avoid overflow we reject lowering of this operator if it is
@@ -115,13 +109,12 @@ def is_node_tosa_supported(
115109
116110 Enforces the following constraints:
117111 - No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor.
118- - Indexing tensors have rank <= 3.
119112 - The value tensor element count fits in ``int32``.
120113
121114 """
122115 indices = node .args [1 ]
123116 for index in indices : # type: ignore[union-attr]
124- # Usage 2 guard
117+ # Usage 1 guard
125118 if index is None :
126119 self .reporter .report_reject (
127120 node ,
@@ -132,17 +125,7 @@ def is_node_tosa_supported(
132125 )
133126 return False
134127
135- # Usage 1 guard
136- index = ensure_type (torch .fx .Node , index )
137- fake_tensor = get_first_fake_tensor (index )
138- if len (fake_tensor .size ()) > 3 :
139- self .reporter .report_reject (
140- node ,
141- ("Indexing tensors of rank >= 4 is not supported." ),
142- )
143- return False
144-
145- # Usage 3 guard
128+ # Usage 2 guard
146129 input_node = ensure_type (torch .fx .Node , node .args [0 ])
147130 total_vals = math .prod (get_first_fake_tensor (input_node ).shape )
148131 if total_vals > torch .iinfo (torch .int32 ).max :
0 commit comments