@@ -83,8 +83,9 @@ def create(cls, m: Model, d: Data, grad: bool = True) -> 'Context':
8383 ):
8484 pass # MLX port: backend check removed
8585
86- jaref = (d ._impl or d ).efc_J @ d .qacc - (d ._impl or d ).efc_aref
87- ma = support .mul_m (m , d , d .qacc )
86+ qacc = mx .array (d .qacc ) if not isinstance (d .qacc , mx .array ) else d .qacc
87+ jaref = (d ._impl or d ).efc_J @ qacc - (d ._impl or d ).efc_aref
88+ ma = support .mul_m (m , d , qacc )
8889 nv_0 = mx .zeros (m .nv )
8990 fri = mx .array (0.0 )
9091 if m .opt .cone == ConeType .ELLIPTIC :
@@ -177,8 +178,8 @@ def create(
177178 mask_ne_nf = mx .arange (x .shape [0 ]) < ne_nf
178179 active = mx .where (mask_ne_nf , True , active )
179180
180- dof_fl = (m . _impl or m ). dof_hasfrictionloss
181- ten_fl = (m . _impl or m ). tendon_hasfrictionloss
181+ dof_fl = (np . array ( m . dof_frictionloss ) > 0 )
182+ ten_fl = (np . array ( m . tendon_frictionloss ) > 0 )
182183 if (dof_fl .any () or ten_fl .any ()) and not (
183184 m .opt .disableflags & DisableBit .FRICTIONLOSS
184185 ):
@@ -302,10 +303,11 @@ def _update_constraint(m: Model, d: Data, ctx: Context) -> Context:
302303 mask_ne_nf = mx .arange (ctx .Jaref .shape [0 ]) < ne_nf
303304 active = mx .where (mask_ne_nf , True , active )
304305
305- floss_force = mx .zeros ((d ._impl or d ).nefc )
306+ nefc_actual = ctx .Jaref .shape [0 ] if ctx .Jaref .ndim > 0 else 0
307+ floss_force = mx .zeros (nefc_actual )
306308 floss_cost = mx .array (0.0 )
307- dof_fl = (m . _impl or m ). dof_hasfrictionloss
308- ten_fl = (m . _impl or m ). tendon_hasfrictionloss
309+ dof_fl = (np . array ( m . dof_frictionloss ) > 0 )
310+ ten_fl = (np . array ( m . tendon_frictionloss ) > 0 )
309311 if (dof_fl .any () or ten_fl .any ()) and not (
310312 m .opt .disableflags & DisableBit .FRICTIONLOSS
311313 ):
@@ -478,11 +480,11 @@ def _update_gradient(m: Model, d: Data, ctx: Context) -> Context:
478480 # Symmetrize to reduce the chance of numerical issues in cholesky
479481 h_sym = (h + h .T ) * 0.5
480482 # MLX Cholesky solve: L = cholesky(h_sym), solve L L^T x = grad
481- L = mx .linalg .cholesky (h_sym )
483+ L = mx .linalg .cholesky (h_sym , stream = mx . cpu )
482484 # Forward substitution: L y = grad
483- y = mx .linalg .solve_triangular (L , grad [:, None ], upper = False )
485+ y = mx .linalg .solve_triangular (L , grad [:, None ], upper = False , stream = mx . cpu )
484486 # Backward substitution: L^T x = y
485- mgrad = mx .linalg .solve_triangular (L .T , y , upper = True ).squeeze (- 1 )
487+ mgrad = mx .linalg .solve_triangular (L .T , y , upper = True , stream = mx . cpu ).squeeze (- 1 )
486488 else :
487489 raise NotImplementedError (f'unsupported solver type: { m .opt .solver } ' )
488490
@@ -640,9 +642,11 @@ def _cond(ctx: Context) -> bool:
640642 improvement = _rescale (m , ctx .prev_cost - ctx .cost )
641643 gradient = _rescale (m , math .norm (ctx .grad ))
642644
643- done = int (ctx .solver_niter .item ()) >= m .opt .iterations
644- done = done or (float (improvement .item ()) < float (m .opt .tolerance .item ()))
645- done = done or (float (gradient .item ()) < float (m .opt .tolerance .item ()))
645+ tol = float (m .opt .tolerance ) if not hasattr (m .opt .tolerance , 'item' ) else float (m .opt .tolerance .item ())
646+ niter = int (ctx .solver_niter .item ()) if hasattr (ctx .solver_niter , 'item' ) else int (ctx .solver_niter )
647+ done = niter >= m .opt .iterations
648+ done = done or (float (improvement .item ()) < tol )
649+ done = done or (float (gradient .item ()) < tol )
646650
647651 return not done
648652
0 commit comments