Skip to content

Commit e73bbff

Browse files
njzjzCopilot
andauthored
fix(tf): fix compatibility with TF 2.20 (#4890)
Fix version finding in pip and CMake; pin TF to <2.20 on Windows; fix TENSORFLOW_ROOT in the CI. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Added compatibility with TensorFlow 2.20+ via runtime version detection and generated version macros. - Bug Fixes - Clearer errors when a specified TensorFlow root is invalid. - Improved version-parsing fallback for newer TensorFlow releases. - Tightened Windows CPU wheel constraint to avoid incompatible versions. - Chores - Updated devcontainer scripts and CI workflows to more reliably locate TensorFlow without importing it directly. - Linked TensorFlow during version checks to ensure accurate detection. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent accc331 commit e73bbff

File tree

11 files changed

+56
-13
lines changed

11 files changed

+56
-13
lines changed

.devcontainer/build_cxx.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ NPROC=$(nproc --all)
55
SCRIPT_PATH=$(dirname $(realpath -s $0))
66

77
export CMAKE_PREFIX_PATH=${SCRIPT_PATH}/../libtorch
8-
TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
8+
TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
99

1010
mkdir -p ${SCRIPT_PATH}/../buildcxx/
1111
cd ${SCRIPT_PATH}/../buildcxx/

.devcontainer/gdb_lmp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
SCRIPT_PATH=$(dirname $(realpath -s $0))
33

44
export CMAKE_PREFIX_PATH=${SCRIPT_PATH}/../libtorch
5-
TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
5+
TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
66

77
env LAMMPS_PLUGIN_PATH=${SCRIPT_PATH}/../dp/lib/deepmd_lmp \
88
LD_LIBRARY_PATH=${SCRIPT_PATH}/../dp/lib:${CMAKE_PREFIX_PATH}/lib:${TENSORFLOW_ROOT} \

.devcontainer/gdb_pytest_lmp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
SCRIPT_PATH=$(dirname $(realpath -s $0))/../..
33

44
export CMAKE_PREFIX_PATH=${SCRIPT_PATH}/../libtorch
5-
TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
5+
TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
66

77
env LAMMPS_PLUGIN_PATH=${SCRIPT_PATH}/../dp/lib/deepmd_lmp \
88
LD_LIBRARY_PATH=${SCRIPT_PATH}/../dp/lib:${CMAKE_PREFIX_PATH}/lib:${TENSORFLOW_ROOT} \

.devcontainer/lmp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
SCRIPT_PATH=$(dirname $(realpath -s $0))
33

44
export CMAKE_PREFIX_PATH=${SCRIPT_PATH}/../libtorch
5-
TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
5+
TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
66

77
env LAMMPS_PLUGIN_PATH=${SCRIPT_PATH}/../dp/lib/deepmd_lmp \
88
LD_LIBRARY_PATH=${SCRIPT_PATH}/../dp/lib:${CMAKE_PREFIX_PATH}/lib:${TENSORFLOW_ROOT} \

.devcontainer/pytest_lmp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
SCRIPT_PATH=$(dirname $(realpath -s $0))/../..
33

44
export CMAKE_PREFIX_PATH=${SCRIPT_PATH}/../libtorch
5-
TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
5+
TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
66

77
env LAMMPS_PLUGIN_PATH=${SCRIPT_PATH}/../dp/lib/deepmd_lmp \
88
LD_LIBRARY_PATH=${SCRIPT_PATH}/../dp/lib:${CMAKE_PREFIX_PATH}/lib:${TENSORFLOW_ROOT} \

.github/workflows/test_cc.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
- name: Install Python dependencies
2727
run: |
2828
source/install/uv_with_retry.sh pip install --system tensorflow-cpu~=2.18.0 jax==0.5.0
29-
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
29+
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3030
source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich
3131
source/install/uv_with_retry.sh pip install --system 'torch==2.7' --index-url https://download.pytorch.org/whl/cpu
3232
- name: Convert models

.github/workflows/test_cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
- run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.7.0" "jax[cuda12]==0.5.0"
4747
- run: |
4848
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
49-
export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
49+
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
5050
pip install "paddlepaddle-gpu==3.0.0" -i https://www.paddlepaddle.org.cn/packages/stable/cu126/
5151
source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py --reinstall-package deepmd-kit
5252
env:

.github/workflows/test_python.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
- run: |
2828
source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu~=2.18.0
2929
source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu
30-
export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])')
30+
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
3131
export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])')
3232
source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py "jax==0.5.0;python_version>='3.10'"
3333
source/install/uv_with_retry.sh pip install --system -U setuptools

backend/find_tensorflow.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import os
3+
import re
34
import site
45
from functools import (
56
lru_cache,
@@ -56,6 +57,10 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]:
5657
) is not None:
5758
site_packages = Path(os.environ.get("TENSORFLOW_ROOT")).parent.absolute()
5859
tf_spec = FileFinder(str(site_packages)).find_spec("tensorflow")
60+
if tf_spec is None:
61+
raise RuntimeError(
62+
f"cannot find TensorFlow under TENSORFLOW_ROOT {os.environ.get('TENSORFLOW_ROOT')}"
63+
)
5964

6065
# get tensorflow spec
6166
# note: isolated build will not work for backend
@@ -153,7 +158,8 @@ def get_tf_requirement(tf_version: str = "") -> dict:
153158
"tensorflow-cpu; platform_machine!='aarch64' and (platform_machine!='arm64' or platform_system != 'Darwin')",
154159
"tensorflow; platform_machine=='aarch64' or (platform_machine=='arm64' and platform_system == 'Darwin')",
155160
# https://github.com/tensorflow/tensorflow/issues/61830
156-
"tensorflow-cpu!=2.15.*; platform_system=='Windows'",
161+
# Since TF 2.20, not all symbols are exported to the public API.
162+
"tensorflow-cpu!=2.15.*,<2.20; platform_system=='Windows'",
157163
# https://github.com/h5py/h5py/issues/2408
158164
"h5py>=3.6.0,!=3.11.0; platform_system=='Linux' and platform_machine=='aarch64'",
159165
*extra_requires,
@@ -228,6 +234,22 @@ def get_tf_version(tf_path: Optional[Union[str, Path]]) -> str:
228234
patch = line.split()[-1]
229235
elif line.startswith("#define TF_VERSION_SUFFIX"):
230236
suffix = line.split()[-1].strip('"')
237+
if None in (major, minor, patch):
238+
# since TF 2.20.0, version information is no more contained in version.h
239+
# try to read version from tools/pip_package/setup.py
240+
# _VERSION = '2.20.0'
241+
setup_file = Path(tf_path) / "tools" / "pip_package" / "setup.py"
242+
if setup_file.exists():
243+
with open(setup_file) as f:
244+
for line in f:
245+
# parse with regex
246+
match = re.search(
247+
r"_VERSION[ \t]*=[ \t]*'(\d+)\.(\d+)\.(\d+)([a-zA-Z0-9]*)?'",
248+
line,
249+
)
250+
if match:
251+
major, minor, patch, suffix = match.groups()
252+
break
231253
if None in (major, minor, patch):
232254
raise RuntimeError("Failed to read TF version")
233255
return ".".join((major, minor, patch)) + suffix

source/cmake/Findtensorflow.cmake

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,10 @@ if(NOT DEFINED TENSORFLOW_VERSION)
291291
TENSORFLOW_VERSION_RUN_RESULT_VAR TENSORFLOW_VERSION_COMPILE_RESULT_VAR
292292
${CMAKE_CURRENT_BINARY_DIR}/tf_version
293293
"${CMAKE_CURRENT_LIST_DIR}/tf_version.cpp"
294-
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES:STRING=${TensorFlow_INCLUDE_DIRS}"
295-
RUN_OUTPUT_VARIABLE TENSORFLOW_VERSION
294+
CMAKE_FLAGS
295+
"-DINCLUDE_DIRECTORIES:STRING=${TensorFlow_INCLUDE_DIRS}" LINK_LIBRARIES
296+
${TensorFlowFramework_LIBRARY} ${TensorFlow_LIBRARY}
297+
RUN_OUTPUT_STDOUT_VARIABLE TENSORFLOW_VERSION
296298
COMPILE_OUTPUT_VARIABLE TENSORFLOW_VERSION_COMPILE_OUTPUT_VAR)
297299
if(NOT ${TENSORFLOW_VERSION_COMPILE_RESULT_VAR})
298300
message(
@@ -304,6 +306,23 @@ if(NOT DEFINED TENSORFLOW_VERSION)
304306
endif()
305307
endif()
306308

309+
if(TENSORFLOW_VERSION VERSION_GREATER_EQUAL 2.20)
310+
# since TF 2.20, macros like TF_MAJOR_VERSION, TF_MINOR_VERSION, and
311+
# TF_PATCH_VERSION are not defined We manuanlly define them in our CMake files
312+
# first, split TENSORFLOW_VERSION (e.g. 2.20.0rc0) to 2 20 0 rc0
313+
string(REGEX MATCH "^([0-9]+)\\.([0-9]+)\\.([0-9]+)(.*)$" _match
314+
${TENSORFLOW_VERSION})
315+
if(_match)
316+
set(TF_MAJOR_VERSION ${CMAKE_MATCH_1})
317+
set(TF_MINOR_VERSION ${CMAKE_MATCH_2})
318+
set(TF_PATCH_VERSION ${CMAKE_MATCH_3})
319+
# add defines
320+
add_definitions(-DTF_MAJOR_VERSION=${TF_MAJOR_VERSION})
321+
add_definitions(-DTF_MINOR_VERSION=${TF_MINOR_VERSION})
322+
add_definitions(-DTF_PATCH_VERSION=${TF_PATCH_VERSION})
323+
endif()
324+
endif()
325+
307326
# print message
308327
if(NOT TensorFlow_FIND_QUIETLY)
309328
message(

0 commit comments

Comments
 (0)