1+ from abc import ABC , abstractmethod
2+ from dataclasses import dataclass
3+ from typing import Any
4+
15from accelforge .frontend .mapping import (
26 Spatial
37)
8+ from accelforge .frontend .arch .components import TopologySpec
49from accelforge .frontend ._workload_isl ._symbolic import (
510 compute_dense_tile_occupancy ,
611 Irrelevant ,
1419from ._stats import NetworkStats , SymbolicAnalysisOutput
1520
1621
17- class NetworkAnalyzer :
18- def __init__ (self , network_stats ):
@dataclass
class PerLoopTransferCost:
    """The per-spatial-loop cost contributed by a single network, as computed
    by a :class:`TopologyModel`.

    All three fields are typed ``Any``; presumably they may hold plain numbers
    or symbolic expressions, depending on what the analysis feeds in (confirm
    against the callers of :meth:`TopologyModel.per_loop_transfer_cost`).
    """

    total_cost: Any
    """Total hops contributed by data movement over this spatial loop."""
    max_hops: Any
    """Hops added to the longest route by this spatial loop."""
    max_traffic: Any
    """Maximum traffic (in actions) on any single link along this dimension."""
33+
34+
class TopologyModel(ABC):
    """Computes the cost of moving data across a network of a given topology.

    Subclasses encapsulate everything topology-specific about how a tensor's
    data is delivered across a spatial fanout. :class:`NetworkAnalyzer` selects
    the model for each network from its component's
    :class:`~accelforge.frontend.arch.components.TopologySpec` and remains
    agnostic to the topology itself.

    Instances are stateful: they accumulate per-network max hops across the
    repeated spatial-loop iterations of a single :class:`NetworkAnalyzer`, so a
    fresh model is constructed for each analyzer (see :func:`get_topology_model`).
    """

    def __init__(self):
        """Start with no accumulated hops for any network."""
        # Running total of max hops per network, accumulated across the
        # repeated spatial-loop iterations handled by one NetworkAnalyzer.
        self.overall_max_hops: dict = {}

    def accumulate_max_hops(self, network, max_hops):
        """Add this loop's ``max_hops`` to ``network``'s running total and
        return the updated total.

        Each call to :meth:`NetworkAnalyzer.accumulate_child_result` (i.e., over
        a different iteration of a spatial loop) adds more to the max hops.

        Args:
            network: Key identifying the network; a network seen for the first
                time starts from a total of 0.
            max_hops: The hops to add for this spatial-loop iteration.

        Returns:
            The running total for ``network`` after the addition.
        """
        updated_total = self.overall_max_hops.get(network, 0) + max_hops
        self.overall_max_hops[network] = updated_total
        return updated_total

    @abstractmethod
    def per_loop_transfer_cost(
        self,
        relevancy,
        *,
        shape_repeats,
        last_fanout,
        volume,
        src_component,
        dim_name: str,
    ) -> PerLoopTransferCost:
        """Return the :class:`PerLoopTransferCost` for moving ``volume`` of data across one
        spatial loop.

        Args:
            relevancy: The relevancy of the spatial loop's rank variable to the
                tensor (``Irrelevant``, ``Relevant``, or ``PartiallyRelevant``).
            shape_repeats: The number of iterations of this spatial loop.
            last_fanout: The fanout in this dimension among mapping nodes below
                (i.e., the stride).
            volume: The data volume (in actions) moved per destination.
            src_component: The flattened-arch component sourcing the data, used
                to query physical fanout/stride.
            dim_name: The name of the spatial dimension (e.g., ``X`` or ``Y``).
        """
        raise NotImplementedError
92+
93+
class MeshTopologyModel(TopologyModel):
    """Cost model for a mesh network.

    Data travels link-by-link along one axis of the mesh. Multicast delivers a
    value to every point along the dimension; unicast delivers a distinct value
    to each point. When the source is physically distributed, data is bound as
    locally as possible across the physical buffers.
    """

    def per_loop_transfer_cost(
        self,
        relevancy,
        *,
        shape_repeats,
        last_fanout,
        volume,
        src_component,
        dim_name,
    ) -> PerLoopTransferCost:
        """Cost one spatial loop over a mesh axis.

        ``Irrelevant`` loops are costed with ``multicast_cost``; ``Relevant``
        loops are costed with ``unicast_cost``, split per physical source when
        ``src_component`` reports a physical fanout greater than 1 along
        ``dim_name``.

        Args:
            relevancy: Relevancy of the loop's rank variable to the tensor.
            shape_repeats: The number of iterations of this spatial loop.
            last_fanout: The fanout in this dimension in mapping nodes below,
                i.e., the stride.
            volume: The data volume moved per destination.
            src_component: Component sourcing the data; queried for its
                physical fanout and stride along ``dim_name``.
            dim_name: The name of the spatial dimension.

        Returns:
            The :class:`PerLoopTransferCost` for this loop.

        Raises:
            NotImplementedError: If ``relevancy`` is ``PartiallyRelevant``.
            RuntimeError: If ``relevancy`` is of any other unhandled type.
        """
        if isinstance(relevancy, Irrelevant):
            # The volume travels through link by link in one axis of the mesh
            # Distributed or not, the amount of total cost is the same.
            # However, the accesses now come from different physical memories
            total_cost = multicast_cost(shape_repeats, last_fanout) * volume
            max_hops = shape_repeats * last_fanout
            max_traffic = volume
        elif isinstance(relevancy, Relevant):
            # If distributed, then we bind data as locally as possible in the
            # physical buffers
            if src_component._get_physical_fanout_along(dim_name) > 1:
                physical_stride = src_component._get_physical_stride_along(dim_name)

                n_dsts_per_physical = MinGeqZero(
                    # if last_fanout > physical_stride, set n_dst to 1, which
                    # results in 0 hops later (which is correct because the set
                    # of destinations always overlap the set of sources).
                    MaxGeqZero(physical_stride / last_fanout, 1),
                    shape_repeats,
                )
                n_activated_physical = MaxGeqZero(
                    shape_repeats * last_fanout / physical_stride, 1
                )
                total_cost = (
                    n_activated_physical
                    * unicast_cost(n_dsts_per_physical, last_fanout)
                    * volume
                )
                max_hops = MinGeqZero(
                    (n_dsts_per_physical - 1) * last_fanout, physical_stride
                )
                max_traffic = (n_dsts_per_physical - 1) * volume
            else:
                total_cost = unicast_cost(shape_repeats, last_fanout) * volume
                max_hops = shape_repeats * last_fanout
                max_traffic = (shape_repeats - 1) * volume
        elif isinstance(relevancy, PartiallyRelevant):
            # Previously raised without a message; say what is unsupported.
            raise NotImplementedError(
                "MeshTopologyModel does not support PartiallyRelevant spatial loops"
            )
        else:
            raise RuntimeError(f"unhandled relevancy type {relevancy}")

        return PerLoopTransferCost(
            total_cost=total_cost,
            max_hops=max_hops,
            max_traffic=max_traffic,
        )
153+
154+
# Registry mapping each topology to the model class that costs its data
# movement. Classes (not instances) are stored because models are stateful and
# each NetworkAnalyzer needs its own.
TOPOLOGY_MODELS: dict[TopologySpec, type[TopologyModel]] = {
    TopologySpec.MESH: MeshTopologyModel,
}
161+
162+
def get_topology_model(topology) -> TopologyModel:
    """Construct a fresh :class:`TopologyModel` for the given topology.

    Args:
        topology: Key into ``TOPOLOGY_MODELS`` selecting the model class.

    Returns:
        A newly constructed instance of the registered model class.

    Raises:
        KeyError: If no model class is registered for ``topology``. The
            message names the topology and lists the registered ones.
    """
    try:
        model_cls = TOPOLOGY_MODELS[topology]
    except KeyError:
        supported = ", ".join(repr(known) for known in TOPOLOGY_MODELS)
        # Same exception type as a plain dict miss, but with an actionable message.
        raise KeyError(
            f"no TopologyModel registered for topology {topology!r}; "
            f"registered topologies: {supported}"
        ) from None
    return model_cls()
166+
167+
class NetworkAnalyzer:
    """Accumulates network statistics for one spatial mapping node.

    The analysis context (``info``, ``einsum_name``, ``node``) is fixed at
    construction; per-network topology models are created on demand and kept
    for the analyzer's lifetime.
    """

    def __init__(self, network_stats, info: AnalysisInfo, einsum_name, node: Spatial):
        """Store the stats container and the per-analyzer constants.

        Args:
            network_stats: The network statistics object this analyzer holds.
            info: Analysis context shared by every accumulation call.
            einsum_name: Name of the einsum being analyzed.
            node: The spatial mapping node this analyzer belongs to.
        """
        self.network_stats = network_stats
        # These don't change across calls to accumulate_child_result.
        self.info = info
        self.einsum_name = einsum_name
        self.node = node
        # Each network gets its own topology model, since different networks may
        # have different topologies. Models are constructed lazily, the first
        # time a network needs costing, and reused for the analyzer's lifetime so
        # their accumulated max hops persist.
        self.topology_models: dict = {}

    def _get_topology_model(self, network, topology) -> TopologyModel:
        """Return ``network``'s topology model, constructing it on first use.

        ``topology`` is only consulted when no model is cached for ``network``
        yet; later calls return the cached instance regardless of it.
        """
        model = self.topology_models.get(network)
        if model is None:
            model = get_topology_model(topology)
            self.topology_models[network] = model
        return model
21185
22186 def accumulate_child_result (
23187 self ,
24188 child_result : SymbolicAnalysisOutput ,
25- info : AnalysisInfo ,
26189 shape_repeats ,
27- einsum_name ,
28190 child_shape ,
29- node : Spatial ,
30191 ):
31192 """This function is called for every repeated shape."""
32- flattened_arch = info .job .flattened_arch
193+ flattened_arch = self . info .job .flattened_arch
33194
34195 for network , child_network_stats in child_result .network_stats .items ():
35196 src_component = flattened_arch [network .source .level ]
@@ -39,7 +200,7 @@ def accumulate_child_result(
39200
40201 # We only need to update the summary if the spatial loop is for
41202 # a component higher than the network of interest
42- if flattened_arch .is_above (node .component , network .component ):
203+ if flattened_arch .is_above (self . node .component , network .component ):
43204 accumulated_network_stats .total_hops += (
44205 child_network_stats .total_hops * shape_repeats
45206 )
@@ -54,71 +215,51 @@ def accumulate_child_result(
54215 )
55216 continue
56217
57- volume = self ._get_data_volume (network , einsum_name , info , child_shape )
218+ volume = self ._get_data_volume (network , child_shape )
58219
59- relevancy = info .tensor_to_relevancy [network .tensor ][node .rank_variable ]
220+ relevancy = self . info .tensor_to_relevancy [network .tensor ][self . node .rank_variable ]
60221
61222 # The fanout in this dimension in mapping nodes below, i.e., the stride
62- last_fanout = child_result .fanout .get ((node .component , einsum_name ), {})
63- last_fanout = last_fanout .get (node .name , 1 )
64- if isinstance (relevancy , Irrelevant ):
65- # The volume travels through link by link in one axis of the mesh
66- # Distributed or not, the amount of total cost is the same.
67- # However, the accesses now come from different physical memories
68- total_cost = multicast_cost (shape_repeats , last_fanout )* volume
69- max_hops = shape_repeats * last_fanout
70- max_traffic = volume
71- elif isinstance (relevancy , Relevant ):
72- # If distributed, then we bind data as locally as possible in the
73- # physical buffers
74- if src_component ._get_physical_fanout_along (node .name ) > 1 :
75- physical_stride = src_component ._get_physical_stride_along (node .name )
76-
77- n_dsts_per_physical = MinGeqZero (
78- # if last_fanout > physical_stride, set n_dst to 1, which results in 0 hops
79- # later (which is correct because the set of destinations always overlap
80- # the set of sources).
81- MaxGeqZero (physical_stride / last_fanout , 1 ),
82- shape_repeats
83- )
84- n_activated_physical = MaxGeqZero (shape_repeats * last_fanout / physical_stride , 1 )
85- total_cost = (
86- n_activated_physical
87- *
88- unicast_cost (n_dsts_per_physical , last_fanout )
89- *
90- volume
91- )
92- max_hops = MinGeqZero ((n_dsts_per_physical - 1 )* last_fanout , physical_stride )
93- max_traffic = (n_dsts_per_physical - 1 )* volume
94- else :
95- total_cost = unicast_cost (shape_repeats , last_fanout )* volume
96- max_hops = shape_repeats * last_fanout
97- max_traffic = (shape_repeats - 1 )* volume
98- elif isinstance (relevancy , PartiallyRelevant ):
99- raise NotImplementedError ()
100- else :
101- raise RuntimeError (f"unhandled relevancy type { relevancy } " )
223+ last_fanout = child_result .fanout .get ((self .node .component , self .einsum_name ), {})
224+ last_fanout = last_fanout .get (self .node .name , 1 )
102225
103- # Each subsequent call to this function (i.e., over different iterations of a spatial loop)
104- # adds more to the max hops
105- self .overall_max_hops [network ] = self .overall_max_hops .get (network , 0 ) + max_hops
226+ topology_model = self ._get_topology_model (
227+ network , flattened_arch [network .component ].topology
228+ )
229+ per_loop_transfer_cost = topology_model .per_loop_transfer_cost (
230+ relevancy ,
231+ shape_repeats = shape_repeats ,
232+ last_fanout = last_fanout ,
233+ volume = volume ,
234+ src_component = src_component ,
235+ dim_name = self .node .name ,
236+ )
237+
238+ overall_max_hops = topology_model .accumulate_max_hops (
239+ network , per_loop_transfer_cost .max_hops
240+ )
106241
107242 accumulated_network_stats .total_hops += (
108- total_cost + child_network_stats .total_hops * shape_repeats
243+ per_loop_transfer_cost .total_cost
244+ + child_network_stats .total_hops * shape_repeats
109245 )
110246 accumulated_network_stats .max_hops = MaxGeqZero (
111247 accumulated_network_stats .max_hops ,
112- self . overall_max_hops [ network ] + child_network_stats .max_hops ,
248+ overall_max_hops + child_network_stats .max_hops ,
113249 )
114- accumulated_network_stats .max_traffic [node .name ] = MaxGeqZero (
115- accumulated_network_stats .max_traffic .get (node .name , 0 ),
116- max_traffic + child_network_stats .max_traffic .get (node .name , 0 )
250+ accumulated_network_stats .max_traffic [self . node .name ] = MaxGeqZero (
251+ accumulated_network_stats .max_traffic .get (self . node .name , 0 ),
252+ per_loop_transfer_cost . max_traffic + child_network_stats .max_traffic .get (self . node .name , 0 )
117253 )
118254
119- return self .overall_max_hops
255+ overall_max_hops = {}
256+ for model in self .topology_models .values ():
257+ overall_max_hops .update (model .overall_max_hops )
258+ return overall_max_hops
120259
121- def _get_data_volume (self , network , einsum_name , info , child_shape ):
260+ def _get_data_volume (self , network , child_shape ):
261+ info = self .info
262+ einsum_name = self .einsum_name
122263 flattened_arch = info .job .flattened_arch
123264 projection = info .einsum_tensor_to_projection [(einsum_name , network .tensor )]
124265 component_object = flattened_arch [network .component ]
@@ -153,4 +294,4 @@ def unicast_cost(n_dsts, stride):
153294
154295
def arithmetic_sum(n):
    """Return the sum ``1 + 2 + ... + n`` via the closed form ``0.5 * (n + 1) * n``.

    The ``0.5`` factor means an integer ``n`` yields a float result.
    """
    return 0.5 * (n + 1) * n
0 commit comments