@@ -2518,3 +2518,128 @@ def forward(self):
25182518 for j in range (2 ):
25192519 expected_storage .append (j * 16 + i )
25202520 self .assertEqual ([int (v ) for v in storage_values ], expected_storage )
2521+
2522+ def test_emit_device_info_propagated_to_serialized_tensor (self ) -> None :
2523+ """Verify that device info from PropagateDevicePass flows through
2524+ the emitter into ExtraTensorInfo.device_type on serialized tensors."""
2525+ from executorch .exir .backend .canonical_partitioners .pattern_op_partitioner import (
2526+ generate_pattern_op_partitions ,
2527+ )
2528+ from executorch .exir .backend .compile_spec_schema import CompileSpec
2529+ from executorch .exir .backend .partitioner import (
2530+ DelegationSpec ,
2531+ Partitioner ,
2532+ PartitionResult ,
2533+ )
2534+ from executorch .exir .backend .test .backend_with_compiler_demo import (
2535+ BackendWithCompilerDemo ,
2536+ )
2537+ from executorch .exir .passes .propagate_device_pass import (
2538+ TARGET_DEVICE_COMPILE_SPEC_KEY ,
2539+ )
2540+ from torch .fx .passes .operator_support import any_chain , OperatorSupportBase
2541+
2542+ class AddSupport (OperatorSupportBase ):
2543+ def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
2544+ return node .op == "call_function" and node .target in [
2545+ exir_ops .edge .aten .add .Tensor ,
2546+ ]
2547+
2548+ class DevicePartitioner (Partitioner ):
2549+ def __init__ (self ):
2550+ super ().__init__ ()
2551+ self .delegation_spec = DelegationSpec (
2552+ BackendWithCompilerDemo .__name__ ,
2553+ [
2554+ CompileSpec ("max_value" , bytes ([4 ])),
2555+ CompileSpec (TARGET_DEVICE_COMPILE_SPEC_KEY , b"cuda:0" ),
2556+ ],
2557+ )
2558+
2559+ def partition (self , exported_program ) -> PartitionResult :
2560+ partition_tags = {}
2561+ partition_list = generate_pattern_op_partitions (
2562+ exported_program .graph_module ,
2563+ op_support = any_chain (AddSupport ()),
2564+ )
2565+ for partition in partition_list :
2566+ for node in partition .nodes :
2567+ tag = f"tag{ partition .id } "
2568+ node .meta ["delegation_tag" ] = tag
2569+ partition_tags [tag ] = self .delegation_spec
2570+ return PartitionResult (
2571+ tagged_exported_program = exported_program ,
2572+ partition_tags = partition_tags ,
2573+ )
2574+
2575+ class Model (torch .nn .Module ):
2576+ def forward (self , a , b ):
2577+ return torch .add (a , b )
2578+
2579+ model = Model ()
2580+ inputs = (torch .randn (2 , 2 ), torch .randn (2 , 2 ))
2581+
2582+ edge = to_edge (
2583+ export (model , inputs ),
2584+ compile_config = EdgeCompileConfig (_check_ir_validity = False ),
2585+ )
2586+ lowered = edge .to_backend (DevicePartitioner ())
2587+ et_prog = lowered .to_executorch ()
2588+ program = et_prog ._emitter_output .program
2589+
2590+ plan = program .execution_plan [0 ]
2591+ self .assertGreater (len (plan .delegates ), 0 )
2592+
2593+ tensor_values = [v .val for v in plan .values if isinstance (v .val , Tensor )]
2594+ cuda_tensors = [
2595+ t
2596+ for t in tensor_values
2597+ if t .extra_tensor_info is not None
2598+ and t .extra_tensor_info .device_type == schema .DeviceType .CUDA
2599+ ]
2600+ # add(a, b) has 2 delegate inputs + 1 delegate output = 3 CUDA tensors
2601+ self .assertEqual (
2602+ len (cuda_tensors ),
2603+ 3 ,
2604+ f"Expected exactly 3 CUDA tensors (2 inputs + 1 output for delegated add), got { len (cuda_tensors )} " ,
2605+ )
2606+ # Verify device_index is also correctly serialized (cuda:0 → index 0)
2607+ for t in cuda_tensors :
2608+ self .assertEqual (
2609+ t .extra_tensor_info .device_index ,
2610+ 0 ,
2611+ "CUDA tensor device_index should be 0 for cuda:0" ,
2612+ )
2613+
2614+ def test_emit_cpu_tensors_no_extra_device_info (self ) -> None :
2615+ """When all tensors are on CPU (default), ExtraTensorInfo should NOT be
2616+ created solely for device info — it should remain None for activation tensors.
2617+ """
2618+
2619+ class Model (torch .nn .Module ):
2620+ def forward (self , a , b ):
2621+ return torch .add (a , b )
2622+
2623+ model = Model ()
2624+ inputs = (torch .randn (2 , 2 ), torch .randn (2 , 2 ))
2625+
2626+ edge = to_edge (
2627+ export (model , inputs ),
2628+ compile_config = EdgeCompileConfig (_check_ir_validity = False ),
2629+ )
2630+ et_prog = edge .to_executorch ()
2631+ program = et_prog ._emitter_output .program
2632+
2633+ plan = program .execution_plan [0 ]
2634+ tensor_values = [v .val for v in plan .values if isinstance (v .val , Tensor )]
2635+ non_cpu_tensors = [
2636+ t
2637+ for t in tensor_values
2638+ if t .extra_tensor_info is not None
2639+ and t .extra_tensor_info .device_type is not None
2640+ ]
2641+ self .assertEqual (
2642+ len (non_cpu_tensors ),
2643+ 0 ,
2644+ "No tensor should have extra device info when model runs entirely on CPU" ,
2645+ )