@@ -94,10 +94,10 @@ def per_loop_transfer_cost(
9494class MeshTopologyModel (TopologyModel ):
9595 """Cost model for a mesh network.
9696
97- Data travels link-by-link along one axis of the mesh. Multicast delivers a
98- value to every point along the dimension; unicast delivers a distinct value
99- to each point. When the source is physically distributed, data is bound as
100- locally as possible across the physical buffers.
97+ Data travels along one axis of the mesh. Multicast delivers a value to every
98+ point along the dimension; unicast delivers a distinct value to each point.
99+ When the source is physically distributed, data is bound as locally as
100+ possible across the physical buffers.
101101 """
102102
103103 def per_loop_transfer_cost (
@@ -153,23 +153,13 @@ def per_loop_transfer_cost(
153153
154154
155155class AllToAllTopologyModel (TopologyModel ):
156- """Cost model for an all-to-all network built around a switch (e.g. NVLink /
157- NVSwitch).
158-
159- Every node connects to every other node through a central switch, so any
160- source reaches any destination in a constant number of hops regardless of
161- how far apart they are in the logical fanout. This differs from a mesh in
162- two ways:
163-
164- - **Uniform latency.** The longest route is a single switch traversal, so
165- ``max_hops`` is constant rather than growing with the distance
166- (``shape_repeats * stride``) between source and destination.
167- - **No store-and-forward accumulation.** Each destination is reached
168- directly, so the total (energy) cost is linear in the number of
169- destinations rather than quadratic as in a mesh unicast.
170-
171- The physical stride is irrelevant here (all nodes are equidistant from the
172- switch), so ``last_fanout`` and physical distribution are not consulted.
156+ """Cost model for an all-to-all network using a switch (e.g. NVLink).
157+
158+ Every node connects to every other node through a switch, so any
159+ source reaches any destination in one hop regardless of
160+
161+ Physical stride is irrelevant, so ``last_fanout`` and physical distribution
162+ are not used.
173163 """
174164
175165 HOPS_PER_TRANSFER = 1
@@ -219,9 +209,7 @@ def per_loop_transfer_cost(
219209 )
220210
221211
222- # Registry mapping each topology to the model class that costs its data
223- # movement. Classes (not instances) are stored because models are stateful and
224- # each NetworkAnalyzer needs its own.
212+ # Registry of topology models
225213TOPOLOGY_MODELS : dict [TopologySpec , type [TopologyModel ]] = {
226214 TopologySpec .MESH : MeshTopologyModel ,
227215 TopologySpec .ALL_TO_ALL : AllToAllTopologyModel ,
0 commit comments