import os
import time
from typing import *

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv

import gymnasium as gym

from onnx_utils import export_onnx, test_onnx
from config import *
from CNN1DExtractor import CNN1DExtractor
from TemporalResNetExtractor import TemporalResNetExtractor
from CNN1DResNetExtractor import CNN1DResNetExtractor

if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback

826
927def log (s : str ):
@@ -90,4 +108,112 @@ def step(self, action):
90108 # check if the context is correct
91109 # if self.simulation_rank == 0:
92110 # print(f"{(obs[0] == 0).mean():.3f} {(obs[1] == 0).mean():.3f}")
93- return obs , reward , done , truncated , info
111+ return obs , reward , done , truncated , info
112+
113+
if __name__ == "__main__":
    # Directory used for IPC with the Webots simulation processes.
    tmp_dir = "/tmp/autotech/"
    os.makedirs(tmp_dir, exist_ok=True)  # exist_ok: no race with a previous run

    # Clear stale IPC files from a previous run. Done in pure Python instead
    # of `os.system("rm ...")` — no shell spawn, no glob/quoting pitfalls.
    for stale in os.listdir(tmp_dir):
        os.remove(os.path.join(tmp_dir, stale))

    if B_DEBUG:
        # `with` closes the log file; the old `print(file=open(...))` leaked it.
        with open(os.path.join(tmp_dir, "logs"), "w") as log_file:
            print("Webots started", file=log_file)

    def make_env(rank: int):
        """Create one server-side gym environment bound to simulation `rank`."""
        log(f"CAREFUL !!! created an SERVER env with {rank=}")
        return WebotsSimulationGymEnvironment(rank)

    # `rank=rank` binds the loop variable per-lambda (avoids late binding).
    envs = SubprocVecEnv([lambda rank=rank: make_env(rank) for rank in range(n_simulations)])

    ExtractorClass = TemporalResNetExtractor

    policy_kwargs = dict(
        features_extractor_class=ExtractorClass,
        features_extractor_kwargs=dict(
            context_size=context_size,
            lidar_horizontal_resolution=lidar_horizontal_resolution,
            camera_horizontal_resolution=camera_horizontal_resolution,
            device=device,
        ),
        activation_fn=nn.ReLU,
        net_arch=[512, 512, 512],
    )

    ppo_args = dict(
        n_steps=4096,
        n_epochs=10,
        batch_size=256,
        learning_rate=3e-4,
        gamma=0.99,
        verbose=1,
        normalize_advantage=True,
        device=device,
    )

    def _stem(filename: str) -> str:
        """Strip a trailing '.zip' extension, if present.

        NOTE: the previous `str.rstrip(".zip")` stripped any trailing run of
        the characters {'.', 'z', 'i', 'p'} — a char-set strip, not a suffix
        strip — which only worked here by accident (numeric stems contain
        none of those characters).
        """
        return filename[:-len(".zip")] if filename.endswith(".zip") else filename

    # Checkpoints live next to this script, one folder per extractor class.
    save_path = os.path.join(os.path.dirname(__file__), "checkpoints", ExtractorClass.__name__) + "/"
    os.makedirs(save_path, exist_ok=True)  # also creates "checkpoints/" if missing

    print(save_path)
    print(os.listdir(save_path))

    # Checkpoints are saved as "<step-index>.zip"; keep only those.
    valid_files = [x for x in os.listdir(save_path) if _stem(x).isnumeric()]

    if valid_files:
        # Resume from the checkpoint with the highest index.
        model_name = max(valid_files, key=lambda x: int(_stem(x)))
        print(f"Loading model {save_path + model_name}")
        model = PPO.load(
            save_path + model_name,
            envs,
            **ppo_args,
            policy_kwargs=policy_kwargs,
        )
        i = int(_stem(model_name)) + 1  # next checkpoint index
        print(f"----- Model found, loading {model_name} -----")
    else:
        model = PPO(
            "MlpPolicy",
            envs,
            **ppo_args,
            policy_kwargs=policy_kwargs,
        )
        i = 0
        print("----- Model not found, creating a new one -----")

    print("MODEL HAS HYPER PARAMETERS:")
    print(f"{model.learning_rate=}")
    print(f"{model.gamma=}")
    print(f"{model.verbose=}")
    print(f"{model.n_steps=}")
    print(f"{model.n_epochs=}")
    print(f"{model.batch_size=}")
    print(f"{model.device=}")

    log("SERVER : finished executing")  # plain string: no placeholders to format

    # obs = envs.reset()
    # while True:
    #     action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
    #     obs, reward, done, info = envs.step(action)
    #     envs.render() # Optional: visualize the environment

    # Train forever: export/check ONNX, train one 500k-step chunk, checkpoint.
    while True:
        export_onnx(model)
        test_onnx(model)

        if B_DEBUG:
            model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
        else:
            model.learn(total_timesteps=500_000)

        model.save(save_path + str(i))

        i += 1