Skip to content

Commit 795fc3b

Browse files
Copilotxadupre
andauthored
Add unit tests for remove_inputs in InputObserver (#423)
* Initial plan * Add unit tests for remove_inputs in InputObserver Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent 48aec15 commit 795fc3b

1 file changed

Lines changed: 107 additions & 0 deletions

File tree

_unittests/ut_investigate/test_input_observer.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,5 +1197,112 @@ def forward(self, a, *args, **kwargs):
11971197
torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds)
11981198

11991199

1200+
def test_remove_inputs_kwargs(self):
1201+
"""Test that remove_inputs removes a kwarg from the observer info."""
1202+
1203+
class Model(torch.nn.Module):
1204+
def forward(self, x, y, z=None):
1205+
r = x + y
1206+
if z is not None:
1207+
r += z
1208+
return r
1209+
1210+
inputs = [
1211+
dict(x=torch.randn((5, 6)), y=torch.randn((1, 6)), z=torch.randn((5, 6))),
1212+
dict(x=torch.randn((7, 7)), y=torch.randn((1, 7)), z=torch.randn((7, 7))),
1213+
dict(x=torch.randn((7, 8)), y=torch.randn((1, 8)), z=torch.randn((7, 8))),
1214+
]
1215+
1216+
model = Model()
1217+
observer = InputObserver()
1218+
with observer(model):
1219+
for kwargs in inputs:
1220+
model(**kwargs)
1221+
self.assertEqual(len(observer.info), 3)
1222+
1223+
cst = torch.export.Dim.DYNAMIC
1224+
ds = observer.infer_dynamic_shapes()
1225+
self.assertIn("z", ds)
1226+
self.assertIn("x", ds)
1227+
self.assertIn("y", ds)
1228+
1229+
# Remove z input
1230+
observer.remove_inputs(["z"])
1231+
1232+
ds_after = observer.infer_dynamic_shapes()
1233+
self.assertNotIn("z", ds_after)
1234+
self.assertIn("x", ds_after)
1235+
self.assertIn("y", ds_after)
1236+
self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after)
1237+
1238+
args_after = observer.infer_arguments()
1239+
self.assertIsInstance(args_after, dict)
1240+
self.assertNotIn("z", args_after)
1241+
self.assertIn("x", args_after)
1242+
self.assertIn("y", args_after)
1243+
1244+
def test_remove_inputs_multiple_kwargs(self):
1245+
"""Test that remove_inputs removes multiple kwargs at once."""
1246+
1247+
class Model(torch.nn.Module):
1248+
def forward(self, x, y, z=None, w=None):
1249+
r = x + y
1250+
if z is not None:
1251+
r += z
1252+
if w is not None:
1253+
r += w
1254+
return r
1255+
1256+
inputs = [
1257+
dict(
1258+
x=torch.randn((5, 6)),
1259+
y=torch.randn((1, 6)),
1260+
z=torch.randn((5, 6)),
1261+
w=torch.randn((1, 6)),
1262+
),
1263+
dict(
1264+
x=torch.randn((6, 7)),
1265+
y=torch.randn((1, 7)),
1266+
z=torch.randn((6, 7)),
1267+
w=torch.randn((1, 7)),
1268+
),
1269+
dict(
1270+
x=torch.randn((7, 8)),
1271+
y=torch.randn((1, 8)),
1272+
z=torch.randn((7, 8)),
1273+
w=torch.randn((1, 8)),
1274+
),
1275+
]
1276+
1277+
model = Model()
1278+
observer = InputObserver()
1279+
with observer(model):
1280+
for kwargs in inputs:
1281+
model(**kwargs)
1282+
self.assertEqual(len(observer.info), 3)
1283+
1284+
cst = torch.export.Dim.DYNAMIC
1285+
ds = observer.infer_dynamic_shapes()
1286+
self.assertIn("z", ds)
1287+
self.assertIn("w", ds)
1288+
1289+
# Remove z and w inputs
1290+
observer.remove_inputs(["z", "w"])
1291+
1292+
ds_after = observer.infer_dynamic_shapes()
1293+
self.assertNotIn("z", ds_after)
1294+
self.assertNotIn("w", ds_after)
1295+
self.assertIn("x", ds_after)
1296+
self.assertIn("y", ds_after)
1297+
self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), ds_after)
1298+
1299+
args_after = observer.infer_arguments()
1300+
self.assertIsInstance(args_after, dict)
1301+
self.assertNotIn("z", args_after)
1302+
self.assertNotIn("w", args_after)
1303+
self.assertIn("x", args_after)
1304+
self.assertIn("y", args_after)
1305+
1306+
12001307
if __name__ == "__main__":
12011308
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)