@@ -128,8 +128,73 @@ def quick_state_hash(port, paths):
128128 return hashlib .md5 (listing .encode ()).hexdigest ()
129129
130130def recorder_handler (method : Callable [["MobileSafetyEnv" , str ], "MobileSafetyTimeStep" ]):
131- def recorder_wrapper (self : "MobileSafetyEnv" , action : str = None ):
132- return method (self , action )
131+ def recorder_wrapper (self : "MobileSafetyEnv" , action : str = "" ):
132+ images_dir = os .path .join (self .traj_dir , "images" )
133+ jsons_dir = os .path .join (self .traj_dir , "jsons" )
134+ objects_dir = os .path .join (self .traj_dir , "objects" )
135+
136+ os .makedirs (self .traj_dir , exist_ok = True )
137+ for d in (images_dir , jsons_dir , objects_dir ):
138+ os .makedirs (d , exist_ok = True )
139+
140+ assert len (action ) > 0
141+ func_name = action .split ("(" )[0 ]
142+
143+ step_index = len (os .listdir (images_dir )) + 1
144+ format_index = f"{ step_index :03d} "
145+ ts = str (int (time .time () * 1000 ))
146+
147+ # save image
148+ image_name = f"{ format_index } _{ func_name } _{ ts } .png"
149+ image_path = os .path .join (images_dir , image_name )
150+
151+ with open (image_path , "wb" ) as f :
152+ f .write (self .driver .get_screenshot_as_png ())
153+
154+ # save json
155+ json_name = f"{ format_index } _{ func_name } _{ ts } .json"
156+ json_path = os .path .join (jsons_dir , json_name )
157+
158+ check_list = ["/system" , "/vendor" , "/data" ]
159+ state_hash = quick_state_hash (self .port , check_list )
160+ texts = screen_text (self )
161+
162+ with open (json_path , "w" , encoding = "utf-8" ) as f :
163+ json .dump ({
164+ "fs" : {"hash" : state_hash },
165+ "screen" : {"text" : texts }
166+ }, f , ensure_ascii = False , indent = 2 )
167+
168+ objects_name = f"objects_{ format_index } _{ func_name } _{ ts } .json"
169+ objects_path = os .path .join (objects_dir , objects_name )
170+
171+ objects_payload = json .dumps (self .parsed_obs , ensure_ascii = False )
172+ with open (objects_path , "w" , encoding = "utf-8" ) as f :
173+ f .write (objects_payload )
174+
175+ traj_path = os .path .join (self .traj_dir , "trajectory.json" )
176+ rel_image = os .path .relpath (image_path , start = self .traj_dir ).replace ("\\ " , "/" )
177+ rel_objects = os .path .relpath (objects_path , start = self .traj_dir ).replace ("\\ " , "/" )
178+
179+ traj_obj = {
180+ "query" : "N/A" ,
181+ "application" : "oss" ,
182+ "platform_type" : "Android" ,
183+ "trajectory" : []
184+ }
185+ if os .path .exists (traj_path ):
186+ traj_obj = json .load (open (traj_path , "r" , encoding = "utf-8" ))
187+
188+ traj_obj ["trajectory" ].append ({
189+ "action" : action ,
190+ "observation" : rel_image ,
191+ "objects" : rel_objects
192+ })
193+
194+ with open (traj_path , "w" , encoding = "utf-8" ) as f :
195+ json .dump (self .traj_dir , f , ensure_ascii = False , indent = 2 )
196+ return None if func_name == "timeout" else method (self , action )
197+
133198 return recorder_wrapper
134199
135200def checker_handler (method : Callable [["MobileSafetyEnv" , str ], "MobileSafetyTimeStep" ]):
@@ -221,13 +286,15 @@ def __init__(
221286 task_tag : str = "" ,
222287 is_emu_already_open : bool = False ,
223288 prompt_mode : str = "" ,
289+ traj_dir : str = ""
224290 ):
225291
226292 self .avd_name = avd_name
227293 self .avd_name_sub = avd_name_sub
228294 self .gui = gui
229295 self .delay = delay
230296 self .prompt_mode = prompt_mode
297+ self .traj_dir = traj_dir
231298
232299 # appium
233300 if appium_port :
0 commit comments