@@ -63,14 +63,23 @@ def get_det_action(self, state):
 
 
 class Critic(nn.Module):
-    def __init__(self, state_size, action_size, hidden_size=256):
+    def __init__(self, state_size, action_size, hidden_size=256, is_discrete=False):
         super(Critic, self).__init__()
-        self.fc1 = nn.Linear(state_size + action_size, hidden_size)
+        self.is_discrete = is_discrete
+        if self.is_discrete:
+            self.action_embedding = nn.Embedding(action_size, hidden_size)
+            self.fc1 = nn.Linear(state_size + hidden_size, hidden_size)
+        else:
+            self.fc1 = nn.Linear(state_size + action_size, hidden_size)
         self.fc2 = nn.Linear(hidden_size, hidden_size)
         self.fc3 = nn.Linear(hidden_size, 1)
 
     def forward(self, state, action):
-        x = torch.cat((state, action), dim=-1)
+        if self.is_discrete:
+            action_emb = self.action_embedding(action.long().squeeze(-1))
+            x = torch.cat((state, action_emb), dim=-1)
+        else:
+            x = torch.cat((state, action), dim=-1)
         x = F.relu(self.fc1(x))
         x = F.relu(self.fc2(x))
         return self.fc3(x)
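Note on the embedding path above: for discrete actions the critic now looks up a learned hidden_size-dimensional vector per action index instead of concatenating a one-hot vector onto the state. A minimal standalone sketch of that forward pass (batch size and dimensions here are illustrative, not taken from the PR):

import torch
import torch.nn as nn

state_size, action_size, hidden_size = 8, 4, 256
action_embedding = nn.Embedding(action_size, hidden_size)   # one learned row per discrete action
fc1 = nn.Linear(state_size + hidden_size, hidden_size)      # mirrors the is_discrete branch

states = torch.randn(32, state_size)                        # (batch, state_size)
actions = torch.randint(0, action_size, (32, 1)).float()    # (batch, 1), as a replay buffer might store them

action_emb = action_embedding(actions.long().squeeze(-1))   # (batch, hidden_size)
x = torch.cat((states, action_emb), dim=-1)                 # (batch, state_size + hidden_size)
assert x.shape == (32, state_size + hidden_size)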
@@ -106,12 +115,12 @@ def __init__(self, cfg):
         self.is_discrete = cfg.dataset.is_discrete
 
         self.actor = Actor(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
-        self.critic1 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
-        self.critic2 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
+        self.critic1 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
+        self.critic2 = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
         self.value_net = Value(cfg.dataset.state_dim, cfg.iql.hidden_size).to(self.device)
 
-        self.critic1_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
-        self.critic2_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size).to(self.device)
+        self.critic1_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
+        self.critic2_target = Critic(cfg.dataset.state_dim, cfg.dataset.act_dim, cfg.iql.hidden_size, is_discrete=self.is_discrete).to(self.device)
         self.critic1_target.load_state_dict(self.critic1.state_dict())
         self.critic2_target.load_state_dict(self.critic2.state_dict())
 
@@ -135,13 +144,8 @@ def learn(self, batch):
 
         # Value loss
         with torch.no_grad():
-            if self.is_discrete:
-                actions_one_hot = F.one_hot(actions.squeeze().long(), num_classes=self.actor.logits.out_features).float()
-                q1 = self.critic1_target(states, actions_one_hot)
-                q2 = self.critic2_target(states, actions_one_hot)
-            else:
-                q1 = self.critic1_target(states, actions)
-                q2 = self.critic2_target(states, actions)
+            q1 = self.critic1_target(states, actions)
+            q2 = self.critic2_target(states, actions)
         min_q = torch.min(q1, q2)
         value = self.value_net(states)
         value_loss = loss_fn(min_q - value, self.expectile).mean()
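loss_fn is not shown in this hunk; in IQL it is conventionally the asymmetric expectile regression loss L2_tau(u) = |tau - 1(u < 0)| * u^2. A sketch under that assumption (the signature matches the call above, but the real definition lives elsewhere in the repo):

import torch

def loss_fn(diff, expectile):
    # Asymmetric L2: positive residuals (value under-estimates min_q) are
    # weighted by expectile, negative ones by (1 - expectile), so for
    # expectile > 0.5 the value net tracks an upper expectile of Q.
    weight = torch.where(diff > 0, expectile, 1 - expectile)
    return weight * diff.pow(2)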
@@ -155,7 +159,10 @@ def learn(self, batch):
         exp_a = torch.exp((min_q - v) * self.temperature)
         exp_a = torch.min(exp_a, torch.tensor(100.0, device=self.device))
         _, dist = self.actor.evaluate(states)
-        log_probs = dist.log_prob(actions)
+        if self.is_discrete:
+            log_probs = dist.log_prob(actions.squeeze(-1).long())
+        else:
+            log_probs = dist.log_prob(actions)
         actor_loss = -(exp_a * log_probs).mean()
         self.actor_optimizer.zero_grad()
         actor_loss.backward()
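One shape subtlety worth flagging in the discrete branch, assuming evaluate returns a torch.distributions.Categorical and exp_a keeps the (batch, 1) shape of min_q - v: Categorical.log_prob yields (batch,), and (batch, 1) * (batch,) broadcasts to (batch, batch) before the mean. An illustration (tensor sizes invented for the example):

import torch
from torch.distributions import Categorical, Normal

batch, n_actions, act_dim = 32, 4, 6

# Discrete head: log_prob consumes action indices of shape (batch,)
dist_d = Categorical(logits=torch.randn(batch, n_actions))
actions = torch.randint(0, n_actions, (batch, 1)).float()  # stored as (batch, 1)
log_p_d = dist_d.log_prob(actions.squeeze(-1).long())      # -> (batch,)

# Continuous head: Normal.log_prob is elementwise over action dimensions
dist_c = Normal(torch.zeros(batch, act_dim), torch.ones(batch, act_dim))
log_p_c = dist_c.log_prob(torch.randn(batch, act_dim))     # -> (batch, act_dim)

# If exp_a is (batch, 1), unsqueezing keeps the product at (batch, 1):
exp_a = torch.rand(batch, 1)
weighted = exp_a * log_p_d.unsqueeze(-1)                   # -> (batch, 1), not (batch, batch)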
@@ -166,13 +173,8 @@ def learn(self, batch):
         next_v = self.value_net(next_states)
         q_target = rewards + self.gamma * (1 - dones) * next_v
 
-        if self.is_discrete:
-            actions_one_hot = F.one_hot(actions.squeeze().long(), num_classes=self.actor.logits.out_features).float()
-            q1 = self.critic1(states, actions_one_hot)
-            q2 = self.critic2(states, actions_one_hot)
-        else:
-            q1 = self.critic1(states, actions)
-            q2 = self.critic2(states, actions)
+        q1 = self.critic1(states, actions)
+        q2 = self.critic2(states, actions)
 
         critic1_loss = F.mse_loss(q1, q_target)
         critic2_loss = F.mse_loss(q2, q_target)
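For the TD target in this last hunk: both critics regress toward r + gamma * (1 - done) * V(s'), bootstrapping from the value net rather than from a target Q over next actions. A toy numeric check of that formula (numbers invented for illustration):

import torch

gamma = 0.99
rewards = torch.tensor([[1.0], [0.0]])
dones = torch.tensor([[0.0], [1.0]])
next_v = torch.tensor([[2.0], [5.0]])

# Terminal transitions (done == 1) fall back to the reward alone:
q_target = rewards + gamma * (1 - dones) * next_v
print(q_target)  # tensor([[2.9800], [0.0000]])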