|
| 1 | +import logging |
| 2 | +import os |
| 3 | +import sys |
| 4 | + |
| 5 | +from typing import * |
| 6 | + |
| 7 | +import torch.nn as nn |
| 8 | + |
| 9 | +from stable_baselines3 import PPO |
| 10 | +from stable_baselines3.common.vec_env import SubprocVecEnv |
| 11 | + |
# Make the simulator package importable: <repo_root>/src/Simulateur, where
# <repo_root> is two directories above this script. os.path is used instead
# of manual '/'-splitting so the path is robust to relative __file__ values
# and non-POSIX separators.
simu_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "src",
    "Simulateur",
)
if simu_path not in sys.path:
    sys.path.insert(0, simu_path)
| 15 | + |
| 16 | +from Simulateur.config import LOG_LEVEL |
| 17 | +from config import * |
| 18 | +from TemporalResNetExtractor import TemporalResNetExtractor |
| 19 | +from CNN1DResNetExtractor import CNN1DResNetExtractor |
| 20 | +from onnx_utils import * |
| 21 | + |
| 22 | +from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment |
| 23 | +if LOG_LEVEL == logging.DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback |
| 24 | + |
| 25 | + |
if __name__ == "__main__":

    # Scratch directory used to exchange data with the simulator processes.
    # makedirs(exist_ok=True) replaces the racy exists()+mkdir() pair.
    os.makedirs("/tmp/autotech", exist_ok=True)

    # Clear stale files from previous runs. Pure-Python replacement for the
    # former shell call `os.system('… rm /tmp/autotech/*')`. Best-effort:
    # entries that cannot be removed (e.g. subdirectories, vanished files)
    # are skipped, matching the old behaviour where `rm` errors were ignored.
    for _entry in os.listdir("/tmp/autotech"):
        try:
            os.remove(os.path.join("/tmp/autotech", _entry))
        except OSError:
            pass
| 32 | + |
| 33 | + |
| 34 | + def make_env(simulation_rank: int, vehicle_rank: int): |
| 35 | + if LOG_LEVEL == logging.DEBUG: |
| 36 | + print("CAREFUL !!! created an SERVER env with {simulation_rank}_{vehicle_rank}") |
| 37 | + return WebotsSimulationGymEnvironment(simulation_rank, vehicle_rank) |
| 38 | + |
| 39 | + envs = SubprocVecEnv([lambda simulation_rank=simulation_rank, vehicle_rank=vehicle_rank : make_env(simulation_rank, vehicle_rank) for vehicle_rank in range(n_vehicles) for simulation_rank in range(n_simulations)]) |
| 40 | + |
| 41 | + ExtractorClass = CNN1DResNetExtractor |
| 42 | + |
| 43 | + policy_kwargs = dict( |
| 44 | + features_extractor_class=ExtractorClass, |
| 45 | + features_extractor_kwargs=dict( |
| 46 | + context_size=context_size, |
| 47 | + lidar_horizontal_resolution=lidar_horizontal_resolution, |
| 48 | + camera_horizontal_resolution=camera_horizontal_resolution, |
| 49 | + device=device |
| 50 | + ), |
| 51 | + activation_fn=nn.ReLU, |
| 52 | + net_arch=[512, 512, 512], |
| 53 | + ) |
| 54 | + |
| 55 | + |
| 56 | + ppo_args = dict( |
| 57 | + n_steps=4096, |
| 58 | + n_epochs=10, |
| 59 | + batch_size=256, |
| 60 | + learning_rate=3e-4, |
| 61 | + gamma=0.99, |
| 62 | + verbose=1, |
| 63 | + normalize_advantage=True, |
| 64 | + device=device |
| 65 | + ) |
| 66 | + |
| 67 | + |
| 68 | + save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/" |
| 69 | + os.makedirs(save_path, exist_ok=True) |
| 70 | + |
| 71 | + |
| 72 | + print(save_path) |
| 73 | + print(os.listdir(save_path)) |
| 74 | + |
| 75 | + valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()] |
| 76 | + |
| 77 | + if valid_files: |
| 78 | + model_name = max( |
| 79 | + valid_files, |
| 80 | + key=lambda x : int(x.rstrip(".zip")) |
| 81 | + ) |
| 82 | + print(f"Loading model {save_path + model_name}") |
| 83 | + model = PPO.load( |
| 84 | + save_path + model_name, |
| 85 | + envs, |
| 86 | + **ppo_args, |
| 87 | + policy_kwargs=policy_kwargs |
| 88 | + ) |
| 89 | + i = int(model_name.rstrip(".zip")) + 1 |
| 90 | + print(f"----- Model found, loading {model_name} -----") |
| 91 | + |
| 92 | + else: |
| 93 | + model = PPO( |
| 94 | + "MlpPolicy", |
| 95 | + envs, |
| 96 | + **ppo_args, |
| 97 | + policy_kwargs=policy_kwargs |
| 98 | + ) |
| 99 | + |
| 100 | + i = 0 |
| 101 | + print("----- Model not found, creating a new one -----") |
| 102 | + |
| 103 | + print("MODEL HAS HYPER PARAMETERS:") |
| 104 | + print(f"{model.learning_rate=}") |
| 105 | + print(f"{model.gamma=}") |
| 106 | + print(f"{model.verbose=}") |
| 107 | + print(f"{model.n_steps=}") |
| 108 | + print(f"{model.n_epochs=}") |
| 109 | + print(f"{model.batch_size=}") |
| 110 | + print(f"{model.device=}") |
| 111 | + |
| 112 | + print("SERVER : finished executing") |
| 113 | + |
| 114 | + # obs = envs.reset() |
| 115 | + # while True: |
| 116 | + # action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation |
| 117 | + # obs, reward, done, info = envs.step(action) |
| 118 | + # envs.render() # Optional: visualize the environment |
| 119 | + |
| 120 | + |
| 121 | + while True: |
| 122 | + export_onnx(model) |
| 123 | + test_onnx(model) |
| 124 | + |
| 125 | + if LOG_LEVEL <= logging.DEBUG: |
| 126 | + model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback()) |
| 127 | + else: |
| 128 | + model.learn(total_timesteps=500_000) |
| 129 | + |
| 130 | + print("iteration over") |
| 131 | + |
| 132 | + model.save(save_path + str(i)) |
| 133 | + |
| 134 | + i += 1 |
0 commit comments