@@ -2518,3 +2518,121 @@ def forward(self):
25182518 for j in range (2 ):
25192519 expected_storage .append (j * 16 + i )
25202520 self .assertEqual ([int (v ) for v in storage_values ], expected_storage )
2521+
def test_emit_device_info_propagated_to_serialized_tensor(self) -> None:
    """Check that a delegate's target device shows up on serialized tensors.

    A single-add model is lowered with a partitioner whose delegation spec
    carries ``TARGET_DEVICE_COMPILE_SPEC_KEY`` set to ``b"cuda:0"``. After
    ``to_executorch()`` the first execution plan must contain at least one
    delegate and exactly one tensor whose ``extra_tensor_info.device_type``
    equals ``schema.DeviceType.CUDA``.
    """
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddSupport(OperatorSupportBase):
        """Operator support that accepts only edge ``aten.add.Tensor`` calls."""

        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            if node.op != "call_function":
                return False
            return node.target in [exir_ops.edge.aten.add.Tensor]

    class DevicePartitioner(Partitioner):
        """Partitioner that tags add nodes for the demo backend on ``cuda:0``."""

        def __init__(self):
            super().__init__()
            compile_specs = [
                CompileSpec("max_value", bytes([4])),
                # The device request is passed to the backend as a compile spec.
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__,
                compile_specs,
            )

        def partition(self, exported_program) -> PartitionResult:
            found_partitions = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddSupport()),
            )
            tag_to_spec = {}
            for part in found_partitions:
                # Every node of one partition shares the same tag.
                tag = f"tag{part.id}"
                for member in part.nodes:
                    member.meta["delegation_tag"] = tag
                    tag_to_spec[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=tag_to_spec,
            )

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    module = Model()
    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))

    exported = export(module, example_inputs)
    edge_manager = to_edge(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    delegated = edge_manager.to_backend(DevicePartitioner())
    executorch_manager = delegated.to_executorch()
    first_plan = executorch_manager._emitter_output.program.execution_plan[0]

    # The add must actually have been handed to the delegate.
    self.assertGreater(len(first_plan.delegates), 0)

    # Count serialized tensors that are explicitly marked as CUDA.
    cuda_count = 0
    for entry in first_plan.values:
        candidate = entry.val
        if not isinstance(candidate, Tensor):
            continue
        info = candidate.extra_tensor_info
        if info is None:
            continue
        if info.device_type == schema.DeviceType.CUDA:
            cuda_count += 1

    # add(a, b) produces 1 delegate output tensor that should be CUDA
    self.assertEqual(
        cuda_count,
        1,
        f"Expected exactly 1 CUDA tensor for delegated add, got {cuda_count}",
    )
2606+
def test_emit_cpu_tensors_no_extra_device_info(self) -> None:
    """Check that a non-delegated model emits no CUDA-tagged tensors.

    The same single-add model is lowered without any partitioner, so
    everything stays on the default (CPU) path. The intent is that device
    info alone should not cause ExtraTensorInfo to be populated; what this
    test concretely asserts is that no serialized tensor carries
    ``schema.DeviceType.CUDA`` in its ``extra_tensor_info``.
    """

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    module = Model()
    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))

    exported = export(module, example_inputs)
    edge_manager = to_edge(
        exported,
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    executorch_manager = edge_manager.to_executorch()
    first_plan = executorch_manager._emitter_output.program.execution_plan[0]

    # Count serialized tensors that are explicitly marked as CUDA.
    cuda_count = 0
    for entry in first_plan.values:
        candidate = entry.val
        if not isinstance(candidate, Tensor):
            continue
        info = candidate.extra_tensor_info
        if info is None:
            continue
        if info.device_type == schema.DeviceType.CUDA:
            cuda_count += 1

    self.assertEqual(
        cuda_count,
        0,
        "No tensor should have CUDA device when model runs entirely on CPU",
    )
0 commit comments