@@ -237,6 +237,14 @@ def __new__(cls, *args, **kwargs):
237237 cond = d .relation (cond , GuardFactor (d ))
238238 conditionals [d ] = cond
239239
240+ # Replace the ConditionalDimensions in `expr`
241+ for d , cond in conditionals .items ():
242+ # Replace dimension with index
243+ index = d .index
244+ index = index - relational_min (cond , d .parent )
245+ shift = relational_shift (cond , d .parent )
246+ expr = uxreplace (expr , {d : IntDiv (index , d .symbolic_factor ) + shift })
247+
240248 # Merge conditionals when possible. E.g if we have an implicit_dim
241249 # and there is a dimension with the same parent, we ca merged
242250 # its condition
@@ -247,19 +255,13 @@ def __new__(cls, *args, **kwargs):
247255 if cd .parent == d .parent and cd != d :
248256 cond = conditionals .pop (d )
249257 mode = cd .relation and d .relation
250- conditionals [cd ] = mode (cond , conditionals [cd ])
258+ if issubclass (mode , sympy .Or ):
259+ conditionals [d ] = cond
260+ conditionals .pop (cd )
261+ else :
262+ conditionals [cd ] = mode (cond , conditionals [cd ])
251263 break
252264
253- conditionals = frozendict (conditionals )
254-
255- # Replace the ConditionalDimensions in `expr`
256- for d , cond in conditionals .items ():
257- # Replace dimension with index
258- index = d .index
259- index = index - relational_min (cond , d .parent )
260- shift = relational_shift (cond , d .parent )
261- expr = uxreplace (expr , {d : IntDiv (index , d .symbolic_factor ) + shift })
262-
263265 # Lower all Differentiable operations into SymPy operations
264266 rhs = diff2sympy (expr .rhs )
265267
0 commit comments