Skip to content

Commit 82d3eed

Browse files
committed
[frontend,mapper,model] Refactor flattened arch into a class
1 parent 9779fd9 commit 82d3eed

13 files changed

Lines changed: 92 additions & 43 deletions

File tree

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
class FlattenedArch:
2+
"""
3+
A flattened arch is an architecture spec that has been
4+
flattened into a hierarchy for the purpose of mapping
5+
a particular Einsum.
6+
7+
Several steps (may not be exhaustive) are applied when
8+
an arch is flattened:
9+
- A compute unit is selected, and a flattened arch
10+
is a path from the root to that compute unit
11+
- Expressions have been evaluated in the context of an
12+
Einsum.
13+
- Non-hierarchical arch nodes (e.g., ones inside the
14+
`nodes` key of an `Array` has been inserted into
15+
a hierarchy.)
16+
17+
This class should only be relevant to the model and
18+
mapper. That is, the user should generally not define
19+
a flattened arch directly. So, unlike other classes in
20+
`frontend`, this one is intentionally *not* an
21+
`EvalableModel`.
22+
"""
23+
def __init__(self, nodes: list["Leaf"]):
24+
self.nodes = nodes
25+
26+
def __getitem__(self, idx: int | str | slice):
27+
if isinstance(idx, (int, slice)):
28+
return self.nodes[idx]
29+
elif isinstance(idx, str):
30+
for node in self.nodes:
31+
if node.name == idx:
32+
return node
33+
raise KeyError(f"arch node with name {idx} not found")
34+
raise ValueError(f"idx should be int or str, but instead {type(idx)}")
35+
36+
def __iter__(self):
37+
for node in self.nodes:
38+
yield node
39+
40+
def index(self, name: str):
41+
for i, node in enumerate(self.nodes):
42+
if node.name == name:
43+
return i
44+
raise ValueError(f"no node found with name {name}")
45+
46+
def is_above(self, name_a: str, name_b: str):
47+
"""
48+
Returns True if node with name_a is above node with name_b.
49+
Raises ValueError if either is not found.
50+
"""
51+
idx_a = self.index(name_a)
52+
idx_b = self.index(name_b)
53+
return idx_a < idx_b

accelforge/frontend/arch/structure.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from accelforge.util.exceptions import EvaluationError
2222

2323
from accelforge.frontend.arch.spatialable import Spatialable
24+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
2425

2526
from pydantic import Discriminator
2627
from accelforge.util._basetypes import _uninstantiable
@@ -328,7 +329,7 @@ def _flatten(
328329
compute_node: str,
329330
fanout: int = 1,
330331
return_fanout: bool = False,
331-
):
332+
) -> FlattenedArch:
332333
from accelforge.frontend.arch.components import Compute
333334

334335
nodes = []
@@ -350,6 +351,7 @@ def _flatten(
350351
e.add_field(node)
351352
raise e
352353

354+
nodes = FlattenedArch(nodes)
353355
if return_fanout:
354356
return nodes, fanout
355357
return nodes
@@ -418,7 +420,7 @@ def _flatten(
418420
compute_node: str,
419421
fanout: int = 1,
420422
return_fanout: bool = False,
421-
):
423+
) -> FlattenedArch:
422424
from accelforge.frontend.arch.components import Compute
423425

424426
nodes = []
@@ -466,6 +468,7 @@ def _flatten(
466468
e.add_field(node)
467469
raise e
468470

471+
nodes = FlattenedArch(nodes)
469472
if return_fanout:
470473
return nodes, fanout
471474
return nodes

accelforge/frontend/mapping/mapping.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pydantic import ConfigDict, Discriminator, Tag, computed_field
3434
import sympy
3535

36+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
3637
from accelforge.frontend.renames import EinsumName, TensorName
3738
from accelforge.util._basetypes import (
3839
# Parsing helpers for the input files.
@@ -1285,7 +1286,7 @@ def compact_str(self) -> str:
12851286
def _get_single_tensor_mapping(
12861287
self,
12871288
tensor_name: TensorName,
1288-
flattened_arch: list[arch.Leaf],
1289+
flattened_arch: FlattenedArch,
12891290
tensor_rank_variables: set[str],
12901291
) -> Self:
12911292
"""

accelforge/frontend/spec.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Container,
1313
Spatialable,
1414
)
15+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
1516

1617
from accelforge.frontend.workload import Workload
1718
from accelforge.frontend.variables import Variables
@@ -276,7 +277,7 @@ def _get_flattened_architecture(
276277
self,
277278
compute_node: str | Compute | None = None,
278279
einsum_name: EinsumName | None = None,
279-
) -> list[list[Leaf]] | list[Leaf]:
280+
) -> list[FlattenedArch] | FlattenedArch:
280281
"""
281282
Return the architecture as paths of ``Leaf`` instances from the highest-level
282283
node to each ``Compute`` node. Parses arithmetic expressions in the

accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from typing import List
44
from accelforge._accelerated_imports import np
5+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
56
from accelforge.frontend._workload_isl._symbolic import PartiallyRelevant, Relevant
67
import accelforge.frontend.arch as arch
78
from accelforge.frontend.arch import (
@@ -196,7 +197,7 @@ def constrained_loops(
196197

197198

198199
def get_constraints(
199-
flattened_arch: list[arch.Leaf],
200+
flattened_arch: FlattenedArch,
200201
mapping: List[MappingNode],
201202
symbol_table: dict[str, InvertibleSet],
202203
einsum_name: EinsumName,

accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from enum import Enum
66

77
import accelforge.frontend.arch as arch
8+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
89
from accelforge.util._frozenset import oset
910
from accelforge.frontend.mapping import (
1011
MappingNode,
@@ -39,15 +40,14 @@ def insert_temporal_loops(
3940
ranks_with_tile_pattern: set,
4041
workload: Workload,
4142
_can_lower_outermost_memory: bool,
42-
flattened_arch: list[arch.Leaf],
43+
flattened_arch: FlattenedArch,
4344
max_fused_loops: int,
4445
fanouts: dict[str, int],
4546
fusable_tensors: set[TensorName],
4647
intermediate_tensors: set[TensorName],
4748
let_non_intermediate_tensors_respawn_in_backing_storage: bool,
4849
explore_loop_orders: bool,
4950
):
50-
arch_node_names = [n.name for n in flattened_arch]
5151
# First establish insertion points. Insertion points are:
5252
# - Below the last instance of the first memory
5353
# - Between any two TensorHolder nodes
@@ -88,7 +88,7 @@ def insert_temporal_loops(
8888
for s in split_mapping:
8989
# Within each split mapping group, sort by arch levels.
9090
# This can help create places to put spatial loops
91-
s.sort(key=lambda tensor_holder: arch_node_names.index(tensor_holder.component))
91+
s.sort(key=lambda tensor_holder: flattened_arch.index(tensor_holder.component))
9292

9393
if sum(map(len, split_mapping)) != len(mapping):
9494
raise RuntimeError("BUG: number of storage nodes post-split != original")
@@ -333,11 +333,10 @@ def _get_next_storages(i: int, toll_allowed: bool = False) -> list[TensorHolder]
333333
def insert_spatial_loops(
334334
mapping: list[MappingNode],
335335
einsum: Einsum,
336-
flattened_arch: list[arch.Memory],
336+
flattened_arch: FlattenedArch,
337337
intermediate_tensors: set[TensorName],
338338
):
339339
nodes_with_fanout = [n for n in flattened_arch if n.get_fanout() > 1]
340-
arch_node_names = [n.name for n in flattened_arch]
341340
tensor2fully_relevant_rank_vars = einsum.tensor2directly_indexing_rank_variables
342341
simple_rank_variables = einsum._simple_rank_variables
343342

@@ -346,7 +345,7 @@ def insert_spatial_loops(
346345
# above the fanout in the arch, and below any temporal loops in the
347346
# same block.
348347
insertion_point = _idx_below_lowest_tensor_holder_with_component_above_fanout(
349-
node, mapping, arch_node_names
348+
node, mapping, flattened_arch
350349
)
351350
while insertion_point < len(mapping) and isinstance(
352351
mapping[insertion_point], Temporal
@@ -386,16 +385,16 @@ def _tensors_seen_above_point(idx, mapping):
386385

387386

388387
def _idx_below_lowest_tensor_holder_with_component_above_fanout(
389-
fanout_node, mapping, arch_node_names
388+
fanout_node, mapping, flattened_arch: FlattenedArch
390389
):
391390
"""Return the index right after the lowest TensorHolder whose component
392391
is above the fanout in the arch. If none found, returns len(mapping)."""
393-
fanout_arch_idx = arch_node_names.index(fanout_node.name)
392+
fanout_arch_idx = flattened_arch.index(fanout_node.name)
394393
result = 0
395394
for i in range(len(mapping)):
396395
if not isinstance(mapping[i], TensorHolder):
397396
continue
398-
if arch_node_names.index(mapping[i].component) < fanout_arch_idx:
397+
if flattened_arch.index(mapping[i].component) < fanout_arch_idx:
399398
result = i + 1
400399
return result
401400

accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tqdm import tqdm
1212

1313
import accelforge.frontend.arch as arch
14+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
1415
from accelforge.util._frozenset import oset
1516
from accelforge.frontend.mapping import (
1617
Compute,
@@ -86,7 +87,7 @@ def unpack_loops_to_rank_variables(mapping: List[MappingNode]):
8687
# Iterate over mappings
8788
# =================================================================================================
8889
def place_missing_temporal_loops(
89-
mapping: List[MappingNode], einsum: Einsum, flattened_arch: list[arch.Leaf]
90+
mapping: List[MappingNode], einsum: Einsum, flattened_arch: FlattenedArch
9091
):
9192
"""
9293
Adds temporal loops to the mapping to fill in any rank variables that are missing.
@@ -132,7 +133,7 @@ def place_missing_temporal_loops(
132133

133134
def remove_unordered_spatial_temporal_loops(
134135
mapping: list[MappingNode],
135-
flattened_arch: list[arch.Leaf],
136+
flattened_arch: FlattenedArch,
136137
einsum: Einsum,
137138
explore_unordered_spatial_loops: bool = True,
138139
):
@@ -258,7 +259,7 @@ def assert_proper_fusion_labeling(
258259
def iterate_mappings_no_constraints(
259260
spec: Spec,
260261
einsum_name: str,
261-
flattened_arch: list[arch.Leaf],
262+
flattened_arch: FlattenedArch,
262263
rank_variable_bounds: dict[RankVariable, int],
263264
job: Job,
264265
) -> Iterator[tuple[Mapping, SymbolTable, arch.Compute, int]]:
@@ -331,7 +332,7 @@ def iterate_mappings_no_constraints(
331332
def iterate_mappings_constraints(
332333
spec: Spec,
333334
einsum_names: list[str] | str,
334-
flattened_arch: list[arch.Leaf],
335+
flattened_arch: FlattenedArch,
335336
rank_variable_bounds: dict[RankVariable, int],
336337
tensor_to_relevancy: dict[
337338
TensorName, dict[RankVariable, Relevant | PartiallyRelevant]

accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any
33

44
import accelforge.frontend.arch as arch
5+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
56
from accelforge.frontend.mapping import MappingNode, Reservation, Storage, TensorHolder
67

78

@@ -22,7 +23,7 @@ def _recursive_iter_fence_positions(
2223

2324
def get_reservation_choices(
2425
mapping: list[TensorHolder],
25-
flattened_arch: list[arch.Leaf],
26+
flattened_arch: FlattenedArch,
2627
) -> Generator[tuple[list[TensorHolder], Any], None, None]:
2728
# Rules:
2829
# - In general, reservations go right under their storage node

accelforge/mapper/FFM/_make_pmappings/pmapper_job.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from uuid import UUID, uuid4
77

88
import accelforge.frontend.arch as arch
9+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
910
from accelforge.util._frozenset import oset
1011
from accelforge.frontend.mapping import (
1112
Mapping,
@@ -57,7 +58,7 @@ class Job:
5758
mapping: Mapping | None = None
5859
constraints: MappingConstraints | None = None
5960
fusable_tensors: set[TensorName] | None = None
60-
flattened_arch: list[arch.Leaf] | None = None
61+
flattened_arch: FlattenedArch | None = None
6162

6263
einsum_name: EinsumName | None = None
6364
"""If the Job is for a single einsum, this is the einsum name."""

accelforge/model/_looptree/latency/memory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from accelforge.frontend import arch
44
from accelforge.frontend.arch import Leaf, Memory, TensorHolder, Component
5+
from accelforge.frontend.arch._flattened_arch import FlattenedArch
56
from accelforge.frontend.mapping import Compute, Mapping
67
from accelforge.frontend.spec import Spec
78

@@ -39,7 +40,7 @@ def isl_to_summarized(
3940

4041
def component_latency(
4142
looptree_results: SymbolicAnalysisOutput,
42-
flattened_arch: list[Leaf],
43+
flattened_arch: FlattenedArch,
4344
mapping: Mapping,
4445
spec: Spec,
4546
):

0 commit comments

Comments
 (0)