Skip to content

Commit e5e6fdb

Browse files
committed
Mathias's issues give to Nachid test 1
1 parent ad403a8 commit e5e6fdb

2 files changed

Lines changed: 94 additions & 8 deletions

File tree

scripts/lanch_one_simu.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import os
2+
import sys
3+
4+
from typing import *
5+
import numpy as np
6+
import onnxruntime as ort
7+
8+
simu_path = __file__.rsplit('/', 2)[0] + '/src/Simulateur'
9+
if simu_path not in sys.path:
10+
sys.path.insert(0, simu_path)
11+
12+
from onnx_utils import run_onnx_model
13+
from config import *
14+
from WebotsSimulationGymEnvironment import WebotsSimulationGymEnvironment
15+
from TemporalResNetExtractor import TemporalResNetExtractor
16+
from CNN1DResNetExtractor import CNN1DResNetExtractor
17+
# -------------------------------------------------------------------------
18+
19+
def log(s: str):
20+
# Conservez votre fonction de log
21+
if B_DEBUG:
22+
print(s, file=open("/tmp/autotech/logs", "a"))
23+
24+
# --- Chemin vers le fichier ONNX ---
25+
26+
ONNX_MODEL_PATH = "model.onnx"
27+
28+
# --- Initialisation du moteur d'inférence ONNX Runtime (ORT) ---
29+
def init_onnx_runtime_session(onnx_path: str) -> ort.InferenceSession:
30+
if not os.path.exists(onnx_path):
31+
raise FileNotFoundError(f"Le fichier ONNX est introuvable à : {onnx_path}. Veuillez l'exporter d'abord.")
32+
33+
# Crée la session d'inférence
34+
return ort.InferenceSession(onnx_path) #On peut modifier le providers afin de mettre une CUDA
35+
36+
37+
if __name__ == "__main__":
38+
if not os.path.exists("/tmp/autotech/"):
39+
os.mkdir("/tmp/autotech/")
40+
41+
os.system('if [ -n "$(ls /tmp/autotech)" ]; then rm /tmp/autotech/*; fi')
42+
if B_DEBUG:
43+
print("Webots started", file=open("/tmp/autotech/logs", "w"))
44+
45+
46+
# 2. Initialisation de la session ONNX Runtime
47+
try:
48+
ort_session = init_onnx_runtime_session(ONNX_MODEL_PATH)
49+
input_name = ort_session.get_inputs()[0].name
50+
output_name = ort_session.get_outputs()[0].name
51+
print(f"Modèle ONNX chargé depuis {ONNX_MODEL_PATH}")
52+
print(f"Input Name: {input_name}, Output Name: {output_name}")
53+
except FileNotFoundError as e:
54+
print(f"ERREUR : {e}")
55+
print(
56+
"Veuillez vous assurer que vous avez exécuté une fois le script d'entraînement pour exporter 'model.onnx'.")
57+
sys.exit(1)
58+
59+
# 3. Boucle d'inférence (Test)
60+
env = WebotsSimulationGymEnvironment(0,0)
61+
obs = env.reset()
62+
print("Début de la simulation en mode inférence...")
63+
64+
max_steps = 5000
65+
step_count = 0
66+
67+
while True:
68+
69+
action = run_onnx_model(ort_session,obs)
70+
71+
# 4. Exécuter l'action dans l'environnement
72+
obs, reward, done, info = env.step(action)
73+
74+
# Note: L'environnement Webots gère généralement son propre affichage
75+
# env.render() # Décommenter si votre env supporte le rendu externe
76+
77+
# Gestion des fins d'épisodes
78+
if done:
79+
print(f"Épisode(s) terminé(s) après {step_count} étapes.")
80+
obs = env.reset()
81+
82+
83+
84+
# Fermeture propre (très important pour les processus parallèles SubprocVecEnv)
85+
envs.close()
86+
print("Simulation terminée. Environnements fermés.")

src/Simulateur/onnx_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn as nn
44
import torch
55
from config import *
6-
6+
import numpy as np
77
from CNN1DExtractor import CNN1DExtractor
88
from TemporalResNetExtractor import TemporalResNetExtractor
99

@@ -38,32 +38,32 @@ def export_onnx(model):
3838
model.policy.to(device)
3939
model.policy.train()
4040

41+
def run_onnx_model(session : ort.InferenceSession,x : np.ndarray):
42+
43+
return session.run(None, {"input": x})[0]
4144

4245
def test_onnx(model):
4346
device = model.policy.device
4447
model.policy.eval()
4548
true_model = get_true_model(model)
4649

50+
loss_fn = nn.MSELoss()
51+
x = torch.randn(1000, 2, context_size, lidar_horizontal_resolution)
52+
4753
try:
4854
session = ort.InferenceSession("model.onnx")
4955
except Exception as e:
5056
print(f"Error loading ONNX model: {e}")
5157
return
5258

53-
def model_onnx(x):
54-
return session.run(None, {"input": x.cpu().numpy()})[0]
55-
56-
loss_fn = nn.MSELoss()
57-
x = torch.randn(1000, 2, context_size, lidar_horizontal_resolution)
58-
5959
with torch.no_grad():
6060
y_true_test = true_model(x)
6161

6262
true_model.train()
6363
y_true_train = true_model(x)
6464
true_model.eval()
6565

66-
y_onnx = model_onnx(x)
66+
y_onnx = run_onnx_model(session,x.cpu().numpy())
6767

6868
loss_test = loss_fn(y_true_test, torch.tensor(y_onnx))
6969
loss_train = loss_fn(y_true_train, torch.tensor(y_onnx))

0 commit comments

Comments
 (0)