Skip to content

Commit 7817aff

Browse files
authored
Topological level and shortest dependency path of dependency tree (directed acyclic graph) (#996)
1 parent 328dc07 commit 7817aff

1 file changed

Lines changed: 67 additions & 21 deletions

File tree

strax/context.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
RUN_DEFAULTS_KEY = "strax_defaults"
2626
TEMP_DATA_TYPE_PREFIX = "_temp_"
27+
NOT_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin)
2728

2829
# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env)
2930
tqdm = strax.utils.tqdm
@@ -986,18 +987,45 @@ def __add_lineage_to_plugin(
986987
# Set chunk_number in the lineage
987988
if chunk_number is not None:
988989
for d_depends in plugin.depends_on:
989-
if d_depends in chunk_number:
990-
if len(plugin.depends_on) > 1:
991-
raise ValueError(
992-
"Can not assign chunk_number for multi-dependencies plugins "
993-
"because it is not clear which input should be assigned."
994-
)
995-
not_allowed_plugins = (strax.LoopPlugin, strax.OverlapWindowPlugin)
996-
if issubclass(plugin.__class__, not_allowed_plugins):
990+
dependencies = self.get_dependencies(d_depends) | {d_depends}
991+
for d in chunk_number.keys():
992+
if d not in dependencies:
993+
continue
994+
if issubclass(plugin.__class__, NOT_ALLOWED_PLUGINS):
997995
raise ValueError(
998-
f"Can not assign chunk_number for {plugin.__class__} "
999-
f"because it is subclass of one of {not_allowed_plugins}!"
996+
f"Can not load per-chunk storage from {d} for {plugin.__class__} "
997+
f"because it is subclass of one of {NOT_ALLOWED_PLUGINS}!"
1000998
)
999+
if d_depends in chunk_number:
1000+
if len(plugin.depends_on) > 1:
1001+
for d in plugin.depends_on:
1002+
dependencies = self.get_dependencies(d) | {d}
1003+
msg = (
1004+
f"Can not assign chunk_number for {plugin.__class__} "
1005+
"because it has multiple dependencies and one of the "
1006+
f"dependencies {d} does not (eventually) depend on {d_depends}."
1007+
)
1008+
mask = d_depends in dependencies
1009+
if not mask:
1010+
raise ValueError(msg)
1011+
# Make sure other dependencies depend on the same per-chunk data_type
1012+
for shortest in [False, True]:
1013+
levels = {
1014+
_d: self.tree_levels[shortest][_d]["level"]
1015+
for _d in dependencies
1016+
}
1017+
mask &= (
1018+
len(
1019+
[
1020+
k
1021+
for k, v in levels.items()
1022+
if v == levels.get(d_depends, -1)
1023+
]
1024+
)
1025+
== 1
1026+
)
1027+
if not mask:
1028+
raise ValueError(msg)
10011029
configs.setdefault("chunk_number", {})
10021030
if d_depends in configs["chunk_number"]:
10031031
raise ValueError(
@@ -2848,32 +2876,50 @@ def tree_levels(self):
28482876
if self._fixed_level_cache is not None and context_hash in self._fixed_level_cache:
28492877
return self._fixed_level_cache[context_hash]
28502878

2851-
def _get_levels(data_type=None, results=None):
2879+
def _get_levels(data_type=None, results=None, shortest=False):
28522880
"""Get the level data_type in the context."""
28532881
if results is None:
28542882
results = dict()
28552883
for k in [data_type] if data_type else self._plugin_class_registry.keys():
2884+
if k in results:
2885+
continue
28562886
results[k] = dict()
28572887
_v = self._plugin_class_registry[k]()
28582888
if _v.depends_on:
2859-
results[k]["level"] = (
2860-
max(_get_levels(d, results)[d]["level"] for d in _v.depends_on) + 1
2861-
)
2889+
if shortest:
2890+
results[k]["level"] = (
2891+
min(
2892+
_get_levels(d, results, shortest=shortest)[d]["level"]
2893+
for d in _v.depends_on
2894+
)
2895+
+ 1
2896+
)
2897+
else:
2898+
results[k]["level"] = (
2899+
max(
2900+
_get_levels(d, results, shortest=shortest)[d]["level"]
2901+
for d in _v.depends_on
2902+
)
2903+
+ 1
2904+
)
28622905
else:
28632906
results[k]["level"] = 0
28642907
results[k]["class"] = self._plugin_class_registry[k].__name__
28652908
results[k]["index"] = _v.provides.index(k)
28662909
return results
28672910

28682911
# Sort the results by level, class, and index in provides
2869-
_results = sorted(
2870-
_get_levels().items(), key=lambda x: (x[1]["level"], x[1]["class"], x[1]["index"])
2871-
)
2912+
results = dict()
2913+
for shortest in [False, True]:
2914+
results[shortest] = sorted(
2915+
_get_levels(shortest=shortest).items(),
2916+
key=lambda x: (x[1]["level"], x[1]["class"], x[1]["index"]),
2917+
)
28722918

2873-
# Assign order to the results
2874-
for order, (key, value) in enumerate(_results):
2875-
value["order"] = order
2876-
results = dict(_results)
2919+
# Assign order to the results
2920+
for order, (key, value) in enumerate(results[shortest]):
2921+
value["order"] = order
2922+
results[shortest] = dict(results[shortest])
28772923

28782924
if self._fixed_level_cache is None:
28792925
self._fixed_level_cache = {context_hash: results}

0 commit comments

Comments
 (0)