-
Notifications
You must be signed in to change notification settings - Fork 111
Expand file tree
/
Copy pathloop.py
More file actions
960 lines (800 loc) · 35 KB
/
loop.py
File metadata and controls
960 lines (800 loc) · 35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
import gzip
import importlib.metadata
import json
import logging
import os
import pickle
import re
import sys
import time
import traceback
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import asdict, dataclass, field, is_dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
import gymnasium as gym
import numpy as np
from browsergym.core.chat import Chat
from browsergym.experiments.agent import Agent
from browsergym.experiments.utils import count_tokens
from dataclasses_json import DataClassJsonMixin
from PIL import Image
from tqdm import tqdm
try:
from agentlab.agents.tapeagent import TapeAgent, save_tape
except ImportError:
TapeAgent = None
logger = logging.getLogger(__name__)
SEED_MAX = 2 ^ 32 # arbitrary max value (exclusive), seems large enough
@dataclass
class EnvArgs(DataClassJsonMixin):
task_name: str
task_seed: Optional[int] = None
max_steps: Optional[int] = None
headless: bool = True
record_video: bool = False
wait_for_user_message: bool = False
viewport: Optional[dict] = None # use default value from BrowserGym
slow_mo: Optional[int] = None # use default value from BrowserGym
storage_state: Optional[str | Path | dict] = None
task_kwargs: Optional[dict] = None # use default value from BrowserGym
pre_observation_delay: float = None # seconds, wait for JS events to be fired
def make_env(
self, action_mapping, exp_dir, exp_task_kwargs: dict = {}, use_raw_page_output=True
):
"""
Instantiates the BrowserGym environment corresponding to the arguments (with some tweaks).
Args:
action_mapping: overrides the action mapping of the environment.
exp_dir: will set some environment parameters (e.g., record_video_dir) with respect to the directory where the experiment is running.
exp_task_kwargs: use with caution! Will override task parameters to experiment-specific values. Useful to set different server configs for different experiments, or output file paths within the experiment's folder (e.g., assistantbench).
use_raw_page_output: if True, the environment will also return raw page output in the observation.
Returns:
env: the gym environment.
"""
extra_kwargs = {}
if self.record_video:
extra_kwargs["record_video_dir"] = exp_dir
if self.viewport:
extra_kwargs["viewport"] = self.viewport
if self.slow_mo is not None:
extra_kwargs["slow_mo"] = self.slow_mo
if self.pre_observation_delay is not None:
extra_kwargs["pre_observation_delay"] = self.pre_observation_delay
if self.storage_state:
extra_kwargs["pw_context_kwargs"] = {"storage_state": self.storage_state}
if self.task_kwargs is not None:
extra_kwargs["task_kwargs"] = self.task_kwargs
if exp_task_kwargs:
extra_kwargs["task_kwargs"] = extra_kwargs.get("task_kwargs", {}) | exp_task_kwargs
# assistantbench hack, write the task output (agent prediction) to a file in the experiment's directory
# TODO: find a better way to deal with this
if self.task_name.startswith("assistantbench.test"):
extra_kwargs["task_kwargs"] = extra_kwargs.get("task_kwargs", {}) | {
"output_file": exp_dir / "assistantbench-prediction.json"
}
return gym.make(
_get_env_name(self.task_name),
disable_env_checker=True,
max_episode_steps=self.max_steps,
headless=self.headless,
wait_for_user_message=self.wait_for_user_message,
action_mapping=action_mapping, # action mapping is provided by the agent
use_raw_page_output=use_raw_page_output,
**extra_kwargs,
)
@dataclass
class AbstractAgentArgs(ABC):
"""A template class that defines the required signature of an agent's arguments."""
agent_name: str = None # type: ignore
def __post_init__(self):
if self.agent_name is None:
self.agent_name = self.__class__.__name__
def prepare(self):
"""Prepare the agent's LLM models before running the experiment."""
pass
def close(self):
"""Close the agent's LLM models after running the experiment."""
pass
@abstractmethod
def make_agent(self) -> Agent:
"""Comply the experiments.loop API for instantiating the agent."""
def save_package_versions(exp_dir: Path):
"""Save the versions of the installed packages in the experiment directory."""
python_dists = "\n".join(
sorted(
[
f"{dist.metadata['Name']}=={dist.metadata['Version']}"
for dist in importlib.metadata.distributions()
]
)
)
(exp_dir / "package_versions.txt").write_text(python_dists)
@dataclass
class StepTimestamps:
env_start: float = 0
action_exec_start: float = 0 # to extract begining of visual action from video
action_exec_stop: float = 0 # to extract end of visual action from video
action_exect_after_timeout: float = 0
env_stop: float = 0
agent_start: float = 0
agent_stop: float = 0
wait_for_page_loading_start: float = 0
wait_for_page_loading_stop: float = 0
validation_start: float = 0
validation_stop: float = 0
get_observation_start: float = 0
get_observation_stop: float = 0
@dataclass
class StepInfo:
"""Collects information about step that will be saved and reloaded.
Helper functions only modify the dataclass attributes and helps keeping the
information organized.
Attributes:
-----------
step: int
The step number of the episode.
obs: dict
The observation of the environment.
reward: float
The reward of the step.
raw_reward: float
The raw reward of the step.
terminated: bool
Whether the episode is terminated i.e. reached a terminal state.
truncated: bool
Whether the episode is truncated i.e. reached a maximum number of steps.
action: str
The action taken by the agent.
agent_info: dict
Additional information from the agent.
stats: dict
Extra statistics about the step.
profiling: StepTimestamps
Timestamps of the different events during the episode.
"""
step: int = None
obs: dict = None
reward: float = 0
raw_reward: float = 0
terminated: bool = None
truncated: bool = None
action: str = None
agent_info: dict = field(default_factory=dict)
stats: dict = None
profiling: StepTimestamps = field(default_factory=StepTimestamps)
task_info: dict = None
def from_step(self, env: gym.Env, action: str, obs_preprocessor: callable):
t = self.profiling
t.env_start = time.time()
self.obs, self.reward, self.terminated, self.truncated, env_info = env.step(action)
t.env_stop = time.time()
self.task_info = env_info.get("task_info", None)
self.raw_reward = env_info.get("RAW_REWARD_GLOBAL", None)
t.action_exec_start = env_info["action_exec_start"] # start
t.action_exect_after_timeout = env_info["action_exec_stop"]
t.action_exec_stop = env_info["action_exec_stop"] - env_info["action_exec_timeout"]
t.wait_for_page_loading_start = env_info.get("wait_for_page_loading_start", None)
t.wait_for_page_loading_stop = env_info.get("wait_for_page_loading_stop", None)
t.validation_start = env_info.get("validation_start", None)
t.validation_stop = env_info.get("validation_stop", None)
t.get_observation_start = env_info.get("get_observation_start", None)
t.get_observation_stop = env_info.get("get_observation_stop", None)
if obs_preprocessor:
self.obs = obs_preprocessor(self.obs)
def from_action(self, agent: Agent):
self.profiling.agent_start = time.time()
self.action, self.agent_info = agent.get_action(self.obs.copy())
self.profiling.agent_stop = time.time()
self.make_stats()
return self.action
def from_reset(self, env: gym.Env, seed: int, obs_preprocessor: callable):
t = self.profiling
t.env_start = time.time()
self.obs, env_info = env.reset(seed=seed)
self.reward, self.terminated, self.truncated = 0, False, False
t.env_stop = time.time()
t.action_exec_start = env_info.get("recording_start_time", t.env_start)
t.action_exect_after_timeout = t.env_stop
t.action_exec_stop = t.env_stop
if obs_preprocessor:
self.obs = obs_preprocessor(self.obs)
@property
def is_done(self):
return self.terminated or self.truncated
def make_stats(self):
if isinstance(self.obs, dict):
stats = {
f"n_token_{key}": count_tokens(val)
for key, val in self.obs.items()
if isinstance(val, str)
}
else:
stats = {}
stats.update(self.agent_info.pop("stats", {}))
t = self.profiling
stats["step_elapsed"] = t.env_stop - t.env_start
stats["agent_elapsed"] = t.agent_stop - t.agent_start
self.stats = stats
def save_step_info(self, exp_dir, save_json=False, save_screenshot=True, save_som=False):
# special treatment for some of the observation fields
if isinstance(self.obs, dict):
# save screenshots to separate files
screenshot = self.obs.pop("screenshot", None)
screenshot_som = self.obs.pop("screenshot_som", None)
if save_screenshot and screenshot is not None:
img = Image.fromarray(screenshot)
img.save(exp_dir / f"screenshot_step_{self.step}.png")
if save_som and screenshot_som is not None:
img = Image.fromarray(screenshot_som)
img.save(exp_dir / f"screenshot_som_step_{self.step}.png")
# save goal object (which might contain images) to a separate file to save space
if self.obs.get("goal_object", False):
# save the goal object only once (goal should never change once setup)
goal_object_file = Path(exp_dir) / "goal_object.pkl.gz"
if not goal_object_file.exists():
with gzip.open(goal_object_file, "wb") as f:
pickle.dump(self.obs["goal_object"], f)
# set goal_object to a special placeholder value, which indicates it should be loaded from a separate file
self.obs["goal_object"] = None
with gzip.open(exp_dir / f"step_{self.step}.pkl.gz", "wb") as f:
pickle.dump(self, f)
if save_json:
with open(exp_dir / "steps_info.json", "w") as f:
json.dump(self, f, indent=4, cls=DataclassJSONEncoder)
if isinstance(self.obs, dict):
# add the screenshots back to the obs
# why do we need this?
if screenshot is not None:
self.obs["screenshot"] = screenshot
if screenshot_som is not None:
self.obs["screenshot_som"] = screenshot_som
@dataclass
class ExpArgs:
"""Arguments to run an experiment, i.e. run agent in an environment until done.
This dataclass is used to store experiments arguments. It contains
agent_args and env_args which follows the same principle. It contains helper
functions to prepare and run experiments.
Attributes:
-----------
agent_args: AbstractAgentArgs
The arguments to instantiate the agent.
env_args: EnvArgs
The arguments to instantiate the environment.
exp_dir: str
The directory where the experiment will be saved.
exp_name: str
The name of the experiment. If None, it will be generated from the
agent and environment names.
enable_debug: bool
If python is running in debug mode and `enable_debug` is True, errors
will be raised instead of only logged
error_msg: str
Error that occured while running the experiment (if any).
stack_trace: str
Stack trace of the error (if any).
order: int (internal)
The order of the experiment in the batch. It is used to keep track of
the original order of the experiments in case they are shuffled.
"""
agent_args: AbstractAgentArgs
env_args: EnvArgs
exp_dir: str = None
exp_name: str = None
enable_debug: bool = True
err_msg: str = None
stack_trace: str = None
order: int = None # use to keep the original order the experiments were meant to be launched.
logging_level: int = logging.INFO
logging_level_stdout: int = logging.INFO
exp_id: str = None
depends_on: tuple[str] = ()
save_screenshot: bool = True
save_som: bool = False
def make_id(self):
"""Create a unique id for the experiment."""
if self.exp_id is None:
self.exp_id = str(uuid.uuid4())
def prepare(self, exp_root):
"""Prepare the experiment directory and save the experiment arguments.
This enables inspecting experiments that are not run yet.
Args:
exp_root: str
The root directory where the experiment will be saved.
"""
if self.env_args.task_seed is None:
self.env_args.task_seed = np.random.randint(0, SEED_MAX)
if self.exp_name is None:
task_name = self.env_args.task_name
self.exp_name = f"{self.agent_args.agent_name}_on_{task_name}_{self.env_args.task_seed}"
# if exp_dir exists, it means it's a re-run, move the old one
if self.exp_dir is not None:
_move_old_exp(self.exp_dir)
self.make_id()
self.exp_date = datetime.now()
self._make_dir(exp_root)
self.exp_dir.mkdir(parents=True, exist_ok=True)
with open(self.exp_dir / "exp_args.pkl", "wb") as f:
pickle.dump(self, f)
def _make_dir(self, exp_root):
"""Create a unique directory for the experiment."""
date_str = self.exp_date.strftime("%Y-%m-%d_%H-%M-%S")
exp_str = re.sub(
r"[\/:*?<>|]", "_", self.exp_name
) # sanitize exp_name to be used as a file name (substitute forbidden characters)
for i in range(1000):
if i >= 999: # make sure we don't loop forever
raise ValueError("Could not find a unique name for the experiment directory.")
tag = f"_{i}" if i > 0 else ""
self.exp_dir = Path(exp_root) / f"{date_str}_{exp_str}{tag}"
if not self.exp_dir.exists():
break
# TODO distinguish between agent error and environment or system error. e.g.
# the parsing error of an action should not be re-run.
def run(self):
"""Run the experiment and save the results"""
# start writing logs to run logfile
self._set_logger()
# log python environment info
save_package_versions(Path(self.exp_dir))
episode_info = []
agent = None
env, step_info, err_msg, stack_trace = None, None, None, None
try:
logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
agent = self.agent_args.make_agent()
if hasattr(agent, "set_task_name"):
agent.set_task_name(self.env_args.task_name)
logger.debug("Agent created.")
env = self.env_args.make_env(
action_mapping=agent.action_set.to_python_code,
exp_dir=self.exp_dir,
use_raw_page_output=getattr(self.agent_args, "use_raw_page_output", False),
)
logger.debug("Environment created.")
step_info = StepInfo(step=0)
episode_info = [step_info]
step_info.from_reset(
env, seed=self.env_args.task_seed or 0, obs_preprocessor=agent.obs_preprocessor
)
logger.debug("Environment reset.")
while not step_info.is_done: # set a limit
logger.debug(f"Starting step {step_info.step}.")
action = step_info.from_action(agent)
logger.debug(f"Agent chose action:\n {action}")
if action is None:
# will end the episode after saving the step info.
step_info.truncated = True
step_info.save_step_info(
self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
)
logger.debug("Step info saved.")
if hasattr(env.unwrapped, "chat") and isinstance(env.unwrapped.chat, Chat):
_send_chat_info(env.unwrapped.chat, action, step_info.agent_info)
logger.debug("Chat info sent.")
if action is None:
logger.debug("Agent returned None action. Ending episode.")
break
step_info = StepInfo(step=step_info.step + 1)
episode_info.append(step_info)
logger.debug("Sending action to environment.")
step_info.from_step(env, action, obs_preprocessor=agent.obs_preprocessor)
logger.debug("Environment stepped.")
if step_info.is_done:
logger.debug(
f"Episode done: terminated: {step_info.terminated}, truncated: {step_info.truncated}."
)
except Exception as e:
err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
stack_trace = traceback.format_exc()
self.err_msg = err_msg
self.stack_trace = stack_trace
logger.warning(err_msg + "\n" + stack_trace)
if _is_debugging() and self.enable_debug:
logger.warning("Debug mode is enabled. Raising the error.")
raise
finally:
try:
if step_info is not None:
step_info.save_step_info(
self.exp_dir, save_screenshot=self.save_screenshot, save_som=self.save_som
)
except Exception as e:
logger.error(f"Error while saving step info in the finally block: {e}")
try:
if (
not err_msg
and len(episode_info) > 0
and not (episode_info[-1].terminated or episode_info[-1].truncated)
):
e = KeyboardInterrupt("Early termination??")
err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
logger.info("Saving experiment info.")
self.save_summary_info(episode_info, Path(self.exp_dir), err_msg, stack_trace)
if TapeAgent is not None and isinstance(agent, TapeAgent):
task = getattr(env, "task", {})
save_tape(self.exp_dir, episode_info, task, agent.final_tape)
except Exception as e:
logger.exception(f"Error while saving experiment info: {e}")
try:
if env is not None:
env.close()
except Exception as e:
logger.exception(f"Error while closing the environment: {e}")
try:
self._unset_logger() # stop writing logs to run logfile
except Exception as e:
logger.exception(f"Error while unsetting the logger: {e}")
def _set_logger(self):
# output logging traces to a log file
file_handler = logging.FileHandler(self.exp_dir / "experiment.log")
file_handler.setLevel(self.logging_level) # same level as console outputs
formatter = logging.Formatter(
"%(asctime)s - %(process)d - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
# output handler
stream_handler = logging.StreamHandler()
stream_handler.setLevel(self.logging_level_stdout)
stream_handler.setFormatter(formatter)
# setup root logger
root_logger = logging.getLogger()
# remove previous stream handlers
for handler in root_logger.handlers:
if isinstance(handler, logging.StreamHandler):
root_logger.removeHandler(handler)
root_logger.setLevel(self.logging_level)
root_logger.addHandler(file_handler)
root_logger.addHandler(stream_handler)
# setup openai logger (don't go below INFO verbosity)
openai_logger = logging.getLogger("openai._base_client")
openai_logger.setLevel(max(logging.INFO, self.logging_level))
self.logging_file_handler = file_handler
def _unset_logger(self):
root_logger = logging.getLogger()
root_logger.removeHandler(self.logging_file_handler)
def save_summary_info(
self,
episode_info: list[StepInfo],
exp_dir: Path,
err_msg: str | None,
stack_trace: str | None,
):
# bring err from agent_info to the top level
if err_msg is None:
err_msg, stack_trace = _extract_err_msg(episode_info)
else:
# useful until we get a proper place in agent_xray to view error
# messages.
if len(episode_info) == 0:
episode_info.append(StepInfo())
episode_info[-1].agent_info["err_msg"] = err_msg
episode_info[-1].agent_info["stack_trace"] = stack_trace
summary_info = dict(
n_steps=len(episode_info) - 1,
cum_reward=sum([step.reward for step in episode_info]),
cum_raw_reward=sum([step.raw_reward for step in episode_info if step.raw_reward]),
err_msg=err_msg,
stack_trace=stack_trace,
)
for key, val in _aggregate_episode_stats(episode_info).items():
summary_info[f"stats.{key}"] = val
if len(episode_info) > 0:
summary_info["terminated"] = episode_info[-1].terminated
summary_info["truncated"] = episode_info[-1].truncated
with open(exp_dir / "summary_info.json", "w") as f:
json.dump(summary_info, f, indent=4)
def _extract_err_msg(episode_info: list[StepInfo]):
"""Extract the last error message from the episode info."""
errors = [(None, None)]
for step_info in episode_info:
if step_info.agent_info is None:
continue
err_msg = step_info.agent_info.get("err_msg", None)
if err_msg is not None:
errors.append((err_msg, step_info.agent_info.get("stack_trace", None)))
return errors[-1]
def _aggregate_episode_stats(episode_info: list[StepInfo]):
"""Aggregate StepInfo.stats across episodes.
It will compute the sum and max of each value in the stats dict.
These two summaries should cover many use cases. If more are needed, the
user can compute other stats by reloading individual StepInfo.
Args:
episode_info: list[StepInfo]
The list of StepInfo objects to aggregate.
Returns:
dict
A dictionary containing the aggregated stats.
"""
stats = defaultdict(list)
for step_info in episode_info:
if step_info.stats is not None:
for key, val in step_info.stats.items():
if val is None:
val = np.nan
stats[key].append(val)
aggregated_stats = {"cum_steps": len(episode_info)} # to be able to compute the mean
for key, val_list in stats.items():
aggregated_stats[f"cum_{key}"] = np.nansum(val_list)
aggregated_stats[f"max_{key}"] = np.nanmax(val_list)
for key, val in aggregated_stats.items():
if isinstance(val, np.generic):
aggregated_stats[key] = val.item()
if np.isnan(val):
aggregated_stats[key] = None
return aggregated_stats
def _is_debugging():
"""Tells you if your code is currently running in debug mode."""
return sys.gettrace() is not None
class ExpResult:
"""Helper class to load and visualize the results of an experiment.
attributes are loaded lazily.
Attributes (lazily loaded):
exp_args: ExpArgs, the arguments of the experiment.
steps_info: list[StepInfo], the information of each steps so far
summary_info: dict, the summary of the experiment.
screenshots: list[Image], the screenshots of each step.
screenshots_som: list[Image], the screenshots of each step with set of
marks inprinted.
flat_exp_args: dict, the flattened version of exp_args.
chat_video_path: Path, the path to the chat video. (if record_video=True)
task_video_path: Path, the path to the task video. (if record_video=True)
combined_video_path: Path, the path to the combined video. (if video was
combined)
"""
def __init__(self, exp_dir) -> None:
self.exp_dir = Path(exp_dir)
self._exp_args = None
self._steps_info = {}
self._summary_info = None
self._screenshots = {}
self._flat_exp_args = None
self._logs = None
@property
def exp_args(self) -> ExpArgs:
if self._exp_args is None:
with open(self.exp_dir / "exp_args.pkl", "rb") as f:
self._exp_args = pickle.load(f)
# in case experiments were moved
self._exp_args.exp_dir = self.exp_dir
return self._exp_args
def get_step_info(self, step: int) -> StepInfo:
"""Load the step info from the file and return it."""
if self._steps_info.get(step, None) is None:
with gzip.open(self.exp_dir / f"step_{step}.pkl.gz", "rb") as f:
self._steps_info[step] = pickle.load(f)
if self._steps_info[step].obs:
if "screenshot" not in self._steps_info[step].obs:
try:
self._steps_info[step].obs["screenshot"] = np.array(
self.get_screenshot(step), dtype=np.uint8
)
except FileNotFoundError:
pass
if "screenshot_som" not in self._steps_info[step].obs:
try:
self._steps_info[step].obs["screenshot_som"] = np.array(
self.get_screenshot(step, som=True), dtype=np.uint8
)
except FileNotFoundError:
pass
# if goal_object is set to None, it indicates it has been saved into a separate file
if (
"goal_object" in self._steps_info[step].obs
and self._steps_info[step].obs["goal_object"] is None
):
with gzip.open(self.exp_dir / "goal_object.pkl.gz", "rb") as f:
goal_object = pickle.load(f)
self._steps_info[step].obs["goal_object"] = goal_object
return self._steps_info[step]
@property
def steps_info(self) -> list[StepInfo]:
step_files = list(self.exp_dir.glob("step_*.pkl.gz"))
for file in step_files:
step = int(file.name.split("_")[-1].split(".")[0])
self.get_step_info(step)
return [self._steps_info[i] for i in range(len(self._steps_info))]
@property
def summary_info(self) -> dict:
if self._summary_info is None:
with open(self.exp_dir / "summary_info.json", "r") as f:
# if length is zero raise file not found error
if os.fstat(f.fileno()).st_size == 0:
raise FileNotFoundError("summary_info.json is empty.")
self._summary_info = json.load(f)
return self._summary_info
def get_screenshot(self, step: int, som=False) -> Image:
key = (step, som)
if self._screenshots.get(key, None) is None:
file_path = self.get_screenshot_path(step, som=som)
self._screenshots[key] = Image.open(file_path).convert("RGB")
return self._screenshots[key]
def get_screenshot_path(self, step: int, som=False) -> Path:
"""Return the path to the screenshot file."""
file_name = f"screenshot_{'som_' if som else ''}step_{step}"
for ext in [".png", ".jpg"]:
file_path = self.exp_dir / (file_name + ext)
if file_path.exists():
return file_path
raise FileNotFoundError(
f"No screenshot found for step {step} (som={som}) in {self.exp_dir}"
)
def get_screenshots(self, som=False):
files = list(self.exp_dir.glob("screenshot_step_*"))
max_step = 0
for file in files:
step = int(file.name.split("_")[-1].split(".")[0])
self.get_screenshot(step, som=som)
max_step = max(max_step, step)
return [self._screenshots.get((i, som), None) for i in range(max_step + 1)]
@property
def screenshots(self):
return self.get_screenshots(som=False)
@property
def screenshots_som(self):
return self.get_screenshots(som=True)
@property
def flat_exp_args(self) -> dict:
"""Return a dict with exp_args flattened."""
if self._flat_exp_args is None:
exp_args = asdict(self.exp_args)
# this will flatten nested dicts
self._flat_exp_args = _flatten_dict(exp_args)
return self._flat_exp_args
def get_exp_record(self) -> dict:
"""Return a dict with exp_args flattened and summary_info."""
record = {"exp_dir": self.exp_dir}
try:
record.update(self.flat_exp_args)
except FileNotFoundError:
pass
try:
record.update(self.summary_info)
except FileNotFoundError:
pass
return record
@property
def chat_video_path(self) -> Path:
try:
return next(self.exp_dir.glob("chat_video/*.webm"))
except StopIteration:
raise FileNotFoundError(f"No chat_video found in {self.exp_dir}")
@property
def task_video_path(self) -> Path:
try:
return next(self.exp_dir.glob("task_video/*.webm"))
except StopIteration:
raise FileNotFoundError(f"No task_video found in {self.exp_dir}")
@property
def combined_video_path(self) -> Path:
return self.exp_dir / "combined_video.mp4"
@property
def logs(self):
if self._logs is None:
self._logs = (self.exp_dir / "experiment.log").read_text()
return self._logs
@property
def status(self):
"""Possible values:
* "done": completed with no error
* "error": completed with error
* "incomplete": not completed yet (may be pending or just stalled)
Returns:
str: the status of the experiment. One of "done", "error", "incomplete".
"""
try:
summary_info = self.summary_info
except FileNotFoundError:
return "incomplete"
if summary_info.get("err_msg", None) is not None:
return "error"
if summary_info.get("terminated", False) or summary_info.get("truncated", False):
return "done"
return "incomplete"
EXP_RESULT_CACHE = {}
def get_exp_result(exp_dir) -> ExpResult:
"""Keep a cache of pre-loaded exp_results for faster loading"""
exp_dir = str(exp_dir) # make sure it's not a Path
exp_result = EXP_RESULT_CACHE.get(exp_dir, None)
if exp_result is None:
exp_result = ExpResult(exp_dir)
EXP_RESULT_CACHE[exp_dir] = exp_result
return exp_result
def yield_all_exp_results(
savedir_base: str | Path, progress_fn=tqdm, load_hidden=False, use_cache=True
):
"""Recursively find all experiments from savedir_base folder.
This will ignore all experiments that start with "_" or ".". use
`load_hidden=True` to load them anyway.
Args:
savedir_base: str or Path
The base directory where the experiments are saved.
progress_fn: function
A function to show progress. Defaults to tqdm.
load_hidden: bool
If True, load hidden experiments (those starting with "_" or ".").
use_cache: bool
If True, use the cache of pre-loaded exp_results.
Yields:
ExpResult
An instance of ExpResult for each experiment found.
"""
if not isinstance(savedir_base, list):
savedir_base = [savedir_base]
exp_args_paths = []
for exp_dir in savedir_base:
exp_args_paths.extend(list(Path(exp_dir).glob("**/exp_args.pkl")))
if progress_fn is not None:
exp_args_paths = progress_fn(exp_args_paths, desc="Searching experiments directories.")
for exp_args_path in exp_args_paths:
exp_dir = exp_args_path.parent
if not load_hidden:
if exp_dir.name.startswith("_") or exp_dir.name.startswith("."):
continue
if use_cache:
yield get_exp_result(exp_dir)
else:
yield ExpResult(exp_dir)
class DataclassJSONEncoder(json.JSONEncoder):
def default(self, obj):
if is_dataclass(obj):
return asdict(obj)
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.ndarray):
return obj.tolist()
return super().default(obj)
def _move_old_exp(exp_dir):
"""Move the old experiment directory to a new name."""
exp_dir = Path(exp_dir)
if exp_dir.exists():
exp_dir.rename(exp_dir.with_name("_" + exp_dir.name))
def _get_env_name(task_name: str):
"""Register tasks if needed (lazy import) and return environment name."""
# lazy import
if task_name.startswith("miniwob"):
import browsergym.miniwob
elif task_name.startswith("workarena"):
import browsergym.workarena
elif task_name.startswith("webarena"):
import browsergym.webarena
import browsergym.webarenalite
try:
import browsergym.webarena_verified
except ImportError:
logger.warning("browsergym.webarena_verified not found. Skipping import.")
elif task_name.startswith("visualwebarena"):
import browsergym.visualwebarena
elif task_name.startswith("assistantbench"):
import browsergym.assistantbench
elif task_name.startswith("weblinx"):
import weblinx_browsergym
return f"browsergym/{task_name}"
def _send_chat_info(chat: Chat, action: str, agent_info: dict):
"""Send the think and action info to the chat."""
msg = ""
if "think" in agent_info:
msg += f"""\
{agent_info["think"]}
"""
msg += f"""\
action:
{action}
"""
logger.info(msg)
chat.add_message(role="info", msg=msg)
def _flatten_dict(d, parent_key="", sep="."):
"""Recursively flatten a nested dictionary."""
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, dict):
items.extend(_flatten_dict(v, new_key, sep).items())
else:
items.append((new_key, v))
return dict(items)