Skip to content

Commit 488f4b1

Browse files
committed
[network] Refactor network cost to handle different topologies
1 parent 7ad792a commit 488f4b1

4 files changed

Lines changed: 207 additions & 63 deletions

File tree

accelforge/frontend/arch/components.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1294,7 +1294,7 @@ def _render_node_color(self) -> str:
12941294
return "#E0EEFF"
12951295

12961296

1297-
class TopologySpec(str, enum.Enum):
1297+
class TopologySpec(enum.StrEnum):
12981298
MESH = "mesh"
12991299

13001300

accelforge/model/_looptree/latency/memory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def component_latency(
109109

110110
network_to_max_link_traffic = defaultdict(lambda: defaultdict(lambda: 0))
111111
network_to_max_hops = defaultdict(lambda: [])
112+
# Aggregates across tensors
112113
for network, network_stats in looptree_results.network_stats.items():
113114
component = network.component
114115
if component not in name2component:
Lines changed: 201 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
from typing import Any
4+
15
from accelforge.frontend.mapping import (
26
Spatial
37
)
8+
from accelforge.frontend.arch.components import TopologySpec
49
from accelforge.frontend._workload_isl._symbolic import (
510
compute_dense_tile_occupancy,
611
Irrelevant,
@@ -14,22 +19,178 @@
1419
from ._stats import NetworkStats, SymbolicAnalysisOutput
1520

1621

17-
class NetworkAnalyzer:
18-
def __init__(self, network_stats):
22+
@dataclass
23+
class PerLoopTransferCost:
24+
"""The per-spatial-loop cost contributed by a single network, as computed
25+
by a :class:`TopologyModel`."""
26+
27+
total_cost: Any
28+
"""Total hops contributed by data movement over this spatial loop."""
29+
max_hops: Any
30+
"""Hops added to the longest route by this spatial loop."""
31+
max_traffic: Any
32+
"""Maximum traffic (in actions) on any single link along this dimension."""
33+
34+
35+
class TopologyModel(ABC):
36+
"""Computes the cost of moving data across a network of a given topology.
37+
38+
Subclasses encapsulate everything topology-specific about how a tensor's
39+
data is delivered across a spatial fanout. :class:`NetworkAnalyzer` selects
40+
the model for each network from its component's
41+
:class:`~accelforge.frontend.arch.components.TopologySpec` and remains
42+
agnostic to the topology itself.
43+
44+
Instances are stateful: they accumulate per-network max hops across the
45+
repeated spatial-loop iterations of a single :class:`NetworkAnalyzer`, so each
46+
analyzer constructs a fresh model per network (see :func:`get_topology_model`).
47+
"""
48+
49+
def __init__(self):
50+
# Running total of max hops per network, accumulated across the
51+
# repeated spatial-loop iterations handled by one NetworkAnalyzer.
1952
self.overall_max_hops: dict = {}
53+
54+
def accumulate_max_hops(self, network, max_hops):
55+
"""Add this loop's ``max_hops`` to ``network``'s running total and
56+
return the updated total.
57+
58+
Each call to :meth:`NetworkAnalyzer.accumulate_child_result` (i.e., over
59+
a different iteration of a spatial loop) adds more to the max hops.
60+
"""
61+
self.overall_max_hops[network] = (
62+
self.overall_max_hops.get(network, 0) + max_hops
63+
)
64+
return self.overall_max_hops[network]
65+
66+
@abstractmethod
67+
def per_loop_transfer_cost(
68+
self,
69+
relevancy,
70+
*,
71+
shape_repeats,
72+
last_fanout,
73+
volume,
74+
src_component,
75+
dim_name: str,
76+
) -> PerLoopTransferCost:
77+
"""Return the :class:`PerLoopTransferCost` for moving ``volume`` of data across one
78+
spatial loop.
79+
80+
Args:
81+
relevancy: The relevancy of the spatial loop's rank variable to the
82+
tensor (``Irrelevant``, ``Relevant``, or ``PartiallyRelevant``).
83+
shape_repeats: The number of iterations of this spatial loop.
84+
last_fanout: The fanout in this dimension among mapping nodes below
85+
(i.e., the stride).
86+
volume: The data volume (in actions) moved per destination.
87+
src_component: The flattened-arch component sourcing the data, used
88+
to query physical fanout/stride.
89+
dim_name: The name of the spatial dimension (e.g., ``X`` or ``Y``).
90+
"""
91+
raise NotImplementedError
92+
93+
94+
class MeshTopologyModel(TopologyModel):
95+
"""Cost model for a mesh network.
96+
97+
Data travels link-by-link along one axis of the mesh. Multicast delivers a
98+
value to every point along the dimension; unicast delivers a distinct value
99+
to each point. When the source is physically distributed, data is bound as
100+
locally as possible across the physical buffers.
101+
"""
102+
103+
def per_loop_transfer_cost(
104+
self,
105+
relevancy,
106+
*,
107+
shape_repeats,
108+
last_fanout,
109+
volume,
110+
src_component,
111+
dim_name,
112+
) -> PerLoopTransferCost:
113+
if isinstance(relevancy, Irrelevant):
114+
# The volume travels link by link along one axis of the mesh.
115+
# Distributed or not, the total cost is the same.
116+
# However, the accesses now come from different physical memories.
117+
total_cost = multicast_cost(shape_repeats, last_fanout) * volume
118+
max_hops = shape_repeats * last_fanout
119+
max_traffic = volume
120+
elif isinstance(relevancy, Relevant):
121+
# If distributed, then we bind data as locally as possible in the
122+
# physical buffers
123+
if src_component._get_physical_fanout_along(dim_name) > 1:
124+
physical_stride = src_component._get_physical_stride_along(dim_name)
125+
126+
n_dsts_per_physical = MinGeqZero(
127+
# if last_fanout > physical_stride, set n_dst to 1, which results in 0 hops
128+
# later (which is correct because the set of destinations always overlaps
129+
# the set of sources).
130+
MaxGeqZero(physical_stride / last_fanout, 1),
131+
shape_repeats
132+
)
133+
n_activated_physical = MaxGeqZero(shape_repeats * last_fanout / physical_stride, 1)
134+
total_cost = (
135+
n_activated_physical
136+
*
137+
unicast_cost(n_dsts_per_physical, last_fanout)
138+
*
139+
volume
140+
)
141+
max_hops = MinGeqZero((n_dsts_per_physical - 1) * last_fanout, physical_stride)
142+
max_traffic = (n_dsts_per_physical - 1) * volume
143+
else:
144+
total_cost = unicast_cost(shape_repeats, last_fanout) * volume
145+
max_hops = shape_repeats * last_fanout
146+
max_traffic = (shape_repeats - 1) * volume
147+
elif isinstance(relevancy, PartiallyRelevant):
148+
raise NotImplementedError()
149+
else:
150+
raise RuntimeError(f"unhandled relevancy type {relevancy}")
151+
152+
return PerLoopTransferCost(total_cost=total_cost, max_hops=max_hops, max_traffic=max_traffic)
153+
154+
155+
# Registry mapping each topology to the model class that costs its data
156+
# movement. Classes (not instances) are stored because models are stateful and
157+
# each NetworkAnalyzer needs its own.
158+
TOPOLOGY_MODELS: dict[TopologySpec, type[TopologyModel]] = {
159+
TopologySpec.MESH: MeshTopologyModel,
160+
}
161+
162+
163+
def get_topology_model(topology) -> TopologyModel:
164+
"""Construct a fresh :class:`TopologyModel` for the given topology."""
165+
return TOPOLOGY_MODELS[topology]()
166+
167+
168+
class NetworkAnalyzer:
169+
def __init__(self, network_stats, info: AnalysisInfo, einsum_name, node: Spatial):
20170
self.network_stats = network_stats
171+
# These don't change across calls to accumulate_child_result.
172+
self.info = info
173+
self.einsum_name = einsum_name
174+
self.node = node
175+
# Each network gets its own topology model, since different networks may
176+
# have different topologies. Models are constructed lazily, the first
177+
# time a network needs costing, and reused for the analyzer's lifetime so
178+
# their accumulated max hops persist.
179+
self.topology_models: dict = {}
180+
181+
def _get_topology_model(self, network, topology) -> TopologyModel:
182+
if network not in self.topology_models:
183+
self.topology_models[network] = get_topology_model(topology)
184+
return self.topology_models[network]
21185

22186
def accumulate_child_result(
23187
self,
24188
child_result: SymbolicAnalysisOutput,
25-
info: AnalysisInfo,
26189
shape_repeats,
27-
einsum_name,
28190
child_shape,
29-
node: Spatial,
30191
):
31192
"""This function is called for every repeated shape."""
32-
flattened_arch = info.job.flattened_arch
193+
flattened_arch = self.info.job.flattened_arch
33194

34195
for network, child_network_stats in child_result.network_stats.items():
35196
src_component = flattened_arch[network.source.level]
@@ -39,7 +200,7 @@ def accumulate_child_result(
39200

40201
# We only need to update the summary if the spatial loop is for
41202
# a component higher than the network of interest
42-
if flattened_arch.is_above(node.component, network.component):
203+
if flattened_arch.is_above(self.node.component, network.component):
43204
accumulated_network_stats.total_hops += (
44205
child_network_stats.total_hops * shape_repeats
45206
)
@@ -54,71 +215,51 @@ def accumulate_child_result(
54215
)
55216
continue
56217

57-
volume = self._get_data_volume(network, einsum_name, info, child_shape)
218+
volume = self._get_data_volume(network, child_shape)
58219

59-
relevancy = info.tensor_to_relevancy[network.tensor][node.rank_variable]
220+
relevancy = self.info.tensor_to_relevancy[network.tensor][self.node.rank_variable]
60221

61222
# The fanout in this dimension in mapping nodes below, i.e., the stride
62-
last_fanout = child_result.fanout.get((node.component, einsum_name), {})
63-
last_fanout = last_fanout.get(node.name, 1)
64-
if isinstance(relevancy, Irrelevant):
65-
# The volume travels through link by link in one axis of the mesh
66-
# Distributed or not, the amount of total cost is the same.
67-
# However, the accesses now come from different physical memories
68-
total_cost = multicast_cost(shape_repeats, last_fanout)*volume
69-
max_hops = shape_repeats*last_fanout
70-
max_traffic = volume
71-
elif isinstance(relevancy, Relevant):
72-
# If distributed, then we bind data as locally as possible in the
73-
# physical buffers
74-
if src_component._get_physical_fanout_along(node.name) > 1:
75-
physical_stride = src_component._get_physical_stride_along(node.name)
76-
77-
n_dsts_per_physical = MinGeqZero(
78-
# if last_fanout > physical_stride, set n_dst to 1, which results in 0 hops
79-
# later (which is correct because the set of destinations always overlap
80-
# the set of sources).
81-
MaxGeqZero(physical_stride / last_fanout, 1),
82-
shape_repeats
83-
)
84-
n_activated_physical = MaxGeqZero(shape_repeats*last_fanout/physical_stride, 1)
85-
total_cost = (
86-
n_activated_physical
87-
*
88-
unicast_cost(n_dsts_per_physical, last_fanout)
89-
*
90-
volume
91-
)
92-
max_hops = MinGeqZero((n_dsts_per_physical-1)*last_fanout, physical_stride)
93-
max_traffic = (n_dsts_per_physical-1)*volume
94-
else:
95-
total_cost = unicast_cost(shape_repeats, last_fanout)*volume
96-
max_hops = shape_repeats * last_fanout
97-
max_traffic = (shape_repeats-1)*volume
98-
elif isinstance(relevancy, PartiallyRelevant):
99-
raise NotImplementedError()
100-
else:
101-
raise RuntimeError(f"unhandled relevancy type {relevancy}")
223+
last_fanout = child_result.fanout.get((self.node.component, self.einsum_name), {})
224+
last_fanout = last_fanout.get(self.node.name, 1)
102225

103-
# Each subsequent call to this function (i.e., over different iterations of a spatial loop)
104-
# adds more to the max hops
105-
self.overall_max_hops[network] = self.overall_max_hops.get(network, 0) + max_hops
226+
topology_model = self._get_topology_model(
227+
network, flattened_arch[network.component].topology
228+
)
229+
per_loop_transfer_cost = topology_model.per_loop_transfer_cost(
230+
relevancy,
231+
shape_repeats=shape_repeats,
232+
last_fanout=last_fanout,
233+
volume=volume,
234+
src_component=src_component,
235+
dim_name=self.node.name,
236+
)
237+
238+
overall_max_hops = topology_model.accumulate_max_hops(
239+
network, per_loop_transfer_cost.max_hops
240+
)
106241

107242
accumulated_network_stats.total_hops += (
108-
total_cost + child_network_stats.total_hops*shape_repeats
243+
per_loop_transfer_cost.total_cost
244+
+ child_network_stats.total_hops * shape_repeats
109245
)
110246
accumulated_network_stats.max_hops = MaxGeqZero(
111247
accumulated_network_stats.max_hops,
112-
self.overall_max_hops[network] + child_network_stats.max_hops,
248+
overall_max_hops + child_network_stats.max_hops,
113249
)
114-
accumulated_network_stats.max_traffic[node.name] = MaxGeqZero(
115-
accumulated_network_stats.max_traffic.get(node.name, 0),
116-
max_traffic + child_network_stats.max_traffic.get(node.name, 0)
250+
accumulated_network_stats.max_traffic[self.node.name] = MaxGeqZero(
251+
accumulated_network_stats.max_traffic.get(self.node.name, 0),
252+
per_loop_transfer_cost.max_traffic + child_network_stats.max_traffic.get(self.node.name, 0)
117253
)
118254

119-
return self.overall_max_hops
255+
overall_max_hops = {}
256+
for model in self.topology_models.values():
257+
overall_max_hops.update(model.overall_max_hops)
258+
return overall_max_hops
120259

121-
def _get_data_volume(self, network, einsum_name, info, child_shape):
260+
def _get_data_volume(self, network, child_shape):
261+
info = self.info
262+
einsum_name = self.einsum_name
122263
flattened_arch = info.job.flattened_arch
123264
projection = info.einsum_tensor_to_projection[(einsum_name, network.tensor)]
124265
component_object = flattened_arch[network.component]
@@ -153,4 +294,4 @@ def unicast_cost(n_dsts, stride):
153294

154295

155296
def arithmetic_sum(n):
156-
return 0.5 * (n+1) * n
297+
return 0.5 * (n+1) * n

accelforge/model/_looptree/reuse/symbolic/_symbolic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,9 @@ def analyze_spatial(node_idx, current_shape, info: AnalysisInfo):
592592

593593
result_accumulator = SymbolicAnalysisOutput()
594594

595-
network_analyzer = NetworkAnalyzer(result_accumulator.network_stats)
595+
network_analyzer = NetworkAnalyzer(
596+
result_accumulator.network_stats, info, einsum_name, node
597+
)
596598

597599
def handle_repeated_value(repeated_shape):
598600
shape_value = repeated_shape.value
@@ -633,7 +635,7 @@ def handle_repeated_value(repeated_shape):
633635
)
634636

635637
network_analyzer.accumulate_child_result(
636-
child_result, info, shape_repeats, einsum_name, child_shape, node
638+
child_result, shape_repeats, child_shape
637639
)
638640

639641
for einsum, child_steps in child_result.temporal_steps.items():

0 commit comments

Comments
 (0)