Skip to content

Commit 71937d4

Browse files
committed
Updating for jax 0.7.2. Using np.bool fixes errors caused by regular bool.
1 parent 061751f commit 71937d4

2 files changed

Lines changed: 2 additions & 2 deletions

File tree

src/tfc/mtfc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def H_xla(ctx, *x, d: uint = 0, full: bool = False):
584584
stablehlo.ConcatenateOp(x, 0).result,
585585
mlir.ir_constant(np.int32(d)),
586586
mlir.ir_constant(np.int32(self.dim)),
587-
mlir.ir_constant(bool(full)),
587+
mlir.ir_constant(np.bool(full)),
588588
mlir.ir_constant(np.int32(dim0)),
589589
mlir.ir_constant(np.int32(dim1)),
590590
],

src/tfc/utfc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def H_xla(ctx, x, d: uint = 0, full: bool = False):
339339
mlir.ir_constant(np.int32(self.basisClass.identifier)),
340340
x,
341341
mlir.ir_constant(np.int32(d)),
342-
mlir.ir_constant(bool(full)),
342+
mlir.ir_constant(np.bool(full)),
343343
mlir.ir_constant(np.int32(dim0)),
344344
mlir.ir_constant(np.int32(dim1)),
345345
],

0 commit comments

Comments
 (0)