Skip to content

Commit dc23d9b

Browse files
private _return_estimation_object
1 parent b5d4621 commit dc23d9b

6 files changed

Lines changed: 69 additions & 44 deletions

File tree

hwcomponents/_model_wrapper.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from types import ModuleType
55
from typing import Any, Callable, Dict, List, Optional, Set, Union
66
from .model import EnergyAreaModel
7-
from ._logging import move_queue_from_one_logger_to_another, ListLoggable, pop_all_messages
7+
from ._logging import (
8+
move_queue_from_one_logger_to_another,
9+
ListLoggable,
10+
pop_all_messages,
11+
)
812

913

1014
class EstimatorError(Exception):

hwcomponents/_version.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
commit_id: COMMIT_ID
2929
__commit_id__: COMMIT_ID
3030

31-
__version__ = version = '5.0.7'
32-
__version_tuple__ = version_tuple = (5, 0, 7)
31+
__version__ = version = "5.0.12"
32+
__version_tuple__ = version_tuple = (5, 0, 12)
3333

34-
__commit_id__ = commit_id = 'ga2a56804b'
34+
__commit_id__ = commit_id = "gb5d4621f3"

hwcomponents/_version_scheme.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
"""Version scheme for setuptools-scm - creates post-release versions."""
2+
23
from setuptools_scm.version import guess_next_version
34

45

56
def post_version(version):
67
"""Create post-release versions instead of dev versions."""
78
if version.exact:
8-
return version.format_with("{tag}").lstrip('v')
9+
return version.format_with("{tag}").lstrip("v")
910

10-
base = str(version.tag).lstrip('v') if version.tag else (guess_next_version(version) or "1.0")
11+
base = (
12+
str(version.tag).lstrip("v")
13+
if version.tag
14+
else (guess_next_version(version) or "1.0")
15+
)
1116
distance = version.distance or 0
1217

1318
return f"{base}.{distance}" if distance > 0 else base

hwcomponents/find_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def get_models(
155155
globbed = glob.glob(path_or_package, recursive=True)
156156
flattened.extend(globbed)
157157
else:
158-
raise ValueError(f'Invalid type: {type(path_or_package)}')
158+
raise ValueError(f"Invalid type: {type(path_or_package)}")
159159

160160
if _return_wrappers:
161161
models = [EnergyAreaModelWrapper(m, m.__name__) for m in models]
@@ -239,7 +239,9 @@ def get_models(
239239
models.extend(new_models)
240240

241241
if _return_wrappers:
242-
models = [m for m in models if name_must_include.lower() in m.model_name.lower()]
242+
models = [
243+
m for m in models if name_must_include.lower() in m.model_name.lower()
244+
]
243245
return sorted(models, key=lambda x: x.model_name)
244246
models = [m for m in models if name_must_include.lower() in m.__name__.lower()]
245247
return sorted(models, key=lambda x: x.__name__)

hwcomponents/model.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
T = TypeVar("T", bound="EnergyAreaModel")
1010

11+
1112
def actionDynamicEnergy(
1213
func: Callable[[], float] = None, bits_per_action: str = None
1314
) -> Callable[[], float]:
@@ -157,11 +158,11 @@ class EnergyAreaModel(ListLoggable, ABC):
157158

158159
@abstractmethod
159160
def __init__(
160-
self,
161-
leak_power: float | None = None,
162-
area: float | None = None,
163-
subcomponents: list["EnergyAreaModel"] | None = None,
164-
):
161+
self,
162+
leak_power: float | None = None,
163+
area: float | None = None,
164+
subcomponents: list["EnergyAreaModel"] | None = None,
165+
):
165166
if subcomponents is None:
166167
if leak_power is None or area is None:
167168
raise ValueError(
@@ -174,7 +175,9 @@ def __init__(
174175
self.leak_scale: float = 1
175176
self._leak_power: float = leak_power if leak_power is not None else 0
176177
self._area: float = area if area is not None else 0
177-
self.subcomponents: list["EnergyAreaModel"] = [] if subcomponents is None else subcomponents
178+
self.subcomponents: list["EnergyAreaModel"] = (
179+
[] if subcomponents is None else subcomponents
180+
)
178181
self._subcomponents_set = subcomponents is not None
179182
self._energy_used: float = 0
180183

@@ -200,9 +203,7 @@ def area(self) -> Number:
200203
-------
201204
The area in m^2 of the component.
202205
"""
203-
return self._area * self.area_scale + sum(
204-
s.area for s in self.subcomponents
205-
)
206+
return self._area * self.area_scale + sum(s.area for s in self.subcomponents)
206207

207208
@classmethod
208209
def _component_name(cls) -> str:
@@ -251,12 +252,8 @@ def scale(
251252
f"Scaled {key} from {default} to {target}: {attr} multiplied by {scale}"
252253
)
253254
except:
254-
target_float = parse_float(
255-
target, f"{self._component_name()}.{key}"
256-
)
257-
default_float = parse_float(
258-
default, f"{self._component_name()}.{key}"
259-
)
255+
target_float = parse_float(target, f"{self._component_name()}.{key}")
256+
default_float = parse_float(default, f"{self._component_name()}.{key}")
260257
scale = callfunc(target_float, default_float)
261258
setattr(self, attr, prev_val * scale)
262259
self.logger.info(
@@ -292,7 +289,6 @@ def required_arguments(self, action_name: str | None = None) -> List[str]:
292289
action_name : str | None
293290
The name of the action to get the required arguments for.
294291
If None, returns the required arguments for the __init__ method.
295-
296292
Returns
297293
-------
298294
list[str]
@@ -309,31 +305,39 @@ def required_arguments(self, action_name: str | None = None) -> List[str]:
309305
return inspect.signature(action_func).parameters.keys()
310306

311307
@classmethod
312-
def try_init_arbitrary_args(cls: Type[T], **kwargs) -> T:
308+
def try_init_arbitrary_args(
309+
cls: Type[T], _return_estimation_object: bool = False, **kwargs
310+
) -> T:
313311
"""
314312
Tries to initialize the model with the given arguments.
315313
316314
Parameters
317315
----------
318316
**kwargs : dict
319317
The arguments with which to initialize the model.
320-
318+
_return_estimation_object : bool
319+
Whether to return the Estimation object instead of the model.
321320
Returns
322321
-------
323322
The initialized model. If the model cannot be initialized with the given
324323
arguments, an exception is raised.
325324
"""
326325
from hwcomponents._model_wrapper import EnergyAreaQuery, EnergyAreaModelWrapper
326+
327327
wrapper = EnergyAreaModelWrapper(cls, cls.component_name)
328328
cname = cls.component_name
329329
query = EnergyAreaQuery(
330330
component_name=cname if isinstance(cname, str) else cname[0],
331331
component_attributes=kwargs,
332332
)
333-
return wrapper.get_initialized_subclass(query)
334-
335-
336-
def try_call_arbitrary_action(self: T, action_name: str, **kwargs) -> T:
333+
value = wrapper.get_initialized_subclass(query)
334+
if _return_estimation_object:
335+
return value
336+
return value.value
337+
338+
def try_call_arbitrary_action(
339+
self: T, action_name: str, _return_estimation_object: bool = False, **kwargs
340+
) -> T:
337341
"""
338342
Tries to call the given action with the given arguments.
339343
@@ -345,11 +349,15 @@ def try_call_arbitrary_action(self: T, action_name: str, **kwargs) -> T:
345349
The arguments with which to call the action.
346350
"""
347351
from hwcomponents._model_wrapper import EnergyAreaQuery, EnergyAreaModelWrapper
352+
348353
wrapper = EnergyAreaModelWrapper(type(self), self.component_name)
349354
query = EnergyAreaQuery(
350355
component_name=self.component_name,
351356
component_attributes={},
352357
action_name=action_name,
353358
action_arguments=kwargs,
354359
)
355-
return wrapper.estimate_energy(query, initialized_obj=self)
360+
value = wrapper.estimate_energy(query, initialized_obj=self)
361+
if _return_estimation_object:
362+
return value
363+
return value.value

hwcomponents/select_models.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _get_best_estimate(
127127
query: EnergyAreaQuery,
128128
target: str,
129129
models: List[EnergyAreaModelWrapper] | List[EnergyAreaModel] = None,
130-
return_estimation_object: bool = False,
130+
_return_estimation_object: bool = False,
131131
_relaxed_component_name_selection: bool = False,
132132
) -> FloatEstimation | EnergyAreaModel:
133133
if models is None:
@@ -263,7 +263,7 @@ def _get_supported_models(relaxed_component_name_selection: bool):
263263

264264
clear_logs()
265265

266-
if return_estimation_object and estimation is not None:
266+
if _return_estimation_object and estimation is not None:
267267
return estimation
268268

269269
if estimation is not None and estimation.success:
@@ -285,7 +285,7 @@ def get_energy(
285285
action_name: str,
286286
action_arguments: Dict[str, Any],
287287
models: List[EnergyAreaModelWrapper] = None,
288-
return_estimation_object: bool = False,
288+
_return_estimation_object: bool = False,
289289
_relaxed_component_name_selection: bool = False,
290290
) -> float | Estimation:
291291
"""
@@ -301,7 +301,7 @@ def get_energy(
301301
action_name: The name of the action.
302302
action_arguments: The arguments of the action.
303303
models: The models to use.
304-
return_estimation_object: Whether to return the estimation object instead of
304+
_return_estimation_object: Whether to return the estimation object instead of
305305
the energy value.
306306
_relaxed_component_name_selection: Whether to relax the component name
307307
selection. Relaxed selection ignores underscores in the component name.
@@ -317,7 +317,7 @@ def get_energy(
317317
query,
318318
"energy",
319319
models,
320-
return_estimation_object,
320+
_return_estimation_object,
321321
_relaxed_component_name_selection,
322322
)
323323

@@ -326,7 +326,7 @@ def get_area(
326326
component_name: str,
327327
component_attributes: Dict[str, Any],
328328
models: List[EnergyAreaModelWrapper] = None,
329-
return_estimation_object: bool = False,
329+
_return_estimation_object: bool = False,
330330
_relaxed_component_name_selection: bool = False,
331331
) -> float | Estimation:
332332
"""
@@ -339,7 +339,7 @@ def get_area(
339339
component_name: The name of the component.
340340
component_attributes: The attributes of the component.
341341
models: The models to use.
342-
return_estimation_object: Whether to return the estimation object instead of
342+
_return_estimation_object: Whether to return the estimation object instead of
343343
the area value.
344344
_relaxed_component_name_selection: Whether to relax the component name
345345
selection. Relaxed selection ignores underscores in the component name.
@@ -353,7 +353,7 @@ def get_area(
353353
query,
354354
"area",
355355
models,
356-
return_estimation_object,
356+
_return_estimation_object,
357357
_relaxed_component_name_selection,
358358
)
359359

@@ -362,7 +362,7 @@ def get_leak_power(
362362
component_name: str,
363363
component_attributes: Dict[str, Any],
364364
models: List[EnergyAreaModelWrapper] = None,
365-
return_estimation_object: bool = False,
365+
_return_estimation_object: bool = False,
366366
_relaxed_component_name_selection: bool = False,
367367
) -> float | Estimation:
368368
"""
@@ -387,7 +387,7 @@ def get_leak_power(
387387
query,
388388
"leak_power",
389389
models,
390-
return_estimation_object,
390+
_return_estimation_object,
391391
_relaxed_component_name_selection,
392392
)
393393

@@ -397,7 +397,7 @@ def get_model(
397397
component_attributes: Dict[str, Any],
398398
required_actions: List[str] = (),
399399
models: List[EnergyAreaModelWrapper] = None,
400-
return_estimation_object: bool = False,
400+
_return_estimation_object: bool = False,
401401
_relaxed_component_name_selection: bool = False,
402402
) -> EnergyAreaModelWrapper:
403403
"""
@@ -411,7 +411,7 @@ def get_model(
411411
component_attributes: The attributes of the component.
412412
required_actions: The actions that are required for the component.
413413
models: The models to use.
414-
return_estimation_object: Whether to return the estimation object instead of
414+
_return_estimation_object: Whether to return the estimation object instead of
415415
the model wrapper.
416416
_relaxed_component_name_selection: Whether to relax the component name
417417
selection. Relaxed selection ignores underscores in the component name.
@@ -423,4 +423,10 @@ def get_model(
423423
query = EnergyAreaQuery(
424424
component_name.lower(), component_attributes, None, None, required_actions
425425
)
426-
return _get_best_estimate(query, "model", models, return_estimation_object, _relaxed_component_name_selection)
426+
return _get_best_estimate(
427+
query,
428+
"model",
429+
models,
430+
_return_estimation_object,
431+
_relaxed_component_name_selection,
432+
)

0 commit comments

Comments
 (0)