1- import copy
2- import math
31import torch
42import torch .nn as nn
53import torch .optim as optim
64import torch .nn .functional as F
7- import torch .distributions as torchd
8- from torch .distributions import Normal , Categorical
95
106
class HIMEstimator(nn.Module):
    """Estimator that regresses a 3-dim velocity and a Gaussian latent.

    A single MLP encoder maps a flattened observation history to
    ``3 + 2 * num_latent`` outputs, laid out along the last dimension as
    ``[vel (3) | mu (num_latent) | logvar (num_latent)]``.  Training
    minimises the MSE between the predicted velocity and the velocity slice
    of the privileged observation, plus a KL term (weighted by
    ``kl_weight``) that pulls ``N(mu, exp(logvar))`` towards ``N(0, 1)``.

    Args:
        temporal_steps: number of observation steps in the history.
        num_one_step_obs: size of one observation step; the encoder input
            width is ``temporal_steps * num_one_step_obs``.
        enc_hidden_dims: encoder widths; the last entry is the latent size,
            the preceding entries are hidden layer widths.
        tar_hidden_dims: kept for config compatibility, unused.
        activation: activation name, resolved through ``get_activation``.
        learning_rate: initial Adam learning rate.
        max_grad_norm: gradient-norm clipping threshold used in ``update``.
        num_prototype: kept for config compatibility, unused.
        temperature: kept for config compatibility, unused.
        kl_weight: multiplier applied to the KL term in the training loss.
        **kwargs: unexpected arguments; reported and ignored.
    """

    def __init__(self,
                 temporal_steps,
                 num_one_step_obs,
                 enc_hidden_dims=[128, 64, 16],  # never mutated here
                 tar_hidden_dims=[128, 64],  # kept for config compatibility, unused
                 activation='elu',
                 learning_rate=1e-3,
                 max_grad_norm=10.0,
                 num_prototype=32,  # kept for config compatibility, unused
                 temperature=3.0,  # kept for config compatibility, unused
                 kl_weight=1.0,
                 **kwargs):
        if kwargs:
            print("HIMEstimator.__init__ got unexpected arguments, which will be ignored: " +
                  str(list(kwargs)))
        super().__init__()
        activation_fn = get_activation(activation)

        self.temporal_steps = temporal_steps
        self.num_one_step_obs = num_one_step_obs
        self.num_latent = enc_hidden_dims[-1]
        self.max_grad_norm = max_grad_norm
        self.kl_weight = kl_weight

        # Encoder: outputs vel(3) + mu(num_latent) + logvar(num_latent).
        # Every entry of enc_hidden_dims except the last is a hidden layer.
        enc_layers = []
        in_dim = self.temporal_steps * self.num_one_step_obs
        for hidden_dim in enc_hidden_dims[:-1]:
            enc_layers += [nn.Linear(in_dim, hidden_dim), activation_fn]
            in_dim = hidden_dim
        enc_layers.append(nn.Linear(in_dim, 3 + 2 * self.num_latent))
        self.encoder = nn.Sequential(*enc_layers)

        # The optimizer is created last so it sees every registered parameter.
        self.learning_rate = learning_rate
        self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

    def _split_encoder_output(self, out):
        """Slice raw encoder output into ``(vel, mu, logvar)`` along the last dim."""
        mu_end = 3 + self.num_latent
        return out[..., :3], out[..., 3:mu_end], out[..., mu_end:]

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, 1)`` and ``std = exp(0.5 * logvar)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def get_latent(self, obs_history):
        """Inference: return ``(vel, mu)`` using mu directly (no sampling noise).

        The encoder runs under ``torch.no_grad()``, so no autograd graph is
        built and the returned tensors carry no gradient.
        """
        with torch.no_grad():
            out = self.encoder(obs_history)
        vel, mu, _ = self._split_encoder_output(out)
        return vel, mu

    def forward(self, obs_history):
        """Alias of :meth:`get_latent`."""
        return self.get_latent(obs_history)

    def encode(self, obs_history):
        """Training: return ``(vel, mu, logvar, z)`` with ``z`` sampled via reparameterization.

        ``obs_history`` is detached first, so gradients reach the encoder
        weights but never flow back into the input tensor.
        """
        out = self.encoder(obs_history.detach())
        vel, mu, logvar = self._split_encoder_output(out)
        z = self.reparameterize(mu, logvar)
        return vel, mu, logvar, z

    def update(self, obs_history, next_critic_obs, lr=None):
        """Run one optimisation step and return ``(estimation_loss, kl_loss)`` as floats.

        Args:
            obs_history: encoder input batch.
            next_critic_obs: 2-D privileged observation batch; columns
                ``[num_one_step_obs, num_one_step_obs + 3)`` are read as the
                ground-truth velocity.
            lr: if given, replaces the learning rate of every optimizer
                param group before the step.
        """
        if lr is not None:
            self.learning_rate = lr
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.learning_rate

        # Ground-truth velocity from privileged obs
        vel_start = self.num_one_step_obs
        vel_gt = next_critic_obs[:, vel_start:vel_start + 3].detach()

        # The sampled z is not part of the loss; only vel, mu and logvar are.
        pred_vel, mu, logvar, _ = self.encode(obs_history)

        estimation_loss = F.mse_loss(pred_vel, vel_gt)

        # KL divergence: D_KL( N(mu, sigma) || N(0,1) ), averaged over all elements
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        loss = estimation_loss + self.kl_weight * kl_loss

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
        self.optimizer.step()

        return estimation_loss.item(), kl_loss.item()
13492
13593
13694def get_activation (act_name ):
@@ -152,4 +110,4 @@ def get_activation(act_name):
152110 return nn .Sigmoid ()
153111 else :
154112 print ("invalid activation function!" )
155- return None
113+ return None
0 commit comments