From 2f9203d37047b2b0707b15da99932806e0b0e052 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 4 Sep 2025 13:36:10 +0800 Subject: [PATCH] Fix file utilities, improve docs, and add tests --- deepmd/tf/nvnmd/utils/fio.py | 21 +++++++++++++-------- doc/model/train-fitting-tensor.md | 2 +- source/tests/tf/test_fio_utils.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 9 deletions(-) create mode 100644 source/tests/tf/test_fio_utils.py diff --git a/deepmd/tf/nvnmd/utils/fio.py b/deepmd/tf/nvnmd/utils/fio.py index ab2f5baa57..8c376d50d4 100644 --- a/deepmd/tf/nvnmd/utils/fio.py +++ b/deepmd/tf/nvnmd/utils/fio.py @@ -39,7 +39,7 @@ def is_file(self, file_name): def get_file_list(self, path) -> list: if self.is_file(path): return [] - if self.is_path: + if self.is_path(path): listdir = os.listdir(path) file_lst = [] for name in listdir: @@ -53,8 +53,9 @@ def get_file_list(self, path) -> list: class FioDic: - r"""Input and output for dict class data - the file can be .json or .npy file containing a dictionary. + r"""Input and output for dictionary data. + + The file can be a `.json` or `.npy` file containing a dictionary. """ def __init__(self) -> None: @@ -83,7 +84,7 @@ def get(self, jdata, key, default_value): return default_value def update(self, jdata, jdata_o): - r"""Update key-value pair is key in jdata_o.keys(). + r"""Update key-value pairs if the key exists in ``jdata_o``. Parameters ---------- @@ -186,8 +187,10 @@ class FioTxt: def __init__(self) -> None: pass - def load(self, file_name="", default_value=[]): - r"""Load .txt file into string list.""" + def load(self, file_name="", default_value=None): + r"""Load a text file into a list of strings.""" + if default_value is None: + default_value = [] if Fio().exits(file_name): log.info(f"load {file_name}") with open(file_name, encoding="utf-8") as fr: @@ -198,11 +201,13 @@ def load(self, file_name="", default_value=[]): log.info(f"can not find {file_name}") return default_value - def save(self, file_name: str = "", data: list = []) -> None: - r"""Save string list into .txt file.""" + def save(self, file_name: str = "", data: list | str | None = None) -> None: + r"""Save a list of strings into a text file.""" log.info(f"write string to txt file {file_name}") Fio().create_file_path(file_name) + if data is None: + data = [] if isinstance(data, str): data = [data] data = [d + "\n" for d in data] diff --git a/doc/model/train-fitting-tensor.md b/doc/model/train-fitting-tensor.md index 29c95b2d68..fade6423ee 100644 --- a/doc/model/train-fitting-tensor.md +++ b/doc/model/train-fitting-tensor.md @@ -174,7 +174,7 @@ The loss section should be provided like In tensor mode, the identification of the label's type (global or atomic) is derived from the file name. The global label should be named `dipole.npy/raw` or `polarizability.npy/raw`, while the atomic label should be named `atomic_dipole.npy/raw` or `atomic_polarizability.npy/raw`. If wrongly named, DP will report an error ```bash -ValueError: cannot reshape array of size xxx into shape (xx,xx). This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`. +ValueError: cannot reshape array of size xxx into shape (xx,xx). This error may occur when your label mismatches its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`. ``` In this case, please check the file name of the label. diff --git a/source/tests/tf/test_fio_utils.py b/source/tests/tf/test_fio_utils.py new file mode 100644 index 0000000000..54cb9942e3 --- /dev/null +++ b/source/tests/tf/test_fio_utils.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.tf.nvnmd.utils.fio import ( + Fio, + FioTxt, +) + + +def test_get_file_list(tmp_path): + """get_file_list should handle non-existent paths and collect files recursively.""" + # create directory with one file + subdir = tmp_path / "sub" + subdir.mkdir() + file_path = subdir / "file.txt" + file_path.write_text("hello") + + fio = Fio() + + # existing directory returns the file + files = fio.get_file_list(str(tmp_path)) + assert files == [str(file_path)] + + # non-existent directory should return an empty list + missing = tmp_path / "missing" + assert fio.get_file_list(str(missing)) == [] + + +def test_fiotxt_load_default(tmp_path): + """FioTxt.load should return default empty list when file does not exist.""" + fio_txt = FioTxt() + missing = tmp_path / "no.txt" + assert fio_txt.load(str(missing)) == []