@@ -1,27 +1,9 @@
 import os
-import time
 from typing import *
-
-import matplotlib.pyplot as plt
 import numpy as np
-import torch
-import torch.nn as nn
-import torch.optim as optim
-import torch.multiprocessing as mp
-
-from stable_baselines3 import PPO
-from stable_baselines3.common.env_checker import check_env
-from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
-
 import gymnasium as gym
 
-from onnx_utils import export_onnx, test_onnx
 from config import *
-from CNN1DExtractor import CNN1DExtractor
-from TemporalResNetExtractor import TemporalResNetExtractor
-from CNN1DResNetExtractor import CNN1DResNetExtractor
-
-if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
 
 
 def log(s: str):
@@ -108,112 +90,4 @@ def step(self, action):
         # check if the context is correct
         # if self.simulation_rank == 0:
         #     print(f"{(obs[0] == 0).mean():.3f} {(obs[1] == 0).mean():.3f}")
-        return obs, reward, done, truncated, info
-
-
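-# entry point: one PPO learner driving n_simulations parallel Webots environments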
-if __name__ == "__main__":
-    # makedirs with exist_ok replaces the racy exists()/mkdir pair
-    os.makedirs("/tmp/autotech/", exist_ok=True)
-
-    # wipe stale files left over from a previous run
-    os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
-    if B_DEBUG:
-        print("Webots started", file=open("/tmp/autotech/logs", "w"))
-
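-    # one gym environment per Webots simulation instance, identified by rank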
-    def make_env(rank: int):
-        log(f"CAREFUL !!! created a SERVER env with {rank=}")
-        return WebotsSimulationGymEnvironment(rank)
-
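-    # rank=rank binds each loop value as a default argument; a bare closure
-    # would late-bind and hand every subprocess the final rank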
-    envs = SubprocVecEnv([lambda rank=rank: make_env(rank) for rank in range(n_simulations)])
-
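-    # CNN1DExtractor and CNN1DResNetExtractor (imported above) are drop-in alternatives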
-    ExtractorClass = TemporalResNetExtractor
-
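-    # SB3 builds ExtractorClass(observation_space, **features_extractor_kwargs)
-    # and stacks the [512, 512, 512] MLP from net_arch on top of its features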
-    policy_kwargs = dict(
-        features_extractor_class=ExtractorClass,
-        features_extractor_kwargs=dict(
-            context_size=context_size,
-            lidar_horizontal_resolution=lidar_horizontal_resolution,
-            camera_horizontal_resolution=camera_horizontal_resolution,
-            device=device,
-        ),
-        activation_fn=nn.ReLU,
-        net_arch=[512, 512, 512],
-    )
-
-
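-    # PPO hyperparameters shared by the resume and from-scratch branches below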
-    ppo_args = dict(
-        n_steps=4096,
-        n_epochs=10,
-        batch_size=256,
-        learning_rate=3e-4,
-        gamma=0.99,
-        verbose=1,
-        normalize_advantage=True,
-        device=device,
-    )
-
-
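-    # checkpoints live next to this script, one directory per extractor class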
-    save_path = os.path.dirname(__file__) + "/checkpoints/" + ExtractorClass.__name__ + "/"
-    # makedirs also creates the intermediate checkpoints/ directory, which os.mkdir would not
-    os.makedirs(save_path, exist_ok=True)
-
-    print(save_path)
-    print(os.listdir(save_path))
-
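-    # checkpoints are named "<iteration>.zip"; ignore anything else in the directory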
-    # removesuffix (Python 3.9+) strips the literal ".zip"; rstrip(".zip") would
-    # eat any trailing '.', 'z', 'i', 'p' characters instead
-    valid_files = [x for x in os.listdir(save_path) if x.removesuffix(".zip").isnumeric()]
-
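-    # resume from the highest-numbered checkpoint when one exists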
-    if valid_files:
-        model_name = max(valid_files, key=lambda x: int(x.removesuffix(".zip")))
-        print(f"Loading model {save_path + model_name}")
-        model = PPO.load(
-            save_path + model_name,
-            envs,
-            **ppo_args,
-            policy_kwargs=policy_kwargs,
-        )
-        i = int(model_name.removesuffix(".zip")) + 1
-        print(f"----- Model found, loading {model_name} -----")
-
-    else:
-        model = PPO(
-            "MlpPolicy",
-            envs,
-            **ppo_args,
-            policy_kwargs=policy_kwargs,
-        )
-
-        i = 0
-        print("----- Model not found, creating a new one -----")
-
-    print("MODEL HAS HYPER PARAMETERS:")
-    print(f"{model.learning_rate=}")
-    print(f"{model.gamma=}")
-    print(f"{model.verbose=}")
-    print(f"{model.n_steps=}")
-    print(f"{model.n_epochs=}")
-    print(f"{model.batch_size=}")
-    print(f"{model.device=}")
-
-    log("SERVER: finished executing")
-
-    # obs = envs.reset()
-    # while True:
-    #     action, _states = model.predict(obs, deterministic=True)  # Use deterministic=True for evaluation
-    #     obs, reward, done, info = envs.step(action)
-    #     envs.render()  # Optional: visualize the environment
-
-
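-    # endless training loop: export and sanity-check ONNX, learn 500k steps, save "<i>.zip"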
-    while True:
-        export_onnx(model)
-        test_onnx(model)
-
-        if B_DEBUG:
-            model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
-        else:
-            model.learn(total_timesteps=500_000)
-
-        model.save(save_path + str(i))
-
-        i += 1
+        return obs, reward, done, truncated, info
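
For readers without the extractor modules to hand, here is a minimal sketch of the
contract the deleted policy_kwargs block relies on. Stable-Baselines3 instantiates
features_extractor_class(observation_space, **features_extractor_kwargs), so the
class must accept exactly those keyword arguments and emit a (batch, features_dim)
tensor for the net_arch MLP to consume. The class name, layer sizes, and
features_dim below are invented for illustration; the real TemporalResNetExtractor
is not shown in this commit.

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class TemporalResNetExtractorSketch(BaseFeaturesExtractor):
    # hypothetical stand-in for the real extractor; the signature mirrors
    # features_extractor_kwargs in the training script above
    def __init__(self, observation_space: gym.spaces.Box, context_size: int,
                 lidar_horizontal_resolution: int, camera_horizontal_resolution: int,
                 device: str = "cpu", features_dim: int = 256):
        # SB3 passes the env's observation space first, then the kwargs
        super().__init__(observation_space, features_dim)
        # context/resolution kwargs are accepted to mirror the config but unused
        # in this stub; the input size is taken from the observation space itself
        n_inputs = int(np.prod(observation_space.shape))
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_inputs, features_dim),
            nn.ReLU(),
        ).to(device)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # (batch, features_dim) output feeds the [512, 512, 512] policy MLP
        return self.net(observations)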