Skip to content

Commit 05ce411

Browse files
committed
fix : fix some issues after the merge between nikopol and arthur
1 parent 0a33030 commit 05ce411

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

scripts/launch_train_multiprocessing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def log(s: str):
3333
if B_DEBUG:
3434
print("Webots started", file=open("/tmp/autotech/logs", "w"))
3535

36-
def make_env(rank: int):
37-
log(f"CAREFUL !!! created an SERVER env with {rank=}")
38-
return WebotsSimulationGymEnvironment(rank)
36+
def make_env(simulation_rank: int, vehicle_rank: int):
37+
log(f"CAREFUL !!! created an SERVER env with {simulation_rank}_{vehicle_rank}")
38+
return WebotsSimulationGymEnvironment(simulation_rank, vehicle_rank)
3939

40-
envs = SubprocVecEnv([lambda rank=rank : make_env(rank) for rank in range(n_simulations)])
40+
envs = SubprocVecEnv([lambda simulation_rank=simulation_rank, vehicle_rank=vehicle_rank : make_env(simulation_rank, vehicle_rank) for vehicle_rank in range(n_vehicles) for simulation_rank in range(n_simulations)])
4141

4242
ExtractorClass = TemporalResNetExtractor
4343

@@ -124,7 +124,7 @@ def make_env(rank: int):
124124
test_onnx(model)
125125

126126
if B_DEBUG:
127-
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
127+
model.learn(total_timesteps=500_000, callback=DynamicActionPlotDistributionCallback())
128128
else:
129129
model.learn(total_timesteps=500_000)
130130

src/Simulateur/WebotsSimulationGymEnvironment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
def log(s: str):
1010
if B_DEBUG:
1111
print(s, file=open("/tmp/autotech/logs", "a"))
12-
1312
class WebotsSimulationGymEnvironment(gym.Env):
1413
"""
1514
One environment for each vehicle
@@ -92,3 +91,5 @@ def step(self, action):
9291
# if self.simulation_rank == 0:
9392
# print(f"{(obs[0] == 0).mean():.3f} {(obs[1] == 0).mean():.3f}")
9493
return obs, reward, done, truncated, info
94+
95+

0 commit comments

Comments
 (0)