Skip to content

Commit 3c1b6cc

Browse files
authored
Merge pull request #33 from leakec/jax_7_1
Updating for JAX 0.7.1
2 parents 7c50a62 + e5efe8d commit 3c1b6cc

5 files changed

Lines changed: 6 additions & 12 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ The authors of this repsitory and the associated theory have gone to lengths to
6666
author = {Carl Leake and Hunter Johnston},
6767
title = {{TFC: A Functional Interpolation Framework}},
6868
url = {https://github.com/leakec/tfc},
69-
version = {1.3.0},
69+
version = {1.3.1},
7070
year = {2025},
7171
}
7272
@article{TFC,

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ build-backend = "setuptools.build_meta"
88

99
[project]
1010
name = "tfc"
11-
version = "1.3.0"
12-
requires-python = ">=3.11,<3.14"
11+
version = "1.3.1"
12+
requires-python = ">=3.11"
1313
readme = "README.md"
1414
dynamic = ["dependencies", "classifiers", "authors", "license", "description"]
1515

src/tfc/mtfc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -623,10 +623,7 @@ def H_jvp(arg_vals, arg_tans, d: tuple[int, ...] = d0, full: bool = False):
623623
out_tans = np.zeros((dim0, dim1))
624624
for k in range(n):
625625
if not (type(arg_tans[k]) is ad.Zero):
626-
if type(arg_tans[k]) is batching.BatchTracer:
627-
flag = onp.any(arg_tans[k].val != 0)
628-
else:
629-
flag = onp.any(arg_tans[k] != 0)
626+
flag = onp.any(arg_tans[k] != 0)
630627
if flag:
631628
dark = tuple(d[j] + 1 if k == j else d[j] for j in range(len(d)))
632629
if flat:

src/tfc/utfc.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,7 @@ def H_jvp(arg_vals, arg_tans, d: uint = 0, full: bool = False):
362362
x = arg_vals[0]
363363
dx = arg_tans[0]
364364
if not (dx is ad.Zero):
365-
if type(dx) is batching.BatchTracer:
366-
flag = onp.any(dx.val != 0)
367-
else:
368-
flag = onp.any(dx != 0)
365+
flag = onp.any(dx != 0)
369366
if flag:
370367
if len(dx.shape) == 1:
371368
out_tans = Hjax(x, d=d + 1, full=full) * onp.expand_dims(dx, 1)

src/tfc/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.3.0"
1+
__version__ = "1.3.1"
22

33

44
def _version_as_tuple(version_str: str) -> tuple[int, ...]:

0 commit comments

Comments
 (0)