Skip to content

Commit 9901c7e

Browse files
Create evaluation.py
1 parent 63a015d commit 9901c7e

1 file changed

Lines changed: 85 additions & 0 deletions

File tree

src/utils/evaluation.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import gym
2+
import numpy as np
3+
import torch
4+
import logging
5+
6+
logger = logging.getLogger(__name__)
7+
8+
class RealEnv:
    """Thin wrapper around a live gym simulator.

    Exposes the minimal environment interface used elsewhere in this
    module (reset / step / close plus spaces and an episode-step limit)
    so it is interchangeable with ReplayEnv.
    """

    def __init__(self, env_name):
        self.env = gym.make(env_name)
        self.observation_space = self.env.observation_space
        self.action_space = self.env.action_space
        # `_max_episode_steps` is a private TimeLimit attribute; newer
        # gym/gymnasium releases expose the limit on env.spec instead.
        # Fall back to the spec so construction does not crash there.
        limit = getattr(self.env, "_max_episode_steps", None)
        if limit is None and getattr(self.env, "spec", None) is not None:
            limit = self.env.spec.max_episode_steps
        self._max_episode_steps = limit

    def reset(self, **kwargs):
        """Delegate to the wrapped environment's reset."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Delegate to the wrapped environment's step."""
        return self.env.step(action)

    def close(self):
        """Release the wrapped environment's resources."""
        self.env.close()
23+
24+
class ReplayEnv:
    """Environment that replays logged trajectories from an .npz dataset.

    Stands in for a real simulator when one is unavailable: reset() picks
    a random recorded clip, and step() ignores the supplied action, simply
    advancing through the recorded states/rewards/dones.

    The dataset is expected to contain arrays keyed 'states', 'actions',
    'rewards', 'dones' and 'mask', with states shaped
    (num_clips, clip_len, state_dim) and actions
    (num_clips, clip_len, act_dim).
    """

    def __init__(self, dataset_path, env_name):
        self.dataset = np.load(dataset_path)
        self.states = self.dataset['states']
        self.actions = self.dataset['actions']
        self.rewards = self.dataset['rewards']
        self.dones = self.dataset['dones']
        # NOTE(review): loaded but currently unused — presumably marks
        # valid (non-padded) timesteps; confirm before relying on clip_len.
        self.masks = self.dataset['mask']

        # Infer spaces from data
        state_dim = self.states.shape[2]
        act_dim = self.actions.shape[2]

        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(state_dim,), dtype=np.float32)
        # Assuming continuous for now, can infer from metadata if needed
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(act_dim,), dtype=np.float32)
        self._max_episode_steps = self.states.shape[1]

        self.current_idx = 0
        self.t = 0
        self.env_name = env_name

    def reset(self, **kwargs):
        """Select a random clip and return its first observation.

        Honors the gym-style ``seed`` kwarg (the original dropped it),
        so replay rollouts can be made reproducible.
        """
        seed = kwargs.get("seed")
        if seed is not None:
            np.random.seed(seed)
        # Pick a random clip
        self.current_idx = np.random.randint(0, len(self.states))
        self.t = 0

        obs = self.states[self.current_idx, 0]
        return obs, {}

    def step(self, action):
        """Advance one recorded timestep, ignoring ``action`` (replay).

        Returns the gymnasium-style 5-tuple
        (obs, reward, terminated, truncated, info).
        """
        # We ignore the action for state transition (replay)
        # But we could log it if we wanted to compute BC loss

        if self.t >= self._max_episode_steps - 1:
            # End of clip. Clamp the index: ``t`` keeps growing if the
            # caller steps past done, and the unclamped states[idx, t]
            # would raise IndexError on the second post-terminal step.
            done = True
            last = min(self.t, self._max_episode_steps - 1)
            obs = self.states[self.current_idx, last]
            reward = 0
        else:
            obs = self.states[self.current_idx, self.t + 1]
            reward = self.rewards[self.current_idx, self.t]
            done = bool(self.dones[self.current_idx, self.t])
        self.t += 1

        return obs, float(reward), done, False, {"replay_mode": True}

    def close(self):
        """Nothing to release: the data is held in memory."""
        pass
73+
74+
def create_env(env_name, simulator_available=False, dataset_path=None):
    """Create an environment: a live simulator if possible, else a replay.

    Args:
        env_name: Gym environment id to instantiate or tag the replay with.
        simulator_available: when True, try RealEnv first and fall back to
            ReplayEnv if construction fails for any reason.
        dataset_path: path to the .npz trajectory file backing ReplayEnv;
            required whenever the replay fallback is used.

    Returns:
        A RealEnv or ReplayEnv instance.

    Raises:
        ValueError: if the replay path is taken without a dataset_path.
    """
    if simulator_available:
        try:
            return RealEnv(env_name)
        except Exception as e:
            # Lazy %-args: the message is only formatted if the record is
            # actually emitted (same rendered text as the old f-string).
            logger.warning(
                "Failed to create real environment %s: %s. Falling back to ReplayEnv.",
                env_name, e,
            )

    if dataset_path is None:
        raise ValueError("dataset_path must be provided for ReplayEnv (when simulator is unavailable).")

    logger.info("Creating ReplayEnv from %s", dataset_path)
    return ReplayEnv(dataset_path, env_name)

0 commit comments

Comments
 (0)