11import copy
22import itertools
33import logging
4+ from numbers import Number
45from accelforge .util ._frozenset import oset
56from typing import (
67 Any ,
@@ -90,8 +91,35 @@ class TensorHolderAction(Action):
9091 bits_per_action : EvalsTo [int | float ] = (
9192 "1 if bits_per_action is None else bits_per_action"
9293 )
93- """ The number of bits accessed in this action. For example, setting bits_per_action
94- to 16 means that each call to this action yields 16 bits. """
94+ """
95+ The number of bits accessed in this action. For example, setting bits_per_action to
96+ 16 means that each call to this action yields 16 bits. Overridden by
97+ values_per_action in this action or by the parent component's values per action.
98+ """
99+
100+ values_per_action : EvalsTo [dict ] = {}
101+ """
102+ Sets the number of tensor values that are accessed by each call of this action. Keys
103+ are evaluated as expressions and may reference one or more tensors. Overrides
104+ bits_per_action, and sets bits_per_action to values_per_action[tensor] *
105+ bits_per_value[tensor].
106+ """
107+
108+ def _eval_expressions (self , * args , ** kwargs ):
109+ if getattr (self , "_evaluated" , False ):
110+ return super ()._eval_expressions (* args , ** kwargs )
111+
112+ class MyPostCall (_PostCall ):
113+ def __call__ (self , field , value , evaluated , symbol_table ):
114+ if field == "values_per_action" :
115+ evaluated = _eval_tensor2number (
116+ evaluated ,
117+ location = "values_per_action" ,
118+ symbol_table = symbol_table ,
119+ )
120+ return evaluated
121+
122+ return super ()._eval_expressions (* args , ** kwargs , post_calls = (MyPostCall (),))
95123
96124
97125_COMPONENT_MODEL_CACHE : dict [tuple , "Component" ] = {}
@@ -198,6 +226,13 @@ class Component(Spatialable):
198226 this action's energy. Multiplies the calculated energy of each action.
199227 """
200228
229+ actions_scale : EvalsTo [int | float ] = 1
230+ """
231+ Scales the number of actions performed by this component. Multiplies the action
232+ count for each action of this component, which proportionally increases this
233+ component's energy and latency.
234+ """
235+
201236 total_latency : str | int | float = "sum(*action2latency.values())"
202237 """
203238 An expression representing the total latency of this component in seconds. This is
@@ -232,8 +267,8 @@ class Component(Spatialable):
232267 n_parallel_instances : EvalsTo [int | float ] = 1
233268 """
234269 The number of parallel instances of this component. Increasing parallel instances
235- will proportionally increase area and leakage, while reducing latency (unless
236- latency calculation is overridden).
270+ will proportionally increase area and leakage while reducing latency (unless latency
271+ calculation is overridden).
237272 """
238273
239274 extra_attributes_for_component_model : _ExtraAttrs = _ExtraAttrs ()
@@ -727,7 +762,7 @@ def _copy_for_component_modeling(self) -> Self:
727762)
728763
729764
730- def _eval_tensor2bits (
765+ def _eval_tensor2number (
731766 toeval : dict [str , Any ],
732767 location : str ,
733768 symbol_table : dict [str , Any ],
@@ -936,7 +971,17 @@ class TensorHolder(Component, Leaf):
936971 """
937972 The number of bits accessed in each of this component's actions. Overridden by
938973 bits_per_action in any action of this component. If set here, acts as a default
939- value for the bits_per_action of all actions of this component.
974+ value for the bits_per_action of all actions of this component. Overridden by
975+ values_per_action or by values in each action.
976+ """
977+
978+ values_per_action : EvalsTo [dict ] = {}
979+ """
980+ Sets the number of tensor values that are accessed by each action of this
981+ `TensorHolder`. Keys are evaluated as expressions and may reference one or more
982+ tensors. Overrides bits_per_action, and sets bits_per_action to
983+ values_per_action[tensor] * bits_per_value[tensor]. Overridden by values_per_action
984+ in any action of this component.
940985 """
941986
942987 def model_post_init (self , __context__ = None ) -> None :
@@ -951,15 +996,36 @@ def _eval_expressions(self, *args, **kwargs):
951996 class MyPostCall (_PostCall ):
952997 def __call__ (self , field , value , evaluated , symbol_table ):
953998 if field == "bits_per_value" :
954- evaluated = _eval_tensor2bits (
999+ evaluated = _eval_tensor2number (
9551000 evaluated ,
9561001 location = "bits_per_value" ,
9571002 symbol_table = symbol_table ,
9581003 )
1004+ if field == "values_per_action" :
1005+ evaluated = _eval_tensor2number (
1006+ evaluated ,
1007+ location = "values_per_action" ,
1008+ symbol_table = symbol_table ,
1009+ )
9591010 return evaluated
9601011
9611012 return super ()._eval_expressions (* args , ** kwargs , post_calls = (MyPostCall (),))
9621013
1014+ def _get_values_per_action (
1015+ self , action_name : str , tensor_name : TensorName , bits_per_value_default : Number
1016+ ):
1017+ action = self .actions [action_name ]
1018+
1019+ if tensor_name in action .values_per_action :
1020+ return action .values_per_action [tensor_name ]
1021+ if tensor_name in self .values_per_action :
1022+ return self .values_per_action [tensor_name ]
1023+
1024+ tensor_bpv = self .bits_per_value .get (tensor_name , bits_per_value_default )
1025+ action_bpa = action .bits_per_action
1026+
1027+ return action_bpa / tensor_bpv
1028+
9631029
9641030class Container (Leaf , Spatialable ):
9651031 """
@@ -1049,24 +1115,27 @@ def _eval_expressions(self, *args, **kwargs):
10491115 if getattr (self , "_evaluated" , False ):
10501116 return super ()._eval_expressions (* args , ** kwargs )
10511117
1052- # Override TensorHolder's _PostCall to also handle direction
10531118 class MyPostCall (_PostCall ):
10541119 def __call__ (self_pc , field , value , evaluated , symbol_table ):
10551120 if field == "bits_per_value" :
1056- evaluated = _eval_tensor2bits (
1121+ evaluated = _eval_tensor2number (
10571122 evaluated ,
10581123 location = "bits_per_value" ,
10591124 symbol_table = symbol_table ,
10601125 )
1126+ if field == "values_per_action" :
1127+ evaluated = _eval_tensor2number (
1128+ evaluated ,
1129+ location = "values_per_action" ,
1130+ symbol_table = symbol_table ,
1131+ )
10611132 if field == "direction" :
10621133 evaluated = _eval_direction (
10631134 evaluated ,
10641135 symbol_table = symbol_table ,
10651136 )
10661137 return evaluated
10671138
1068- # Skip TensorHolder's _eval_expressions (which adds its own post_calls
1069- # for bits_per_value) since we handle it here too
10701139 return Component ._eval_expressions (
10711140 self , * args , ** kwargs , post_calls = (MyPostCall (),)
10721141 )
@@ -1135,7 +1204,7 @@ def _eval_expressions(self, *args, **kwargs):
11351204 class MyPostCall (_PostCall ):
11361205 def __call__ (self , field , value , evaluated , symbol_table ):
11371206 if field == "bits_per_value" :
1138- evaluated = _eval_tensor2bits (
1207+ evaluated = _eval_tensor2number (
11391208 evaluated ,
11401209 location = "bits_per_value" ,
11411210 symbol_table = symbol_table ,
0 commit comments