Skip to content

Commit e652603

Browse files
committed
Refactor file path handling to utilize utility functions for database and test data directories
1 parent 2db833d commit e652603

2 files changed

Lines changed: 22 additions & 26 deletions

File tree

autotst/data/base_test.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import autotst
3737
from ..reaction import Reaction
3838
from .base import QMData, DistanceData, TransitionStates, TransitionStateDepository, TSGroups
39+
from ..utils.paths import database_dir, test_data_dir
3940
import rmgpy
4041
import rmgpy.data.rmg
4142

@@ -61,7 +62,8 @@ def test_get_qmdata(self):
6162
A method that is designed to obtain the QM data for a transitionstate or molecule
6263
Returns a qmdata object
6364
"""
64-
self.qmdata.get_qmdata(os.path.expandvars("$AUTOTST/test/bin/log-files/CC+[O]O_[CH2]C+OO_forward_0.log"))
65+
log_path = test_data_dir() / "CC+[O]O_[CH2]C+OO_forward_0.log"
66+
self.qmdata.get_qmdata(str(log_path))
6567

6668
self.assertEqual(self.qmdata.ground_state_degeneracy, 2)
6769
self.assertAlmostEqual(self.qmdata.molecular_mass[0], 126.1, places=1)
@@ -154,9 +156,7 @@ def setUp(self):
154156
self.ts_depository = TransitionStateDepository(label="test")
155157

156158
self.settings = {
157-
"file_path": os.path.join(
158-
os.path.expandvars("$AUTOTST"), "database", "H_Abstraction", "TS_training", "reactions.py"
159-
),
159+
"file_path": str(database_dir() / "H_Abstraction" / "TS_training" / "reactions.py"),
160160
"local_context": {"DistanceData":DistanceData},
161161
"global_context": {'__builtins__': None}
162162
}
@@ -176,9 +176,7 @@ def setUp(self):
176176
self.ts_groups = TSGroups(label="test")
177177

178178
self.settings = {
179-
"file_path": os.path.join(
180-
os.path.expandvars("$AUTOTST"), "database", "H_Abstraction", "TS_groups.py"
181-
),
179+
"file_path": str(database_dir() / "H_Abstraction" / "TS_groups.py"),
182180
"local_context": {"DistanceData":DistanceData},
183181
"global_context": {'__builtins__': None}
184182
}
@@ -211,4 +209,4 @@ def test_estimate_distances_using_group_additivity(self):
211209
self.assertAlmostEquals(d23, distance_data.distances["d23"], places=1)
212210

213211
if __name__ == "__main__":
214-
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
212+
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))

autotst/data/update.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from ..species import Species, Conformer
3838
from ..reaction import Reaction, TS
3939
from .base import *
40+
from ..utils.paths import database_dir
4041
import rmgpy
4142
import rmgpy.molecule
4243
import rmgpy.data.base
@@ -395,25 +396,24 @@ def update_databases(reactions, method='', short_desc='', reaction_family='', ov
395396
# logging.warning(
396397
# 'Defaulting to reaction family of {}'.format(reaction_family))
397398

398-
general_path = os.path.join(os.path.expandvars(
399-
'$AUTOTST'), 'database', reaction_family, 'TS_training')
399+
general_path = database_dir() / reaction_family / "TS_training"
400400

401-
dict_path = os.path.join(general_path, 'dictionary.txt')
402-
old_reactions_path = os.path.join(general_path, 'reactions.py')
401+
dict_path = general_path / "dictionary.txt"
402+
old_reactions_path = general_path / "reactions.py"
403403

404404
if overwrite:
405-
new_dict_path = os.path.join(general_path, 'dictionary.txt')
406-
new_reactions_path = os.path.join(general_path, 'reactions.py')
405+
new_dict_path = general_path / "dictionary.txt"
406+
new_reactions_path = general_path / "reactions.py"
407407
else:
408-
new_dict_path = os.path.join(general_path, 'updated_dictionary.txt')
409-
new_reactions_path = os.path.join(general_path, 'updated_reactions.py')
408+
new_dict_path = general_path / "updated_dictionary.txt"
409+
new_reactions_path = general_path / "updated_reactions.py"
410410

411-
known_species = rmgpy.data.base.Database().get_species(dict_path)
411+
known_species = rmgpy.data.base.Database().get_species(str(dict_path))
412412
unknown_species = get_unknown_species(reactions, known_species)
413413

414414
updated_known_species = []
415415
if len(unknown_species) > 0:
416-
old_dict_entries = rote_load_dict(dict_path)
416+
old_dict_entries = rote_load_dict(str(dict_path))
417417

418418
assert len(known_species) == len(old_dict_entries)
419419

@@ -424,15 +424,15 @@ def update_databases(reactions, method='', short_desc='', reaction_family='', ov
424424
len(unknown_species) == len(all_dict_entries)
425425

426426
if check_dictionary_entries(all_dict_entries):
427-
rote_save_dictionary(new_dict_path, all_dict_entries)
427+
rote_save_dictionary(str(new_dict_path), all_dict_entries)
428428

429-
updated_known_species = rmgpy.data.base.Database().get_species(new_dict_path)
429+
updated_known_species = rmgpy.data.base.Database().get_species(str(new_dict_path))
430430
unk_spec = get_unknown_species(reactions, updated_known_species)
431431
assert len(unk_spec) == 0, f'{len(unk_spec)} unknown species found after updating'
432432
else:
433433
updated_known_species = known_species
434434

435-
r_db, old_db, new_db = update_known_reactions(old_reactions_path,
435+
r_db, old_db, new_db = update_known_reactions(str(old_reactions_path),
436436
reactions,
437437
updated_known_species,
438438
method=method,
@@ -443,7 +443,7 @@ def update_databases(reactions, method='', short_desc='', reaction_family='', ov
443443
# if check_reactions_database():
444444
if True:
445445
logging.warning('No duplicate check for reactions database')
446-
r_db.save(new_reactions_path)
446+
r_db.save(str(new_reactions_path))
447447
if len(reactions) < 5:
448448
for reaction in reactions:
449449
logging.info(
@@ -468,8 +468,7 @@ def TS_Database_Update(families, path=None, auto_save=False):
468468

469469
assert isinstance(
470470
families, list), "Families must be a list. If singular family, still keep it in list"
471-
acceptable_families = os.listdir(os.path.join(
472-
os.path.expandvars("$AUTOTST"), "database"))
471+
acceptable_families = os.listdir(database_dir())
473472
for family in families:
474473
assert isinstance(
475474
family, str), "Family names must be provided as strings"
@@ -555,8 +554,7 @@ def __init__(self, family, rmg_database, path=None):
555554
if path is not None:
556555
self.path = path
557556
else:
558-
self.path = os.path.join(os.path.expandvars(
559-
"$AUTOTST"), "database", family)
557+
self.path = database_dir() / family
560558

561559
self.family = family
562560

0 commit comments

Comments
 (0)