Skip to content

Commit 0162879

Browse files
Update run_grouped_ablation.py
1 parent 54531c8 commit 0162879

1 file changed

Lines changed: 15 additions & 10 deletions

File tree

ablation_studies/run_grouped_ablation.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ class Colors:
2424
UNDERLINE = '\033[4m'
2525

2626
def print_header(variant, env, num_seeds):
27-
print(f"\n{Colors.OKCYAN}---------------------------------------------------------------------------------{Colors.ENDC}")
28-
print(f"{Colors.OKCYAN}|{Colors.ENDC} {Colors.BOLD}Ablation Group:{Colors.ENDC} {variant:<15} | {env:<20} {Colors.OKCYAN}|{Colors.ENDC}")
29-
print(f"{Colors.OKCYAN}|{Colors.ENDC} {Colors.BOLD}Seeds:{Colors.ENDC} 0 to {num_seeds - 1:<3} {Colors.OKCYAN}|{Colors.ENDC}")
30-
print(f"{Colors.OKCYAN}---------------------------------------------------------------------------------{Colors.ENDC}\n")
27+
print(f"\n{Colors.OKCYAN}---------------------------------------------------------------{Colors.ENDC}")
28+
print(f"{Colors.OKCYAN}|{Colors.ENDC} {Colors.BOLD}Ablation Group:{Colors.ENDC} {variant:<18} | {env:<23}|")
29+
seeds_str = f"0 to {num_seeds - 1}"
30+
print(f"{Colors.OKCYAN}|{Colors.ENDC} {Colors.BOLD}Seeds:{Colors.ENDC} {seeds_str:<44}{Colors.OKCYAN}|{Colors.ENDC}")
31+
print(f"{Colors.OKCYAN}---------------------------------------------------------------{Colors.ENDC}\n")
3132

3233
def run_single_seed(variant, env, seed, contract):
3334
"""
@@ -67,14 +68,14 @@ def run_single_seed(variant, env, seed, contract):
6768
# Construct error message instead of printing directly
6869
error_msg = []
6970
error_msg.append(f" [{Colors.BOLD}SEED {seed}{Colors.ENDC}] {Colors.FAIL}x FAILED{Colors.ENDC}")
70-
error_msg.append(f"{Colors.FAIL} | Error Log ---------------------------------------------------{Colors.ENDC}")
71+
error_msg.append(f"{Colors.FAIL} | Error Log ---------------------------------------------{Colors.ENDC}")
7172

7273
# Filter stderr to remove tqdm noise (lines containing %|)
7374
err_lines = [line for line in stderr_content.splitlines() if "%|" not in line and "it/s" not in line]
7475
# Print last 20 lines of filtered error
7576
for line in err_lines[-20:]:
7677
error_msg.append(f"{Colors.FAIL} | {line}{Colors.ENDC}")
77-
error_msg.append(f"{Colors.FAIL} ---------------------------------------------------------------{Colors.ENDC}")
78+
error_msg.append(f"{Colors.FAIL} -------------------------------------------------------{Colors.ENDC}")
7879

7980
return False, "\n".join(error_msg)
8081

@@ -189,7 +190,10 @@ def main():
189190
spikes_str = f"{val_spikes:.2f}" if val_spikes is not None else "N/A"
190191
if val_spikes is not None: spikes_list.append(val_spikes)
191192

192-
print(f" [{Colors.BOLD}SEED {seed}{Colors.ENDC}] {Colors.OKGREEN}+ Finished{Colors.ENDC} Return: {Colors.BOLD}{val_return:.2f}{Colors.ENDC} Spikes/Inf: {Colors.OKCYAN}{spikes_str}{Colors.ENDC}")
193+
msg = f" [{Colors.BOLD}SEED {seed}{Colors.ENDC}] {Colors.OKGREEN}+ Finished{Colors.ENDC} Return: {Colors.BOLD}{val_return:.2f}{Colors.ENDC}"
194+
if val_spikes is not None:
195+
msg += f" Spikes: {Colors.OKCYAN}{spikes_str}{Colors.ENDC}"
196+
print(msg)
193197
else:
194198
print(f" [{Colors.BOLD}SEED {seed}{Colors.ENDC}] {Colors.WARNING}? Finished{Colors.ENDC} Return: {Colors.WARNING}Not Found{Colors.ENDC}")
195199
# Print the captured output for debugging
@@ -206,7 +210,7 @@ def main():
206210
except Exception as exc:
207211
print(f" [{Colors.BOLD}SEED {seed}{Colors.ENDC}] {Colors.FAIL}Generated an exception: {exc}{Colors.ENDC}")
208212

209-
print(f"\n{Colors.OKCYAN}---------------------------------------------------------------------------------{Colors.ENDC}")
213+
print(f"\n{Colors.OKCYAN}---------------------------------------------------------------{Colors.ENDC}")
210214
if returns:
211215
mean_ret = np.mean(returns)
212216
std_ret = np.std(returns)
@@ -218,11 +222,12 @@ def main():
218222

219223
print(f" {Colors.BOLD}FINAL RESULT:{Colors.ENDC}")
220224
print(f" Mean Return: {Colors.OKGREEN}{mean_str}{Colors.ENDC} +/- {Colors.OKGREEN}{std_str}{Colors.ENDC}")
221-
print(f" Mean Spikes: {Colors.OKCYAN}{spikes_final_str}{Colors.ENDC}")
225+
if spikes_list:
226+
print(f" Mean Spikes: {Colors.OKCYAN}{spikes_final_str}{Colors.ENDC}")
222227
print(f" Success Rate: {len(returns)}/{args.num_seeds}")
223228
else:
224229
print(f" {Colors.FAIL}NO SUCCESSFUL RUNS{Colors.ENDC}")
225-
print(f"{Colors.OKCYAN}---------------------------------------------------------------------------------{Colors.ENDC}\n")
230+
print(f"{Colors.OKCYAN}---------------------------------------------------------------{Colors.ENDC}\n")
226231

227232
if __name__ == "__main__":
228233
main()

0 commit comments

Comments
 (0)