|
24 | 24 |
|
25 | 25 | RUN_DEFAULTS_KEY = "strax_defaults" |
26 | 26 | TEMP_DATA_TYPE_PREFIX = "_temp_" |
| 27 | +NOT_ALLOWED_PLUGINS = (strax.LoopPlugin, strax.OverlapWindowPlugin) |
27 | 28 |
|
28 | 29 | # use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env) |
29 | 30 | tqdm = strax.utils.tqdm |
@@ -986,18 +987,45 @@ def __add_lineage_to_plugin( |
986 | 987 | # Set chunk_number in the lineage |
987 | 988 | if chunk_number is not None: |
988 | 989 | for d_depends in plugin.depends_on: |
989 | | - if d_depends in chunk_number: |
990 | | - if len(plugin.depends_on) > 1: |
991 | | - raise ValueError( |
992 | | - "Can not assign chunk_number for multi-dependencies plugins " |
993 | | - "because it is not clear which input should be assigned." |
994 | | - ) |
995 | | - not_allowed_plugins = (strax.LoopPlugin, strax.OverlapWindowPlugin) |
996 | | - if issubclass(plugin.__class__, not_allowed_plugins): |
| 990 | + dependencies = self.get_dependencies(d_depends) | {d_depends} |
| 991 | + for d in chunk_number.keys(): |
| 992 | + if d not in dependencies: |
| 993 | + continue |
| 994 | + if issubclass(plugin.__class__, NOT_ALLOWED_PLUGINS): |
997 | 995 | raise ValueError( |
998 | | - f"Can not assign chunk_number for {plugin.__class__} " |
999 | | - f"because it is subclass of one of {not_allowed_plugins}!" |
| 996 | + f"Can not load per-chunk storage from {d} for {plugin.__class__} " |
| 997 | + f"because it is subclass of one of {NOT_ALLOWED_PLUGINS}!" |
1000 | 998 | ) |
| 999 | + if d_depends in chunk_number: |
| 1000 | + if len(plugin.depends_on) > 1: |
| 1001 | + for d in plugin.depends_on: |
| 1002 | + dependencies = self.get_dependencies(d) | {d} |
| 1003 | + msg = ( |
| 1004 | + f"Can not assign chunk_number for {plugin.__class__} " |
| 1005 | + "because it has multiple dependencies and one of the " |
| 1006 | + f"dependencies {d} does not (eventually) depend on {d_depends}." |
| 1007 | + ) |
| 1008 | + mask = d_depends in dependencies |
| 1009 | + if not mask: |
| 1010 | + raise ValueError(msg) |
| 1011 | + # Make sure other dependencies depend on the same per-chunk data_type |
| 1012 | + for shortest in [False, True]: |
| 1013 | + levels = { |
| 1014 | + _d: self.tree_levels[shortest][_d]["level"] |
| 1015 | + for _d in dependencies |
| 1016 | + } |
| 1017 | + mask &= ( |
| 1018 | + len( |
| 1019 | + [ |
| 1020 | + k |
| 1021 | + for k, v in levels.items() |
| 1022 | + if v == levels.get(d_depends, -1) |
| 1023 | + ] |
| 1024 | + ) |
| 1025 | + == 1 |
| 1026 | + ) |
| 1027 | + if not mask: |
| 1028 | + raise ValueError(msg) |
1001 | 1029 | configs.setdefault("chunk_number", {}) |
1002 | 1030 | if d_depends in configs["chunk_number"]: |
1003 | 1031 | raise ValueError( |
@@ -2848,32 +2876,50 @@ def tree_levels(self): |
2848 | 2876 | if self._fixed_level_cache is not None and context_hash in self._fixed_level_cache: |
2849 | 2877 | return self._fixed_level_cache[context_hash] |
2850 | 2878 |
|
2851 | | - def _get_levels(data_type=None, results=None): |
| 2879 | + def _get_levels(data_type=None, results=None, shortest=False): |
2852 | 2880 | """Get the level data_type in the context.""" |
2853 | 2881 | if results is None: |
2854 | 2882 | results = dict() |
2855 | 2883 | for k in [data_type] if data_type else self._plugin_class_registry.keys(): |
| 2884 | + if k in results: |
| 2885 | + continue |
2856 | 2886 | results[k] = dict() |
2857 | 2887 | _v = self._plugin_class_registry[k]() |
2858 | 2888 | if _v.depends_on: |
2859 | | - results[k]["level"] = ( |
2860 | | - max(_get_levels(d, results)[d]["level"] for d in _v.depends_on) + 1 |
2861 | | - ) |
| 2889 | + if shortest: |
| 2890 | + results[k]["level"] = ( |
| 2891 | + min( |
| 2892 | + _get_levels(d, results, shortest=shortest)[d]["level"] |
| 2893 | + for d in _v.depends_on |
| 2894 | + ) |
| 2895 | + + 1 |
| 2896 | + ) |
| 2897 | + else: |
| 2898 | + results[k]["level"] = ( |
| 2899 | + max( |
| 2900 | + _get_levels(d, results, shortest=shortest)[d]["level"] |
| 2901 | + for d in _v.depends_on |
| 2902 | + ) |
| 2903 | + + 1 |
| 2904 | + ) |
2862 | 2905 | else: |
2863 | 2906 | results[k]["level"] = 0 |
2864 | 2907 | results[k]["class"] = self._plugin_class_registry[k].__name__ |
2865 | 2908 | results[k]["index"] = _v.provides.index(k) |
2866 | 2909 | return results |
2867 | 2910 |
|
2868 | 2911 | # Sort the results by level, class, and index in provides |
2869 | | - _results = sorted( |
2870 | | - _get_levels().items(), key=lambda x: (x[1]["level"], x[1]["class"], x[1]["index"]) |
2871 | | - ) |
| 2912 | + results = dict() |
| 2913 | + for shortest in [False, True]: |
| 2914 | + results[shortest] = sorted( |
| 2915 | + _get_levels(shortest=shortest).items(), |
| 2916 | + key=lambda x: (x[1]["level"], x[1]["class"], x[1]["index"]), |
| 2917 | + ) |
2872 | 2918 |
|
2873 | | - # Assign order to the results |
2874 | | - for order, (key, value) in enumerate(_results): |
2875 | | - value["order"] = order |
2876 | | - results = dict(_results) |
| 2919 | + # Assign order to the results |
| 2920 | + for order, (key, value) in enumerate(results[shortest]): |
| 2921 | + value["order"] = order |
| 2922 | + results[shortest] = dict(results[shortest]) |
2877 | 2923 |
|
2878 | 2924 | if self._fixed_level_cache is None: |
2879 | 2925 | self._fixed_level_cache = {context_hash: results} |
|
0 commit comments