2424
2525RUN_DEFAULTS_KEY = "strax_defaults"
2626TEMP_DATA_TYPE_PREFIX = "_temp_"
27- NOT_ALLOWED_PLUGINS = (strax .LoopPlugin , strax .OverlapWindowPlugin )
27+ NOT_PER_CHUNK_ALLOWED_PLUGINS = (strax .LoopPlugin , strax .OverlapWindowPlugin )
2828
2929# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env)
3030tqdm = strax .utils .tqdm
@@ -740,17 +740,9 @@ def _context_hash(self):
740740 )
741741 return strax .deterministic_hash (_base_hash_on_config )
742742
743- def _plugins_are_cached (
744- self ,
745- targets : ty .Union [ty .Tuple [str ], ty .List [str ]],
746- chunk_number : ty .Optional [ty .Dict [str , ty .List [int ]]] = None ,
747- ) -> bool :
743+ def _plugins_are_cached (self , targets : ty .Union [ty .Tuple [str ], ty .List [str ]]) -> bool :
748744 """Check if all the requested targets are in the _fixed_plugin_cache."""
749- if (
750- self .context_config ["use_per_run_defaults" ]
751- or self ._fixed_plugin_cache is None
752- or chunk_number is not None
753- ):
745+ if self .context_config ["use_per_run_defaults" ] or self ._fixed_plugin_cache is None :
754746 # There is no point in caching if plugins (lineage) can
755747 # change per run or the cache is empty.
756748 return False
@@ -761,12 +753,8 @@ def _plugins_are_cached(
761753 plugin_cache = self ._fixed_plugin_cache [context_hash ]
762754 return all ([t in plugin_cache for t in targets ])
763755
764- def _plugins_to_cache (
765- self ,
766- plugins : dict ,
767- chunk_number : ty .Optional [ty .Dict [str , ty .List [int ]]] = None ,
768- ) -> None :
769- if self .context_config ["use_per_run_defaults" ] or chunk_number is not None :
756+ def _plugins_to_cache (self , plugins : dict ) -> None :
757+ if self .context_config ["use_per_run_defaults" ]:
770758 # There is no point in caching if plugins (lineage) can change per run
771759 return
772760 context_hash = self ._context_hash ()
@@ -867,9 +855,15 @@ def __get_plugin(
867855 ):
868856 """Get single plugin either from cache or initialize it."""
869857 # Check if plugin for data_type is already cached
870- if self ._plugins_are_cached ((data_type ,), chunk_number = chunk_number ):
858+ if self ._plugins_are_cached ((data_type ,)):
871859 cached_plugins = self .__get_requested_plugins_from_cache (run_id , (data_type ,))
872- target_plugin = cached_plugins [data_type ]
860+ if chunk_number is not None :
861+ target_plugin = cached_plugins [data_type ].__copy__ (True )
862+ self .__assign_chunk_number_to_plugin (target_plugin , chunk_number = chunk_number )
863+ target_plugin .run_id = run_id
864+ target_plugin .fix_dtype ()
865+ else :
866+ target_plugin = cached_plugins [data_type ]
873867 return target_plugin
874868
875869 if data_type not in self ._plugin_class_registry :
@@ -884,8 +878,7 @@ def __get_plugin(
884878 self ._set_plugin_config (plugin , run_id , tolerant = True )
885879
886880 plugin .deps = {
887- d_depends : self .__get_plugin (run_id , d_depends , chunk_number = chunk_number )
888- for d_depends in plugin .depends_on
881+ d_depends : self .__get_plugin (run_id , d_depends ) for d_depends in plugin .depends_on
889882 }
890883 if plugin .compute_takes_chunk_i :
891884 for k , v in plugin .deps .items ():
@@ -900,7 +893,7 @@ def __get_plugin(
900893 "which is not supported."
901894 )
902895
903- self .__add_lineage_to_plugin (run_id , plugin , chunk_number = chunk_number )
896+ self .__add_lineage_to_plugin (run_id , plugin )
904897
905898 if not hasattr (plugin , "data_kind" ) and not plugin .multi_output :
906899 if len (plugin .depends_on ):
@@ -915,11 +908,17 @@ def __get_plugin(
915908 plugin .fix_dtype ()
916909
917910 # Add plugin to cache
918- self ._plugins_to_cache (
919- {data_type : plugin for data_type in plugin .provides }, chunk_number = chunk_number
920- )
911+ self ._plugins_to_cache ({data_type : plugin for data_type in plugin .provides })
921912
922- return plugin
913+ if chunk_number is not None :
914+ target_plugin = plugin .__copy__ (True )
915+ self .__assign_chunk_number_to_plugin (target_plugin , chunk_number = chunk_number )
916+ target_plugin .run_id = run_id
917+ target_plugin .fix_dtype ()
918+ else :
919+ target_plugin = plugin
920+
921+ return target_plugin
923922
924923 @staticmethod
925924 def _check_chunk_number (chunk_number : ty .List [int ]):
@@ -937,18 +936,13 @@ def _check_chunk_number(chunk_number: ty.List[int]):
937936 f"but got { chunk_number } "
938937 )
939938
940- def __add_lineage_to_plugin (
941- self ,
942- run_id ,
943- plugin ,
944- chunk_number : ty .Optional [ty .Dict [str , ty .List [int ]]] = None ,
945- ):
939+ def __add_lineage_to_plugin (self , run_id , plugin ):
946940 """Adds lineage to plugin in place.
947941
948942 Also adds parent information in case of a child plugin.
949943
950944 """
951- last_provide = [ d_provides for d_provides in plugin .provides ] [- 1 ]
945+ last_provide = plugin .provides [- 1 ]
952946
953947 if plugin .child_plugin :
954948 # Plugin is a child of another plugin, hence we have to
@@ -984,68 +978,91 @@ def __add_lineage_to_plugin(
984978 if plugin .takes_config [option ].track
985979 }
986980
987- # Set chunk_number in the lineage
988- if chunk_number is not None :
989- for d_depends in plugin .depends_on :
990- dependencies = self .get_dependencies (d_depends ) | {d_depends }
991- for d in chunk_number .keys ():
992- if d not in dependencies :
993- continue
994- if issubclass (plugin .__class__ , NOT_ALLOWED_PLUGINS ):
995- raise ValueError (
996- f"Can not load per-chunk storage from { d } for { plugin .__class__ } "
997- f"because it is subclass of one of { NOT_ALLOWED_PLUGINS } !"
998- )
999- if d_depends in chunk_number :
1000- if len (plugin .depends_on ) > 1 :
1001- for d in plugin .depends_on :
1002- dependencies = self .get_dependencies (d ) | {d }
1003- msg = (
1004- f"Can not assign chunk_number for { plugin .__class__ } "
1005- "because it has multiple dependencies and one of the "
1006- f"dependencies { d } does not (eventually) depend on { d_depends } ."
1007- )
1008- mask = d_depends in dependencies
1009- if not mask :
1010- raise ValueError (msg )
1011- # Make sure other dependencies depend on the same per-chunk data_type
1012- for shortest in [False , True ]:
1013- levels = {
1014- _d : self .tree_levels [shortest ][_d ]["level" ]
1015- for _d in dependencies
1016- }
1017- mask &= (
1018- len (
1019- [
1020- k
1021- for k , v in levels .items ()
1022- if v == levels .get (d_depends , - 1 )
1023- ]
1024- )
1025- == 1
1026- )
1027- if not mask :
1028- raise ValueError (msg )
1029- configs .setdefault ("chunk_number" , {})
1030- if d_depends in configs ["chunk_number" ]:
1031- raise ValueError (
1032- f"Chunk number for { d_depends } is already set in the lineage"
1033- )
1034- self ._check_chunk_number (chunk_number [d_depends ])
1035- plugin .chunk_number = chunk_number [d_depends ]
1036- if plugin .compute_takes_chunk_i and plugin .deps [d_depends ].rechunk_on_load :
1037- raise ValueError (
1038- "Can not assign chunk_number for a plugin that takes chunk_i as input "
1039- "when dependency's rechunk_on_load is True."
1040- )
1041- configs ["chunk_number" ][d_depends ] = chunk_number [d_depends ]
1042-
1043981 plugin .lineage = {last_provide : (plugin .__class__ .__name__ , plugin .version (), configs )}
1044982
1045983 # This is why the lineage of a plugin contains all its dependencies
1046984 for d_depends in plugin .depends_on :
1047985 plugin .lineage .update (plugin .deps [d_depends ].lineage )
1048986
def __assign_chunk_number_to_plugin(
    self,
    plugin,
    chunk_number: ty.Optional[ty.Dict[str, ty.List[int]]] = None,
):
    """Record the requested per-chunk selection on ``plugin``, modifying it in place.

    Two things are written:

    * ``plugin.chunk_number``, taken from the direct dependencies of ``plugin`` that are
      listed in ``chunk_number`` (if several are listed, the last one in ``depends_on``
      order wins);
    * a ``"chunk_number"`` mapping inside the config dictionary (third element) of every
      lineage entry whose plugin directly depends on a listed data_type.

    :param plugin: Plugin to which we assign chunk_number. Its ``lineage`` and ``deps``
        are read here, so they must already be populated.
    :param chunk_number: Dictionary with data_type as key and chunk_number as value. If
        None, do nothing.
    :raises ValueError: If the plugin takes chunk_i and either has more than one per-chunk
        dependency or a per-chunk dependency with ``rechunk_on_load``; if a plugin in the
        lineage that depends on a per-chunk data_type is a subclass of one of
        ``NOT_PER_CHUNK_ALLOWED_PLUGINS``; if a multi-dependency plugin does not reach the
        per-chunk data_type through all of its dependencies; or if the chunk number for a
        data_type is already present in the lineage.

    """
    if chunk_number is None:
        return

    per_chunk_types = set(chunk_number)

    def _require_connector(upstream, per_chunk_type):
        """Check that every dependency of ``upstream`` funnels through ``per_chunk_type``."""
        # For details: https://github.com/AxFoundation/strax/pull/996
        for sibling in upstream.depends_on:
            reachable = self.get_dependencies(sibling) | {sibling}
            msg = (
                f"Can not assign chunk_number for {upstream.__class__} "
                "because it has multiple dependencies and one of the "
                f"dependencies {sibling} does not (eventually) depend on {per_chunk_type}."
            )
            if per_chunk_type not in reachable:
                raise ValueError(msg)
            # The per-chunk data_type must be the only entry on its tree level, for both
            # variants of the level bookkeeping (shortest False and True).
            alone_on_level = True
            for shortest in (False, True):
                levels = {
                    name: self.tree_levels[shortest][name]["level"] for name in reachable
                }
                peers = [
                    name for name, level in levels.items() if level == levels[per_chunk_type]
                ]
                if len(peers) != 1:
                    alone_on_level = False
            if not alone_on_level:
                raise ValueError(msg)

    direct_hits = per_chunk_types.intersection(plugin.depends_on)
    if len(direct_hits) > 1 and plugin.compute_takes_chunk_i:
        raise ValueError(
            "Can not assign chunk_number for a plugin that takes chunk_i as input "
            "when multiple dependencies are per-chunk."
        )

    for dep in plugin.depends_on:
        if dep in chunk_number:
            # This attribute assignment is needed by p.iter
            plugin.chunk_number = chunk_number[dep]
            if plugin.compute_takes_chunk_i and plugin.deps[dep].rechunk_on_load:
                raise ValueError(
                    "Can not assign chunk_number for a plugin that takes chunk_i as input "
                    "when dependency's rechunk_on_load is True."
                )

    # Walk the lineage of the plugin and set chunk_number for every entry whose plugin
    # directly depends on a per-chunk data_type.
    for provided in plugin.lineage:
        # NOTE(review): "0" is passed as run_id; presumably a placeholder, since only
        # depends_on and the class of the returned plugin are used below -- confirm.
        upstream = self.__get_plugin("0", provided)
        if per_chunk_types.isdisjoint(upstream.depends_on):
            continue

        if isinstance(upstream, NOT_PER_CHUNK_ALLOWED_PLUGINS):
            raise ValueError(
                f"Can not load per-chunk storage from {chunk_number} for {upstream.__class__} "
                f"because it is subclass of one of {NOT_PER_CHUNK_ALLOWED_PLUGINS}!"
            )

        for dep in upstream.depends_on:
            if dep not in chunk_number:
                continue
            self._check_chunk_number(chunk_number[dep])
            if len(upstream.depends_on) > 1:
                _require_connector(upstream, dep)
            # Set chunk_number in the lineage
            assigned = plugin.lineage[provided][2].setdefault("chunk_number", {})
            if dep in assigned:
                raise ValueError(f"Chunk number for {dep} is already set in the lineage")
            assigned[dep] = chunk_number[dep]
10491066 def _per_run_default_allowed_check (self , option_name , option ):
10501067 """Check if an option of a registered plugin is allowed."""
10511068 per_run_default = option .default_by_run != strax .OMITTED
@@ -2079,7 +2096,7 @@ def key_for(self, run_id, target, chunk_number=None, combining=False):
20792096 :return: strax.DataKey of the target
20802097
20812098 """
2082- if self ._plugins_are_cached ((target ,), chunk_number = chunk_number ):
2099+ if self ._plugins_are_cached ((target ,)):
20832100 context_hash = self ._context_hash ()
20842101 if context_hash in self ._fixed_plugin_cache :
20852102 plugins = self ._fixed_plugin_cache [self ._context_hash ()]
@@ -2088,12 +2105,18 @@ def key_for(self, run_id, target, chunk_number=None, combining=False):
20882105 self .log .warning (
20892106 f"Context hash changed to { context_hash } for { self ._plugin_class_registry } ?"
20902107 )
2091- plugins = self ._get_plugins ((target ,), run_id , chunk_number = chunk_number )
2108+ plugins = self ._get_plugins ((target ,), run_id )
2109+ else :
2110+ plugins = self ._get_plugins ((target ,), run_id )
2111+
2112+ # Prevent modifying the cached plugin
2113+ if chunk_number is not None :
2114+ plugin = plugins [target ].__copy__ (True )
2115+ self .__assign_chunk_number_to_plugin (plugin , chunk_number = chunk_number )
20922116 else :
2093- plugins = self . _get_plugins (( target ,), run_id , chunk_number = chunk_number )
2117+ plugin = plugins [ target ]. __copy__ ( False )
20942118
2095- lineage = plugins [target ].lineage
2096- return self .get_data_key (run_id , target , lineage , combining = combining )
2119+ return self .get_data_key (run_id , target , plugin .lineage , combining = combining )
20972120
20982121 def get_metadata (self , run_id , target , chunk_number = None , combining = False ) -> dict :
20992122 """Return metadata for target for run_id, or raise DataNotAvailable if data is not yet
0 commit comments