|
2 | 2 | import inspect |
3 | 3 | from numbers import Number |
4 | 4 | from functools import wraps |
5 | | -from typing import Any, Callable, List, Type, Union, TypeVar |
| 5 | +from typing import Any, Callable, List, NamedTuple, Type, Union, TypeVar |
6 | 6 | from hwcomponents._logging import ListLoggable, messages_from_logger, pop_all_messages |
7 | 7 | from hwcomponents._util import parse_float |
8 | 8 |
|
9 | 9 | T = TypeVar("T", bound="ComponentModel") |
10 | 10 |
|
11 | 11 |
|
| 12 | +class EnergyLatency(NamedTuple): |
| 13 | + energy: float |
| 14 | + latency: float |
| 15 | + |
| 16 | + def __add__(self, other: "EnergyLatency") -> "EnergyLatency": |
| 17 | + return EnergyLatency(self.energy + other.energy, self.latency + other.latency) |
| 18 | + |
| 19 | + |
12 | 20 | def action( |
13 | 21 | func: Callable[..., Union[float, tuple[float, float]]] = None, |
14 | 22 | bits_per_action: str = None, |
15 | 23 | pipelined_subcomponents: bool = False, |
16 | | -) -> Callable[..., tuple]: |
| 24 | +) -> Callable[..., EnergyLatency]: |
17 | 25 | """ |
18 | 26 | Decorator that adds an action to an energy/area model. If the component has no |
19 | 27 | subcomponents, then the action is expected to return a tuple of (energy, latency) |
@@ -52,92 +60,100 @@ def action( |
52 | 60 |
|
53 | 61 | @wraps(func) |
54 | 62 | def wrapper(self: "ComponentModel", *args, **kwargs): |
55 | | - self.logger.info("") |
56 | | - self.logger.info( |
57 | | - f"Calling action {self.__class__.__name__}.{func.__name__} with arguments {args} and {kwargs}" |
58 | | - ) |
59 | | - for subcomponent in self.subcomponents: |
60 | | - subcomponent._energy_used = 0 |
61 | | - subcomponent._latency_used = 0 |
62 | | - scale = 1 |
63 | | - scalestr = None |
64 | | - if bits_per_action is not None and "bits_per_action" in kwargs: |
65 | | - nominal_bits = None |
66 | | - try: |
67 | | - nominal_bits = getattr(self, bits_per_action) |
68 | | - except: |
69 | | - pass |
70 | | - if nominal_bits is None: |
71 | | - raise ValueError( |
72 | | - f"{self.__name__} has no attribute {bits_per_action}. " |
73 | | - f"Ensure that the attributes referenced in @action " |
74 | | - f"are defined in the class." |
75 | | - ) |
76 | | - scale = kwargs["bits_per_action"] / nominal_bits |
77 | | - scalestr = f"Scaling by {kwargs['bits_per_action']=} / {nominal_bits=}" |
78 | | - kwargs = {k: v for k, v in kwargs.items() if k not in additional_kwargs} |
79 | | - returned_value = func(self, *args, **kwargs) |
80 | | - # Normalize return to (energy, latency) |
81 | | - if returned_value is None: |
82 | | - if not self._subcomponents_set and not self.subcomponents: |
| 63 | + was_already_calling_action = getattr(self, "_currently_calling_action", False) |
| 64 | + self._currently_calling_action = True |
| 65 | + try: |
| 66 | + self.logger.info("") |
| 67 | + self.logger.info( |
| 68 | + f"Calling action {self.__class__.__name__}.{func.__name__} with arguments {args} and {kwargs}" |
| 69 | + ) |
| 70 | + for subcomponent in self.subcomponents: |
| 71 | + subcomponent._energy_used = 0 |
| 72 | + subcomponent._latency_used = 0 |
| 73 | + scale = 1 |
| 74 | + scalestr = None |
| 75 | + if bits_per_action is not None and "bits_per_action" in kwargs: |
| 76 | + nominal_bits = None |
| 77 | + try: |
| 78 | + nominal_bits = getattr(self, bits_per_action) |
| 79 | + except: |
| 80 | + pass |
| 81 | + if nominal_bits is None: |
| 82 | + raise ValueError( |
| 83 | + f"{self.__name__} has no attribute {bits_per_action}. " |
| 84 | + f"Ensure that the attributes referenced in @action " |
| 85 | + f"are defined in the class." |
| 86 | + ) |
| 87 | + scale = kwargs["bits_per_action"] / nominal_bits |
| 88 | + scalestr = f"Scaling by {kwargs['bits_per_action']=} / {nominal_bits=}" |
| 89 | + kwargs = {k: v for k, v in kwargs.items() if k not in additional_kwargs} |
| 90 | + returned_value = func(self, *args, **kwargs) |
| 91 | + # Normalize return to (energy, latency) |
| 92 | + if returned_value is None: |
| 93 | + if not self._subcomponents_set and not self.subcomponents: |
| 94 | + raise ValueError( |
| 95 | + f"@action function {func.__name__} did not return a value. " |
| 96 | + f"This is permitted if and only if the component has no " |
| 97 | + f"subcomponents. Please either initialize subcomponents or ensure " |
| 98 | + f"that the @action function returns a tuple of (energy, latency)." |
| 99 | + ) |
| 100 | + energy_val, latency_val = 0.0, 0.0 |
| 101 | + elif isinstance(returned_value, (tuple, list)) and len(returned_value) == 2: |
| 102 | + energy_val, latency_val = returned_value |
| 103 | + else: |
83 | 104 | raise ValueError( |
84 | | - f"@action function {func.__name__} did not return a value. " |
85 | | - f"This is permitted if and only if the component has no " |
86 | | - f"subcomponents. Please either initialize subcomponents or ensure " |
87 | | - f"that the @action function returns a tuple of (energy, latency)." |
| 105 | + f"@action function {func.__name__} returned an invalid value. " |
| 106 | + f"Expected a tuple of (energy, latency), got {returned_value}." |
88 | 107 | ) |
89 | | - energy_val, latency_val = 0.0, 0.0 |
90 | | - elif isinstance(returned_value, (tuple, list)) and len(returned_value) == 2: |
91 | | - energy_val, latency_val = returned_value |
92 | | - else: |
93 | | - raise ValueError( |
94 | | - f"@action function {func.__name__} returned an invalid value. " |
95 | | - f"Expected a tuple of (energy, latency), got {returned_value}." |
96 | | - ) |
97 | 108 |
|
98 | | - self.logger.info( |
99 | | - f"Function {func.__name__} returned energy {energy_val} and latency {latency_val}" |
100 | | - ) |
101 | | - if scalestr is not None: |
102 | | - self.logger.info(scalestr) |
103 | | - |
104 | | - energy_val *= self.energy_scale |
105 | | - if self.energy_scale != 1: |
106 | | - self.logger.info(f"Scaling energy by {self.energy_scale=}") |
107 | | - for subcomponent in self.subcomponents: |
108 | | - self.logger.info( |
109 | | - f"Adding subcomponent {subcomponent.__class__.__name__} energy {subcomponent._energy_used}" |
110 | | - ) |
111 | | - energy_val += subcomponent._energy_used |
112 | | - subcomponent._energy_used = 0 |
113 | | - energy_val *= scale |
114 | | - self._energy_used += energy_val |
115 | | - |
116 | | - latency_val *= self.latency_scale |
117 | | - if self.latency_scale != 1: |
118 | | - self.logger.info(f"Scaling latency by {self.latency_scale=}") |
119 | | - target_func = max if pipelined_subcomponents else sum |
120 | | - x = "Max" if pipelined_subcomponents else "Summ" |
121 | | - for subcomponent in self.subcomponents: |
122 | 109 | self.logger.info( |
123 | | - f"{x}ing subcomponent {subcomponent.__class__.__name__} latency {subcomponent._latency_used}" |
| 110 | + f"Function {func.__name__} returned energy {energy_val} and latency {latency_val}" |
124 | 111 | ) |
125 | | - latency_val = target_func((latency_val, subcomponent._latency_used)) |
126 | | - subcomponent._latency_used = 0 |
127 | | - latency_val *= scale |
128 | | - self._latency_used += latency_val |
129 | | - |
130 | | - for subcomponent in self.subcomponents: |
131 | | - self.logger.info(f"Log for subcomponent {subcomponent.__class__.__name__}:") |
132 | | - for message in pop_all_messages(subcomponent.logger): |
133 | | - if message: |
134 | | - self.logger.info(f"\t{message}") |
135 | | - |
136 | | - self.logger.info( |
137 | | - f"** Final return value for {self.__class__.__name__}.{func.__name__}: energy {energy_val} and latency {latency_val}" |
138 | | - ) |
| 112 | + if scalestr is not None: |
| 113 | + self.logger.info(scalestr) |
| 114 | + |
| 115 | + if not was_already_calling_action: |
| 116 | + energy_val *= self.energy_scale |
| 117 | + if self.energy_scale != 1: |
| 118 | + self.logger.info(f"Scaling energy by {self.energy_scale=}") |
| 119 | + for subcomponent in self.subcomponents: |
| 120 | + self.logger.info( |
| 121 | + f"Adding subcomponent {subcomponent.__class__.__name__} energy {subcomponent._energy_used}" |
| 122 | + ) |
| 123 | + energy_val += subcomponent._energy_used |
| 124 | + subcomponent._energy_used = 0 |
| 125 | + energy_val *= scale |
| 126 | + self._energy_used += energy_val |
| 127 | + |
| 128 | + latency_val *= self.latency_scale |
| 129 | + if self.latency_scale != 1: |
| 130 | + self.logger.info(f"Scaling latency by {self.latency_scale=}") |
| 131 | + target_func = max if pipelined_subcomponents else sum |
| 132 | + x = "Max" if pipelined_subcomponents else "Summ" |
| 133 | + for subcomponent in self.subcomponents: |
| 134 | + self.logger.info( |
| 135 | + f"{x}ing subcomponent {subcomponent.__class__.__name__} latency {subcomponent._latency_used}" |
| 136 | + ) |
| 137 | + latency_val = target_func((latency_val, subcomponent._latency_used)) |
| 138 | + subcomponent._latency_used = 0 |
| 139 | + latency_val *= scale |
| 140 | + self._latency_used += latency_val |
| 141 | + |
| 142 | + for subcomponent in self.subcomponents: |
| 143 | + self.logger.info(f"Log for subcomponent {subcomponent.__class__.__name__}:") |
| 144 | + for message in pop_all_messages(subcomponent.logger): |
| 145 | + if message: |
| 146 | + self.logger.info(f"\t{message}") |
| 147 | + |
| 148 | + self.logger.info( |
| 149 | + f"** Final return value for {self.__class__.__name__}.{func.__name__}: energy {energy_val} and latency {latency_val}" |
| 150 | + ) |
139 | 151 |
|
140 | | - return energy_val, latency_val |
| 152 | + return EnergyLatency(energy_val, latency_val) |
| 153 | + except: |
| 154 | + raise |
| 155 | + finally: |
| 156 | + self._currently_calling_action = False |
141 | 157 |
|
142 | 158 | wrapper._is_component_action = True |
143 | 159 | wrapper._original_function = func |
|
0 commit comments