Skip to content

Commit b42aa03

Browse files
Revert "Split train"
1 parent 1de4aa7 commit b42aa03

11 files changed

Lines changed: 175 additions & 149 deletions

File tree

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,3 @@ Debug_Wayfinding
1010
.venv
1111
*.onnx
1212
checkpoints
13-
*.wbproj

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ uv sync --extra rpi
3333

3434
Navigate to the simulator directory.
3535
```bash
36-
cd scripts
36+
cd src/Simulateur
3737
```
3838

3939
Run the multi-process training script.

scripts/launch_train_multiprocessing.py

Lines changed: 0 additions & 133 deletions
This file was deleted.

src/Simulateur/__init__.py

Whitespace-only changes.

src/Simulateur/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch.cuda import is_available
33

44
n_map = 2
5-
n_simulations = 2
5+
n_simulations = 8
66
n_vehicles = 1
77
n_stupid_vehicles = 0
88
n_actions_steering = 16

src/Simulateur/controllers/controllerWorldSupervisor/controllerWorldSupervisor.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import *
33
import numpy as np
44
import gymnasium as gym
5-
import time
65

76
from checkpointmanager import CheckpointManager, checkpoints
87

@@ -220,17 +219,7 @@ def main():
220219
#Prédiction pour séléctionner une action à partir de l"observation
221220
for e in envs:
222221
log(f"CLIENT{simulation_rank}/{e.vehicle_rank} : trying to read from fifo")
223-
224-
timeout = 10 # seconds
225-
start_time = time.time()
226-
227-
while time.time() - start_time < timeout:
228-
raw = e.fifo_r.read(np.dtype(np.int64).itemsize * 2)
229-
if len(raw) == np.dtype(np.int64).itemsize * 2:
230-
# We got the full action data
231-
action = np.frombuffer(raw, dtype=np.int64)
232-
break
233-
222+
action = np.frombuffer(e.fifo_r.read(np.dtype(np.int64).itemsize * 2), dtype=np.int64)
234223
log(f"CLIENT{simulation_rank}/{e.vehicle_rank} : received {action=}")
235224

236225
obs, reward, done, truncated, info = e.step(action)

src/Simulateur/WebotsSimulationGymEnvironment.py renamed to src/Simulateur/launch_train_multiprocessing.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,27 @@
11
import os
2+
import time
23
from typing import *
4+
5+
import matplotlib.pyplot as plt
36
import numpy as np
7+
import torch
8+
import torch.nn as nn
9+
import torch.optim as optim
10+
import torch.multiprocessing as mp
11+
12+
from stable_baselines3 import PPO
13+
from stable_baselines3.common.env_checker import check_env
14+
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
15+
416
import gymnasium as gym
517

18+
from onnx_utils import export_onnx, test_onnx
619
from config import *
20+
from CNN1DExtractor import CNN1DExtractor
21+
from TemporalResNetExtractor import TemporalResNetExtractor
22+
from CNN1DResNetExtractor import CNN1DResNetExtractor
23+
24+
if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
725

826

927
def log(s: str):
@@ -90,4 +108,112 @@ def step(self, action):
90108
# check if the context is correct
91109
# if self.simulation_rank == 0:
92110
# print(f"{(obs[0] == 0).mean():.3f} {(obs[1] == 0).mean():.3f}")
93-
return obs, reward, done, truncated, info
111+
return obs, reward, done, truncated, info
112+
113+
114+
if __name__ == "__main__":
115+
if not os.path.exists("/tmp/autotech/"):
116+
os.mkdir("/tmp/autotech/")
117+
118+
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
119+
if B_DEBUG:
120+
print("Webots started", file=open("/tmp/autotech/logs", "w"))
121+
122+
def make_env(rank: int):
123+
log(f"CAREFUL !!! created an SERVER env with {rank=}")
124+
return WebotsSimulationGymEnvironment(rank)
125+
126+
envs = SubprocVecEnv([lambda rank=rank : make_env(rank) for rank in range(n_simulations)])
127+
128+
ExtractorClass = TemporalResNetExtractor
129+
130+
policy_kwargs = dict(
131+
features_extractor_class=ExtractorClass,
132+
features_extractor_kwargs=dict(
133+
context_size=context_size,
134+
lidar_horizontal_resolution=lidar_horizontal_resolution,
135+
camera_horizontal_resolution=camera_horizontal_resolution,
136+
device=device
137+
),
138+
activation_fn=nn.ReLU,
139+
net_arch=[512, 512, 512],
140+
)
141+
142+
143+
ppo_args = dict(
144+
n_steps=4096,
145+
n_epochs=10,
146+
batch_size=256,
147+
learning_rate=3e-4,
148+
gamma=0.99,
149+
verbose=1,
150+
normalize_advantage=True,
151+
device=device
152+
)
153+
154+
155+
save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/"
156+
if not os.path.exists(save_path):
157+
os.mkdir(save_path)
158+
159+
print(save_path)
160+
print(os.listdir(save_path))
161+
162+
valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()]
163+
164+
if valid_files:
165+
model_name = max(
166+
valid_files,
167+
key=lambda x : int(x.rstrip(".zip"))
168+
)
169+
print(f"Loading model {save_path + model_name}")
170+
model = PPO.load(
171+
save_path + model_name,
172+
envs,
173+
**ppo_args,
174+
policy_kwargs=policy_kwargs
175+
)
176+
i = int(model_name.rstrip(".zip")) + 1
177+
print(f"----- Model found, loading {model_name} -----")
178+
179+
else:
180+
model = PPO(
181+
"MlpPolicy",
182+
envs,
183+
**ppo_args,
184+
policy_kwargs=policy_kwargs
185+
)
186+
187+
i = 0
188+
print("----- Model not found, creating a new one -----")
189+
190+
print("MODEL HAS HYPER PARAMETERS:")
191+
print(f"{model.learning_rate=}")
192+
print(f"{model.gamma=}")
193+
print(f"{model.verbose=}")
194+
print(f"{model.n_steps=}")
195+
print(f"{model.n_epochs=}")
196+
print(f"{model.batch_size=}")
197+
print(f"{model.device=}")
198+
199+
log(f"SERVER : finished executing")
200+
201+
# obs = envs.reset()
202+
# while True:
203+
# action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
204+
# obs, reward, done, info = envs.step(action)
205+
# envs.render() # Optional: visualize the environment
206+
207+
208+
while True:
209+
export_onnx(model)
210+
test_onnx(model)
211+
212+
if B_DEBUG:
213+
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
214+
else:
215+
model.learn(total_timesteps=500_000)
216+
217+
model.save(save_path + str(i))
218+
219+
i += 1
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Webots Project File version R2023b
2+
perspectives: 000000ff00000000fd0000000200000001000001cf00000278fc0200000001fb0000001400540065007800740045006400690074006f00720100000016000002780000008900ffffff000000030000078000000176fc0100000001fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c0100000000000007800000006900ffffff000005af0000027800000001000000020000000100000008fc00000000
3+
simulationViewPerspectives: 000000ff0000000100000002000001520000045b0100000002010000000100
4+
sceneTreePerspectives: 000000ff00000001000000030000001f000000c0000000fa0100000002010000000200
5+
maximizedDockId: -1
6+
centralWidgetVisible: 1
7+
orthographicViewHeight: 1
8+
textFiles: 2 "Bobox.proto" "../../CoVAPSy_Intech/Simulateur/worlds/piste.wbt" "../../CoVAPSy_Intech/Simulateur/protos/Vehicle.proto"
9+
consoles: Console:All:All
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Webots Project File version R2025a
2+
perspectives: 000000ff00000000fd00000002000000010000011c00000177fc0200000001fb0000001400540065007800740045006400690074006f00720000000016000001770000004500ffffff00000003000006c000000216fc0100000001fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c0100000000000006c00000008700ffffff000006c0000001c400000001000000020000000100000008fc00000000
3+
simulationViewPerspectives: 000000ff0000000100000002000001000000017d0100000002010000000100
4+
sceneTreePerspectives: 000000ff00000001000000030000001f0000013e000000fa0100000002010000000200
5+
maximizedDockId: -1
6+
centralWidgetVisible: 1
7+
orthographicViewHeight: 1
8+
textFiles: -1
9+
globalOptionalRendering: LidarRaysPaths::LidarPointClouds
10+
consoles: Console:All:All
11+
renderingDevicePerspectives: TT02_0:RASPI_Camera_V2;1;32;0;0
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Webots Project File version R2025a
2+
perspectives: 000000ff00000000fd0000000200000001000000870000028afc0200000001fb0000001400540065007800740045006400690074006f007201000000000000028a0000004500ffffff00000003000006c000000150fc0100000002fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c0100000000000006c00000008700fffffffb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c010000000000000a000000000000000000000006370000028a00000001000000020000000100000008fc00000000
3+
simulationViewPerspectives: 000000ff00000001000000020000012b000005a50100000002010000000100
4+
sceneTreePerspectives: 000000ff00000001000000030000001e00000364000000fa0100000002010000000200
5+
minimizedPerspectives: 000000ff00000000fd0000000200000001000000750000017bfc0200000001fb0000001400540065007800740045006400690074006f007201000000160000017b0000003f00ffffff000000030000039b00000039fc0100000002fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c01000000000000039b0000006900fffffffb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c010000000000000a000000000000000000000003240000017b00000001000000020000000100000008fc00000000
6+
maximizedDockId: -1
7+
centralWidgetVisible: 1
8+
orthographicViewHeight: 1
9+
textFiles: -1
10+
globalOptionalRendering: LidarPointClouds::LidarRaysPaths
11+
consoles: Console:All:All
12+
renderingDevicePerspectives: TT02_0:RASPI_Camera_V2;1;32;0;0
13+
renderingDevicePerspectives: sparringpartner_car_0:RASPI_Camera_V2;1;32;0;0

0 commit comments

Comments
 (0)