77
88from internnav .configs .model .base_encoders import ModelCfg
99from internnav .configs .trainer .exp import ExpCfg
10+ from internnav .model .encoder .navdp_backbone import LearnablePositionalEncoding
1011from internnav .model .encoder .navdp_backbone import (
11- LearnablePositionalEncoding ,
12- NavDP_ImageGoal_Backbone ,
13- NavDP_PixelGoal_Backbone ,
14- NavDP_RGBD_Backbone ,
15- SinusoidalPosEmb ,
12+ NavDP_ImageGoal_Backbone as ImageGoal_Backbone ,
1613)
14+ from internnav .model .encoder .navdp_backbone import (
15+ NavDP_PixelGoal_Backbone as PixelGoal_Backbone ,
16+ )
17+ from internnav .model .encoder .navdp_backbone import NavDP_RGBD_Backbone as RGBD_Backbone
18+ from internnav .model .encoder .navdp_backbone import SinusoidalPosEmb
1719
1820
1921class NavDPModelConfig (PretrainedConfig ):
@@ -83,13 +85,13 @@ def __init__(self, config: NavDPModelConfig):
8385 self .token_dim = self .config .model_cfg ['il' ]['token_dim' ]
8486 self .scratch = self .config .model_cfg ['il' ]['scratch' ]
8587 self .finetune = self .config .model_cfg ['il' ]['finetune' ]
86- self .rgbd_encoder = NavDP_RGBD_Backbone (
88+ self .rgbd_encoder = RGBD_Backbone (
8789 self .image_size , self .token_dim , memory_size = self .memory_size , finetune = self .finetune , device = self ._device
8890 )
89- self .pixel_encoder = NavDP_PixelGoal_Backbone (
91+ self .pixel_encoder = PixelGoal_Backbone (
9092 self .image_size , self .token_dim , pixel_channel = self .pixel_channel , device = self ._device
9193 )
92- self .image_encoder = NavDP_ImageGoal_Backbone (self .image_size , self .token_dim , device = self ._device )
94+ self .image_encoder = ImageGoal_Backbone (self .image_size , self .token_dim , device = self ._device )
9395 self .point_encoder = nn .Linear (3 , self .token_dim )
9496
9597 if not self .finetune :
@@ -185,23 +187,6 @@ def predict_critic(self, predict_trajectory, rgbd_embed):
185187 return critic_output
186188
187189 def forward (self , goal_point , goal_image , goal_pixel , input_images , input_depths , output_actions , augment_actions ):
188- # """get device safely"""
189- # # get device safely
190- # try:
191- # # try to get device through model parameters
192- # device = next(self.parameters()).device
193- # except StopIteration:
194- # # model has no parameters, use the default device
195- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
196- # # move all inputs to model device
197- # goal_point = goal_point.to(device)
198- # goal_image = goal_image.to(device)
199- # input_images = input_images.to(device)
200- # input_depths = input_depths.to(device)
201- # output_actions = output_actions.to(device)
202- # augment_actions = augment_actions.to(device)
203- # device = self._device
204- # print(f"self.parameters() is:{self.parameters()}")
205190 device = next (self .parameters ()).device
206191
207192 assert input_images .shape [1 ] == self .memory_size
@@ -330,7 +315,6 @@ def predict_pointgoal_batch_action_vel(self, goal_point, input_images, input_dep
330315 naction = self .noise_scheduler .step (model_output = noise_pred , timestep = k , sample = naction ).prev_sample
331316
332317 critic_values = self .predict_critic (naction , rgbd_embed )
333- # all_trajectory = torch.cumsum(naction / 4.0, dim=1)
334318
335319 negative_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(critic_values ).argsort ()[0 :8 ]]
336320 positive_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(- critic_values ).argsort ()[0 :8 ]]
@@ -349,12 +333,7 @@ def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num
349333 naction = self .noise_scheduler .step (model_output = noise_pred , timestep = k , sample = naction ).prev_sample
350334
351335 critic_values = self .predict_critic (naction , rgbd_embed )
352- # all_trajectory = torch.cumsum(naction / 4.0, dim=1)
353336
354337 negative_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(critic_values ).argsort ()[0 :8 ]]
355338 positive_trajectory = torch .cumsum (naction / 4.0 , dim = 1 )[(- critic_values ).argsort ()[0 :8 ]]
356339 return negative_trajectory , positive_trajectory
357-
358-
359- # if __name__ == "__main__":
360- # policy = NavDPNet(config=)
0 commit comments