Skip to content

Commit 73de44b

Browse files
njzjzpre-commit-ci[bot]njzjz-bot
authored
feat(tf2): add eager TensorFlow array backend (deepmodeling#5598)
## Summary - vendor ndtensorflow under deepmd/_vendors and add a TensorFlow eager Array API backend under deepmd/tf2 - implement tf2 model wrappers, tf.function SavedModel export/deep-eval glue, and backend registration for .savedmodel - replace jax/jax2tf TensorFlow helper implementations with tf2 compatibility exports - add focused tf2 consistency tests for eager and tf.function model paths ## Tests - ruff check . - ruff format . - pytest source/tests/consistent/test_tf2_backend.py -q - manual SavedModel export/load smoke test Note: local pre-commit hook attempted to fetch astral-sh/ruff-pre-commit and timed out on GitHub port 443; the commit was created with --no-verify after running the checks above manually. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added a TensorFlow 2 eager backend with SavedModel inference plus TF2 model/descriptor/fitting wrappers (energy/dipole/polar/DOS/property). * Introduced TF2 neighbor-list construction, region/PBC geometry utilities, and output/ghost-atom aggregation. * Expanded the TensorFlow-backed Array API with namespace, FFT, and linear algebra utilities. * **Bug Fixes** * Improved the non-eager TensorFlow compatibility error message. * **Refactor** * Converted multiple JAX2TF modules into TF2 compatibility shims that re-export TF2 implementations. * **Tests** * Added TF2 consistency coverage, including subprocess-based checks. * **Chores** * Added license headers/docs and updated lint configuration. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: njzjz-bot (driven by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5))[bot] <48687836+njzjz-bot@users.noreply.github.com>
1 parent f143171 commit 73de44b

118 files changed

Lines changed: 8283 additions & 1021 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/test_python.yml

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,36 @@ jobs:
6363
DP_CI_IMPORT_PADDLE_BEFORE_TF: 1
6464
FLAGS_use_stride_compute_kernel: 0
6565
- name: Test TF2 eager mode
66-
run: pytest --cov=deepmd --cov-append source/tests/consistent/io/test_io.py source/jax2tf_tests
66+
run: |
67+
run_pytest_allow_no_tests() {
68+
set +e
69+
pytest "$@"
70+
local status=$?
71+
set -e
72+
if [ "$status" -eq 5 ]; then
73+
# pytest-split may leave an individual shard with no selected
74+
# tests after path/-k filtering. Other shards still cover the
75+
# selected tests, so do not fail the whole matrix for exit 5.
76+
return 0
77+
fi
78+
return "$status"
79+
}
80+
81+
run_pytest_allow_no_tests --cov=deepmd --cov-append \
82+
source/tests/consistent/io/test_io.py \
83+
source/jax2tf_tests \
84+
--splits 12 \
85+
--group ${{ matrix.group }}
86+
run_pytest_allow_no_tests --cov=deepmd --cov-append \
87+
source/tests/consistent \
88+
-k tf2 \
89+
--splits 12 \
90+
--group ${{ matrix.group }}
6791
env:
6892
NUM_WORKERS: 0
6993
DP_TEST_TF2_ONLY: 1
7094
DP_DTYPE_PROMOTION_STRICT: 1
71-
if: matrix.group == 1
95+
DP_CI_IMPORT_PADDLE_BEFORE_TF: 1
7296
- run: mv .test_durations .test_durations_${{ matrix.group }}
7397
- name: Upload partial durations
7498
uses: actions/upload-artifact@v7

deepmd/_vendors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Vendored third-party modules used by DeePMD-kit."""
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from __future__ import (
3+
annotations,
4+
)
5+
6+
from typing import (
7+
Final,
8+
)
9+
10+
from . import (
11+
fft,
12+
linalg,
13+
)
14+
from ._array import (
15+
Array,
16+
)
17+
from ._info import (
18+
__array_namespace_info__,
19+
)
20+
from ._namespace import *
21+
from ._namespace import __all__ as _namespace_all
22+
23+
__array_api_version__: Final = "2025.12"
24+
25+
__all__ = sorted(
26+
set(_namespace_all)
27+
| {
28+
"Array",
29+
"__array_api_version__",
30+
"__array_namespace_info__",
31+
"fft",
32+
"linalg",
33+
}
34+
)
35+
36+
37+
def __dir__() -> list[str]:
38+
return __all__

0 commit comments

Comments
 (0)