77# pyre-strict
88
99import inspect
10+ import operator
1011import unittest
1112from typing import Callable
1213
@@ -483,12 +484,18 @@ def _build_addmm_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
483484 self .assertEqual (len (addmm_nodes ), 1 , "Should find exactly one addmm node" )
484485 return gm , addmm_nodes [0 ]
485486
486- def _build_max_pool2d_graph (self ) -> tuple [torch .fx .GraphModule , torch .fx .Node ]:
487- """Build a simple graph with a max_pool2d_with_indices operation."""
487+ def _build_max_pool2d_graph (
488+ self ,
489+ ) -> tuple [torch .fx .GraphModule , torch .fx .Node , torch .fx .Node ]:
490+ """Build a graph with max_pool2d_with_indices followed by getitem[0].
491+
492+ Returns:
493+ A tuple of (graph_module, getitem_node, max_pool_node).
494+ The getitem_node is where the output annotation is placed.
495+ The max_pool_node is where the input annotation is placed.
496+ """
488497 builder = GraphBuilder ()
489- # Input shape: (batch, channels, height, width)
490498 x = builder .placeholder ("x" , torch .randn (1 , 3 , 8 , 8 ))
491- # max_pool2d_with_indices args: (input, kernel_size, stride, padding, dilation, ceil_mode)
492499 max_pool = builder .call_operator (
493500 op = torch .ops .aten .max_pool2d_with_indices .default ,
494501 args = (x , [2 , 2 ], [2 , 2 ], [0 , 0 ], [1 , 1 ], False ),
@@ -503,19 +510,24 @@ def _build_max_pool2d_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
503510 }
504511 ),
505512 )
506- builder .output ([max_pool ])
513+ getitem = builder .call_operator (
514+ op = operator .getitem ,
515+ args = (max_pool , 0 ),
516+ )
517+ builder .output ([getitem ])
507518 gm = builder .get_graph_module ()
508519
509520 max_pool_nodes = gm .graph .find_nodes (
510521 op = "call_function" ,
511522 target = torch .ops .aten .max_pool2d_with_indices .default ,
512523 )
513- self .assertEqual (
514- len ( max_pool_nodes ),
515- 1 ,
516- "Should find exactly one max_pool2d_with_indices node" ,
524+ self .assertEqual (len ( max_pool_nodes ), 1 )
525+ getitem_nodes = gm . graph . find_nodes (
526+ op = "call_function" ,
527+ target = operator . getitem ,
517528 )
518- return gm , max_pool_nodes [0 ]
529+ self .assertEqual (len (getitem_nodes ), 1 )
530+ return gm , getitem_nodes [0 ], max_pool_nodes [0 ]
519531
520532 def _build_add_relu_graph (
521533 self ,
0 commit comments