@@ -1197,5 +1197,112 @@ def forward(self, a, *args, **kwargs):
11971197 torch .export .export (model , args , kwargs = kwargs , dynamic_shapes = ds )
11981198
11991199
1200+ def test_remove_inputs_kwargs (self ):
1201+ """Test that remove_inputs removes a kwarg from the observer info."""
1202+
1203+ class Model (torch .nn .Module ):
1204+ def forward (self , x , y , z = None ):
1205+ r = x + y
1206+ if z is not None :
1207+ r += z
1208+ return r
1209+
1210+ inputs = [
1211+ dict (x = torch .randn ((5 , 6 )), y = torch .randn ((1 , 6 )), z = torch .randn ((5 , 6 ))),
1212+ dict (x = torch .randn ((7 , 7 )), y = torch .randn ((1 , 7 )), z = torch .randn ((7 , 7 ))),
1213+ dict (x = torch .randn ((7 , 8 )), y = torch .randn ((1 , 8 )), z = torch .randn ((7 , 8 ))),
1214+ ]
1215+
1216+ model = Model ()
1217+ observer = InputObserver ()
1218+ with observer (model ):
1219+ for kwargs in inputs :
1220+ model (** kwargs )
1221+ self .assertEqual (len (observer .info ), 3 )
1222+
1223+ cst = torch .export .Dim .DYNAMIC
1224+ ds = observer .infer_dynamic_shapes ()
1225+ self .assertIn ("z" , ds )
1226+ self .assertIn ("x" , ds )
1227+ self .assertIn ("y" , ds )
1228+
1229+ # Remove z input
1230+ observer .remove_inputs (["z" ])
1231+
1232+ ds_after = observer .infer_dynamic_shapes ()
1233+ self .assertNotIn ("z" , ds_after )
1234+ self .assertIn ("x" , ds_after )
1235+ self .assertIn ("y" , ds_after )
1236+ self .assertEqual (dict (x = {0 : cst , 1 : cst }, y = {1 : cst }), ds_after )
1237+
1238+ args_after = observer .infer_arguments ()
1239+ self .assertIsInstance (args_after , dict )
1240+ self .assertNotIn ("z" , args_after )
1241+ self .assertIn ("x" , args_after )
1242+ self .assertIn ("y" , args_after )
1243+
1244+ def test_remove_inputs_multiple_kwargs (self ):
1245+ """Test that remove_inputs removes multiple kwargs at once."""
1246+
1247+ class Model (torch .nn .Module ):
1248+ def forward (self , x , y , z = None , w = None ):
1249+ r = x + y
1250+ if z is not None :
1251+ r += z
1252+ if w is not None :
1253+ r += w
1254+ return r
1255+
1256+ inputs = [
1257+ dict (
1258+ x = torch .randn ((5 , 6 )),
1259+ y = torch .randn ((1 , 6 )),
1260+ z = torch .randn ((5 , 6 )),
1261+ w = torch .randn ((1 , 6 )),
1262+ ),
1263+ dict (
1264+ x = torch .randn ((6 , 7 )),
1265+ y = torch .randn ((1 , 7 )),
1266+ z = torch .randn ((6 , 7 )),
1267+ w = torch .randn ((1 , 7 )),
1268+ ),
1269+ dict (
1270+ x = torch .randn ((7 , 8 )),
1271+ y = torch .randn ((1 , 8 )),
1272+ z = torch .randn ((7 , 8 )),
1273+ w = torch .randn ((1 , 8 )),
1274+ ),
1275+ ]
1276+
1277+ model = Model ()
1278+ observer = InputObserver ()
1279+ with observer (model ):
1280+ for kwargs in inputs :
1281+ model (** kwargs )
1282+ self .assertEqual (len (observer .info ), 3 )
1283+
1284+ cst = torch .export .Dim .DYNAMIC
1285+ ds = observer .infer_dynamic_shapes ()
1286+ self .assertIn ("z" , ds )
1287+ self .assertIn ("w" , ds )
1288+
1289+ # Remove z and w inputs
1290+ observer .remove_inputs (["z" , "w" ])
1291+
1292+ ds_after = observer .infer_dynamic_shapes ()
1293+ self .assertNotIn ("z" , ds_after )
1294+ self .assertNotIn ("w" , ds_after )
1295+ self .assertIn ("x" , ds_after )
1296+ self .assertIn ("y" , ds_after )
1297+ self .assertEqual (dict (x = {0 : cst , 1 : cst }, y = {1 : cst }), ds_after )
1298+
1299+ args_after = observer .infer_arguments ()
1300+ self .assertIsInstance (args_after , dict )
1301+ self .assertNotIn ("z" , args_after )
1302+ self .assertNotIn ("w" , args_after )
1303+ self .assertIn ("x" , args_after )
1304+ self .assertIn ("y" , args_after )
1305+
1306+
12001307if __name__ == "__main__" :
12011308 unittest .main (verbosity = 2 )
0 commit comments