Skip to content

Commit 799fb33

Browse files
Named tuples are returned from action calls
1 parent 024d124 commit 799fb33

1 file changed

Lines changed: 98 additions & 82 deletions

File tree

hwcomponents/model.py

Lines changed: 98 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,26 @@
22
import inspect
33
from numbers import Number
44
from functools import wraps
5-
from typing import Any, Callable, List, Type, Union, TypeVar
5+
from typing import Any, Callable, List, NamedTuple, Type, Union, TypeVar
66
from hwcomponents._logging import ListLoggable, messages_from_logger, pop_all_messages
77
from hwcomponents._util import parse_float
88

99
T = TypeVar("T", bound="ComponentModel")
1010

1111

12+
class EnergyLatency(NamedTuple):
13+
energy: float
14+
latency: float
15+
16+
def __add__(self, other: "EnergyLatency") -> "EnergyLatency":
17+
return EnergyLatency(self.energy + other.energy, self.latency + other.latency)
18+
19+
1220
def action(
1321
func: Callable[..., Union[float, tuple[float, float]]] = None,
1422
bits_per_action: str = None,
1523
pipelined_subcomponents: bool = False,
16-
) -> Callable[..., tuple]:
24+
) -> Callable[..., EnergyLatency]:
1725
"""
1826
Decorator that adds an action to an energy/area model. If the component has no
1927
subcomponents, then the action is expected to return a tuple of (energy, latency)
@@ -52,92 +60,100 @@ def action(
5260

5361
@wraps(func)
5462
def wrapper(self: "ComponentModel", *args, **kwargs):
55-
self.logger.info("")
56-
self.logger.info(
57-
f"Calling action {self.__class__.__name__}.{func.__name__} with arguments {args} and {kwargs}"
58-
)
59-
for subcomponent in self.subcomponents:
60-
subcomponent._energy_used = 0
61-
subcomponent._latency_used = 0
62-
scale = 1
63-
scalestr = None
64-
if bits_per_action is not None and "bits_per_action" in kwargs:
65-
nominal_bits = None
66-
try:
67-
nominal_bits = getattr(self, bits_per_action)
68-
except:
69-
pass
70-
if nominal_bits is None:
71-
raise ValueError(
72-
f"{self.__name__} has no attribute {bits_per_action}. "
73-
f"Ensure that the attributes referenced in @action "
74-
f"are defined in the class."
75-
)
76-
scale = kwargs["bits_per_action"] / nominal_bits
77-
scalestr = f"Scaling by {kwargs['bits_per_action']=} / {nominal_bits=}"
78-
kwargs = {k: v for k, v in kwargs.items() if k not in additional_kwargs}
79-
returned_value = func(self, *args, **kwargs)
80-
# Normalize return to (energy, latency)
81-
if returned_value is None:
82-
if not self._subcomponents_set and not self.subcomponents:
63+
was_already_calling_action = getattr(self, "_currently_calling_action", False)
64+
self._currently_calling_action = True
65+
try:
66+
self.logger.info("")
67+
self.logger.info(
68+
f"Calling action {self.__class__.__name__}.{func.__name__} with arguments {args} and {kwargs}"
69+
)
70+
for subcomponent in self.subcomponents:
71+
subcomponent._energy_used = 0
72+
subcomponent._latency_used = 0
73+
scale = 1
74+
scalestr = None
75+
if bits_per_action is not None and "bits_per_action" in kwargs:
76+
nominal_bits = None
77+
try:
78+
nominal_bits = getattr(self, bits_per_action)
79+
except:
80+
pass
81+
if nominal_bits is None:
82+
raise ValueError(
83+
f"{self.__name__} has no attribute {bits_per_action}. "
84+
f"Ensure that the attributes referenced in @action "
85+
f"are defined in the class."
86+
)
87+
scale = kwargs["bits_per_action"] / nominal_bits
88+
scalestr = f"Scaling by {kwargs['bits_per_action']=} / {nominal_bits=}"
89+
kwargs = {k: v for k, v in kwargs.items() if k not in additional_kwargs}
90+
returned_value = func(self, *args, **kwargs)
91+
# Normalize return to (energy, latency)
92+
if returned_value is None:
93+
if not self._subcomponents_set and not self.subcomponents:
94+
raise ValueError(
95+
f"@action function {func.__name__} did not return a value. "
96+
f"This is permitted if and only if the component has no "
97+
f"subcomponents. Please either initialize subcomponents or ensure "
98+
f"that the @action function returns a tuple of (energy, latency)."
99+
)
100+
energy_val, latency_val = 0.0, 0.0
101+
elif isinstance(returned_value, (tuple, list)) and len(returned_value) == 2:
102+
energy_val, latency_val = returned_value
103+
else:
83104
raise ValueError(
84-
f"@action function {func.__name__} did not return a value. "
85-
f"This is permitted if and only if the component has no "
86-
f"subcomponents. Please either initialize subcomponents or ensure "
87-
f"that the @action function returns a tuple of (energy, latency)."
105+
f"@action function {func.__name__} returned an invalid value. "
106+
f"Expected a tuple of (energy, latency), got {returned_value}."
88107
)
89-
energy_val, latency_val = 0.0, 0.0
90-
elif isinstance(returned_value, (tuple, list)) and len(returned_value) == 2:
91-
energy_val, latency_val = returned_value
92-
else:
93-
raise ValueError(
94-
f"@action function {func.__name__} returned an invalid value. "
95-
f"Expected a tuple of (energy, latency), got {returned_value}."
96-
)
97108

98-
self.logger.info(
99-
f"Function {func.__name__} returned energy {energy_val} and latency {latency_val}"
100-
)
101-
if scalestr is not None:
102-
self.logger.info(scalestr)
103-
104-
energy_val *= self.energy_scale
105-
if self.energy_scale != 1:
106-
self.logger.info(f"Scaling energy by {self.energy_scale=}")
107-
for subcomponent in self.subcomponents:
108-
self.logger.info(
109-
f"Adding subcomponent {subcomponent.__class__.__name__} energy {subcomponent._energy_used}"
110-
)
111-
energy_val += subcomponent._energy_used
112-
subcomponent._energy_used = 0
113-
energy_val *= scale
114-
self._energy_used += energy_val
115-
116-
latency_val *= self.latency_scale
117-
if self.latency_scale != 1:
118-
self.logger.info(f"Scaling latency by {self.latency_scale=}")
119-
target_func = max if pipelined_subcomponents else sum
120-
x = "Max" if pipelined_subcomponents else "Summ"
121-
for subcomponent in self.subcomponents:
122109
self.logger.info(
123-
f"{x}ing subcomponent {subcomponent.__class__.__name__} latency {subcomponent._latency_used}"
110+
f"Function {func.__name__} returned energy {energy_val} and latency {latency_val}"
124111
)
125-
latency_val = target_func((latency_val, subcomponent._latency_used))
126-
subcomponent._latency_used = 0
127-
latency_val *= scale
128-
self._latency_used += latency_val
129-
130-
for subcomponent in self.subcomponents:
131-
self.logger.info(f"Log for subcomponent {subcomponent.__class__.__name__}:")
132-
for message in pop_all_messages(subcomponent.logger):
133-
if message:
134-
self.logger.info(f"\t{message}")
135-
136-
self.logger.info(
137-
f"** Final return value for {self.__class__.__name__}.{func.__name__}: energy {energy_val} and latency {latency_val}"
138-
)
112+
if scalestr is not None:
113+
self.logger.info(scalestr)
114+
115+
if not was_already_calling_action:
116+
energy_val *= self.energy_scale
117+
if self.energy_scale != 1:
118+
self.logger.info(f"Scaling energy by {self.energy_scale=}")
119+
for subcomponent in self.subcomponents:
120+
self.logger.info(
121+
f"Adding subcomponent {subcomponent.__class__.__name__} energy {subcomponent._energy_used}"
122+
)
123+
energy_val += subcomponent._energy_used
124+
subcomponent._energy_used = 0
125+
energy_val *= scale
126+
self._energy_used += energy_val
127+
128+
latency_val *= self.latency_scale
129+
if self.latency_scale != 1:
130+
self.logger.info(f"Scaling latency by {self.latency_scale=}")
131+
target_func = max if pipelined_subcomponents else sum
132+
x = "Max" if pipelined_subcomponents else "Summ"
133+
for subcomponent in self.subcomponents:
134+
self.logger.info(
135+
f"{x}ing subcomponent {subcomponent.__class__.__name__} latency {subcomponent._latency_used}"
136+
)
137+
latency_val = target_func((latency_val, subcomponent._latency_used))
138+
subcomponent._latency_used = 0
139+
latency_val *= scale
140+
self._latency_used += latency_val
141+
142+
for subcomponent in self.subcomponents:
143+
self.logger.info(f"Log for subcomponent {subcomponent.__class__.__name__}:")
144+
for message in pop_all_messages(subcomponent.logger):
145+
if message:
146+
self.logger.info(f"\t{message}")
147+
148+
self.logger.info(
149+
f"** Final return value for {self.__class__.__name__}.{func.__name__}: energy {energy_val} and latency {latency_val}"
150+
)
139151

140-
return energy_val, latency_val
152+
return EnergyLatency(energy_val, latency_val)
153+
except:
154+
raise
155+
finally:
156+
self._currently_calling_action = False
141157

142158
wrapper._is_component_action = True
143159
wrapper._original_function = func

0 commit comments

Comments
 (0)