Skip to content

Commit 0a33030

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/HL' into HL
2 parents b3b4951 + d3d878f commit 0a33030

10 files changed

Lines changed: 136 additions & 173 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ Debug_Wayfinding
1010
.venv
1111
*.onnx
1212
checkpoints
13+
*.wbproj

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ uv sync --extra rpi
3333

3434
Navigate to the scripts directory.
3535
```bash
36-
cd src/Simulateur
36+
cd scripts
3737
```
3838

3939
Run the multi-process training script.
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import os
2+
import sys
3+
4+
from typing import *
5+
6+
import torch.nn as nn
7+
8+
from stable_baselines3 import PPO
9+
from stable_baselines3.common.vec_env import SubprocVecEnv
10+
11+
simu_path = __file__.rsplit('/', 2)[0] + '/src/Simulateur'
12+
if simu_path not in sys.path:
13+
sys.path.insert(0, simu_path)
14+
15+
from config import *
16+
from TemporalResNetExtractor import TemporalResNetExtractor
17+
from onnx_utils import *
18+
19+
from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment
20+
if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
21+
22+
def log(s: str) -> None:
    """Append `s` as one line to /tmp/autotech/logs when B_DEBUG is truthy.

    Does nothing when B_DEBUG is falsy.
    """
    if B_DEBUG:
        # Use a context manager so the handle is closed after every call;
        # passing a bare open(...) to print() left one unclosed file per call.
        with open("/tmp/autotech/logs", "a") as log_file:
            print(s, file=log_file)
25+
26+
27+
28+
if __name__ == "__main__":
29+
if not os.path.exists("/tmp/autotech/"):
30+
os.mkdir("/tmp/autotech/")
31+
32+
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
33+
if B_DEBUG:
34+
print("Webots started", file=open("/tmp/autotech/logs", "w"))
35+
36+
def make_env(rank: int):
37+
log(f"CAREFUL !!! created an SERVER env with {rank=}")
38+
return WebotsSimulationGymEnvironment(rank)
39+
40+
envs = SubprocVecEnv([lambda rank=rank : make_env(rank) for rank in range(n_simulations)])
41+
42+
ExtractorClass = TemporalResNetExtractor
43+
44+
policy_kwargs = dict(
45+
features_extractor_class=ExtractorClass,
46+
features_extractor_kwargs=dict(
47+
context_size=context_size,
48+
lidar_horizontal_resolution=lidar_horizontal_resolution,
49+
camera_horizontal_resolution=camera_horizontal_resolution,
50+
device=device
51+
),
52+
activation_fn=nn.ReLU,
53+
net_arch=[512, 512, 512],
54+
)
55+
56+
57+
ppo_args = dict(
58+
n_steps=4096,
59+
n_epochs=10,
60+
batch_size=256,
61+
learning_rate=3e-4,
62+
gamma=0.99,
63+
verbose=1,
64+
normalize_advantage=True,
65+
device=device
66+
)
67+
68+
69+
save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/"
70+
os.makedirs(save_path, exist_ok=True)
71+
72+
73+
print(save_path)
74+
print(os.listdir(save_path))
75+
76+
valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()]
77+
78+
if valid_files:
79+
model_name = max(
80+
valid_files,
81+
key=lambda x : int(x.rstrip(".zip"))
82+
)
83+
print(f"Loading model {save_path + model_name}")
84+
model = PPO.load(
85+
save_path + model_name,
86+
envs,
87+
**ppo_args,
88+
policy_kwargs=policy_kwargs
89+
)
90+
i = int(model_name.rstrip(".zip")) + 1
91+
print(f"----- Model found, loading {model_name} -----")
92+
93+
else:
94+
model = PPO(
95+
"MlpPolicy",
96+
envs,
97+
**ppo_args,
98+
policy_kwargs=policy_kwargs
99+
)
100+
101+
i = 0
102+
print("----- Model not found, creating a new one -----")
103+
104+
print("MODEL HAS HYPER PARAMETERS:")
105+
print(f"{model.learning_rate=}")
106+
print(f"{model.gamma=}")
107+
print(f"{model.verbose=}")
108+
print(f"{model.n_steps=}")
109+
print(f"{model.n_epochs=}")
110+
print(f"{model.batch_size=}")
111+
print(f"{model.device=}")
112+
113+
log(f"SERVER : finished executing")
114+
115+
# obs = envs.reset()
116+
# while True:
117+
# action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
118+
# obs, reward, done, info = envs.step(action)
119+
# envs.render() # Optional: visualize the environment
120+
121+
122+
while True:
123+
export_onnx(model)
124+
test_onnx(model)
125+
126+
if B_DEBUG:
127+
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
128+
else:
129+
model.learn(total_timesteps=500_000)
130+
131+
model.save(save_path + str(i))
132+
133+
i += 1
Lines changed: 0 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,9 @@
11
import os
2-
import time
32
from typing import *
4-
5-
import matplotlib.pyplot as plt
63
import numpy as np
7-
import torch
8-
import torch.nn as nn
9-
import torch.optim as optim
10-
import torch.multiprocessing as mp
11-
12-
from stable_baselines3 import PPO
13-
from stable_baselines3.common.env_checker import check_env
14-
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
15-
164
import gymnasium as gym
175

18-
from onnx_utils import export_onnx, test_onnx
196
from config import *
20-
from CNN1DExtractor import CNN1DExtractor
21-
from TemporalResNetExtractor import TemporalResNetExtractor
22-
from CNN1DResNetExtractor import CNN1DResNetExtractor
23-
24-
if B_DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
257

268

279
def log(s: str):
@@ -110,111 +92,3 @@ def step(self, action):
11092
# if self.simulation_rank == 0:
11193
# print(f"{(obs[0] == 0).mean():.3f} {(obs[1] == 0).mean():.3f}")
11294
return obs, reward, done, truncated, info
113-
114-
115-
if __name__ == "__main__":
116-
if not os.path.exists("/tmp/autotech/"):
117-
os.mkdir("/tmp/autotech/")
118-
119-
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
120-
if B_DEBUG:
121-
print("Webots started", file=open("/tmp/autotech/logs", "w"))
122-
123-
def make_env(rank: int, rank_v: int):
124-
log(f"CAREFUL !!! created an SERVER env with {rank=}")
125-
return WebotsSimulationGymEnvironment(rank, rank_v)
126-
127-
envs = SubprocVecEnv([lambda rank=rank, rank_v =rank_v : make_env(rank, rank_v) for rank_v in range(n_vehicles) for rank in range(n_simulations)])
128-
129-
ExtractorClass = TemporalResNetExtractor
130-
131-
policy_kwargs = dict(
132-
features_extractor_class=ExtractorClass,
133-
features_extractor_kwargs=dict(
134-
context_size=context_size,
135-
lidar_horizontal_resolution=lidar_horizontal_resolution,
136-
camera_horizontal_resolution=camera_horizontal_resolution,
137-
device=device
138-
),
139-
activation_fn=nn.ReLU,
140-
net_arch=[512, 512, 512],
141-
)
142-
143-
144-
ppo_args = dict(
145-
n_steps=4096,
146-
n_epochs=10,
147-
batch_size=256,
148-
learning_rate=3e-4,
149-
gamma=0.99,
150-
verbose=1,
151-
normalize_advantage=True,
152-
device=device
153-
)
154-
155-
156-
save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/"
157-
if not os.path.exists(save_path):
158-
os.mkdir(save_path)
159-
160-
print(save_path)
161-
print(os.listdir(save_path))
162-
163-
valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()]
164-
165-
if valid_files:
166-
model_name = max(
167-
valid_files,
168-
key=lambda x : int(x.rstrip(".zip"))
169-
)
170-
print(f"Loading model {save_path + model_name}")
171-
model = PPO.load(
172-
save_path + model_name,
173-
envs,
174-
**ppo_args,
175-
policy_kwargs=policy_kwargs
176-
)
177-
i = int(model_name.rstrip(".zip")) + 1
178-
print(f"----- Model found, loading {model_name} -----")
179-
180-
else:
181-
model = PPO(
182-
"MlpPolicy",
183-
envs,
184-
**ppo_args,
185-
policy_kwargs=policy_kwargs
186-
)
187-
188-
i = 0
189-
print("----- Model not found, creating a new one -----")
190-
191-
print("MODEL HAS HYPER PARAMETERS:")
192-
print(f"{model.learning_rate=}")
193-
print(f"{model.gamma=}")
194-
print(f"{model.verbose=}")
195-
print(f"{model.n_steps=}")
196-
print(f"{model.n_epochs=}")
197-
print(f"{model.batch_size=}")
198-
print(f"{model.device=}")
199-
200-
log(f"SERVER : finished executing")
201-
202-
# obs = envs.reset()
203-
# while True:
204-
# action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
205-
# obs, reward, done, info = envs.step(action)
206-
# envs.render() # Optional: visualize the environment
207-
208-
209-
while True:
210-
export_onnx(model)
211-
test_onnx(model)
212-
213-
if B_DEBUG:
214-
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
215-
else:
216-
model.learn(total_timesteps=500_000)
217-
218-
model.save(save_path + str(i))
219-
220-
i += 1

src/Simulateur/__init__.py

Whitespace-only changes.

src/Simulateur/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch.cuda import is_available
33

44
n_map = 2
5-
n_simulations = 8
5+
n_simulations = 2
66
n_vehicles = 1
77
n_stupid_vehicles = 0
88
n_actions_steering = 16

src/Simulateur/worlds/.piste.wbproj

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/Simulateur/worlds/.piste0.wbproj

Lines changed: 0 additions & 11 deletions
This file was deleted.

src/Simulateur/worlds/.piste1.wbproj

Lines changed: 0 additions & 13 deletions
This file was deleted.

src/Simulateur/worlds/.piste2.wbproj

Lines changed: 0 additions & 12 deletions
This file was deleted.

0 commit comments

Comments
 (0)