Skip to content

Commit b063dd7

Browse files
Merge branch 'main' into Recul
2 parents 14857fa + 02f922d commit b063dd7

25 files changed

Lines changed: 548 additions & 559 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ Debug_Wayfinding
1010
.venv
1111
*.onnx
1212
checkpoints
13+
*.wbproj

scripts/envoie_camera_sur_web.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,19 @@ class StreamingServer(socketserver.ThreadingMixIn, server.HTTPServer):
8181
allow_reuse_address = True
8282
daemon_threads = True
8383

84-
# Create Picamera2 instance and configure it
85-
picam2 = Picamera2()
86-
picam2.configure(picam2.create_video_configuration(main={"size": (640, 480)}))
87-
output = StreamingOutput()
88-
picam2.start_recording(JpegEncoder(), FileOutput(output))
8984

90-
try:
91-
# Set up and start the streaming server
92-
address = ('', 8000)
93-
server = StreamingServer(address, StreamingHandler)
94-
server.serve_forever()
95-
finally:
96-
# Stop recording when the script is interrupted
97-
picam2.stop_recording()
85+
if __name__ = "__name__":
86+
# Create Picamera2 instance and configure it
87+
picam2 = Picamera2()
88+
picam2.configure(picam2.create_video_configuration(main={"size": (640, 480)}))
89+
output = StreamingOutput()
90+
picam2.start_recording(JpegEncoder(), FileOutput(output))
91+
92+
try:
93+
# Set up and start the streaming server
94+
address = ('0.0.0.0', 8000)
95+
server = StreamingServer(address, StreamingHandler)
96+
server.serve_forever()
97+
finally:
98+
# Stop recording when the script is interrupted
99+
picam2.stop_recording()

scripts/lanch_one_simu.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import sys
3+
4+
from typing import *
5+
import numpy as np
6+
import onnxruntime as ort
7+
8+
simu_path = __file__.rsplit('/', 2)[0] + '/src/Simulateur'
9+
if simu_path not in sys.path:
10+
sys.path.insert(0, simu_path)
11+
12+
from onnx_utils import run_onnx_model
13+
from config import *
14+
from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment
15+
from TemporalResNetExtractor import TemporalResNetExtractor
16+
from CNN1DResNetExtractor import CNN1DResNetExtractor
17+
# -------------------------------------------------------------------------
18+
19+
20+
21+
# --- Chemin vers le fichier ONNX ---
22+
23+
ONNX_MODEL_PATH = "model.onnx"
24+
25+
# --- Initialisation du moteur d'inférence ONNX Runtime (ORT) ---
26+
def init_onnx_runtime_session(onnx_path: str) -> ort.InferenceSession:
27+
if not os.path.exists(onnx_path):
28+
raise FileNotFoundError(f"Le fichier ONNX est introuvable à : {onnx_path}. Veuillez l'exporter d'abord.")
29+
30+
# Crée la session d'inférence
31+
return ort.InferenceSession(onnx_path) #On peut modifier le providers afin de mettre une CUDA
32+
33+
34+
if __name__ == "__main__":
35+
if not os.path.exists("/tmp/autotech/"):
36+
os.mkdir("/tmp/autotech/")
37+
38+
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
39+
40+
41+
# 2. Initialisation de la session ONNX Runtime
42+
try:
43+
ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH)
44+
input_name = ort_session.get_inputs()[0].name
45+
output_name = ort_session.get_outputs()[0].name
46+
print(f"Modèle ONNX chargé depuis {ONNX_MODEL_PATH}")
47+
print(f"Input Name: {input_name}, Output Name: {output_name}")
48+
except FileNotFoundError as e:
49+
print(f"ERREUR : {e}")
50+
print(
51+
"Veuillez vous assurer que vous avez exécuté une fois le script d'entraînement pour exporter 'model.onnx'.")
52+
sys.exit(1)
53+
54+
# 3. Boucle d'inférence (Test)
55+
env = WebotsSimulationGymEnvironment(0,0)
56+
obs = env.reset()
57+
print("Début de la simulation en mode inférence...")
58+
59+
max_steps = 5000
60+
step_count = 0
61+
62+
while True:
63+
64+
action = run_onnx_model(ort_session,obs)
65+
66+
# 4. Exécuter l'action dans l'environnement
67+
obs, reward, done, info = env.step(action)
68+
69+
# Note: L'environnement Webots gère généralement son propre affichage
70+
# env.render() # Décommenter si votre env supporte le rendu externe
71+
72+
# Gestion des fins d'épisodes
73+
if done:
74+
print(f"Épisode(s) terminé(s) après {step_count} étapes.")
75+
obs = env.reset()
76+
77+
78+
79+
# Fermeture propre (très important pour les processus parallèles SubprocVecEnv)
80+
envs.close()
81+
print("Simulation terminée. Environnements fermés.")
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import logging
2+
import os
3+
import sys
4+
5+
from typing import *
6+
7+
import torch.nn as nn
8+
9+
from stable_baselines3 import PPO
10+
from stable_baselines3.common.vec_env import SubprocVecEnv
11+
12+
simu_path = __file__.rsplit('/', 2)[0] + '/src/Simulateur'
13+
if simu_path not in sys.path:
14+
sys.path.insert(0, simu_path)
15+
16+
from Simulateur.config import LOG_LEVEL
17+
from config import *
18+
from TemporalResNetExtractor import TemporalResNetExtractor
19+
from CNN1DResNetExtractor import CNN1DResNetExtractor
20+
from onnx_utils import *
21+
22+
from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment
23+
if LOG_LEVEL == logging.DEBUG: from DynamicActionPlotCallback import DynamicActionPlotDistributionCallback
24+
25+
26+
if __name__ == "__main__":
27+
28+
if not os.path.exists("/tmp/autotech/"):
29+
os.mkdir("/tmp/autotech/")
30+
31+
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
32+
33+
34+
def make_env(simulation_rank: int, vehicle_rank: int):
35+
if LOG_LEVEL == logging.DEBUG:
36+
print("CAREFUL !!! created an SERVER env with {simulation_rank}_{vehicle_rank}")
37+
return WebotsSimulationGymEnvironment(simulation_rank, vehicle_rank)
38+
39+
envs = SubprocVecEnv([lambda simulation_rank=simulation_rank, vehicle_rank=vehicle_rank : make_env(simulation_rank, vehicle_rank) for vehicle_rank in range(n_vehicles) for simulation_rank in range(n_simulations)])
40+
41+
ExtractorClass = CNN1DResNetExtractor
42+
43+
policy_kwargs = dict(
44+
features_extractor_class=ExtractorClass,
45+
features_extractor_kwargs=dict(
46+
context_size=context_size,
47+
lidar_horizontal_resolution=lidar_horizontal_resolution,
48+
camera_horizontal_resolution=camera_horizontal_resolution,
49+
device=device
50+
),
51+
activation_fn=nn.ReLU,
52+
net_arch=[512, 512, 512],
53+
)
54+
55+
56+
ppo_args = dict(
57+
n_steps=4096,
58+
n_epochs=10,
59+
batch_size=256,
60+
learning_rate=3e-4,
61+
gamma=0.99,
62+
verbose=1,
63+
normalize_advantage=True,
64+
device=device
65+
)
66+
67+
68+
save_path = __file__.rsplit("/", 1)[0] + "/checkpoints/" + ExtractorClass.__name__ + "/"
69+
os.makedirs(save_path, exist_ok=True)
70+
71+
72+
print(save_path)
73+
print(os.listdir(save_path))
74+
75+
valid_files = [x for x in os.listdir(save_path) if x.rstrip(".zip").isnumeric()]
76+
77+
if valid_files:
78+
model_name = max(
79+
valid_files,
80+
key=lambda x : int(x.rstrip(".zip"))
81+
)
82+
print(f"Loading model {save_path + model_name}")
83+
model = PPO.load(
84+
save_path + model_name,
85+
envs,
86+
**ppo_args,
87+
policy_kwargs=policy_kwargs
88+
)
89+
i = int(model_name.rstrip(".zip")) + 1
90+
print(f"----- Model found, loading {model_name} -----")
91+
92+
else:
93+
model = PPO(
94+
"MlpPolicy",
95+
envs,
96+
**ppo_args,
97+
policy_kwargs=policy_kwargs
98+
)
99+
100+
i = 0
101+
print("----- Model not found, creating a new one -----")
102+
103+
print("MODEL HAS HYPER PARAMETERS:")
104+
print(f"{model.learning_rate=}")
105+
print(f"{model.gamma=}")
106+
print(f"{model.verbose=}")
107+
print(f"{model.n_steps=}")
108+
print(f"{model.n_epochs=}")
109+
print(f"{model.batch_size=}")
110+
print(f"{model.device=}")
111+
112+
print("SERVER : finished executing")
113+
114+
# obs = envs.reset()
115+
# while True:
116+
# action, _states = model.predict(obs, deterministic=True) # Use deterministic=True for evaluation
117+
# obs, reward, done, info = envs.step(action)
118+
# envs.render() # Optional: visualize the environment
119+
120+
121+
while True:
122+
export_onnx(model)
123+
test_onnx(model)
124+
125+
if LOG_LEVEL <= logging.DEBUG:
126+
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
127+
else:
128+
model.learn(total_timesteps=500_000)
129+
130+
print("iteration over")
131+
132+
model.save(save_path + str(i))
133+
134+
i += 1

scripts/remote_control_controller.py

Lines changed: 62 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -38,62 +38,66 @@ def set_vitesse_m_ms(vit):
3838
global vitesse_m
3939
vitesse_m = vit
4040

41-
###################################################
42-
# Init pygame + manette
43-
###################################################
44-
pygame.init()
45-
pygame.joystick.init()
46-
47-
if pygame.joystick.get_count() == 0:
48-
print("Aucune manette détectée")
49-
exit(1)
50-
51-
joy = pygame.joystick.Joystick(0)
52-
joy.init()
53-
print("Manette détectée:", joy.get_name())
5441

55-
###################################################
56-
# Boucle principale
57-
###################################################
58-
Thread(target=envoie_donnee, daemon=True).start()
59-
60-
try:
61-
while True:
62-
pygame.event.pump()
63-
64-
# Axes :
65-
# Pour Xbox/PS4 USB :
66-
# L2 = axis 2 (souvent 0..1)
67-
# R2 = axis 5 (souvent 0..1)
68-
# Stick gauche horizontal = axis 0 (-1..1)
69-
70-
axis_lx = joy.get_axis(0) # Gauche droite
71-
axis_l2 = joy.get_axis(2) # Accélération inverse
72-
axis_r2 = joy.get_axis(5) # Accélération
73-
74-
# Direction
75-
direction = map_range(axis_lx, -1, 1, -angle_degre_max, angle_degre_max)
76-
set_direction_degre(round(direction))
77-
78-
# Accélération
79-
accel = (axis_r2 + 1)/2
80-
brake = (axis_l2 + 1)/2
81-
82-
# Certaines manettes vont de -1..1, d'autres 0..1
83-
84-
# Avant
85-
if accel > 0.05:
86-
vit = accel * vitesse_max_m_s_soft * 1000
87-
set_vitesse_m_ms(round(vit))
88-
89-
# Arrière
90-
elif brake > 0.05:
91-
vit = brake * vitesse_min_m_s_soft * 1000
92-
set_vitesse_m_ms(round(vit))
93-
else :
94-
set_vitesse_m_ms(0)
95-
time.sleep(0.01)
96-
97-
except KeyboardInterrupt:
98-
print("Fin du programme.")
99-
pygame.quit()
42+
if __name__ == "__main__":
43+
44+
45+
###################################################
46+
# Init pygame + manette
47+
###################################################
48+
pygame.init()
49+
pygame.joystick.init()
50+
51+
if pygame.joystick.get_count() == 0:
52+
print("Aucune manette détectée")
53+
exit(1)
54+
55+
joy = pygame.joystick.Joystick(0)
56+
joy.init()
57+
print("Manette détectée:", joy.get_name())
58+
59+
###################################################
60+
# Boucle principale
61+
###################################################
62+
Thread(target=envoie_donnee, daemon=True).start()
63+
64+
try:
65+
while True:
66+
pygame.event.pump()
67+
68+
# Axes :
69+
# Pour Xbox/PS4 USB :
70+
# L2 = axis 2 (souvent 0..1)
71+
# R2 = axis 5 (souvent 0..1)
72+
# Stick gauche horizontal = axis 0 (-1..1)
73+
74+
axis_lx = joy.get_axis(0) # Gauche droite
75+
axis_l2 = joy.get_axis(2) # Accélération inverse
76+
axis_r2 = joy.get_axis(5) # Accélération
77+
78+
# Direction
79+
direction = map_range(axis_lx, -1, 1, -angle_degre_max, angle_degre_max)
80+
set_direction_degre(round(direction))
81+
82+
# Accélération
83+
accel = (axis_r2 + 1)/2
84+
brake = (axis_l2 + 1)/2
85+
86+
# Certaines manettes vont de -1..1, d'autres 0..1
87+
88+
# Avant
89+
if accel > 0.05:
90+
vit = accel * vitesse_max_m_s_soft * 1000
91+
set_vitesse_m_ms(round(vit))
92+
93+
# Arrière
94+
elif brake > 0.05:
95+
vit = brake * vitesse_min_m_s_soft * 1000
96+
set_vitesse_m_ms(round(vit))
97+
else :
98+
set_vitesse_m_ms(0)
99+
time.sleep(0.01)
100+
101+
except KeyboardInterrupt:
102+
print("Fin du programme.")
103+
pygame.quit()

src/HL/Autotech_constant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import numpy as np
3+
import logging
34

45
# Car control
56
MAX_IA_SPEED = 2 # maximum speed for ia
@@ -57,5 +58,5 @@
5758
# the higher the temperature the more unprobalbe actions become probable, the lower the temperature the more probable actions become probable.
5859
# In our case Higher temperature means less agressive driving and lower temperature means more aggressive driving.
5960

60-
import logging
61+
6162
LOGGING_LEVEL = logging.DEBUG # can be either NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL

0 commit comments

Comments
 (0)