@@ -10,5 +10,63 @@ class TestParsing(unittest.TestCase):
1010 def test_gpt3 (self ):
1111 spec = af .Spec .from_yaml (INPUT_FILES_DIR / "gpt3_6.7B.yaml" )
1212 self .assertEqual (
13- spec .workload .einsum_names , ["V" , "K" , "Q" , "QK_softmax" , "Z" , "FFA" , "FFB" ]
13+ spec .workload .einsum_names ,
14+ ["V" , "K" , "Q" , "QK" , "QK_softmax" , "AV" , "Z" , "FFA" , "FFB" ],
1415 )
16+
17+
18+ class TestMangling (unittest .TestCase ):
19+ def setUp (self ):
20+ self .spec = af .Spec .from_yaml (INPUT_FILES_DIR / "gpt3_6.7B.yaml" )
21+ self .workload = self .spec .workload
22+ self .adapted = self .workload .get_adapted_workload ()
23+
24+ def test_gpt3 (self ):
25+ self .assertIn ("copy_I__I" , self .adapted .einsums ["Q" ].input_tensor_names )
26+
27+ def test_all_consumers_of_adapted_tensor_are_mangled (self ):
28+ # Every Einsum that read the original input I should now read the mangled
29+ # name instead, and should no longer reference I directly.
30+ for name in ["V" , "K" , "Q" ]:
31+ inputs = self .adapted .einsums [name ].input_tensor_names
32+ self .assertIn ("copy_I__I" , inputs )
33+ self .assertNotIn ("I" , inputs )
34+
35+ def test_copy_einsum_is_inserted (self ):
36+ # The adapter is lowered into a copy Einsum named after the adapter.
37+ self .assertIn ("copy_I" , self .adapted .einsum_names )
38+ copy_einsum = self .adapted .einsums ["copy_I" ]
39+ self .assertTrue (copy_einsum .is_copy_operation )
40+ self .assertEqual (copy_einsum .input_tensor_names , {"I" })
41+ self .assertEqual (copy_einsum .output_tensor_names , {"copy_I__I" })
42+
43+ def test_copy_einsum_mirrors_original_projection (self ):
44+ # The copy reads/writes the same ranks the original tensor was accessed by.
45+ copy_einsum = self .adapted .einsums ["copy_I" ]
46+ src = next (t for t in copy_einsum .tensor_accesses if t .name == "I" )
47+ dst = next (t for t in copy_einsum .tensor_accesses if t .name == "copy_I__I" )
48+ self .assertEqual (set (src .ranks ), {"B" , "M" , "D" })
49+ self .assertEqual (set (dst .ranks ), {"B" , "M" , "D" })
50+
51+ def test_original_tensor_only_remains_on_copy (self ):
52+ # After adapting, the original I is produced/consumed only by the copy
53+ # Einsum; downstream Einsums use the mangled name.
54+ einsums_with_I = [e .name for e in self .adapted .einsums_with_tensor ("I" )]
55+ self .assertEqual (einsums_with_I , ["copy_I" ])
56+
57+ def test_downstream_einsums_unaffected (self ):
58+ # Tensors unrelated to the adapter keep their names.
59+ qk_inputs = self .adapted .einsums ["QK" ].input_tensor_names
60+ self .assertEqual (qk_inputs , {"Q" , "K" })
61+
62+ def test_einsum_order_preserved (self ):
63+ self .assertEqual (
64+ self .adapted .einsum_names ,
65+ ["copy_I" , "V" , "K" , "Q" , "QK" , "QK_softmax" , "AV" , "Z" , "FFA" , "FFB" ],
66+ )
67+
68+ def test_original_workload_unchanged (self ):
69+ # get_adapted_workload returns a copy; the source workload is untouched.
70+ self .assertIn ("I" , self .workload .einsums ["Q" ].input_tensor_names )
71+ self .assertNotIn ("copy_I__I" , self .workload .einsums ["Q" ].input_tensor_names )
72+ self .assertNotIn ("copy_I" , self .workload .einsum_names )
0 commit comments