@@ -709,12 +709,17 @@ def _resolve_conditions(self):
709709 4. We set the `self._resolved` flag to True if all conditional
710710 priors were added in the right order
711711 """
712- self ._unconditional_keys = [
713- key for key in self .keys () if not hasattr (self [key ], "condition_func" )
714- ]
715- conditional_keys_unsorted = [
716- key for key in self .keys () if hasattr (self [key ], "condition_func" )
717- ]
712+ conditional_keys_unsorted = []
713+ self ._unconditional_keys = []
714+ joint_dists = {}
715+ for key in self .keys ():
716+ if not hasattr (self [key ], "condition_func" ):
717+ self ._unconditional_keys .append (key )
718+ else :
719+ conditional_keys_unsorted .append (key )
720+ if isinstance (self [key ], JointPrior ):
721+ joint_dists [self [key ].dist .distname ] = self [key ].dist .names
722+
718723 self ._conditional_keys = []
719724 for _ in range (len (self )):
720725 for key in conditional_keys_unsorted [:]:
@@ -726,6 +731,12 @@ def _resolve_conditions(self):
726731 if len (conditional_keys_unsorted ) != 0 :
727732 self ._resolved = False
728733
734+ # ensure that all joint dist names are resolved
735+ for names in joint_dists .values ():
736+ for name in names :
737+ if name not in self .sorted_keys :
738+ self ._resolved = False
739+
729740 def _check_conditions_resolved (self , key , sampled_keys ):
730741 """Checks if all required variables have already been sampled so we can sample this key"""
731742 conditions_resolved = True
0 commit comments