@@ -33,7 +33,27 @@ def _is_view_copy(node: torch.fx.Node) -> bool:
     )
 
 
+def _is_select_copy(node: torch.fx.Node) -> bool:
+    return node.op == "call_function" and node.target in (
+        torch.ops.aten.select_copy.int,
+        ops.edge.aten.select_copy.int,
+    )
+
+
 _VIEW_OP = memory.view
+_SELECT_OP = memory.select
+
+
+def _is_contiguous_base(spec: TensorSpec) -> bool:
+    if not isinstance(spec, TensorSpec):
+        return False
+    if spec.const or spec.is_dynamic_shape_tensor:
+        return False
+    shape = list(spec.shape)
+    if tuple(spec.dim_order) != tuple(range(len(shape))):
+        return False
+    expected_stride = contiguous_stride_from_shape(torch.Size(shape))
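+    # For example, a dense [2, 3] spec with dim_order (0, 1) must have
+    # stride (3, 1); a transposed layout with stride (1, 2) is rejected.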
+    return tuple(spec.stride) == tuple(expected_stride)
 
 
 class _Guard:
@@ -54,7 +74,12 @@ def __call__(self, view_spec) -> None: # pyre-ignore[2]
 
 
 class _ViewSpec(TensorSpec):
-    def __init__(self, base: TensorSpec, shape: List[int]) -> None:
+    def __init__(
+        self,
+        base: TensorSpec,
+        shape: List[int],
+        byte_offset: int = 0,
+    ) -> None:
         """
         A _ViewSpec is a TensorSpec that shares non-size-related fields with its base.
         The size-related fields are: shape, stride, dim_order, and shape_dynamism.
@@ -65,7 +90,11 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
 
         A _ViewSpec can only be created from a non-sparse, strided TensorSpec.
         On creation, a _ViewSpec must be compatible with its base with respect to
-        shape_dynamism, dtype, and nbytes.
+        shape_dynamism, dtype, and nbytes (when byte_offset is 0).
+
+        When byte_offset is non-zero (used for select/slice sub-views), the view
+        describes a contiguous sub-region of the base at the given byte offset.
+        In this case, nbytes may differ from the base and the rank may change.
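+        For example, selecting index 2 along dim 0 of a contiguous float32
+        [4, 5] base yields a [5] sub-view at byte_offset 2 * 5 * 4 = 40.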
 
         A _ViewSpec contains _guards that are evaluated on every __getattribute__ call.
         The purpose of the guards is to make sure the _ViewSpec is still compatible
@@ -119,6 +148,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
 
         self._guards: List[_Guard] = []
         self._unguarded_access = False
+        self._byte_offset: int = byte_offset
 
         # Make sure base is not sparse and add a guard
         if base.is_sparse:
@@ -183,15 +213,19 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
             _Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype)
         )
 
-        # We do not guard nbytes because dynamic symints are replaced by upper bounds.
-        # We do guard on rank, though
-        if self.nbytes() != base.nbytes():
-            raise Exception(
-                f"_ViewSpec is incompatible with its base on creation. It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}."
+        # For traditional views (same nbytes, zero offset), rank is guarded.
+        # For sub-views (select/slice), the output is a contiguous subset, so
+        # nbytes will differ and the rank may change.
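+        # e.g. view_copy [2, 3] -> [6] is a full view; select(dim=0) on a
+        # [4, 5] base yields a [5] sub-view with smaller nbytes and rank.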
+        is_full_view = byte_offset == 0 and self.nbytes() == base.nbytes()
+        if is_full_view:
+            self._guards.append(
+                _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
             )
-        self._guards.append(
-            _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
-        )
+        else:
+            if self.nbytes() + byte_offset > base.nbytes():
+                raise Exception(
+                    f"_ViewSpec sub-view extends beyond its base. The sub-view needs {self.nbytes()} bytes at offset {byte_offset}, but the base has only {base.nbytes()} bytes."
+                )
 
     def _run_guards(self) -> None:
         unguarded_access = self._unguarded_access
@@ -211,6 +245,7 @@ def __getattribute__(self, name: str): # pyre-ignore
             "_guards",
             "_unguarded_access",
             "_run_guards",
+            "_byte_offset",
         ]:
             return object.__getattribute__(self, name)
 
@@ -219,6 +254,11 @@ def __getattribute__(self, name: str): # pyre-ignore
             val = object.__getattribute__(self, name)
         elif name in self._base_fields:
             val = object.__getattribute__(self._base, name)
+            # For sub-views (select/slice), shift mem_offset by byte_offset
+            if name == "mem_offset" and val is not None:
+                byte_offset = object.__getattribute__(self, "_byte_offset")
+                if byte_offset != 0:
+                    val = val + byte_offset
         else:
             if len(name) > 0 and name[0] != "_":
                 logger.warning(
@@ -239,6 +279,7 @@ def __setattr__(self, name: str, val) -> None: # pyre-ignore
             "_guards",
             "_unguarded_access",
             "_run_guards",
+            "_byte_offset",
         ]:
             object.__setattr__(self, name, val)
             return
@@ -271,6 +312,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         During memory planning, view nodes share the same storage as their base.
         """
 
+        n_select_replaced = 0
         n_replaced = 0
         for module in graph_module.modules():
             if not isinstance(module, torch.fx.GraphModule):
@@ -293,11 +335,61 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
                     n_replaced += 1
 
+                elif _is_select_copy(node) and all(
+                    u.op != "output" for u in node.users
+                ):
+                    replaced = self._try_replace_select(node)
+                    if replaced:
+                        n_select_replaced += 1
+
             module.recompile()
 
         logger.debug(f"Replaced {n_replaced} view_copy nodes with {_VIEW_OP} nodes.")
+        logger.debug(
+            f"Replaced {n_select_replaced} select_copy nodes with {_SELECT_OP} nodes."
+        )
-        return PassResult(graph_module, n_replaced > 0)
+        return PassResult(graph_module, n_replaced + n_select_replaced > 0)
 
+    def _try_replace_select(self, node: torch.fx.Node) -> bool:
+        base = node.args[0]
+        assert isinstance(base, torch.fx.Node)
+        base_spec = base.meta["spec"]
+        if not _is_contiguous_base(base_spec):
+            return False
+
+        dim = node.args[1]
+        index = node.args[2]
+        if not isinstance(dim, int) or not isinstance(index, int):
+            return False
+
+        base_shape = list(base_spec.shape)
+        if dim < 0:
+            dim += len(base_shape)
+        if dim < 0 or dim >= len(base_shape):
+            return False
+
+        # For C-contiguous layout, select(dim=d) produces a contiguous output
+        # only when all dimensions before d have size 1.
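+        # e.g. for a [1, 8, 16] base, select(dim=1) keeps the elements
+        # contiguous, while select(dim=2) would stride through memory.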
+        if any(base_shape[i] != 1 for i in range(dim)):
+            return False
+
+        if index < 0:
+            index += base_shape[dim]
+        if index < 0 or index >= base_shape[dim]:
+            return False
+
+        out_shape = list(node.meta["val"].shape)
+        if any(not isinstance(s, int) for s in out_shape):
+            return False
+
+        base_stride = contiguous_stride_from_shape(torch.Size(base_shape))
+        element_size = torch._utils._element_size(base_spec.dtype)
+        byte_offset = index * base_stride[dim] * element_size
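+        # e.g. a float32 base [1, 4, 8] with dim=1, index=2:
+        # byte_offset = 2 * base_stride[1] * 4 = 2 * 8 * 4 = 64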
+
+        node.target = _SELECT_OP
+        node.meta["spec"] = _ViewSpec(base_spec, out_shape, byte_offset)
+        return True
+
     def ensures(self, graph_module: torch.fx.GraphModule) -> None:
         for module in graph_module.modules():
             if not isinstance(module, torch.fx.GraphModule):
@@ -310,6 +402,8 @@ def ensures(self, graph_module: torch.fx.GraphModule) -> None:
                 )
                 if node.op == "call_function" and node.target == _VIEW_OP:
                     assert isinstance(node.meta["spec"], _ViewSpec)
+                if node.op == "call_function" and node.target == _SELECT_OP:
+                    assert isinstance(node.meta["spec"], _ViewSpec)
 
     def requires(self, graph_module: torch.fx.GraphModule) -> None:
         """