Skip to content

Commit 469dbf9

Browse files
author
Markgraf
committed
non-functional: split launch program
1 parent 72154cb commit 469dbf9

5 files changed

Lines changed: 15 additions & 130 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ uv sync --extra rpi
3333

3434
Navigate to the scripts directory.
3535
```bash
36-
cd src/Simulateur
36+
cd scripts
3737
```
3838

3939
Run the multi-process training script.
Lines changed: 1 addition & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,9 @@
11
import os
2-
import time
32
from typing import *
4-
5-
import matplotlib.pyplot as plt
63
import numpy as np
7-
import torch
8-
import torch.nn as nn
9-
import torch.optim as optim
10-
import torch.multiprocessing as mp
11-
12-
from stable_baselines3 import PPO
13-
from stable_baselines3.common.env_checker import check_env
14-
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
15-
164
import gymnasium as gym
175

18-
from onnx_utils import export_onnx, test_onnx
196
from config import *
20-
from CNN1DExtractor import CNN1DExtractor
21-
from TemporalResNetExtractor import TemporalResNetExtractor
22-
from CNN1DResNetExtractor import CNN1DResNetExtractor
23-
24-
if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
257

268

279
def log(s: str):
@@ -108,112 +90,4 @@ def step(self, action):
10890
# check if the context is correct
10991
# if self.simulation_rank == 0:
11092
# print(f"{(obs[0] == 0).mean():.3f} {(obs[1] == 0).mean():.3f}")
111-
return obs, reward, done, truncated, info
112-
113-
114-
if __name__ == "__main__":
115-
if not os.path.exists("/tmp/autotech/"):
116-
os.mkdir("/tmp/autotech/")
117-
118-
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
119-
if B_DEBUG:
120-
print("Webots started", file=open("/tmp/autotech/logs", "w"))
121-
122-
def make_env(rank: int):
123-
log(f"CAREFUL !!! created an SERVER env with {rank=}")
124-
return WebotsSimulationGymEnvironment(rank)
125-
126-
envs = SubprocVecEnv([lambda rank=rank : make_env(rank) for rank in range(n_simulations)])
127-
128-
ExtractorClass = TemporalResNetExtractor
129-
130-
policy_kwargs = dict(
131-
features_extractor_class=ExtractorClass,
132-
features_extractor_kwargs=dict(
133-
context_size=context_size,
134-
lidar_horizontal_resolution=lidar_horizontal_resolution,
135-
camera_horizontal_resolution=camera_horizontal_resolution,
136-
device=device
137-
),
138-
activation_fn=nn.ReLU,
139-
net_arch=[512, 512, 512],
140-
)
141-
142-
143-
ppo_args = dict(
144-
n_steps=4096,
145-
n_epochs=10,
146-
batch_size=256,
147-
learning_rate=3e-4,
148-
gamma=0.99,
149-
verbose=1,
150-
normalize_advantage=True,
151-
device=device
152-
)
153-
154-
155-
save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/"
156-
if not os.path.exists(save_path):
157-
os.mkdir(save_path)
158-
159-
print(save_path)
160-
print(os.listdir(save_path))
161-
162-
valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()]
163-
164-
if valid_files:
165-
model_name = max(
166-
valid_files,
167-
key=lambda x : int(x.rstrip(".zip"))
168-
)
169-
print(f"Loading model {save_path + model_name}")
170-
model = PPO.load(
171-
save_path + model_name,
172-
envs,
173-
**ppo_args,
174-
policy_kwargs=policy_kwargs
175-
)
176-
i = int(model_name.rstrip(".zip")) + 1
177-
print(f"----- Model found, loading {model_name} -----")
178-
179-
else:
180-
model = PPO(
181-
"MlpPolicy",
182-
envs,
183-
**ppo_args,
184-
policy_kwargs=policy_kwargs
185-
)
186-
187-
i = 0
188-
print("----- Model not found, creating a new one -----")
189-
190-
print("MODEL HAS HYPER PARAMETERS:")
191-
print(f"{model.learning_rate=}")
192-
print(f"{model.gamma=}")
193-
print(f"{model.verbose=}")
194-
print(f"{model.n_steps=}")
195-
print(f"{model.n_epochs=}")
196-
print(f"{model.batch_size=}")
197-
print(f"{model.device=}")
198-
199-
log(f"SERVER : finished executing")
200-
201-
# obs = envs.reset()
202-
# while True:
203-
# action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
204-
# obs, reward, done, info = envs.step(action)
205-
# envs.render() # Optional: visualize the environment
206-
207-
208-
while True:
209-
export_onnx(model)
210-
test_onnx(model)
211-
212-
if B_DEBUG:
213-
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
214-
else:
215-
model.learn(total_timesteps=500_000)
216-
217-
model.save(save_path + str(i))
218-
219-
i += 1
93+
return obs, reward, done, truncated, info

src/Simulateur/__init__.py

Whitespace-only changes.

src/Simulateur/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch.cuda import is_available
33

44
n_map = 2
5-
n_simulations = 8
5+
n_simulations = 2
66
n_vehicles = 1
77
n_stupid_vehicles = 0
88
n_actions_steering = 16

src/Simulateur/controllers/controllerWorldSupervisor/controllerWorldSupervisor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import *
33
import numpy as np
44
import gymnasium as gym
5+
import time
56

67
from checkpointmanager import CheckpointManager, checkpoints
78

@@ -219,7 +220,17 @@ def main():
219220
#Prédiction pour sélectionner une action à partir de l'observation
220221
for e in envs:
221222
log(f"CLIENT{simulation_rank}/{e.vehicle_rank} : trying to read from fifo")
222-
action = np.frombuffer(e.fifo_r.read(np.dtype(np.int64).itemsize * 2), dtype=np.int64)
223+
224+
timeout = 10 # seconds
225+
start_time = time.time()
226+
227+
while time.time() - start_time < timeout:
228+
raw = e.fifo_r.read(np.dtype(np.int64).itemsize * 2)
229+
if len(raw) == np.dtype(np.int64).itemsize * 2:
230+
# We got the full action data
231+
action = np.frombuffer(raw, dtype=np.int64)
232+
break
233+
223234
log(f"CLIENT{simulation_rank}/{e.vehicle_rank} : received {action=}")
224235

225236
obs, reward, done, truncated, info = e.step(action)

0 commit comments

Comments
 (0)