Skip to content

Commit d4e3906

Browse files
committed
wip
1 parent 7499f20 commit d4e3906

6 files changed

Lines changed: 215 additions & 6 deletions

File tree

accelforge/frontend/renames.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def _eval_expressions(self, symbol_table: dict[str, Any], *args, **kwargs):
5050
)
5151
raise EvaluationError(
5252
f"Expected count is {evaluated.expected_count}, but got "
53-
f"{len(evaluated.source)}: {evaluated.source}",
53+
f"{len(evaluated.source)}: {evaluated.source}\n"
54+
f"Symbol table: {symbol_table}",
5455
source_field="source",
5556
)
5657
return evaluated, symbol_table

accelforge/frontend/workload.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ def _canonicalize_einsums_and_adapters(self, data):
973973
)
974974

975975
def is_einsum(d):
976-
return not hasattr(d, "tag") or d.tag == "!Einsum"
976+
return _get_tag(d, default="Einsum") == "Einsum"
977977

978978
if has_einsums:
979979
data["einsums_and_adapters"] = data["einsums"]
@@ -1362,3 +1362,79 @@ def get_compute_intensity(self, einsum_name: str) -> float:
13621362
self.get_tensor_size(tensor)
13631363
for tensor in self.einsums[einsum_name].tensor_names
13641364
)
1365+
1366+
def get_adapted_workload(self) -> "Workload":
1367+
"""
1368+
Return a (deep) copy of the workload that has tensor names mangled
1369+
via the adapter.
1370+
1371+
In general, a tensor will be renamed <name of adapter>__<tensor name>.
1372+
1373+
Each adapter is turned into a copy Einsum that reads the original tensor
1374+
and writes the mangled tensor. Every Einsum that comes after the adapter
1375+
and accesses the original tensor is rewired to use the mangled name, so
1376+
the copy sits logically between the tensor's producer (or workload input)
1377+
and its downstream consumers.
1378+
1379+
Returns
1380+
-------
1381+
Workload
1382+
A new workload with adapters lowered into copy Einsums and the
1383+
affected tensor names mangled.
1384+
"""
1385+
1386+
def _projection_for(tensor: TensorName) -> dict[Rank, str]:
1387+
"""Find how `tensor` is accessed so the copy Einsum can mirror it."""
1388+
for item in self.einsums_and_adapters:
1389+
if not isinstance(item, Einsum):
1390+
continue
1391+
for ta in item.tensor_accesses:
1392+
if ta.name == tensor:
1393+
return dict(ta.projection)
1394+
raise ValueError(
1395+
f"Adapter references tensor {tensor}, but no Einsum accesses it."
1396+
)
1397+
1398+
new_einsums: list[Einsum] = []
1399+
# Maps an original tensor name to its mangled name, for every adapter seen
1400+
# so far. Einsums after an adapter use the mangled name.
1401+
mangled: dict[TensorName, TensorName] = {}
1402+
1403+
for item in self.einsums_and_adapters:
1404+
if isinstance(item, Einsum):
1405+
einsum = item.model_copy(deep=True)
1406+
for ta in einsum.tensor_accesses:
1407+
if ta.name in mangled:
1408+
ta.name = mangled[ta.name]
1409+
new_einsums.append(einsum)
1410+
elif isinstance(item, CopyAdapter):
1411+
tensor = TensorName(item.tensor)
1412+
new_name = f"{item.name}__{tensor}"
1413+
# The copy reads whatever name the tensor currently has (it may
1414+
# itself have been mangled by an earlier adapter) and writes the
1415+
# newly mangled name.
1416+
source = mangled.get(tensor, tensor)
1417+
projection = _projection_for(tensor)
1418+
copy_einsum = Einsum(
1419+
name=item.name,
1420+
tensor_accesses=[
1421+
{"name": source, "projection": projection, "output": False},
1422+
{"name": new_name, "projection": projection, "output": True},
1423+
],
1424+
is_copy_operation=True,
1425+
)
1426+
new_einsums.append(copy_einsum)
1427+
mangled[tensor] = new_name
1428+
else:
1429+
raise ValueError(
1430+
f"Unsupported adapter type {type(item).__name__} in workload."
1431+
)
1432+
1433+
return Workload(
1434+
einsums=new_einsums,
1435+
iteration_space_shape=self.iteration_space_shape,
1436+
rank_sizes=self.rank_sizes,
1437+
n_instances=self.n_instances,
1438+
bits_per_value=self.bits_per_value,
1439+
persistent_tensors=self.persistent_tensors,
1440+
)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from accelforge import Spec
2+
from accelforge.frontend.mapper.metrics import Metrics
3+
from accelforge.mapper.FFM.pmappings import MultiEinsumPmappings
4+
from accelforge.mapper.FFM.mappings import Mappings
5+
import accelforge.mapper.FFM._make_pmappings.make_pmappings as pmapper
6+
from accelforge.frontend.workload import EinsumName
7+
from accelforge.util._frozenset import oset
8+
9+
10+
def make_copy_adapter(spec: Spec) -> MultiEinsumPmappings:
11+
"""
12+
Return a MultiEinsumPmappings that simply allows two pmappings to be
13+
compatible iff they are already compatible.
14+
"""
15+
return MultiEinsumPmappings(
16+
spec=spec,
17+
einsum2pmappings={},
18+
pmapping_objects={},
19+
einsum2jobs={},
20+
can_combine_multiple_runs=True,
21+
einsums_with_pmappings_generated=oset(),
22+
flattened_arches={},
23+
evaluated_specs={},
24+
)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from accelforge.frontend import Spec, EinsumName
2+
3+
4+
def add_adapter_between_einsums(
5+
spec: Spec, einsum_a: EinsumName, einsum_b: EinsumName, adapter
6+
):
7+
"""Insert an adapter between Einsums in a spec, modifying the intermediate tensor names."""
8+
workload = spec.workload
9+
e_a = workload.einsums[einsum_a]
10+
e_b = workload.einsums[einsum_b]
11+
12+
a2b = e_a.output_tensor_names & e_b.input_tensor_names
13+
b2a = e_b.output_tensor_names & e_a.input_tensor_names
14+
15+
if a2b and not b2a:
16+
producer, consumer = e_a, e_b
17+
intermediate = next(iter(a2b))
18+
elif b2a and not a2b:
19+
producer, consumer = e_b, e_a
20+
intermediate = next(iter(b2a))
21+
else:
22+
raise ValueError(
23+
f"Cannot insert adapter: expected exactly one intermediate tensor "
24+
f"flowing between {einsum_a} and {einsum_b}, found a->b={a2b}, b->a={b2a}"
25+
)
26+
27+
new_name = f"{intermediate}_{adapter.name}"
28+
29+
for ta in consumer.tensor_accesses:
30+
if ta.name == intermediate:
31+
ta.name = new_name
32+
33+
adapter_inputs = [ta for ta in adapter.tensor_accesses if not ta.output]
34+
adapter_outputs = [ta for ta in adapter.tensor_accesses if ta.output]
35+
if len(adapter_inputs) != 1 or len(adapter_outputs) != 1:
36+
raise ValueError(
37+
f"Adapter {adapter.name} must have exactly one input and one output "
38+
f"tensor access, found {len(adapter_inputs)} inputs and "
39+
f"{len(adapter_outputs)} outputs"
40+
)
41+
adapter_inputs[0].name = intermediate
42+
adapter_outputs[0].name = new_name
43+
44+
producer_idx = next(
45+
i for i, e in enumerate(workload.einsums) if e.name == producer.name
46+
)
47+
workload.einsums.insert(producer_idx + 1, adapter)

tests/input_files/adapters/gpt3_6.7B.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@ workload:
3535
- !Copy
3636
name: copy_I
3737
tensor: I
38-
- "V[b, m, h, e] = I[b, m, d] * WV[h, e, d]"
39-
- "K[b, m, h, e] = I[b, m, d] * WK[h, e, d]"
40-
- "Q[b, m, h, e] = I[b, m, d] * WQ[h, e, d]"
38+
- einsum: "V[b, m, h, e] = I[b, m, d] * WV[h, e, d]"
39+
renames: {input: I}
40+
- einsum: "K[b, m, h, e] = I[b, m, d] * WK[h, e, d]"
41+
renames: {input: I}
42+
- einsum: "Q[b, m, h, e] = I[b, m, d] * WQ[h, e, d]"
43+
renames: {input: I}
4144

4245
- einsum: "QK[b, m, p, h] = Q[b, m, h, e] * K[b, M: p, h, e]"
4346
renames: {input: Q}

tests/test_adapter.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,63 @@ class TestParsing(unittest.TestCase):
1010
def test_gpt3(self):
1111
spec = af.Spec.from_yaml(INPUT_FILES_DIR / "gpt3_6.7B.yaml")
1212
self.assertEqual(
13-
spec.workload.einsum_names, ["V", "K", "Q", "QK_softmax", "Z", "FFA", "FFB"]
13+
spec.workload.einsum_names,
14+
["V", "K", "Q", "QK", "QK_softmax", "AV", "Z", "FFA", "FFB"],
1415
)
16+
17+
18+
class TestMangling(unittest.TestCase):
19+
def setUp(self):
20+
self.spec = af.Spec.from_yaml(INPUT_FILES_DIR / "gpt3_6.7B.yaml")
21+
self.workload = self.spec.workload
22+
self.adapted = self.workload.get_adapted_workload()
23+
24+
def test_gpt3(self):
25+
self.assertIn("copy_I__I", self.adapted.einsums["Q"].input_tensor_names)
26+
27+
def test_all_consumers_of_adapted_tensor_are_mangled(self):
28+
# Every Einsum that read the original input I should now read the mangled
29+
# name instead, and should no longer reference I directly.
30+
for name in ["V", "K", "Q"]:
31+
inputs = self.adapted.einsums[name].input_tensor_names
32+
self.assertIn("copy_I__I", inputs)
33+
self.assertNotIn("I", inputs)
34+
35+
def test_copy_einsum_is_inserted(self):
36+
# The adapter is lowered into a copy Einsum named after the adapter.
37+
self.assertIn("copy_I", self.adapted.einsum_names)
38+
copy_einsum = self.adapted.einsums["copy_I"]
39+
self.assertTrue(copy_einsum.is_copy_operation)
40+
self.assertEqual(copy_einsum.input_tensor_names, {"I"})
41+
self.assertEqual(copy_einsum.output_tensor_names, {"copy_I__I"})
42+
43+
def test_copy_einsum_mirrors_original_projection(self):
44+
# The copy reads/writes the same ranks the original tensor was accessed by.
45+
copy_einsum = self.adapted.einsums["copy_I"]
46+
src = next(t for t in copy_einsum.tensor_accesses if t.name == "I")
47+
dst = next(t for t in copy_einsum.tensor_accesses if t.name == "copy_I__I")
48+
self.assertEqual(set(src.ranks), {"B", "M", "D"})
49+
self.assertEqual(set(dst.ranks), {"B", "M", "D"})
50+
51+
def test_original_tensor_only_remains_on_copy(self):
52+
# After adapting, the original I is produced/consumed only by the copy
53+
# Einsum; downstream Einsums use the mangled name.
54+
einsums_with_I = [e.name for e in self.adapted.einsums_with_tensor("I")]
55+
self.assertEqual(einsums_with_I, ["copy_I"])
56+
57+
def test_downstream_einsums_unaffected(self):
58+
# Tensors unrelated to the adapter keep their names.
59+
qk_inputs = self.adapted.einsums["QK"].input_tensor_names
60+
self.assertEqual(qk_inputs, {"Q", "K"})
61+
62+
def test_einsum_order_preserved(self):
63+
self.assertEqual(
64+
self.adapted.einsum_names,
65+
["copy_I", "V", "K", "Q", "QK", "QK_softmax", "AV", "Z", "FFA", "FFB"],
66+
)
67+
68+
def test_original_workload_unchanged(self):
69+
# get_adapted_workload returns a copy; the source workload is untouched.
70+
self.assertIn("I", self.workload.einsums["Q"].input_tensor_names)
71+
self.assertNotIn("copy_I__I", self.workload.einsums["Q"].input_tensor_names)
72+
self.assertNotIn("copy_I", self.workload.einsum_names)

0 commit comments

Comments
 (0)