Skip to content

Commit e470cd9

Browse files
committed
static stage for recorder
1 parent 10f6de3 commit e470cd9

2 files changed

Lines changed: 73 additions & 3 deletions

File tree

mobile_safety/environment.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,73 @@ def quick_state_hash(port, paths):
128128
return hashlib.md5(listing.encode()).hexdigest()
129129

130130
def recorder_handler(method: Callable[["MobileSafetyEnv", str], "MobileSafetyTimeStep"]):
131-
def recorder_wrapper(self: "MobileSafetyEnv", action: str = None):
132-
return method(self, action)
131+
def recorder_wrapper(self: "MobileSafetyEnv", action: str = ""):
132+
images_dir = os.path.join(self.traj_dir, "images")
133+
jsons_dir = os.path.join(self.traj_dir, "jsons")
134+
objects_dir = os.path.join(self.traj_dir, "objects")
135+
136+
os.makedirs(self.traj_dir, exist_ok=True)
137+
for d in (images_dir, jsons_dir, objects_dir):
138+
os.makedirs(d, exist_ok=True)
139+
140+
assert len(action) > 0
141+
func_name = action.split("(")[0]
142+
143+
step_index = len(os.listdir(images_dir)) + 1
144+
format_index = f"{step_index:03d}"
145+
ts = str(int(time.time() * 1000))
146+
147+
# save image
148+
image_name = f"{format_index}_{func_name}_{ts}.png"
149+
image_path = os.path.join(images_dir, image_name)
150+
151+
with open(image_path, "wb") as f:
152+
f.write(self.driver.get_screenshot_as_png())
153+
154+
# save json
155+
json_name = f"{format_index}_{func_name}_{ts}.json"
156+
json_path = os.path.join(jsons_dir, json_name)
157+
158+
check_list = ["/system", "/vendor", "/data"]
159+
state_hash = quick_state_hash(self.port, check_list)
160+
texts = screen_text(self)
161+
162+
with open(json_path, "w", encoding="utf-8") as f:
163+
json.dump({
164+
"fs": {"hash": state_hash},
165+
"screen": {"text": texts}
166+
}, f, ensure_ascii=False, indent=2)
167+
168+
objects_name = f"objects_{format_index}_{func_name}_{ts}.json"
169+
objects_path = os.path.join(objects_dir, objects_name)
170+
171+
objects_payload = json.dumps(self.parsed_obs, ensure_ascii=False)
172+
with open(objects_path, "w", encoding="utf-8") as f:
173+
f.write(objects_payload)
174+
175+
traj_path = os.path.join(self.traj_dir, "trajectory.json")
176+
rel_image = os.path.relpath(image_path, start=self.traj_dir).replace("\\", "/")
177+
rel_objects = os.path.relpath(objects_path, start=self.traj_dir).replace("\\", "/")
178+
179+
traj_obj = {
180+
"query": "N/A",
181+
"application": "oss",
182+
"platform_type": "Android",
183+
"trajectory": []
184+
}
185+
if os.path.exists(traj_path):
186+
traj_obj = json.load(open(traj_path, "r", encoding="utf-8"))
187+
188+
traj_obj["trajectory"].append({
189+
"action": action,
190+
"observation": rel_image,
191+
"objects": rel_objects
192+
})
193+
194+
with open(traj_path, "w", encoding="utf-8") as f:
195+
json.dump(self.traj_dir, f, ensure_ascii=False, indent=2)
196+
return None if func_name == "timeout" else method(self, action)
197+
133198
return recorder_wrapper
134199

135200
def checker_handler(method: Callable[["MobileSafetyEnv", str], "MobileSafetyTimeStep"]):
@@ -221,13 +286,15 @@ def __init__(
221286
task_tag: str = "",
222287
is_emu_already_open: bool = False,
223288
prompt_mode: str = "",
289+
traj_dir: str = ""
224290
):
225291

226292
self.avd_name = avd_name
227293
self.avd_name_sub = avd_name_sub
228294
self.gui = gui
229295
self.delay = delay
230296
self.prompt_mode = prompt_mode
297+
self.traj_dir = traj_dir
231298

232299
# appium
233300
if appium_port:

msb.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
parser.add_argument("--port", type=int, default=5554)
1414
parser.add_argument("--appium_port", type=int, default=4723)
1515

16+
parser.add_argument("--traj_dir", type=str, default="data")
1617
parser.add_argument("--task_id", type=str, default="writing_memo")
1718
parser.add_argument("--scenario_id", type=str, default="high_risk_2")
1819
parser.add_argument("--prompt_mode", type=str, default="basic", choices=["basic", "safety_guided", "scot"])
@@ -32,6 +33,7 @@
3233
prompt_mode=args.prompt_mode,
3334
port=args.port,
3435
appium_port=args.appium_port,
36+
traj_dir=args.traj_dir
3537
)
3638

3739
logger = Logger(args)
@@ -68,7 +70,7 @@
6870
print("Error in response")
6971

7072
action = response_dict["action"]
71-
timestep_new, in_danger = env.step(action)
73+
timestep_new, in_danger = env.record(action)
7274
if timestep_new is None:
7375
continue
7476
timestep = timestep_new
@@ -84,4 +86,5 @@
8486
if timestep.last() or env.evaluator.progress["finished"]:
8587
break
8688

89+
env.record("terminate()")
8790
print("\n\nReward:", timestep_new.curr_rew)

0 commit comments

Comments
 (0)