Skip to content

Commit 524bee8

Browse files
author
ronuchit
authored
fix learning time (#587)
* fix learning time Note one issue is that all three of these things are cumulative now -- num_transitions, cumulative_query_cost, learning_time. But we're only naming one of them "cumulative". Should we just make them all consistent @tomsilver? * remove "cumulative"
1 parent 3d8cc14 commit 524bee8

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

scripts/analyze_results_directory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
# ("AVG_PLAN_LEN", "avg_plan_length"),
4141
# ("AVG_EXECUTION_FAILURES", "avg_execution_failures"),
4242
# ("NUM_TRANSITIONS", "num_transitions"),
43-
# ("CUM_QUERY_COST", "cumulative_query_cost"),
43+
# ("QUERY_COST", "query_cost"),
4444
]
4545

4646

src/main.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,18 @@ def _run_pipeline(env: BaseEnv,
123123
total_num_transitions = sum(
124124
len(traj.actions) for traj in offline_dataset.trajectories)
125125
total_query_cost = 0.0
126-
learning_start = time.time()
127126
if CFG.load_approach:
128127
approach.load(online_learning_cycle=None)
128+
learning_time = 0.0 # ignore loading time
129129
else:
130+
learning_start = time.time()
130131
approach.learn_from_offline_dataset(offline_dataset)
132+
learning_time = time.time() - learning_start
131133
# Run evaluation once before online learning starts.
132134
results = _run_testing(env, approach)
133135
results["num_transitions"] = total_num_transitions
134-
results["cumulative_query_cost"] = total_query_cost
135-
results["learning_time"] = time.time() - learning_start
136+
results["query_cost"] = total_query_cost
137+
results["learning_time"] = learning_time
136138
_save_test_results(results, online_learning_cycle=None)
137139
teacher = Teacher(train_tasks)
138140
# The online learning loop.
@@ -156,17 +158,21 @@ def _run_pipeline(env: BaseEnv,
156158
logging.info(f"Query cost incurred this cycle: {query_cost}")
157159
if CFG.load_approach:
158160
approach.load(online_learning_cycle=i)
161+
learning_time += 0.0 # ignore loading time
159162
else:
163+
learning_start = time.time()
160164
approach.learn_from_interaction_results(interaction_results)
165+
learning_time += time.time() - learning_start
161166
# Evaluate approach after every online learning cycle.
162167
results = _run_testing(env, approach)
163168
results["num_transitions"] = total_num_transitions
164-
results["cumulative_query_cost"] = total_query_cost
165-
results["learning_time"] = time.time() - learning_start
169+
results["query_cost"] = total_query_cost
170+
results["learning_time"] = learning_time
166171
_save_test_results(results, online_learning_cycle=i)
167172
else:
168173
results = _run_testing(env, approach)
169174
results["num_transitions"] = 0
175+
results["query_cost"] = 0.0
170176
results["learning_time"] = 0.0
171177
_save_test_results(results, online_learning_cycle=None)
172178

0 commit comments

Comments
 (0)