Skip to content

Commit 294e526

Browse files
committed
update
1 parent 7d19c07 commit 294e526

3 files changed

Lines changed: 39 additions & 10 deletions

File tree

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from agentlab.agents.vl_agent.config import VL_AGENT_ARGS_DICT
2+
from agentlab.experiments.study import Study
3+
import logging
4+
import os
5+
6+
7+
logging.getLogger().setLevel(logging.INFO)
8+
9+
vl_agent_args_list = [VL_AGENT_ARGS_DICT["ui_agent"]]
10+
benchmark = "miniwob"
11+
os.environ["MINIWOB_URL"] = "file:///mnt/home/miniwob-plusplus/miniwob/html/miniwob/"
12+
reproducibility_mode = False
13+
relaunch = False
14+
n_jobs = 1
15+
16+
17+
if __name__ == "__main__":
18+
if reproducibility_mode:
19+
for vl_agent_args in vl_agent_args_list:
20+
vl_agent_args.set_reproducibility_mode()
21+
if relaunch:
22+
study = Study.load_most_recent(contains=None)
23+
study.find_incomplete(include_errors=True)
24+
else:
25+
study = Study(vl_agent_args_list, benchmark=benchmark, logging_level_stdout=logging.WARNING)
26+
study.run(
27+
n_jobs=n_jobs,
28+
parallel_backend="sequential",
29+
strict_reproducibility=reproducibility_mode,
30+
n_relaunch=3,
31+
)
32+
if reproducibility_mode:
33+
study.append_to_journal(strict_reproducibility=True)

src/agentlab/agents/vl_agent/vl_model/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from abc import ABC, abstractmethod
22
from agentlab.llm.llm_utils import AIMessage, Discussion
3-
from torch.nn import Module
43

54

6-
class VLModel(ABC, Module):
5+
class VLModel(ABC):
76
@abstractmethod
87
def __call__(self, messages: Discussion) -> AIMessage:
98
raise NotImplementedError

src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from agentlab.llm.llm_utils import AIMessage, Discussion
22
from dataclasses import dataclass
3-
from openai import AsyncOpenAI, RateLimitError
3+
from openai import OpenAI, RateLimitError
44
from .base import VLModel, VLModelArgs
5-
import asyncio
65
import backoff
76
import os
87

@@ -15,15 +14,15 @@ def __init__(
1514
max_tokens: int,
1615
reproducibility_config: dict,
1716
):
18-
self.client = AsyncOpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY"))
17+
self.client = OpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY"))
1918
self.model_id = model_id
2019
self.max_tokens = max_tokens
2120
self.reproducibility_config = reproducibility_config
2221

2322
def __call__(self, messages: Discussion) -> AIMessage:
2423
@backoff.on_exception(backoff.expo, RateLimitError)
25-
async def get_response(messages, max_tokens, **kwargs):
26-
completion = await self.client.chat.completions.create(
24+
def get_response(messages, max_tokens, **kwargs):
25+
completion = self.client.chat.completions.create(
2726
model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs
2827
)
2928
try:
@@ -32,9 +31,7 @@ async def get_response(messages, max_tokens, **kwargs):
3231
response = ""
3332
return response
3433

35-
response = asyncio.run(
36-
get_response(messages, self.max_tokens, **self.reproducibility_config)
37-
)
34+
response = get_response(messages, self.max_tokens, **self.reproducibility_config)
3835
return AIMessage([{"type": "text", "text": response}])
3936

4037
def get_stats(self) -> dict:

0 commit comments

Comments
 (0)