|
1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
2 | 2 | import os |
| 3 | +import re |
3 | 4 | import site |
4 | 5 | from functools import ( |
5 | 6 | lru_cache, |
@@ -56,6 +57,10 @@ def find_tensorflow() -> tuple[Optional[str], list[str]]: |
56 | 57 | ) is not None: |
57 | 58 | site_packages = Path(os.environ.get("TENSORFLOW_ROOT")).parent.absolute() |
58 | 59 | tf_spec = FileFinder(str(site_packages)).find_spec("tensorflow") |
| 60 | + if tf_spec is None: |
| 61 | + raise RuntimeError( |
| 62 | + f"cannot find TensorFlow under TENSORFLOW_ROOT {os.environ.get('TENSORFLOW_ROOT')}" |
| 63 | + ) |
59 | 64 |
|
60 | 65 | # get tensorflow spec |
61 | 66 | # note: isolated build will not work for backend |
@@ -153,7 +158,8 @@ def get_tf_requirement(tf_version: str = "") -> dict: |
153 | 158 | "tensorflow-cpu; platform_machine!='aarch64' and (platform_machine!='arm64' or platform_system != 'Darwin')", |
154 | 159 | "tensorflow; platform_machine=='aarch64' or (platform_machine=='arm64' and platform_system == 'Darwin')", |
155 | 160 | # https://github.com/tensorflow/tensorflow/issues/61830 |
156 | | - "tensorflow-cpu!=2.15.*; platform_system=='Windows'", |
| 161 | + # Since TF 2.20, not all symbols are exported to the public API. |
| 162 | + "tensorflow-cpu!=2.15.*,<2.20; platform_system=='Windows'", |
157 | 163 | # https://github.com/h5py/h5py/issues/2408 |
158 | 164 | "h5py>=3.6.0,!=3.11.0; platform_system=='Linux' and platform_machine=='aarch64'", |
159 | 165 | *extra_requires, |
@@ -228,6 +234,22 @@ def get_tf_version(tf_path: Optional[Union[str, Path]]) -> str: |
228 | 234 | patch = line.split()[-1] |
229 | 235 | elif line.startswith("#define TF_VERSION_SUFFIX"): |
230 | 236 | suffix = line.split()[-1].strip('"') |
| 237 | + if None in (major, minor, patch): |
| 238 | + # since TF 2.20.0, version information is no more contained in version.h |
| 239 | + # try to read version from tools/pip_package/setup.py |
| 240 | + # _VERSION = '2.20.0' |
| 241 | + setup_file = Path(tf_path) / "tools" / "pip_package" / "setup.py" |
| 242 | + if setup_file.exists(): |
| 243 | + with open(setup_file) as f: |
| 244 | + for line in f: |
| 245 | + # parse with regex |
| 246 | + match = re.search( |
| 247 | + r"_VERSION[ \t]*=[ \t]*'(\d+)\.(\d+)\.(\d+)([a-zA-Z0-9]*)?'", |
| 248 | + line, |
| 249 | + ) |
| 250 | + if match: |
| 251 | + major, minor, patch, suffix = match.groups() |
| 252 | + break |
231 | 253 | if None in (major, minor, patch): |
232 | 254 | raise RuntimeError("Failed to read TF version") |
233 | 255 | return ".".join((major, minor, patch)) + suffix |
0 commit comments