1212 PartiallyRelevant ,
1313)
1414
15- from accelforge .util ._sympy .broadcast_max import Min , Max , MaxGeqZero
15+ from accelforge .util ._sympy .broadcast_max import MaxGeqZero , MinGeqZero
1616
1717from ._common import AnalysisInfo
18- from ._stats import NetworkStats
18+ from ._stats import NetworkStats , SymbolicAnalysisOutput
1919
2020
2121class NetworkAnalyzer :
@@ -25,7 +25,7 @@ def __init__(self, network_stats):
2525
2626 def accumulate_child_result (
2727 self ,
28- child_result ,
28+ child_result : SymbolicAnalysisOutput ,
2929 info : AnalysisInfo ,
3030 shape_repeats ,
3131 einsum_name ,
@@ -35,6 +35,7 @@ def accumulate_child_result(
3535 flattened_arch = info .job .flattened_arch
3636
3737 for network , child_network_stats in child_result .network_stats .items ():
38+ src_component = flattened_arch [network .source .level ]
3839 if network not in self .network_stats :
3940 self .network_stats [network ] = NetworkStats ()
4041 accumulated_network_stats = self .network_stats [network ]
@@ -64,93 +65,72 @@ def accumulate_child_result(
6465 * actions_per_value
6566 )
6667
67- if is_component_a_above_b (node .component , network .component , flattened_arch ):
68+ if flattened_arch . is_above (node .component , network .component ):
6869 continue
6970
7071 relevancy = info .tensor_to_relevancy [network .tensor ][node .rank_variable ]
7172
73+ # The fanout in this dimension in mapping nodes below, i.e., the stride
7274 last_fanout = child_result .fanout .get ((node .component , einsum_name ), {})
7375 last_fanout = last_fanout .get (node .name , 1 )
7476 if isinstance (relevancy , Irrelevant ):
75- # Cost of multicasting is the cost of delivering along the dimension
76- multicast_hops = shape_repeats * last_fanout
77- multicast_cost = multicast_hops * volume
78- self .overall_max_hops += multicast_hops
79-
80- accumulated_network_stats .total_hops += multicast_cost
81- accumulated_network_stats .max_hops = MaxGeqZero (
82- accumulated_network_stats .max_hops ,
83- self .overall_max_hops + child_network_stats .max_hops ,
84- )
77+ # Distributed or not, the amount of total cost is the same.
78+ # However, the accesses now come from different physical memories
79+ total_cost = multicast_cost (shape_repeats , last_fanout )* volume
80+ max_hops = shape_repeats * last_fanout
8581 elif isinstance (relevancy , Relevant ):
86- # Cost of unicast is the cost of delivering to each point in
87- # the dimension with shape as stride
88- # TODO: we should use the actual stride
89- total_unicast_cost = (
90- 0.5 * (shape_repeats + 1 ) * shape_repeats * last_fanout * volume
91- )
92- max_unicast_hops = shape_repeats * last_fanout
93- self .overall_max_hops += max_unicast_hops
94-
95- accumulated_network_stats .total_hops += total_unicast_cost
96- accumulated_network_stats .max_hops = MaxGeqZero (
97- accumulated_network_stats .max_hops ,
98- self .overall_max_hops + child_network_stats .max_hops ,
99- )
82+ # If distributed, then we bind data as locally as possible in the
83+ # physical buffers
84+ if src_component ._get_physical_fanout_along (node .name ) > 1 :
85+ physical_stride = src_component ._get_physical_stride_along (node .name )
86+
87+ n_dsts_per_physical = MinGeqZero (
88+ # if last_fanout > physical_stride, set n_dst to 1, which results in 0 hops
89+ # later (which is correct because the set of destinations always overlap
90+ # the set of sources).
91+ MaxGeqZero (physical_stride / last_fanout , 1 ),
92+ shape_repeats
93+ )
94+ n_activated_physical = MaxGeqZero (shape_repeats * last_fanout / physical_stride , 1 )
95+ total_cost = (
96+ n_activated_physical
97+ *
98+ unicast_cost (n_dsts_per_physical , last_fanout )
99+ *
100+ volume
101+ )
102+ max_hops = MinGeqZero (shape_repeats * last_fanout , physical_stride )
103+ else :
104+ total_cost = unicast_cost (shape_repeats , last_fanout )* volume
105+ max_hops = shape_repeats * last_fanout
100106 elif isinstance (relevancy , PartiallyRelevant ):
101107 raise NotImplementedError ()
102108 else :
103109 raise RuntimeError (f"unhandled relevancy type { relevancy } " )
104110
105- return self .overall_max_hops
106-
107-
108- def reduce_dicts (dict1 : dict , dict2 : dict , reduce_op ):
109- for key in dict1 :
110- if key not in dict2 :
111- dict2 [key ] = dict1 [key ]
112- else :
113- dict2 [key ] = reduce_op (dict1 [key ], dict2 [key ])
114-
115-
116- def get_total_to_per_unit (total , max_per_unit ):
117- if total == 0 and max_per_unit != 0 :
118- raise ValueError (f"total is 0 but max_per_unit is { max_per_unit } " )
119- if total == 0 :
120- return 1
121- return max_per_unit / total
111+ # TODO: this is sketchy
112+ self .overall_max_hops += max_hops
122113
114+ accumulated_network_stats .total_hops += total_cost
115+ accumulated_network_stats .max_hops = MaxGeqZero (
116+ accumulated_network_stats .max_hops ,
117+ self .overall_max_hops + child_network_stats .max_hops ,
118+ )
123119
124- def has_parent_tensor_holder (
125- tensor : TensorName , node_idx : int , info
126- ) -> bool :
127- for node in info .mapping [:node_idx ]:
128- if isinstance (node , TensorHolder ) and tensor in node .tensors :
129- return True
130- return False
120+ return self .overall_max_hops
131121
132122
133- def find_component_object (
134- component : str , flattened_arch : list [arch .Leaf ]
135- ) -> arch .TensorHolder :
136- for node in flattened_arch :
137- if node .name == component :
138- return node
139- raise ValueError (f"Component { component } not found in flattened arch" )
def multicast_cost(n_dsts, stride):
    """Return the total hops of a multicast along one dimension.

    Args:
        n_dsts: Number of destinations along the dimension.
        stride: Hops between neighbouring destinations along the dimension.

    Returns:
        ``(n_dsts - 1) * stride``. A single destination therefore costs
        zero hops. Only ``-`` and ``*`` are applied, so any operands that
        support them work (callers presumably also pass symbolic
        expressions -- TODO confirm).
    """
    return (n_dsts - 1) * stride
140126
141127
142- def is_component_a_above_b (component_a : str , component_b : str , flattened_arch ):
143- a_found = False
144- b_found = False
145- for node in flattened_arch :
146- if node .name == component_a :
147- a_found = True
148- if node .name == component_b :
149- b_found = True
def unicast_cost(n_dsts, stride):
    """Return the total hops of unicast along one dimension.

    Args:
        n_dsts: Number of destinations along the dimension.
        stride: Hops between neighbouring destinations along the dimension.

    Returns:
        ``arithmetic_sum(n_dsts - 1) * stride``. With ``arithmetic_sum(n)``
        evaluating ``0.5 * (n + 1) * n``, this equals
        ``0.5 * n_dsts * (n_dsts - 1) * stride``; a single destination
        costs zero hops.
    """
    # Cost of unicast is the cost of delivering to each point in
    # the dimension with shape as stride.
    return arithmetic_sum(n_dsts - 1) * stride
150133
151- if a_found and not b_found :
152- return True
153- elif b_found and not a_found :
154- return False
155- raise ValueError (f"Neither { component_a } nor { component_b } found in flattened arch" )
156134
def arithmetic_sum(n):
    """Return ``0.5 * (n + 1) * n``, i.e. the sum ``1 + 2 + ... + n``.

    The leading ``0.5`` factor makes the result a float even for integer
    ``n`` (e.g. ``arithmetic_sum(3)`` is ``6.0``). The formula is applied
    as-is to whatever ``n`` is given; no integrality or sign check is made.
    """
    return 0.5 * (n + 1) * n
0 commit comments