Skip to content

Commit 4b3fcab

Browse files
committed
[network] Add *untested* distributed model
1 parent 82bc1b9 commit 4b3fcab

3 files changed

Lines changed: 72 additions & 92 deletions

File tree

accelforge/frontend/arch/spatialable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ def _get_physical_fanout_along(self, dim_name: str, default: int = 1) -> int:
142142
return s.fanout
143143
return default
144144

145-
def _get_physical_stride_along(self, dim_name: str, default: int = 1) -> int:
145+
def _get_physical_stride_along(self, dim_name: str) -> int:
146146
for s in self._physical_spatial:
147147
if s.name == dim_name:
148148
return s.stride
149-
return default
149+
raise ValueError(f"dimension {dim_name} not found")
150150

151151
def _spatial_str(self, include_newline=True) -> str:
152152
if not self.spatial:

accelforge/model/_looptree/reuse/symbolic/_network.py

Lines changed: 52 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
PartiallyRelevant,
1313
)
1414

15-
from accelforge.util._sympy.broadcast_max import Min, Max, MaxGeqZero
15+
from accelforge.util._sympy.broadcast_max import MaxGeqZero, MinGeqZero
1616

1717
from ._common import AnalysisInfo
18-
from ._stats import NetworkStats
18+
from ._stats import NetworkStats, SymbolicAnalysisOutput
1919

2020

2121
class NetworkAnalyzer:
@@ -25,7 +25,7 @@ def __init__(self, network_stats):
2525

2626
def accumulate_child_result(
2727
self,
28-
child_result,
28+
child_result: SymbolicAnalysisOutput,
2929
info: AnalysisInfo,
3030
shape_repeats,
3131
einsum_name,
@@ -35,6 +35,7 @@ def accumulate_child_result(
3535
flattened_arch = info.job.flattened_arch
3636

3737
for network, child_network_stats in child_result.network_stats.items():
38+
src_component = flattened_arch[network.source.level]
3839
if network not in self.network_stats:
3940
self.network_stats[network] = NetworkStats()
4041
accumulated_network_stats = self.network_stats[network]
@@ -64,93 +65,72 @@ def accumulate_child_result(
6465
* actions_per_value
6566
)
6667

67-
if is_component_a_above_b(node.component, network.component, flattened_arch):
68+
if flattened_arch.is_above(node.component, network.component):
6869
continue
6970

7071
relevancy = info.tensor_to_relevancy[network.tensor][node.rank_variable]
7172

73+
# Fanout along this dimension from the mapping nodes below; it is passed as the stride to the cost helpers
7274
last_fanout = child_result.fanout.get((node.component, einsum_name), {})
7375
last_fanout = last_fanout.get(node.name, 1)
7476
if isinstance(relevancy, Irrelevant):
75-
# Cost of multicasting is the cost of delivering along the dimension
76-
multicast_hops = shape_repeats * last_fanout
77-
multicast_cost = multicast_hops * volume
78-
self.overall_max_hops += multicast_hops
79-
80-
accumulated_network_stats.total_hops += multicast_cost
81-
accumulated_network_stats.max_hops = MaxGeqZero(
82-
accumulated_network_stats.max_hops,
83-
self.overall_max_hops + child_network_stats.max_hops,
84-
)
77+
# Whether or not the source is distributed, the total cost is the same.
78+
# However, the accesses now come from different physical memories
79+
total_cost = multicast_cost(shape_repeats, last_fanout)*volume
80+
max_hops = shape_repeats*last_fanout
8581
elif isinstance(relevancy, Relevant):
86-
# Cost of unicast is the cost of delivering to each point in
87-
# the dimension with shape as stride
88-
# TODO: we should use the actual stride
89-
total_unicast_cost = (
90-
0.5 * (shape_repeats + 1) * shape_repeats * last_fanout * volume
91-
)
92-
max_unicast_hops = shape_repeats * last_fanout
93-
self.overall_max_hops += max_unicast_hops
94-
95-
accumulated_network_stats.total_hops += total_unicast_cost
96-
accumulated_network_stats.max_hops = MaxGeqZero(
97-
accumulated_network_stats.max_hops,
98-
self.overall_max_hops + child_network_stats.max_hops,
99-
)
82+
# If distributed, then we bind data as locally as possible in the
83+
# physical buffers
84+
if src_component._get_physical_fanout_along(node.name) > 1:
85+
physical_stride = src_component._get_physical_stride_along(node.name)
86+
87+
n_dsts_per_physical = MinGeqZero(
88+
# If last_fanout > physical_stride, clamp n_dsts_per_physical to 1, which results in 0 hops
89+
# later (which is correct because the set of destinations always overlaps
90+
# the set of sources).
91+
MaxGeqZero(physical_stride / last_fanout, 1),
92+
shape_repeats
93+
)
94+
n_activated_physical = MaxGeqZero(shape_repeats*last_fanout/physical_stride, 1)
95+
total_cost = (
96+
n_activated_physical
97+
*
98+
unicast_cost(n_dsts_per_physical, last_fanout)
99+
*
100+
volume
101+
)
102+
max_hops = MinGeqZero(shape_repeats*last_fanout, physical_stride)
103+
else:
104+
total_cost = unicast_cost(shape_repeats, last_fanout)*volume
105+
max_hops = shape_repeats * last_fanout
100106
elif isinstance(relevancy, PartiallyRelevant):
101107
raise NotImplementedError()
102108
else:
103109
raise RuntimeError(f"unhandled relevancy type {relevancy}")
104110

105-
return self.overall_max_hops
106-
107-
108-
def reduce_dicts(dict1: dict, dict2: dict, reduce_op):
109-
for key in dict1:
110-
if key not in dict2:
111-
dict2[key] = dict1[key]
112-
else:
113-
dict2[key] = reduce_op(dict1[key], dict2[key])
114-
115-
116-
def get_total_to_per_unit(total, max_per_unit):
117-
if total == 0 and max_per_unit != 0:
118-
raise ValueError(f"total is 0 but max_per_unit is {max_per_unit}")
119-
if total == 0:
120-
return 1
121-
return max_per_unit / total
111+
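# TODO(review): this is sketchy -- overall_max_hops is a single running total on the analyzer that is added to each network's max_hops below; confirm it should not be tracked per network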
112+
self.overall_max_hops += max_hops
122113

114+
accumulated_network_stats.total_hops += total_cost
115+
accumulated_network_stats.max_hops = MaxGeqZero(
116+
accumulated_network_stats.max_hops,
117+
self.overall_max_hops + child_network_stats.max_hops,
118+
)
123119

124-
def has_parent_tensor_holder(
125-
tensor: TensorName, node_idx: int, info
126-
) -> bool:
127-
for node in info.mapping[:node_idx]:
128-
if isinstance(node, TensorHolder) and tensor in node.tensors:
129-
return True
130-
return False
120+
return self.overall_max_hops
131121

132122

133-
def find_component_object(
134-
component: str, flattened_arch: list[arch.Leaf]
135-
) -> arch.TensorHolder:
136-
for node in flattened_arch:
137-
if node.name == component:
138-
return node
139-
raise ValueError(f"Component {component} not found in flattened arch")
123+
def multicast_cost(n_dsts, stride):
124+
"""Returns total hops of multicast along a dimension."""
125+
return (n_dsts-1)*stride
140126

141127

142-
def is_component_a_above_b(component_a: str, component_b: str, flattened_arch):
143-
a_found = False
144-
b_found = False
145-
for node in flattened_arch:
146-
if node.name == component_a:
147-
a_found = True
148-
if node.name == component_b:
149-
b_found = True
128+
def unicast_cost(n_dsts, stride):
129+
"""Returns total hops of unicast along a dimension."""
130+
# Cost of unicast is the cost of delivering to each point in
131+
# the dimension, where adjacent points are `stride` hops apart
132+
return arithmetic_sum(n_dsts-1)*stride
150133

151-
if a_found and not b_found:
152-
return True
153-
elif b_found and not a_found:
154-
return False
155-
raise ValueError(f"Neither {component_a} nor {component_b} found in flattened arch")
156134

135+
def arithmetic_sum(n):
136+
return 0.5 * (n+1) * n

tests/test_network.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_hierarchical_1d(self):
7878
* (KN / MAC_TILE) # number of used Scratchpad
7979
* M_TILE
8080
* KN # temporal for n1 in mapping
81-
* sum(i+1 for i in range(MAC_TILE)) # unicast along X-axis of MacArray
81+
* sum(i for i in range(MAC_TILE)) # unicast along X-axis of MacArray
8282
* BITS_PER_VALUE,
8383
)
8484
# NOTE: assuming XY routing (as defined in mapping)
@@ -88,7 +88,7 @@ def test_hierarchical_1d(self):
8888
* (KN / MAC_TILE)
8989
* M_TILE
9090
* KN # temporal for n1 in mapping
91-
* MAC_TILE # multicast along X-axis of MacArray
91+
* (MAC_TILE - 1) # multicast along X-axis of MacArray
9292
* BITS_PER_VALUE,
9393
)
9494
self.assertEqual(
@@ -97,14 +97,14 @@ def test_hierarchical_1d(self):
9797
* (KN / MAC_TILE)
9898
* M_TILE
9999
* KN
100-
* sum(i+1 for i in range(MAC_TILE))
100+
* sum(i for i in range(MAC_TILE))
101101
* BITS_PER_VALUE,
102102
)
103103

104104
self.assertEqual(
105105
result.data["Matmul0<SEP>action<SEP>PeArray<SEP>T0<SEP>hops"].iloc[0],
106106
(M / M_TILE)
107-
* sum(i+1 for i in range(KN // MAC_TILE)) # unicast along X-axis of PeArray
107+
* sum(i for i in range(KN // MAC_TILE)) # unicast along X-axis of PeArray
108108
* M_TILE
109109
* MAC_TILE
110110
* BITS_PER_VALUE,
@@ -113,15 +113,15 @@ def test_hierarchical_1d(self):
113113
self.assertEqual(
114114
result.data["Matmul0<SEP>action<SEP>PeArray<SEP>T1<SEP>hops"].iloc[0],
115115
(M / M_TILE)
116-
* KN // MAC_TILE # multicast along X-axis of PeArray
116+
* (KN // MAC_TILE - 1) # multicast along X-axis of PeArray
117117
* M_TILE
118118
* KN
119119
* BITS_PER_VALUE,
120120
)
121121
self.assertEqual(
122122
result.data["Matmul0<SEP>action<SEP>PeArray<SEP>W0<SEP>hops"].iloc[0],
123123
(M / M_TILE)
124-
* sum(i+1 for i in range(KN // MAC_TILE)) # unicast along PeArray
124+
* sum(i for i in range(KN // MAC_TILE)) # unicast along PeArray
125125
* MAC_TILE
126126
* KN
127127
* BITS_PER_VALUE,
@@ -156,9 +156,9 @@ def test_hierarchical(self):
156156
* (KN / MAC_TILE) ** 2
157157
* M_TILE
158158
* (
159-
sum(i+1 for i in range(MAC_TILE)) # unicasting along X
159+
sum(i for i in range(MAC_TILE)) # unicasting along X
160160
+
161-
MAC_TILE * MAC_TILE # multicast along Y for each column
161+
MAC_TILE * (MAC_TILE-1) # multicast along Y for each column
162162
)
163163
* BITS_PER_VALUE,
164164
)
@@ -169,9 +169,9 @@ def test_hierarchical(self):
169169
* (KN / MAC_TILE) ** 2
170170
* M_TILE
171171
* (
172-
MAC_TILE * MAC_TILE # multicast along X (the tile is shape N1, which is MAC_TILE here)
172+
MAC_TILE * (MAC_TILE - 1) # multicast along X (the tile is shape N1, which is MAC_TILE here)
173173
+
174-
MAC_TILE * sum(i+1 for i in range(MAC_TILE)) # unicasting along Y for each row
174+
MAC_TILE * sum(i for i in range(MAC_TILE)) # unicasting along Y for each row
175175
)
176176
* BITS_PER_VALUE,
177177
)
@@ -181,9 +181,9 @@ def test_hierarchical(self):
181181
* (KN / MAC_TILE) ** 2
182182
* M_TILE
183183
* (
184-
MAC_TILE * sum(i+1 for i in range(MAC_TILE)) # unicast along X (the tile is shape N1, which is MAC_TILE here)
184+
MAC_TILE * sum(i for i in range(MAC_TILE)) # unicast along X (the tile is shape N1, which is MAC_TILE here)
185185
+
186-
MAC_TILE * sum(i+1 for i in range(MAC_TILE)) # unicasting along Y for each row
186+
MAC_TILE * sum(i for i in range(MAC_TILE)) # unicasting along Y for each row
187187
)
188188
* BITS_PER_VALUE,
189189
)
@@ -192,9 +192,9 @@ def test_hierarchical(self):
192192
result.data["Matmul0<SEP>action<SEP>PeArray<SEP>T0<SEP>hops"].iloc[0],
193193
(M / M_TILE)
194194
* (
195-
sum(i+1 for i in range(PE_TILE))
195+
sum(i for i in range(PE_TILE))
196196
+
197-
PE_TILE * PE_TILE
197+
PE_TILE * (PE_TILE - 1)
198198
)
199199
# tile shape
200200
* M_TILE
@@ -206,9 +206,9 @@ def test_hierarchical(self):
206206
result.data["Matmul0<SEP>action<SEP>PeArray<SEP>T1<SEP>hops"].iloc[0],
207207
(M / M_TILE)
208208
* (
209-
PE_TILE * PE_TILE
209+
PE_TILE * (PE_TILE - 1)
210210
+
211-
PE_TILE * sum(i+1 for i in range(PE_TILE))
211+
PE_TILE * sum(i for i in range(PE_TILE))
212212
)
213213
* M_TILE
214214
* MAC_TILE
@@ -218,9 +218,9 @@ def test_hierarchical(self):
218218
result.data["Matmul0<SEP>action<SEP>PeArray<SEP>W0<SEP>hops"].iloc[0],
219219
(M / M_TILE)
220220
* (
221-
PE_TILE * sum(i+1 for i in range(PE_TILE))
221+
PE_TILE * sum(i for i in range(PE_TILE))
222222
+
223-
PE_TILE * sum(i+1 for i in range(PE_TILE))
223+
PE_TILE * sum(i for i in range(PE_TILE))
224224
)
225225
* MAC_TILE**2
226226
* BITS_PER_VALUE,

0 commit comments

Comments
 (0)