Skip to content

Commit 6c70174

Browse files
committed
Merge branch 'main' into pipeline
2 parents 45f7626 + 94d7ebc commit 6c70174

36 files changed

Lines changed: 2689 additions & 1823 deletions

accelforge/frontend/arch/components.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import itertools
33
import logging
4+
from numbers import Number
45
from accelforge.util._frozenset import oset
56
from typing import (
67
Any,
@@ -90,8 +91,35 @@ class TensorHolderAction(Action):
9091
bits_per_action: EvalsTo[int | float] = (
9192
"1 if bits_per_action is None else bits_per_action"
9293
)
93-
""" The number of bits accessed in this action. For example, setting bits_per_action
94-
to 16 means that each call to this action yields 16 bits. """
94+
"""
95+
The number of bits accessed in this action. For example, setting bits_per_action to
96+
16 means that each call to this action yields 16 bits. Overridden by
97+
values_per_action in this action or by the parent component's values per action.
98+
"""
99+
100+
values_per_action: EvalsTo[dict] = {}
101+
"""
102+
Sets the number of tensor values that are accessed by each call of this action. Keys
103+
are evaluated as expressions and may reference one or more tensors. Overrides
104+
bits_per_action, and sets bits_per_action to values_per_action[tensor] *
105+
bits_per_value[tensor].
106+
"""
107+
108+
def _eval_expressions(self, *args, **kwargs):
109+
if getattr(self, "_evaluated", False):
110+
return super()._eval_expressions(*args, **kwargs)
111+
112+
class MyPostCall(_PostCall):
113+
def __call__(self, field, value, evaluated, symbol_table):
114+
if field == "values_per_action":
115+
evaluated = _eval_tensor2number(
116+
evaluated,
117+
location="values_per_action",
118+
symbol_table=symbol_table,
119+
)
120+
return evaluated
121+
122+
return super()._eval_expressions(*args, **kwargs, post_calls=(MyPostCall(),))
95123

96124

97125
_COMPONENT_MODEL_CACHE: dict[tuple, "Component"] = {}
@@ -198,6 +226,13 @@ class Component(Spatialable):
198226
this action's energy. Multiplies the calculated energy of each action.
199227
"""
200228

229+
actions_scale: EvalsTo[int | float] = 1
230+
"""
231+
Scales the number of actions performed by this component. Multiplies the action
232+
count for each action of this component, which proportionally increases this
233+
component's energy and latency.
234+
"""
235+
201236
total_latency: str | int | float = "sum(*action2latency.values())"
202237
"""
203238
An expression representing the total latency of this component in seconds. This is
@@ -232,8 +267,8 @@ class Component(Spatialable):
232267
n_parallel_instances: EvalsTo[int | float] = 1
233268
"""
234269
The number of parallel instances of this component. Increasing parallel instances
235-
will proportionally increase area and leakage, while reducing latency (unless
236-
latency calculation is overridden).
270+
will proportionally increase area and leakage while reducing latency (unless latency
271+
calculation is overridden).
237272
"""
238273

239274
extra_attributes_for_component_model: _ExtraAttrs = _ExtraAttrs()
@@ -727,7 +762,7 @@ def _copy_for_component_modeling(self) -> Self:
727762
)
728763

729764

730-
def _eval_tensor2bits(
765+
def _eval_tensor2number(
731766
toeval: dict[str, Any],
732767
location: str,
733768
symbol_table: dict[str, Any],
@@ -936,7 +971,17 @@ class TensorHolder(Component, Leaf):
936971
"""
937972
The number of bits accessed in each of this component's actions. Overridden by
938973
bits_per_action in any action of this component. If set here, acts as a default
939-
value for the bits_per_action of all actions of this component.
974+
value for the bits_per_action of all actions of this component. Overridden by
975+
values_per_action or by values in each action.
976+
"""
977+
978+
values_per_action: EvalsTo[dict] = {}
979+
"""
980+
Sets the number of tensor values that are accessed by each action of this
981+
`TensorHolder`. Keys are evaluated as expressions and may reference one or more
982+
tensors. Overrides bits_per_action, and sets bits_per_action to
983+
values_per_action[tensor] * bits_per_value[tensor]. Overridden by values_per_action
984+
in any action of this component.
940985
"""
941986

942987
def model_post_init(self, __context__=None) -> None:
@@ -951,15 +996,36 @@ def _eval_expressions(self, *args, **kwargs):
951996
class MyPostCall(_PostCall):
952997
def __call__(self, field, value, evaluated, symbol_table):
953998
if field == "bits_per_value":
954-
evaluated = _eval_tensor2bits(
999+
evaluated = _eval_tensor2number(
9551000
evaluated,
9561001
location="bits_per_value",
9571002
symbol_table=symbol_table,
9581003
)
1004+
if field == "values_per_action":
1005+
evaluated = _eval_tensor2number(
1006+
evaluated,
1007+
location="values_per_action",
1008+
symbol_table=symbol_table,
1009+
)
9591010
return evaluated
9601011

9611012
return super()._eval_expressions(*args, **kwargs, post_calls=(MyPostCall(),))
9621013

1014+
def _get_values_per_action(
1015+
self, action_name: str, tensor_name: TensorName, bits_per_value_default: Number
1016+
):
1017+
action = self.actions[action_name]
1018+
1019+
if tensor_name in action.values_per_action:
1020+
return action.values_per_action[tensor_name]
1021+
if tensor_name in self.values_per_action:
1022+
return self.values_per_action[tensor_name]
1023+
1024+
tensor_bpv = self.bits_per_value.get(tensor_name, bits_per_value_default)
1025+
action_bpa = action.bits_per_action
1026+
1027+
return action_bpa / tensor_bpv
1028+
9631029

9641030
class Container(Leaf, Spatialable):
9651031
"""
@@ -1049,24 +1115,27 @@ def _eval_expressions(self, *args, **kwargs):
10491115
if getattr(self, "_evaluated", False):
10501116
return super()._eval_expressions(*args, **kwargs)
10511117

1052-
# Override TensorHolder's _PostCall to also handle direction
10531118
class MyPostCall(_PostCall):
10541119
def __call__(self_pc, field, value, evaluated, symbol_table):
10551120
if field == "bits_per_value":
1056-
evaluated = _eval_tensor2bits(
1121+
evaluated = _eval_tensor2number(
10571122
evaluated,
10581123
location="bits_per_value",
10591124
symbol_table=symbol_table,
10601125
)
1126+
if field == "values_per_action":
1127+
evaluated = _eval_tensor2number(
1128+
evaluated,
1129+
location="values_per_action",
1130+
symbol_table=symbol_table,
1131+
)
10611132
if field == "direction":
10621133
evaluated = _eval_direction(
10631134
evaluated,
10641135
symbol_table=symbol_table,
10651136
)
10661137
return evaluated
10671138

1068-
# Skip TensorHolder's _eval_expressions (which adds its own post_calls
1069-
# for bits_per_value) since we handle it here too
10701139
return Component._eval_expressions(
10711140
self, *args, **kwargs, post_calls=(MyPostCall(),)
10721141
)
@@ -1135,7 +1204,7 @@ def _eval_expressions(self, *args, **kwargs):
11351204
class MyPostCall(_PostCall):
11361205
def __call__(self, field, value, evaluated, symbol_table):
11371206
if field == "bits_per_value":
1138-
evaluated = _eval_tensor2bits(
1207+
evaluated = _eval_tensor2number(
11391208
evaluated,
11401209
location="bits_per_value",
11411210
symbol_table=symbol_table,

accelforge/frontend/arch/spatialable.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,6 @@ class Spatial(EvalableModel):
6464
will be power gated if not used by a particular Einsum.
6565
"""
6666

67-
allow_imperfect_spatial_loops: EvalsTo[bool] = False
68-
"""
69-
If True, spatial loops over this fanout are allowed to not-perfectly divide the full
70-
rank shape, which may let us find mappings with better utilization. For example, if
71-
the full rank shape is 7, then allow_imperfect_spatial_loops=False would only permit
72-
a spatial loop of size 7, while allow_imperfect_spatial_loops=True would allow
73-
spatial loops of size 1, 2, 3, 4, and 7. If our spatial fanout is of size 4, then we
74-
could do one tile of size 4 and another tile of size 3, with one unit of padding
75-
that is skipped.
76-
77-
Only "simple" rank variables-- those that appear alone and not as part of an
78-
expression-- may have imperfect loops.
79-
"""
80-
8167

8268
class Spatialable(EvalableModel):
8369
"""Something that can be duplicated to create an array of."""

accelforge/frontend/mapper/ffm.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,38 @@ class FFM(EvalableModel):
124124
are so many templates being generated?).
125125
"""
126126

127+
explore_imperfect_spatial_loops: bool = False
128+
"""
129+
If True, spatial loop bounds may not perfectly divide the full rank shape. This
130+
takes longer to explore and requires more RAM, but mappings found may have better
131+
spatial utilization. This is especially helpful when the rank shapes have few prime
132+
factors.
133+
134+
For example, if the rank shape is 7, then explore_imperfect_spatial_loops=False
135+
would explore loop bounds of 1, 7 and explore_imperfect_spatial_loops=True would
136+
explore loop bounds of 1, 2, 3, 4, 7. This would be helfpul for a size-4 PE array,
137+
where we could get full utilization using 4 PEs in one timestep and 3 PEs in another
138+
timestep.
139+
140+
Only "simple" rank variables (those appearing alone and not inside an expression in
141+
any tensor access) may have imperfect loop bounds.
142+
"""
143+
144+
explore_imperfect_temporal_loops: bool = False
145+
"""
146+
If True, temporal loop bounds may not perfectly divide the full rank shape. This
147+
takes longer to explore and requires more RAM, but mappings found may have lower
148+
memory usage. This is especially helpful when the rank shapes have few prime
149+
factors.
150+
151+
For example, if the rank shape is 7, then explore_imperfect_temporal_loops=False
152+
would explore loop bounds of 1, 7 and explore_imperfect_temporal_loops=True would
153+
explore loop bounds of 1, 2, 3, 4, 7.
154+
155+
Only "simple" rank variables (those appearing alone and not inside an expression in
156+
any tensor access) may have imperfect loop bounds.
157+
"""
158+
127159
prioritize_reuse_of_unfused_tensors: bool = False
128160
"""
129161
If set to True, then for all memory levels, the mapper will place the storage nodes

accelforge/frontend/mapping/mapping.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ class Loop(MappingNode):
382382
""" Whether this Loop is shared with another Einsum. """
383383

384384
_may_cause_imperfect: bool = False
385+
"""
386+
This means that the tile shape of this loop may not perfectly factorize the rank
387+
shape.
388+
"""
385389

386390
model_config = ConfigDict(arbitrary_types_allowed=True)
387391

accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def _loop_bound_constraint_from_no_refetch_and_resend(
334334
if (
335335
isinstance(mapping[end_index], TensorHolder)
336336
and n in mapping[end_index].tensors
337-
and mapping[end_index].component == m.name
337+
and mapping[end_index].component == arch_node.name
338338
):
339339
break
340340
end_index += 1
@@ -363,7 +363,7 @@ def _loop_bound_constraint_from_no_refetch_and_resend(
363363
if (
364364
isinstance(mapping[start_index], TensorHolder)
365365
and n in mapping[start_index].tensors
366-
and mapping[start_index].component == m.name
366+
and mapping[start_index].component == arch_node.name
367367
):
368368
break
369369
start_index += 1
@@ -375,7 +375,7 @@ def _loop_bound_constraint_from_no_refetch_and_resend(
375375
and n in mapping[end_index].tensors
376376
):
377377
# Can't have two tensor holders for the same tensor + component
378-
assert mapping[end_index].component != m.name
378+
assert mapping[end_index].component != arch_node.name
379379
break
380380
end_index += 1
381381

0 commit comments

Comments
 (0)