 
 # pyre-strict
 
+import copy
 import logging
+import operator
 from typing import Optional
 
+# Import to register the et_copy ops so torch.ops.et_copy is available.
+import executorch.exir.passes._device_copy_ops_registry  # noqa: F401
+
 import executorch.exir.schema as schema
 
 import torch
@@ -124,23 +129,150 @@ def _tag_specs_with_device(
     return False
 
 
+def _clone_spec_with_device(
+    spec: TensorSpec,
+    device_type: schema.DeviceType,
+    device_index: int = 0,
+) -> TensorSpec:
+    """Create a copy of a TensorSpec with a different device."""
+    new_spec = copy.copy(spec)
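+    # Note (assumption, based on the method name): re-initializing the
+    # memory-planning fields gives the clone fresh planning state before its
+    # device is retargeted below.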
+    new_spec.init_mem_planning_fields()
+    _set_device_on_spec(new_spec, device_type, device_index)
+    return new_spec
+
+
 class PropagateDevicePass(PassBase):
     """
-    After to_backend, walk the graph and set device metadata on TensorSpecs
-    based on partitioner-assigned delegation info.
-
-    Rules:
-    1. Delegated nodes: Input and output tensors of a delegate call are marked
-       with the target device derived from the delegate's CompileSpec
-       (key="target_device").
-    2. Non-delegated nodes: Remain on CPU (default).
-    3. Getitem nodes that extract from a delegate call inherit the device from
-       the delegate call's output spec at the corresponding index.
+    After to_backend, walk the graph and insert H2D/D2H copy ops at delegate
+    boundaries based on partitioner-assigned device info.
+
+    When a delegate has a target_device CompileSpec (e.g., "cuda:0"):
+    - For each delegate input: insert et_copy._h2d_copy before the delegate call.
+      The original input node stays CPU; the h2d_copy output is tagged as device.
+    - For each delegate output: insert et_copy._d2h_copy after each getitem.
+      The getitem stays device; the d2h_copy output is tagged as CPU.
+    - Getitem nodes that extract from a delegate call inherit the device.
+
+    Skip-copy optimizations:
+    - skip_h2d_for_method_inputs: If the input is a graph-level placeholder
+      feeding directly to a delegate, don't insert H2D — tag the placeholder
+      as device instead (user provides GPU tensor at runtime).
+    - skip_d2h_for_method_outputs: If the getitem feeds directly to graph
+      output, don't insert D2H — the output stays on device.
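+
+    Schematic example (illustrative only) for a single delegate whose
+    CompileSpec carries target_device="cuda:0":
+
+        Before:  x (cpu) -> executorch_call_delegate -> getitem -> out (cpu)
+        After:   x (cpu) -> et_copy._h2d_copy (cuda) -> executorch_call_delegate
+                 -> getitem (cuda) -> et_copy._d2h_copy (cpu) -> out (cpu)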
     """
 
+    def __init__(
+        self,
+    ) -> None:
+        super().__init__()
+
+    def _is_placeholder(self, node: torch.fx.Node) -> bool:
+        """Check if a node is a graph-level input (placeholder)."""
+        return node.op == "placeholder"
+
+    def _feeds_directly_to_output(self, node: torch.fx.Node) -> bool:
+        """Check if all users of a node are output nodes."""
+        return all(user.op == "output" for user in node.users)
+
+    def _insert_h2d_copies(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.Node,
+        target_device_type: schema.DeviceType,
+        device_index: int,
+    ) -> bool:
+        """Insert H2D copy nodes for each tensor input to a delegate call."""
+        changed = False
+        new_args = list(node.args)
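+        # args[0] is the get_attr node for the lowered module; skip it and only
+        # consider the remaining tensor arguments.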
+        for i, arg in enumerate(node.args[1:], start=1):
+            if not isinstance(arg, torch.fx.Node):
+                continue
+            arg_spec = arg.meta.get("spec")
+            if not isinstance(arg_spec, TensorSpec):
+                continue
+
+            with graph_module.graph.inserting_before(node):
+                h2d_node = graph_module.graph.call_function(
+                    torch.ops.et_copy._h2d_copy.default,
+                    (arg,),
+                )
+                h2d_spec = _clone_spec_with_device(
+                    arg_spec, target_device_type, device_index
+                )
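+                # Carry the example value and tensor metadata over from the
+                # source node so shape/dtype information stays visible to
+                # later passes.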
+                h2d_node.meta["spec"] = h2d_spec
+                h2d_node.meta["val"] = arg.meta.get("val")
+                if "tensor_meta" in arg.meta:
+                    h2d_node.meta["tensor_meta"] = arg.meta["tensor_meta"]
+                new_args[i] = h2d_node
+                changed = True
+
+        node.args = tuple(new_args)
+        return changed
+
+    def _insert_d2h_for_getitem(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.Node,
+    ) -> bool:
+        """If *node* is a getitem extracting from a delegate call, tag its spec
+        with the delegate device and insert a D2H copy after it."""
+        source_node = node.args[0]
+        if not (
+            isinstance(source_node, torch.fx.Node)
+            and source_node.op == "call_function"
+            and source_node.target == executorch_call_delegate
+        ):
+            return False
+
+        spec = node.meta.get("spec")
+        source_specs = source_node.meta.get("spec")
+        idx = node.args[1]
+        if not (
+            isinstance(spec, TensorSpec)
+            and isinstance(source_specs, (tuple, list))
+            and isinstance(idx, int)
+            and idx < len(source_specs)
+        ):
+            return False
+
+        source_spec = source_specs[idx]
+        if not isinstance(source_spec, TensorSpec):
+            return False
+
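+        # The getitem itself is tagged with the delegate's device; the D2H copy
+        # inserted below hands a CPU tensor to any non-delegated consumers.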
+        _set_device_on_spec(spec, source_spec.device, source_spec.device_index)
+
+        with graph_module.graph.inserting_after(node):
+            d2h_node = graph_module.graph.call_function(
+                torch.ops.et_copy._d2h_copy.default,
+                (node,),
+            )
+            d2h_spec = _clone_spec_with_device(spec, schema.DeviceType.CPU, 0)
+            d2h_node.meta["spec"] = d2h_spec
+            d2h_node.meta["val"] = node.meta.get("val")
+            if "tensor_meta" in node.meta:
+                d2h_node.meta["tensor_meta"] = node.meta["tensor_meta"]
+
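+        # Redirect every user of the getitem to the D2H copy, except the D2H
+        # node itself, which must keep the original getitem as its input.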
+        node.replace_all_uses_with(
+            d2h_node,
+            delete_user_cb=lambda user, _d2h=d2h_node: user != _d2h,
+        )
+        return True
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        # Two-pass approach:
+        # Pass 1 – For each delegate with a target_device CompileSpec, insert
+        #          H2D copy nodes before delegate inputs and tag the delegate
+        #          output specs with the target device. Delegates without a
+        #          target_device are left untouched (no copies, specs stay CPU).
+        # Pass 2 – For each getitem that extracts from a device-tagged delegate
+        #          (tracked in device_delegates), propagate the device onto the
+        #          getitem spec and insert a D2H copy after it so downstream
+        #          non-delegated ops receive CPU tensors.
         changed = False
-        for node in graph_module.graph.nodes:
+        device_delegates: set[torch.fx.Node] = set()
+
+        # Pass 1: insert H2D copies and tag delegate output specs.
+        for node in list(graph_module.graph.nodes):
             if node.op == "call_function" and node.target == executorch_call_delegate:
                 lowered_module = _get_lowered_module(graph_module, node)
                 if lowered_module is None:
@@ -151,18 +283,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     continue
 
                 target_device_type, device_index = result
+                device_delegates.add(node)
+
+                changed |= self._insert_h2d_copies(
+                    graph_module, node, target_device_type, device_index
+                )
 
-                # Tag delegate input tensors.
-                # args[0] is the get_attr node for the lowered module; skip it.
-                for arg in node.args[1:]:
-                    if isinstance(arg, torch.fx.Node):
-                        changed |= _tag_specs_with_device(
-                            arg.meta.get("spec"),
-                            target_device_type,
-                            device_index,
-                        )
-
-                # Tag delegate output tensors.
                 changed |= _tag_specs_with_device(
                     node.meta.get("spec"),
                     target_device_type,
@@ -177,34 +303,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     lowered_module.backend_id,
                 )
 
-        # Second pass: propagate device through getitem nodes that extract
-        # individual outputs from a delegate call.
-        for node in graph_module.graph.nodes:
-            if node.op == "call_function" and node.target.__name__ == "getitem":
-                source_node = node.args[0]
-                if (
-                    isinstance(source_node, torch.fx.Node)
-                    and source_node.op == "call_function"
-                    and source_node.target == executorch_call_delegate
-                ):
-                    spec = node.meta.get("spec")
-                    source_specs = source_node.meta.get("spec")
-                    idx = node.args[1]
-                    if (
-                        spec is not None
-                        and isinstance(spec, TensorSpec)
-                        and source_specs is not None
-                        and isinstance(source_specs, (tuple, list))
-                        and isinstance(idx, int)
-                        and idx < len(source_specs)
-                    ):
-                        source_spec = source_specs[idx]
-                        if isinstance(source_spec, TensorSpec):
-                            _set_device_on_spec(
-                                spec,
-                                source_spec.device,
-                                source_spec.device_index,
-                            )
-                            changed = True
+        # Second pass: propagate device through getitem nodes and insert D2H
+        # only for delegates that have a target_device.
+        for node in list(graph_module.graph.nodes):
+            if node.op == "call_function" and node.target == operator.getitem:
+                source = node.args[0]
+                if isinstance(source, torch.fx.Node) and source in device_delegates:
+                    changed |= self._insert_d2h_for_getitem(graph_module, node)
 
+        graph_module.recompile()
         return PassResult(graph_module, changed)