
Commit d41341b

example.py ugliness moved to puffer_policy (#1632)
Quick update to move some of the ugliness up into puffer_policy [Asana Task](https://app.asana.com/1/1209016784099267/project/1210348820405981/task/1210873205908826)
Parent: 259e859

2 files changed: 97 additions & 84 deletions
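For context before the diffs: max_vec is the per-feature normalization buffer at the center of this change. Its (1, num_layers, 1, 1) shape lets a single divide broadcast over grid observations. The forward pass is not part of this commit, so the snippet below is only a sketch of how such a buffer is typically consumed; the tensor names and the observation shape are illustrative.

```python
import torch

num_layers = 22  # matches self.num_layers in the example policy below

# Per-feature maxima, reshaped to (1, num_layers, 1, 1) so one divide
# broadcasts across the batch and both spatial dimensions.
max_vec = torch.ones(num_layers, dtype=torch.float32)[None, :, None, None]

obs = torch.rand(8, num_layers, 11, 11)  # hypothetical (batch, layers, H, W) batch
normalized = obs / max_vec               # each feature channel scaled to roughly [0, 1]
```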

example.py

Lines changed: 26 additions & 80 deletions
```diff
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 import einops
 import pufferlib.models
 import pufferlib.pytorch
@@ -11,17 +9,6 @@ class Recurrent(pufferlib.models.LSTMWrapper):
     def __init__(self, env, policy, input_size=512, hidden_size=512):
         super().__init__(env, policy, input_size, hidden_size)
 
-    def initialize_to_environment(
-        self,
-        features: dict[str, dict],
-        action_names: list[str],
-        action_max_params: list[int],
-        device,
-    ):
-        """Pass initialization to wrapped policy."""
-        if hasattr(self.policy, "initialize_to_environment"):
-            self.policy.initialize_to_environment(features, action_names, action_max_params, device)
-
 
 class Policy(nn.Module):
     def __init__(self, env, cnn_channels=128, hidden_size=512, **kwargs):
@@ -33,34 +20,6 @@ def __init__(self, env, cnn_channels=128, hidden_size=512, **kwargs):
         self.out_height = 11
         self.num_layers = 22
 
-        # Define the standard feature order and their empirically determined normalizations
-        # This acts like original_feature_mapping in MettaAgent
-        self.feature_normalizations = {
-            "type_id": 9.0,
-            "agent:group": 1.0,
-            "hp": 1.0,
-            "agent:frozen": 10.0,
-            "agent:orientation": 3.0,
-            "agent:color": 254.0,
-            "converting": 1.0,
-            "swappable": 1.0,
-            "episode_completion_pct": 235.0,
-            "last_action": 8.0,
-            "last_action_arg": 9.0,
-            "last_reward": 250.0,
-            "agent:glyph": 29.0,
-            "resource_rewards": 1.0,
-            # Inventory features (positions 14-21)
-            "inv:0": 1.0,
-            "inv:1": 8.0,
-            "inv:2": 1.0,
-            "inv:3": 1.0,
-            "inv:4": 6.0,
-            "inv:5": 3.0,
-            "inv:6": 1.0,
-            "inv:7": 2.0,
-        }
-
         self.network = nn.Sequential(
             pufferlib.pytorch.layer_init(nn.Conv2d(self.num_layers, cnn_channels, 5, stride=3)),
             nn.ReLU(),
@@ -76,9 +35,32 @@ def __init__(self, env, cnn_channels=128, hidden_size=512, **kwargs):
             nn.ReLU(),
         )
 
-        # Initialize max_vec with ones - will be properly set during initialize_to_environment
-        # This ensures the model works even if initialize_to_environment isn't called
-        max_vec = torch.ones(self.num_layers, dtype=torch.float32)[None, :, None, None]
+        max_vec = torch.tensor(
+            [
+                9.0,
+                1.0,
+                1.0,
+                10.0,
+                3.0,
+                254.0,
+                1.0,
+                1.0,
+                235.0,
+                8.0,
+                9.0,
+                250.0,
+                29.0,
+                1.0,
+                1.0,
+                8.0,
+                1.0,
+                1.0,
+                6.0,
+                3.0,
+                1.0,
+                2.0,
+            ]
+        )[None, :, None, None]  # noqa:E231
         self.register_buffer("max_vec", max_vec)
 
         action_nvec = env.single_action_space.nvec
@@ -147,39 +129,3 @@ def decode_actions(self, hidden):
         logits = [dec(hidden) for dec in self.actor]
         value = self.value(hidden)
         return logits, value
-
-    def initialize_to_environment(
-        self,
-        features: dict[str, dict],
-        action_names: list[str],
-        action_max_params: list[int],
-        device,
-    ):
-        """Initialize policy by mapping our feature normalizations to current environment IDs.
-
-        This works like MettaAgent's feature remapping: we have a fixed set of known
-        features with empirically determined normalizations, and we map them to whatever
-        IDs the current environment uses.
-        """
-        # Create max_vec based on current environment's feature IDs
-        max_values = [1.0] * self.num_layers  # Default normalization
-
-        # Map our known features to the environment's feature IDs
-        for feature_name, feature_props in features.items():
-            if "id" in feature_props and 0 <= feature_props["id"] < self.num_layers:
-                feature_id = feature_props["id"]
-
-                # Check if this is a feature we know about
-                if feature_name in self.feature_normalizations:
-                    # Use our empirically determined normalization
-                    max_values[feature_id] = self.feature_normalizations[feature_name]
-                elif feature_name.startswith("inv:") and "inv:0" in self.feature_normalizations:
-                    # For unknown inventory items, use a default inventory normalization
-                    max_values[feature_id] = 100.0  # DEFAULT_INVENTORY_NORMALIZATION
-                elif "normalization" in feature_props:
-                    # Use environment's normalization for unknown features
-                    max_values[feature_id] = feature_props["normalization"]
-
-        # Update max_vec with the mapped values
-        new_max_vec = torch.tensor(max_values, dtype=torch.float32, device=device)[None, :, None, None]
-        self.max_vec.data = new_max_vec
```
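The hard-coded tensor above is the removed feature_normalizations dict flattened into feature-ID order; the name-based remapping now lives in PytorchAgent (next file). A quick equivalence sketch, assuming dict insertion order matches the default feature layout, as the removed comments state:

```python
import torch

# Values copied verbatim from the removed feature_normalizations dict.
FEATURE_NORMALIZATIONS = {
    "type_id": 9.0, "agent:group": 1.0, "hp": 1.0, "agent:frozen": 10.0,
    "agent:orientation": 3.0, "agent:color": 254.0, "converting": 1.0,
    "swappable": 1.0, "episode_completion_pct": 235.0, "last_action": 8.0,
    "last_action_arg": 9.0, "last_reward": 250.0, "agent:glyph": 29.0,
    "resource_rewards": 1.0,
    # Inventory features (positions 14-21)
    "inv:0": 1.0, "inv:1": 8.0, "inv:2": 1.0, "inv:3": 1.0,
    "inv:4": 6.0, "inv:5": 3.0, "inv:6": 1.0, "inv:7": 2.0,
}

# dicts preserve insertion order in Python 3.7+, so this reproduces the
# hard-coded tensor in the new __init__ exactly.
max_vec = torch.tensor(list(FEATURE_NORMALIZATIONS.values()))[None, :, None, None]
assert max_vec.shape == (1, 22, 1, 1)
```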

metta/rl/puffer_policy.py

Lines changed: 71 additions & 4 deletions
```diff
@@ -82,6 +82,33 @@ class PytorchAgent(nn.Module):
     differences like critic→value, hidden→logits, etc.
     """
 
+    # Default feature normalizations for legacy policies using max_vec
+    DEFAULT_FEATURE_NORMALIZATIONS = {
+        "type_id": 9.0,
+        "agent:group": 1.0,
+        "hp": 1.0,
+        "agent:frozen": 10.0,
+        "agent:orientation": 3.0,
+        "agent:color": 254.0,
+        "converting": 1.0,
+        "swappable": 1.0,
+        "episode_completion_pct": 235.0,
+        "last_action": 8.0,
+        "last_action_arg": 9.0,
+        "last_reward": 250.0,
+        "agent:glyph": 29.0,
+        "resource_rewards": 1.0,
+        # Inventory features (positions 14-21)
+        "inv:0": 1.0,
+        "inv:1": 8.0,
+        "inv:2": 1.0,
+        "inv:3": 1.0,
+        "inv:4": 6.0,
+        "inv:5": 3.0,
+        "inv:6": 1.0,
+        "inv:7": 2.0,
+    }
+
     def __init__(self, policy: nn.Module):
         super().__init__()
         self.policy = policy
@@ -160,19 +187,59 @@ def initialize_to_environment(
         device,
         is_training: bool = True,
     ):
-        """Initialize to environment - forward to wrapped policy if it has this method."""
+        """Initialize to environment - handle max_vec normalization for legacy policies."""
         # is_training parameter is deprecated and ignored - mode is auto-detected
 
-        # TODO: This hasattr pattern is a transitional state to support both old and new interfaces.
-        # Once all policies have been migrated to implement initialize_to_environment,
-        # we should remove these checks and make the interface mandatory.
+        # Handle max_vec normalization for policies that use it
+        target_policy = self._get_inner_policy()
+        if hasattr(target_policy, "max_vec") and hasattr(target_policy, "num_layers"):
+            self._update_max_vec_normalizations(target_policy, features, device)
+
+        # Forward to wrapped policy if it has initialize_to_environment
         if hasattr(self.policy, "initialize_to_environment"):
             self.policy.initialize_to_environment(features, action_names, action_max_params, device)
         elif hasattr(self.policy, "activate_actions"):
             # Fallback to old interface if available
             self.policy.activate_actions(action_names, action_max_params, device)
         self.device = device
 
+    def _get_inner_policy(self):
+        """Get the inner policy (unwrap if this is a wrapped policy like Recurrent)."""
+        target_policy = self.policy
+        if hasattr(self.policy, "policy") and hasattr(self.policy.policy, "max_vec"):
+            # This is a wrapped policy, get the inner policy
+            target_policy = self.policy.policy
+        return target_policy
+
+    def _update_max_vec_normalizations(self, policy, features: dict[str, dict], device):
+        """Update max_vec based on feature normalizations for legacy policies."""
+        # Don't update if the policy has its own feature normalization system
+        if hasattr(policy, "feature_normalizations"):
+            return
+
+        # Create max_vec based on current environment's feature IDs
+        max_values = [1.0] * policy.num_layers  # Default normalization
+
+        # Map our known features to the environment's feature IDs
+        for feature_name, feature_props in features.items():
+            if "id" in feature_props and 0 <= feature_props["id"] < policy.num_layers:
+                feature_id = feature_props["id"]
+
+                # Check if this is a feature we know about
+                if feature_name in self.DEFAULT_FEATURE_NORMALIZATIONS:
+                    # Use our empirically determined normalization
+                    max_values[feature_id] = self.DEFAULT_FEATURE_NORMALIZATIONS[feature_name]
+                elif feature_name.startswith("inv:") and "inv:0" in self.DEFAULT_FEATURE_NORMALIZATIONS:
+                    # For unknown inventory items, use a default inventory normalization
+                    max_values[feature_id] = 100.0  # DEFAULT_INVENTORY_NORMALIZATION
+                elif "normalization" in feature_props:
+                    # Use environment's normalization for unknown features
+                    max_values[feature_id] = feature_props["normalization"]
+
+        # Update max_vec with the mapped values
+        new_max_vec = torch.tensor(max_values, dtype=torch.float32, device=device)[None, :, None, None]
+        policy.max_vec.data = new_max_vec
+
     def l2_reg_loss(self) -> torch.Tensor:
         """L2 regularization loss."""
         if hasattr(self.policy, "l2_reg_loss"):
```
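Per feature, the remapping above has four outcomes: a name present in DEFAULT_FEATURE_NORMALIZATIONS takes its stored value, an unrecognized inv: slot falls back to 100.0, any other unknown feature with an environment-supplied normalization uses that, and everything else keeps the default 1.0. A usage sketch follows; the import path, the stand-in policy, and the feature names and IDs are assumptions, not part of this commit.

```python
import torch
import torch.nn as nn

from metta.rl.puffer_policy import PytorchAgent  # import path assumed from the file above


class LegacyPolicy(nn.Module):
    """Minimal stand-in exposing the two attributes the remapping checks for."""

    def __init__(self, num_layers: int = 22):
        super().__init__()
        self.num_layers = num_layers
        self.register_buffer("max_vec", torch.ones(num_layers)[None, :, None, None])


# Hypothetical features dict in the shape initialize_to_environment expects:
# feature name -> properties, including the environment-assigned channel ID.
features = {
    "hp": {"id": 0},                                 # known name -> 1.0 from the defaults
    "inv:9": {"id": 1},                              # unknown inventory slot -> 100.0 fallback
    "new_sensor": {"id": 2, "normalization": 42.0},  # unknown, env-supplied -> 42.0
    "mystery": {"id": 3},                            # unknown, nothing supplied -> 1.0
}

agent = PytorchAgent(LegacyPolicy())
agent.initialize_to_environment(features, action_names=["noop"], action_max_params=[0], device="cpu")
print(agent.policy.max_vec.flatten()[:4])  # tensor([  1., 100.,  42.,   1.])
```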
