Skip to content

Commit a5a6696

Browse files
committed
Replace HIM dual-encoder + terrain table with DreamWaQ-style VAE
The original HIMEstimator used two encoders (primary + target) with a prototype embedding table and a Sinkhorn-based contrastive (swap) loss to learn terrain-aware latent representations. Following DreamWaQ (Nahrendra et al., 2023), this commit replaces that architecture with a simple Variational Autoencoder:

- The encoder now outputs vel(3) + mu(16) + logvar(16) instead of vel(3) + z(16).
- The reparameterization trick (z = mu + eps * sigma) is used during training; inference uses mu directly (deterministic, no noise).
- Loss = MSE velocity estimation + beta * KL divergence to N(0, I).
- Removed: the target encoder, the prototype embedding table, and the Sinkhorn algorithm.
- The old config keys (tar_hidden_dims, num_prototype, temperature) are kept as no-op parameters so that existing config files remain compatible.
- him_ppo.py and him_on_policy_runner.py are updated accordingly: swap_loss -> kl_loss.
1 parent ef289ac commit a5a6696

3 files changed

Lines changed: 48 additions & 90 deletions

File tree

rsl_rl/rsl_rl/algorithms/him_ppo.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def update(self):
120120
mean_value_loss = 0
121121
mean_surrogate_loss = 0
122122
mean_estimation_loss = 0
123-
mean_swap_loss = 0
123+
mean_kl_loss = 0
124124

125125
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
126126

@@ -150,7 +150,7 @@ def update(self):
150150
param_group['lr'] = self.learning_rate
151151

152152
#Estimator Update
153-
estimation_loss, swap_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate)
153+
estimation_loss, kl_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate)
154154

155155
# Surrogate loss
156156
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
@@ -180,13 +180,13 @@ def update(self):
180180
mean_value_loss += value_loss.item()
181181
mean_surrogate_loss += surrogate_loss.item()
182182
mean_estimation_loss += estimation_loss
183-
mean_swap_loss += swap_loss
183+
mean_kl_loss += kl_loss
184184

185185
num_updates = self.num_learning_epochs * self.num_mini_batches
186186
mean_value_loss /= num_updates
187187
mean_surrogate_loss /= num_updates
188188
mean_estimation_loss /= num_updates
189-
mean_swap_loss /= num_updates
189+
mean_kl_loss /= num_updates
190190
self.storage.clear()
191191

192-
return mean_value_loss, mean_surrogate_loss, estimation_loss, swap_loss
192+
return mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss
Lines changed: 39 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,136 +1,94 @@
1-
import copy
2-
import math
31
import torch
42
import torch.nn as nn
53
import torch.optim as optim
64
import torch.nn.functional as F
7-
import torch.distributions as torchd
8-
from torch.distributions import Normal, Categorical
95

106

117
class HIMEstimator(nn.Module):
128
def __init__(self,
139
temporal_steps,
1410
num_one_step_obs,
1511
enc_hidden_dims=[128, 64, 16],
16-
tar_hidden_dims=[128, 64],
12+
tar_hidden_dims=[128, 64], # kept for config compatibility, unused
1713
activation='elu',
1814
learning_rate=1e-3,
1915
max_grad_norm=10.0,
20-
num_prototype=32,
21-
temperature=3.0,
16+
num_prototype=32, # kept for config compatibility, unused
17+
temperature=3.0, # kept for config compatibility, unused
18+
kl_weight=1.0,
2219
**kwargs):
2320
if kwargs:
24-
print("Estimator_CL.__init__ got unexpected arguments, which will be ignored: " + str(
25-
[key for key in kwargs.keys()]))
21+
print("HIMEstimator.__init__ got unexpected arguments, which will be ignored: " +
22+
str([key for key in kwargs.keys()]))
2623
super(HIMEstimator, self).__init__()
27-
activation = get_activation(activation)
24+
activation_fn = get_activation(activation)
2825

2926
self.temporal_steps = temporal_steps
3027
self.num_one_step_obs = num_one_step_obs
3128
self.num_latent = enc_hidden_dims[-1]
3229
self.max_grad_norm = max_grad_norm
33-
self.temperature = temperature
30+
self.kl_weight = kl_weight
3431

35-
# Encoder
32+
# Encoder: outputs vel(3) + mu(num_latent) + logvar(num_latent)
3633
enc_input_dim = self.temporal_steps * self.num_one_step_obs
3734
enc_layers = []
3835
for l in range(len(enc_hidden_dims) - 1):
39-
enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation]
36+
enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation_fn]
4037
enc_input_dim = enc_hidden_dims[l]
41-
enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[-1] + 3)]
38+
enc_layers += [nn.Linear(enc_input_dim, 3 + self.num_latent * 2)]
4239
self.encoder = nn.Sequential(*enc_layers)
4340

44-
# Target
45-
tar_input_dim = self.num_one_step_obs
46-
tar_layers = []
47-
for l in range(len(tar_hidden_dims)):
48-
tar_layers += [nn.Linear(tar_input_dim, tar_hidden_dims[l]), activation]
49-
tar_input_dim = tar_hidden_dims[l]
50-
tar_layers += [nn.Linear(tar_input_dim, enc_hidden_dims[-1])]
51-
self.target = nn.Sequential(*tar_layers)
52-
53-
# Prototype
54-
self.proto = nn.Embedding(num_prototype, enc_hidden_dims[-1])
55-
56-
# Optimizer
5741
self.learning_rate = learning_rate
5842
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
5943

44+
def reparameterize(self, mu, logvar):
45+
std = torch.exp(0.5 * logvar)
46+
eps = torch.randn_like(std)
47+
return mu + eps * std
48+
6049
def get_latent(self, obs_history):
61-
vel, z = self.encode(obs_history)
62-
return vel.detach(), z.detach()
50+
"""Inference: use mu directly (no sampling noise)."""
51+
out = self.encoder(obs_history.detach())
52+
vel = out[..., :3]
53+
mu = out[..., 3:3 + self.num_latent]
54+
return vel.detach(), mu.detach()
6355

6456
def forward(self, obs_history):
65-
parts = self.encoder(obs_history.detach())
66-
vel, z = parts[..., :3], parts[..., 3:]
67-
z = F.normalize(z, dim=-1, p=2)
68-
return vel.detach(), z.detach()
57+
return self.get_latent(obs_history)
6958

7059
def encode(self, obs_history):
71-
parts = self.encoder(obs_history.detach())
72-
vel, z = parts[..., :3], parts[..., 3:]
73-
z = F.normalize(z, dim=-1, p=2)
74-
return vel, z
60+
"""Training: sample z via reparameterization."""
61+
out = self.encoder(obs_history.detach())
62+
vel = out[..., :3]
63+
mu = out[..., 3:3 + self.num_latent]
64+
logvar = out[..., 3 + self.num_latent:]
65+
z = self.reparameterize(mu, logvar)
66+
return vel, mu, logvar, z
7567

7668
def update(self, obs_history, next_critic_obs, lr=None):
7769
if lr is not None:
7870
self.learning_rate = lr
7971
for param_group in self.optimizer.param_groups:
8072
param_group['lr'] = self.learning_rate
81-
82-
vel = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs+3].detach()
83-
next_obs = next_critic_obs.detach()[:, 3:self.num_one_step_obs+3]
84-
85-
z_s = self.encoder(obs_history)
86-
z_t = self.target(next_obs)
87-
pred_vel, z_s = z_s[..., :3], z_s[..., 3:]
88-
89-
z_s = F.normalize(z_s, dim=-1, p=2)
90-
z_t = F.normalize(z_t, dim=-1, p=2)
9173

92-
with torch.no_grad():
93-
w = self.proto.weight.data.clone()
94-
w = F.normalize(w, dim=-1, p=2)
95-
self.proto.weight.copy_(w)
74+
# Ground-truth velocity from privileged obs
75+
vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach()
9676

97-
score_s = z_s @ self.proto.weight.T
98-
score_t = z_t @ self.proto.weight.T
77+
pred_vel, mu, logvar, _ = self.encode(obs_history)
9978

100-
with torch.no_grad():
101-
q_s = sinkhorn(score_s)
102-
q_t = sinkhorn(score_t)
79+
estimation_loss = F.mse_loss(pred_vel, vel_gt)
10380

104-
log_p_s = F.log_softmax(score_s / self.temperature, dim=-1)
105-
log_p_t = F.log_softmax(score_t / self.temperature, dim=-1)
81+
# KL divergence: D_KL( N(mu, sigma) || N(0,1) )
82+
kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
10683

107-
swap_loss = -0.5 * (q_s * log_p_t + q_t * log_p_s).mean()
108-
estimation_loss = F.mse_loss(pred_vel, vel)
109-
losses = estimation_loss + swap_loss
84+
loss = estimation_loss + self.kl_weight * kl_loss
11085

11186
self.optimizer.zero_grad()
112-
losses.backward()
87+
loss.backward()
11388
nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
11489
self.optimizer.step()
11590

116-
return estimation_loss.item(), swap_loss.item()
117-
118-
119-
@torch.no_grad()
120-
def sinkhorn(out, eps=0.05, iters=3):
121-
Q = torch.exp(out / eps).T
122-
K, B = Q.shape[0], Q.shape[1]
123-
Q /= Q.sum()
124-
125-
for it in range(iters):
126-
# normalize each row: total weight per prototype must be 1/K
127-
Q /= torch.sum(Q, dim=1, keepdim=True)
128-
Q /= K
129-
130-
# normalize each column: total weight per sample must be 1/B
131-
Q /= torch.sum(Q, dim=0, keepdim=True)
132-
Q /= B
133-
return (Q * B).T
91+
return estimation_loss.item(), kl_loss.item()
13492

13593

13694
def get_activation(act_name):
@@ -152,4 +110,4 @@ def get_activation(act_name):
152110
return nn.Sigmoid()
153111
else:
154112
print("invalid activation function!")
155-
return None
113+
return None

rsl_rl/rsl_rl/runners/him_on_policy_runner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def learn(self, num_learning_iterations, init_at_random_ep_len=False):
139139
start = stop
140140
self.alg.compute_returns(critic_obs)
141141

142-
mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_swap_loss = self.alg.update()
142+
mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss = self.alg.update()
143143
stop = time.time()
144144
learn_time = stop - start
145145
if self.log_dir is not None:
@@ -176,7 +176,7 @@ def log(self, locs, width=80, pad=35):
176176
self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it'])
177177
self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it'])
178178
self.writer.add_scalar('Loss/Estimation Loss', locs['mean_estimation_loss'], locs['it'])
179-
self.writer.add_scalar('Loss/Swap Loss', locs['mean_swap_loss'], locs['it'])
179+
self.writer.add_scalar('Loss/KL Loss', locs['mean_kl_loss'], locs['it'])
180180
self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it'])
181181
self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it'])
182182
self.writer.add_scalar('Perf/total_fps', fps, locs['it'])
@@ -198,7 +198,7 @@ def log(self, locs, width=80, pad=35):
198198
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
199199
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
200200
f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n"""
201-
f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n"""
201+
f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n"""
202202
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
203203
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
204204
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""")
@@ -212,7 +212,7 @@ def log(self, locs, width=80, pad=35):
212212
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
213213
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
214214
f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n"""
215-
f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n"""
215+
f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n"""
216216
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""")
217217
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
218218
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")

0 commit comments

Comments
 (0)