Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions deepmd/tf/nvnmd/utils/fio.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def is_file(self, file_name):
def get_file_list(self, path) -> list:
if self.is_file(path):
return []
if self.is_path:
if self.is_path(path):
listdir = os.listdir(path)
file_lst = []
for name in listdir:
Expand All @@ -53,8 +53,9 @@ def get_file_list(self, path) -> list:


class FioDic:
r"""Input and output for dict class data
the file can be .json or .npy file containing a dictionary.
r"""Input and output for dictionary data.

The file can be a `.json` or `.npy` file containing a dictionary.
"""

def __init__(self) -> None:
Expand Down Expand Up @@ -83,7 +84,7 @@ def get(self, jdata, key, default_value):
return default_value

def update(self, jdata, jdata_o):
r"""Update key-value pair is key in jdata_o.keys().
r"""Update key-value pairs if the key exists in ``jdata_o``.

Parameters
----------
Expand Down Expand Up @@ -186,8 +187,10 @@ class FioTxt:
def __init__(self) -> None:
pass

def load(self, file_name="", default_value=[]):
r"""Load .txt file into string list."""
def load(self, file_name="", default_value=None):
r"""Load a text file into a list of strings."""
if default_value is None:
default_value = []
if Fio().exits(file_name):
log.info(f"load {file_name}")
with open(file_name, encoding="utf-8") as fr:
Expand All @@ -198,11 +201,13 @@ def load(self, file_name="", default_value=[]):
log.info(f"can not find {file_name}")
return default_value

def save(self, file_name: str = "", data: list = []) -> None:
r"""Save string list into .txt file."""
def save(self, file_name: str = "", data: list | str | None = None) -> None:
r"""Save a list of strings into a text file."""
log.info(f"write string to txt file {file_name}")
Fio().create_file_path(file_name)

if data is None:
data = []
if isinstance(data, str):
data = [data]
data = [d + "\n" for d in data]
Expand Down
2 changes: 1 addition & 1 deletion doc/model/train-fitting-tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ The loss section should be provided like
In tensor mode, the identification of the label's type (global or atomic) is derived from the file name. The global label should be named `dipole.npy/raw` or `polarizability.npy/raw`, while the atomic label should be named `atomic_dipole.npy/raw` or `atomic_polarizability.npy/raw`. If wrongly named, DP will report an error

```bash
ValueError: cannot reshape array of size xxx into shape (xx,xx). This error may occur when your label mismatch it's name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`.
ValueError: cannot reshape array of size xxx into shape (xx,xx). This error may occur when your label mismatches its name, i.e. you might store global tensor in `atomic_tensor.npy` or atomic tensor in `tensor.npy`.
```

In this case, please check the file name of the label.
Expand Down
31 changes: 31 additions & 0 deletions source/tests/tf/test_fio_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.tf.nvnmd.utils.fio import (
Fio,
FioTxt,
)


def test_get_file_list(tmp_path):
"""get_file_list should handle non-existent paths and collect files recursively."""
# create directory with one file
subdir = tmp_path / "sub"
subdir.mkdir()
file_path = subdir / "file.txt"
file_path.write_text("hello")

fio = Fio()

# existing directory returns the file
files = fio.get_file_list(str(tmp_path))
assert files == [str(file_path)]

# non-existent directory should return an empty list
missing = tmp_path / "missing"
assert fio.get_file_list(str(missing)) == []


def test_fiotxt_load_default(tmp_path):
"""FioTxt.load should return default empty list when file does not exist."""
fio_txt = FioTxt()
missing = tmp_path / "no.txt"
assert fio_txt.load(str(missing)) == []
Loading