@@ -33,7 +33,27 @@ def _is_view_copy(node: torch.fx.Node) -> bool:
3333 )
3434
3535
def _is_select_copy(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is a ``select_copy.int`` call in either the
    core aten or the edge dialect."""
    if node.op != "call_function":
        return False
    select_copy_overloads = (
        torch.ops.aten.select_copy.int,
        ops.edge.aten.select_copy.int,
    )
    return node.target in select_copy_overloads
41+
42+
# Memory-planning ops that replace their *_copy counterparts so the output
# aliases the input's buffer instead of copying it.
_VIEW_OP = memory.view
_SELECT_OP = memory.select
45+
46+
def _is_contiguous_base(spec: TensorSpec) -> bool:
    """Return True if ``spec`` is a non-const TensorSpec with a contiguous
    (row-major, default dim order) layout.

    Const specs are rejected because their storage is not planned by the
    memory planner.
    """
    if not isinstance(spec, TensorSpec):
        return False
    if spec.const:
        return False
    rank = len(spec.shape)
    if tuple(spec.dim_order) != tuple(range(rank)):
        return False
    row_major_stride = contiguous_stride_from_shape(torch.Size(list(spec.shape)))
    return tuple(spec.stride) == tuple(row_major_stride)
3757
3858
3959class _Guard :
@@ -54,7 +74,12 @@ def __call__(self, view_spec) -> None: # pyre-ignore[2]
5474
5575
5676class _ViewSpec (TensorSpec ):
57- def __init__ (self , base : TensorSpec , shape : List [int ]) -> None :
77+ def __init__ (
78+ self ,
79+ base : TensorSpec ,
80+ shape : List [int ],
81+ byte_offset : int = 0 ,
82+ ) -> None :
5883 """
5984 A _ViewSpec is TensorSpec that shares non-size related fields with its base.
6085 The size-related fields are: shape, stride, dim_order, and shape_dynamism.
@@ -65,7 +90,11 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
6590
6691 A _ViewSpec can only be created from a non-sparse, strided TensorSpec.
6792 On creation, a _ViewSpec must be compatible with its base with respect to
68- shape_dynamism, dtype, and nbytes.
93+ shape_dynamism, dtype, and nbytes (when byte_offset is 0).
94+
95+ When byte_offset is non-zero (used for select/slice sub-views), the view
96+ describes a contiguous sub-region of the base at the given byte offset.
97+ In this case, nbytes may differ from the base and rank may change.
6998
7099 A _ViewSpec contains _guards that are evaluated on every __getattribute__ call.
71100 The purpose of the guards is to make sure the _ViewSpec is still compatible
@@ -119,6 +148,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
119148
120149 self ._guards : List [_Guard ] = []
121150 self ._unguarded_access = False
151+ self ._byte_offset : int = byte_offset
122152
123153 # Make sure base is not sparse and add a guard
124154 if base .is_sparse :
@@ -183,15 +213,19 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
183213 _Guard ("dtype" , lambda view_spec : view_spec .dtype , base .dtype )
184214 )
185215
186- # We do not guard nbytes because dynamic symints are replaced by upper bounds.
187- # We do guard on rank, though
188- if self .nbytes () != base .nbytes ():
189- raise Exception (
190- f"_ViewSpec is incompatible with its base on creation. It has nbytes={ self .nbytes ()} , but its base has nbytes={ base .nbytes ()} ."
216+ # For traditional views (same nbytes, zero offset), rank is guarded.
217+ # For sub-views (select/slice), the output is a contiguous subset so
218+ # nbytes will differ and rank may change.
219+ is_full_view = byte_offset == 0 and self .nbytes () == base .nbytes ()
220+ if is_full_view :
221+ self ._guards .append (
222+ _Guard ("rank" , lambda view_spec : len (view_spec .shape ), len (shape ))
191223 )
192- self ._guards .append (
193- _Guard ("rank" , lambda view_spec : len (view_spec .shape ), len (shape ))
194- )
224+ else :
225+ if self .nbytes () + byte_offset > base .nbytes ():
226+ raise Exception (
227+ f"_ViewSpec sub-view extends beyond base. Sub-view needs { self .nbytes ()} bytes at offset { byte_offset } , but base has { base .nbytes ()} bytes."
228+ )
195229
196230 def _run_guards (self ) -> None :
197231 unguarded_access = self ._unguarded_access
@@ -211,6 +245,7 @@ def __getattribute__(self, name: str): # pyre-ignore
211245 "_guards" ,
212246 "_unguarded_access" ,
213247 "_run_guards" ,
248+ "_byte_offset" ,
214249 ]:
215250 return object .__getattribute__ (self , name )
216251
@@ -219,6 +254,11 @@ def __getattribute__(self, name: str): # pyre-ignore
219254 val = object .__getattribute__ (self , name )
220255 elif name in self ._base_fields :
221256 val = object .__getattribute__ (self ._base , name )
257+ # For sub-views (select/slice), adjust mem_offset by byte_offset
258+ if name == "mem_offset" and val is not None :
259+ byte_offset = object .__getattribute__ (self , "_byte_offset" )
260+ if byte_offset != 0 :
261+ val = val + byte_offset
222262 else :
223263 if len (name ) > 0 and name [0 ] != "_" :
224264 logger .warning (
@@ -239,6 +279,7 @@ def __setattr__(self, name: str, val) -> None: # pyre-ignore
239279 "_guards" ,
240280 "_unguarded_access" ,
241281 "_run_guards" ,
282+ "_byte_offset" ,
242283 ]:
243284 object .__setattr__ (self , name , val )
244285 return
@@ -293,11 +334,50 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
293334
294335 n_replaced += 1
295336
337+ elif _is_select_copy (node ) and all (
338+ u .op != "output" for u in node .users
339+ ):
340+ replaced = self ._try_replace_select (node )
341+ if replaced :
342+ n_replaced += 1
343+
296344 module .recompile ()
297345
298346 logger .debug (f"Replaced { n_replaced } view_copy nodes with { _VIEW_OP } nodes." )
347+ logger .debug (
348+ f"Replaced { n_replaced } select_copy nodes with { _SELECT_OP } nodes."
349+ )
299350 return PassResult (graph_module , n_replaced > 0 )
300351
352+ def _try_replace_select (self , node : torch .fx .Node ) -> bool :
353+ base = node .args [0 ]
354+ assert isinstance (base , torch .fx .Node )
355+ base_spec = base .meta ["spec" ]
356+ if not _is_contiguous_base (base_spec ):
357+ return False
358+
359+ dim : int = node .args [1 ]
360+ index : int = node .args [2 ]
361+ base_shape = [int (s ) for s in base_spec .shape ]
362+
363+ if dim < 0 :
364+ dim += len (base_shape )
365+
366+ if any (base_shape [i ] != 1 for i in range (dim )):
367+ return False
368+
369+ if index < 0 :
370+ index += base_shape [dim ]
371+
372+ out_shape = list (node .meta ["spec" ].shape )
373+ base_stride = contiguous_stride_from_shape (torch .Size (base_shape ))
374+ element_size = torch ._utils ._element_size (base_spec .dtype )
375+ byte_offset = index * base_stride [dim ] * element_size
376+
377+ node .target = _SELECT_OP
378+ node .meta ["spec" ] = _ViewSpec (base_spec , out_shape , byte_offset )
379+ return True
380+
301381 def ensures (self , graph_module : torch .fx .GraphModule ) -> None :
302382 for module in graph_module .modules ():
303383 if not isinstance (module , torch .fx .GraphModule ):
@@ -310,6 +390,8 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None:
310390 )
311391 if node .op == "call_function" and node .target == _VIEW_OP :
312392 assert isinstance (node .meta ["spec" ], _ViewSpec )
393+ if node .op == "call_function" and node .target == _SELECT_OP :
394+ assert isinstance (node .meta ["spec" ], _ViewSpec )
313395
314396 def requires (self , graph_module : torch .fx .GraphModule ) -> None :
315397 """
0 commit comments