Skip to content

Commit 30bc57b

Browse files
committed
update
1 parent a7b9999 commit 30bc57b

6 files changed

Lines changed: 75 additions & 97 deletions

File tree

src/agentlab/agents/vl_agent/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
"ui_agent": UIAgentArgs(
4242
main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"],
4343
auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"],
44-
action_set_args=HighLevelActionSetArgs(subsets=["coord"]),
4544
ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"],
45+
action_set_args=HighLevelActionSetArgs(subsets=["coord"]),
4646
max_retry=4,
4747
)
4848
}

src/agentlab/agents/vl_agent/vl_agent/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from abc import ABC, abstractmethod
2+
from browsergym.core.action.highlevel import HighLevelActionSet
23
from browsergym.experiments.benchmark import Benchmark
34
from dataclasses import dataclass
45

56

67
class VLAgent(ABC):
8+
@property
@abstractmethod
def action_set(self) -> HighLevelActionSet:
    """Action set used by the agent; every concrete subclass must supply it."""
    raise NotImplementedError()
12+
713
@abstractmethod
814
def get_action(self, obs: dict) -> tuple[str, dict]:
915
raise NotImplementedError

src/agentlab/agents/vl_agent/vl_agent/ui_agent.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from agentlab.llm.llm_utils import ParseError, retry
22
from agentlab.llm.tracking import cost_tracker_decorator
3+
from browsergym.core.action.highlevel import HighLevelActionSet
34
from browsergym.experiments.agent import AgentInfo
45
from browsergym.experiments.benchmark import Benchmark
56
from browsergym.experiments.benchmark.base import HighLevelActionSetArgs
67
from browsergym.utils.obs import overlay_som
78
from copy import copy, deepcopy
89
from dataclasses import asdict, dataclass
10+
from functools import cache
911
from typing import Optional
1012
from .base import VLAgent, VLAgentArgs
1113
from ..vl_model.base import VLModelArgs
@@ -17,21 +19,26 @@ def __init__(
1719
self,
1820
main_vl_model_args: VLModelArgs,
1921
auxiliary_vl_model_args: Optional[VLModelArgs],
20-
action_set_args: HighLevelActionSetArgs,
2122
ui_prompt_args: UIPromptArgs,
23+
action_set_args: HighLevelActionSetArgs,
2224
max_retry: int,
2325
):
2426
self.main_vl_model = main_vl_model_args.make_model()
2527
if auxiliary_vl_model_args is None:
2628
self.auxiliary_vl_model = None
2729
else:
2830
self.auxiliary_vl_model = auxiliary_vl_model_args.make_model()
29-
self.action_set = action_set_args.make_action_set()
3031
self.ui_prompt_args = ui_prompt_args
32+
self.action_set_args = action_set_args
3133
self.max_retry = max_retry
3234
self.thoughts = []
3335
self.actions = []
3436

37+
@property
def action_set(self) -> "HighLevelActionSet":
    """Return this agent's action set, building it on first access.

    The set is created by ``self.action_set_args.make_action_set()`` and then
    stored on the instance, so later accesses return the same object.

    ``functools.cache`` is deliberately not stacked under ``@property``: it
    memoizes in a function-level table keyed on ``self``, which keeps every
    agent alive for as long as the class exists and raises ``TypeError`` for
    instances that are not hashable. Storing the result on the instance ties
    its lifetime to the agent instead.
    """
    try:
        return self._action_set
    except AttributeError:
        # First access on this instance: build once and remember the result.
        self._action_set = self.action_set_args.make_action_set()
        return self._action_set
41+
3542
@cost_tracker_decorator
3643
def get_action(self, obs: dict) -> tuple[str, dict]:
3744
ui_prompt = self.ui_prompt_args.make_prompt(
@@ -94,25 +101,27 @@ def obs_preprocessor(self, obs: dict) -> dict:
94101
class UIAgentArgs(VLAgentArgs):
95102
main_vl_model_args: VLModelArgs
96103
auxiliary_vl_model_args: Optional[VLModelArgs]
97-
action_set_args: HighLevelActionSetArgs
98104
ui_prompt_args: UIPromptArgs
105+
action_set_args: HighLevelActionSetArgs
99106
max_retry: int
100107

101108
@property
def agent_name(self) -> str:
    """Return the agent's name, derived from the configured model names.

    The result is ``UIAgent-<main>`` when ``auxiliary_vl_model_args`` is
    ``None`` and ``UIAgent-<main>-<auxiliary>`` otherwise, where each part is
    the ``model_name`` of the corresponding model args.

    The name is recomputed on every access rather than wrapped in
    ``functools.cache``: a cache keyed on ``self`` raises ``TypeError`` when
    the instance is unhashable (e.g. a dataclass with the default
    ``eq=True``), pins the instance in memory, and would keep returning the
    old name after the model args are changed.
    """
    main_name = self.main_vl_model_args.model_name
    if self.auxiliary_vl_model_args is None:
        return f"UIAgent-{main_name}"
    return f"UIAgent-{main_name}-{self.auxiliary_vl_model_args.model_name}"
107115

108116
def make_agent(self) -> UIAgent:
    """Build a ``UIAgent`` from these args, keep it on ``self.ui_agent``, and return it."""
    agent_kwargs = dict(
        main_vl_model_args=self.main_vl_model_args,
        auxiliary_vl_model_args=self.auxiliary_vl_model_args,
        ui_prompt_args=self.ui_prompt_args,
        action_set_args=self.action_set_args,
        max_retry=self.max_retry,
    )
    # The created agent is stored on the args object as well as returned.
    self.ui_agent = UIAgent(**agent_kwargs)
    return self.ui_agent
116125

117126
def prepare(self):
118127
self.main_vl_model_args.prepare()

src/agentlab/agents/vl_agent/vl_model/llama_model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from accelerate.utils.modeling import load_checkpoint_in_model
33
from agentlab.llm.llm_utils import AIMessage, Discussion
44
from dataclasses import dataclass
5+
from functools import cache
56
from transformers import AutoProcessor, MllamaForConditionalGeneration
67
from typing import Optional
78
from .base import VLModel, VLModelArgs
@@ -84,6 +85,7 @@ class LlamaModelArgs(VLModelArgs):
8485
device: Optional[str]
8586

8687
@property
def model_name(self) -> str:
    """Return a short model name derived from ``self.model_path``.

    Takes the last ``/``-separated component of the path, replaces ``-`` with
    ``_`` and removes ``.``.

    Computed on each access instead of through ``functools.cache``: a cache
    keyed on ``self`` raises ``TypeError`` when the instance is unhashable
    (e.g. a dataclass with the default ``eq=True``), keeps the instance alive
    for the lifetime of the class, and would go stale if ``model_path`` is
    reassigned. The string work here is too cheap to need memoizing.
    """
    basename = self.model_path.split("/")[-1]
    return basename.replace("-", "_").replace(".", "")
8991

@@ -113,13 +115,14 @@ def make_model(self) -> LlamaModel:
113115
else:
114116
llama_model.model = llama_model.model.to(self.device)
115117
llama_model.model.eval()
116-
return llama_model
118+
self.llama_model = llama_model
119+
return self.llama_model
117120

118121
def prepare(self):
119122
pass
120123

121124
def close(self):
    """Drop the loaded model held by ``self.llama_model``, if there is one.

    In the visible code ``self.llama_model`` is assigned by ``make_model``.
    Calling ``close`` before any model was made, or calling it a second time,
    is a no-op rather than an ``AttributeError``.
    """
    llama_model = getattr(self, "llama_model", None)
    if llama_model is not None and hasattr(llama_model, "model"):
        # Remove the attribute so the wrapper no longer references the model.
        del llama_model.model
123126

124127
def set_reproducibility_mode(self):
125128
self.reproducibility_config = {"do_sample": False}

src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from agentlab.llm.llm_utils import AIMessage, Discussion
22
from dataclasses import dataclass
3+
from functools import cache
34
from openai import OpenAI, RateLimitError
45
from .base import VLModel, VLModelArgs
56
import backoff
@@ -46,16 +47,18 @@ class OpenRouterAPIModelArgs(VLModelArgs):
4647
reproducibility_config: dict
4748

4849
@property
def model_name(self) -> str:
    """Return a short model name derived from ``self.model_id``.

    Takes the last ``/``-separated component of the id, replaces ``-`` with
    ``_`` and removes ``.``.

    Computed on each access instead of through ``functools.cache``: a cache
    keyed on ``self`` raises ``TypeError`` when the instance is unhashable
    (e.g. a dataclass with the default ``eq=True``), keeps the instance alive
    for the lifetime of the class, and would go stale if ``model_id`` is
    reassigned. The string work here is too cheap to need memoizing.
    """
    basename = self.model_id.split("/")[-1]
    return basename.replace("-", "_").replace(".", "")
5153

5254
def make_model(self) -> OpenRouterAPIModel:
    """Build an ``OpenRouterAPIModel`` from these args, keep it on ``self.openrouter_api_model``, and return it."""
    model_kwargs = dict(
        base_url=self.base_url,
        model_id=self.model_id,
        max_tokens=self.max_tokens,
        reproducibility_config=self.reproducibility_config,
    )
    # The created model is stored on the args object as well as returned.
    self.openrouter_api_model = OpenRouterAPIModel(**model_kwargs)
    return self.openrouter_api_model
5962

6063
def prepare(self):
6164
pass

0 commit comments

Comments
 (0)