Skip to content

Commit 4bbe5d3

Browse files
committed
change default output path for aim run export
Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
1 parent a28aa7b commit 4bbe5d3

2 files changed

Lines changed: 16 additions & 6 deletions

File tree

tuning/tracker/aimstack_tracker.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class CustomAimCallback(AimCallback):
1111

1212
# A path to export run hash generated by Aim
1313
# This is used to link back to the expriments from outside aimstack
14-
aim_run_hash_export_path = None
14+
run_hash_export_path = None
1515

1616
def on_init_end(self, args, state, control, **kwargs):
1717

@@ -20,9 +20,18 @@ def on_init_end(self, args, state, control, **kwargs):
2020

2121
self.setup() # initializes the run_hash
2222

23-
# store the run hash
24-
if self.aim_run_hash_export_path:
25-
with open(self.aim_run_hash_export_path, 'w') as f:
23+
# Store the run hash
24+
# Change default run hash path to output directory
25+
if self.run_hash_export_path is None:
26+
if args and args.output_dir:
27+
# args.output_dir/.aim_run_hash
28+
self.run_hash_export_path = os.path.join(
29+
args.output_dir,
30+
'.aim_run_hash'
31+
)
32+
33+
if self.run_hash_export_path:
34+
with open(self.run_hash_export_path, 'w') as f:
2635
f.write('{\"run_hash\":\"'+str(self._run.hash)+'\"}\n')
2736

2837
def on_train_begin(self, args, state, control, model=None, **kwargs):
@@ -60,7 +69,7 @@ def get_hf_callback(self):
6069
else:
6170
aim_callback = CustomAimCallback(experiment=exp)
6271

63-
aim_callback.aim_run_hash_export_path = hash_export_path
72+
aim_callback.run_hash_export_path = hash_export_path
6473
self.hf_callback = aim_callback
6574
return self.hf_callback
6675

tuning/tracker/tracker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ def __init__(self, name=None, tracker_config=None) -> None:
99
else:
1010
self._name = name
1111

12-
def get_hf_callback():
12+
# we use args here to denote any argument.
13+
def get_hf_callback(self):
1314
return None
1415

1516
def track(self, metric, name, stage):

0 commit comments

Comments
 (0)