Skip to content

Commit 49559c1

Browse files
authored
Better performance lineage chunk_number assignment (#1000)
1 parent cd79adc commit 49559c1

3 files changed

Lines changed: 118 additions & 98 deletions

File tree

strax/context.py

Lines changed: 117 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
RUN_DEFAULTS_KEY = "strax_defaults"
2626
TEMP_DATA_TYPE_PREFIX = "_temp_"
27-
NOT_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin)
27+
NOT_PER_CHUNK_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin)
2828

2929
# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env)
3030
tqdm = strax.utils.tqdm
@@ -740,17 +740,9 @@ def _context_hash(self):
740740
)
741741
return strax.deterministic_hash(_base_hash_on_config)
742742

743-
def _plugins_are_cached(
744-
self,
745-
targets: ty.Union[ty.Tuple[str], ty.List[str]],
746-
chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
747-
) -> bool:
743+
def _plugins_are_cached(self, targets: ty.Union[ty.Tuple[str], ty.List[str]]) -> bool:
748744
"""Check if all the requested targets are in the _fixed_plugin_cache."""
749-
if (
750-
self.context_config["use_per_run_defaults"]
751-
or self._fixed_plugin_cache is None
752-
or chunk_number is not None
753-
):
745+
if self.context_config["use_per_run_defaults"] or self._fixed_plugin_cache is None:
754746
# There is no point in caching if plugins (lineage) can
755747
# change per run or the cache is empty.
756748
return False
@@ -761,12 +753,8 @@ def _plugins_are_cached(
761753
plugin_cache = self._fixed_plugin_cache[context_hash]
762754
return all([t in plugin_cache for t in targets])
763755

764-
def _plugins_to_cache(
765-
self,
766-
plugins: dict,
767-
chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
768-
) -> None:
769-
if self.context_config["use_per_run_defaults"] or chunk_number is not None:
756+
def _plugins_to_cache(self, plugins: dict) -> None:
757+
if self.context_config["use_per_run_defaults"]:
770758
# There is no point in caching if plugins (lineage) can change per run
771759
return
772760
context_hash = self._context_hash()
@@ -867,9 +855,15 @@ def __get_plugin(
867855
):
868856
"""Get single plugin either from cache or initialize it."""
869857
# Check if plugin for data_type is already cached
870-
if self._plugins_are_cached((data_type,), chunk_number=chunk_number):
858+
if self._plugins_are_cached((data_type,)):
871859
cached_plugins = self.__get_requested_plugins_from_cache(run_id, (data_type,))
872-
target_plugin = cached_plugins[data_type]
860+
if chunk_number is not None:
861+
target_plugin = cached_plugins[data_type].__copy__(True)
862+
self.__assign_chunk_number_to_plugin(target_plugin, chunk_number=chunk_number)
863+
target_plugin.run_id = run_id
864+
target_plugin.fix_dtype()
865+
else:
866+
target_plugin = cached_plugins[data_type]
873867
return target_plugin
874868

875869
if data_type not in self._plugin_class_registry:
@@ -884,8 +878,7 @@ def __get_plugin(
884878
self._set_plugin_config(plugin, run_id, tolerant=True)
885879

886880
plugin.deps = {
887-
d_depends: self.__get_plugin(run_id, d_depends, chunk_number=chunk_number)
888-
for d_depends in plugin.depends_on
881+
d_depends: self.__get_plugin(run_id, d_depends) for d_depends in plugin.depends_on
889882
}
890883
if plugin.compute_takes_chunk_i:
891884
for k, v in plugin.deps.items():
@@ -900,7 +893,7 @@ def __get_plugin(
900893
"which is not supported."
901894
)
902895

903-
self.__add_lineage_to_plugin(run_id, plugin, chunk_number=chunk_number)
896+
self.__add_lineage_to_plugin(run_id, plugin)
904897

905898
if not hasattr(plugin, "data_kind") and not plugin.multi_output:
906899
if len(plugin.depends_on):
@@ -915,11 +908,17 @@ def __get_plugin(
915908
plugin.fix_dtype()
916909

917910
# Add plugin to cache
918-
self._plugins_to_cache(
919-
{data_type: plugin for data_type in plugin.provides}, chunk_number=chunk_number
920-
)
911+
self._plugins_to_cache({data_type: plugin for data_type in plugin.provides})
921912

922-
return plugin
913+
if chunk_number is not None:
914+
target_plugin = plugin.__copy__(True)
915+
self.__assign_chunk_number_to_plugin(target_plugin, chunk_number=chunk_number)
916+
target_plugin.run_id = run_id
917+
target_plugin.fix_dtype()
918+
else:
919+
target_plugin = plugin
920+
921+
return target_plugin
923922

924923
@staticmethod
925924
def _check_chunk_number(chunk_number: ty.List[int]):
@@ -937,18 +936,13 @@ def _check_chunk_number(chunk_number: ty.List[int]):
937936
f"but got {chunk_number}"
938937
)
939938

940-
def __add_lineage_to_plugin(
941-
self,
942-
run_id,
943-
plugin,
944-
chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
945-
):
939+
def __add_lineage_to_plugin(self, run_id, plugin):
946940
"""Adds lineage to plugin in place.
947941
948942
Also adds parent information in case of a child plugin.
949943
950944
"""
951-
last_provide = [d_provides for d_provides in plugin.provides][-1]
945+
last_provide = plugin.provides[-1]
952946

953947
if plugin.child_plugin:
954948
# Plugin is a child of another plugin, hence we have to
@@ -984,68 +978,91 @@ def __add_lineage_to_plugin(
984978
if plugin.takes_config[option].track
985979
}
986980

987-
# Set chunk_number in the lineage
988-
if chunk_number is not None:
989-
for d_depends in plugin.depends_on:
990-
dependencies = self.get_dependencies(d_depends) | {d_depends}
991-
for d in chunk_number.keys():
992-
if d not in dependencies:
993-
continue
994-
if issubclass(plugin.__class__, NOT_ALLOWED_PLUGINS):
995-
raise ValueError(
996-
f"Can not load per-chunk storage from {d} for {plugin.__class__} "
997-
f"because it is subclass of one of {NOT_ALLOWED_PLUGINS}!"
998-
)
999-
if d_depends in chunk_number:
1000-
if len(plugin.depends_on) > 1:
1001-
for d in plugin.depends_on:
1002-
dependencies = self.get_dependencies(d) | {d}
1003-
msg = (
1004-
f"Can not assign chunk_number for {plugin.__class__} "
1005-
"because it has multiple dependencies and one of the "
1006-
f"dependencies {d} does not (eventually) depend on {d_depends}."
1007-
)
1008-
mask = d_depends in dependencies
1009-
if not mask:
1010-
raise ValueError(msg)
1011-
# Make sure other dependencies depend on the same per-chunk data_type
1012-
for shortest in [False, True]:
1013-
levels = {
1014-
_d: self.tree_levels[shortest][_d]["level"]
1015-
for _d in dependencies
1016-
}
1017-
mask &= (
1018-
len(
1019-
[
1020-
k
1021-
for k, v in levels.items()
1022-
if v == levels.get(d_depends, -1)
1023-
]
1024-
)
1025-
== 1
1026-
)
1027-
if not mask:
1028-
raise ValueError(msg)
1029-
configs.setdefault("chunk_number", {})
1030-
if d_depends in configs["chunk_number"]:
1031-
raise ValueError(
1032-
f"Chunk number for {d_depends} is already set in the lineage"
1033-
)
1034-
self._check_chunk_number(chunk_number[d_depends])
1035-
plugin.chunk_number = chunk_number[d_depends]
1036-
if plugin.compute_takes_chunk_i and plugin.deps[d_depends].rechunk_on_load:
1037-
raise ValueError(
1038-
"Can not assign chunk_number for a plugin that takes chunk_i as input "
1039-
"when dependency's rechunk_on_load is True."
1040-
)
1041-
configs["chunk_number"][d_depends] = chunk_number[d_depends]
1042-
1043981
plugin.lineage = {last_provide: (plugin.__class__.__name__, plugin.version(), configs)}
1044982

1045983
# This is why the lineage of a plugin contains all its dependencies
1046984
for d_depends in plugin.depends_on:
1047985
plugin.lineage.update(plugin.deps[d_depends].lineage)
1048986

987+
def __assign_chunk_number_to_plugin(
988+
self,
989+
plugin,
990+
chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
991+
):
992+
"""Assign chunk_number to plugin in place.
993+
994+
:param plugin: Plugin to which we assign chunk_number
995+
:param chunk_number: Dictionary with data_type as key and chunk_number as value. If None, do
996+
nothing.
997+
998+
"""
999+
if chunk_number is None:
1000+
return
1001+
1002+
if len(set(plugin.depends_on) & set(chunk_number)) > 1 and plugin.compute_takes_chunk_i:
1003+
raise ValueError(
1004+
"Can not assign chunk_number for a plugin that takes chunk_i as input "
1005+
"when multiple dependencies are per-chunk."
1006+
)
1007+
1008+
for d in plugin.depends_on:
1009+
if d not in chunk_number:
1010+
continue
1011+
# This attribute assignment is needed by p.iter
1012+
plugin.chunk_number = chunk_number[d]
1013+
if plugin.compute_takes_chunk_i and plugin.deps[d].rechunk_on_load:
1014+
raise ValueError(
1015+
"Can not assign chunk_number for a plugin that takes chunk_i as input "
1016+
"when dependency's rechunk_on_load is True."
1017+
)
1018+
1019+
# Iterate over the lineage of the plugin and check if chunk_number
1020+
# is needed to be set for the dependencies of the plugin.
1021+
for last_provide in plugin.lineage:
1022+
p = self.__get_plugin("0", last_provide)
1023+
if not (set(p.depends_on) & set(chunk_number)):
1024+
continue
1025+
1026+
if issubclass(p.__class__, NOT_PER_CHUNK_ALLOWED_PLUGINS):
1027+
raise ValueError(
1028+
f"Can not load per-chunk storage from {chunk_number} for {p.__class__} "
1029+
f"because it is subclass of one of {NOT_PER_CHUNK_ALLOWED_PLUGINS}!"
1030+
)
1031+
1032+
# Set chunk_number in the lineage
1033+
for d in p.depends_on:
1034+
if d not in chunk_number:
1035+
continue
1036+
self._check_chunk_number(chunk_number[d])
1037+
# Make sure that d is the connector of subplots
1038+
# For details: https://github.com/AxFoundation/strax/pull/996
1039+
if len(p.depends_on) > 1:
1040+
for c in p.depends_on:
1041+
dependencies = self.get_dependencies(c) | {c}
1042+
msg = (
1043+
f"Can not assign chunk_number for {p.__class__} "
1044+
"because it has multiple dependencies and one of the "
1045+
f"dependencies {c} does not (eventually) depend on {d}."
1046+
)
1047+
mask = d in dependencies
1048+
if not mask:
1049+
raise ValueError(msg)
1050+
# Make sure other dependencies depend on the same per-chunk data_type
1051+
for shortest in [False, True]:
1052+
levels = {
1053+
_d: self.tree_levels[shortest][_d]["level"] for _d in dependencies
1054+
}
1055+
mask &= (
1056+
len([k for k, v in levels.items() if v == levels.get(d, -1)]) == 1
1057+
)
1058+
if not mask:
1059+
raise ValueError(msg)
1060+
configs = plugin.lineage[last_provide][2]
1061+
configs.setdefault("chunk_number", {})
1062+
if d in configs["chunk_number"]:
1063+
raise ValueError(f"Chunk number for {d} is already set in the lineage")
1064+
configs["chunk_number"][d] = chunk_number[d]
1065+
10491066
def _per_run_default_allowed_check(self, option_name, option):
10501067
"""Check if an option of a registered plugin is allowed."""
10511068
per_run_default = option.default_by_run != strax.OMITTED
@@ -2079,7 +2096,7 @@ def key_for(self, run_id, target, chunk_number=None, combining=False):
20792096
:return: strax.DataKey of the target
20802097
20812098
"""
2082-
if self._plugins_are_cached((target,), chunk_number=chunk_number):
2099+
if self._plugins_are_cached((target,)):
20832100
context_hash = self._context_hash()
20842101
if context_hash in self._fixed_plugin_cache:
20852102
plugins = self._fixed_plugin_cache[self._context_hash()]
@@ -2088,12 +2105,18 @@ def key_for(self, run_id, target, chunk_number=None, combining=False):
20882105
self.log.warning(
20892106
f"Context hash changed to {context_hash} for {self._plugin_class_registry}?"
20902107
)
2091-
plugins = self._get_plugins((target,), run_id, chunk_number=chunk_number)
2108+
plugins = self._get_plugins((target,), run_id)
2109+
else:
2110+
plugins = self._get_plugins((target,), run_id)
2111+
2112+
# Prevent modifying the cached plugin
2113+
if chunk_number is not None:
2114+
plugin = plugins[target].__copy__(True)
2115+
self.__assign_chunk_number_to_plugin(plugin, chunk_number=chunk_number)
20922116
else:
2093-
plugins = self._get_plugins((target,), run_id, chunk_number=chunk_number)
2117+
plugin = plugins[target].__copy__(False)
20942118

2095-
lineage = plugins[target].lineage
2096-
return self.get_data_key(run_id, target, lineage, combining=combining)
2119+
return self.get_data_key(run_id, target, plugin.lineage, combining=combining)
20972120

20982121
def get_metadata(self, run_id, target, chunk_number=None, combining=False) -> dict:
20992122
"""Return metadata for target for run_id, or raise DataNotAvailable if data is not yet

strax/plugins/plugin.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,6 @@ def __copy__(self, _deep_copy=False):
196196
plugin_copy.__setattr__(attribute, copy(source_value))
197197
return plugin_copy
198198

199-
def __deepcopy__(self):
200-
return self.__copy__(_deep_copy=True)
201-
202199
def __getattr__(self, name):
203200
"""Allow access to config parameters as attributes this allows backwards compatibility in
204201
cases where a descriptor style config depends on a non descriptor style config."""

tests/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def test_per_chunk_storage():
514514
# Per-chunk storage not allowed for some plugins
515515
p = type("whatever", (strax.OverlapWindowPlugin,), dict(depends_on="records"))
516516
st.register(p)
517-
with pytest.raises(ValueError):
517+
with pytest.raises(NotImplementedError):
518518
st.make(run_id, "whatever", chunk_number={"records": [0]})
519519

520520

0 commit comments

Comments
 (0)