Skip to content

Commit 1545c20

Browse files
committed
[FIX] Support NavDP finetuning
1 parent b73c85e commit 1545c20

3 files changed

Lines changed: 14 additions & 32 deletions

File tree

internnav/dataset/navdp_dataset_lerobot.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,10 @@ def __getitem__(self, index):
511511
camera_intrinsic,
512512
trajectory_base_extrinsic,
513513
)
514+
# pixel channel == 7 represents the navdp works pixel navigation under asynchronous pace,
515+
# pixel_mask (1), the history image with the assigned pixel goal (3), current image (3)
516+
# if pixel_channel == 4, pixel goal is assigned at current frame, therefore,
517+
# only pixel_mask (1) and current image (3) are needed
514518
if self.pixel_channel == 7:
515519
pixel_goal = np.concatenate((pixel_goal, memory_images[-1]), axis=-1)
516520

internnav/model/basemodel/navdp/navdp_policy.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
from internnav.configs.model.base_encoders import ModelCfg
99
from internnav.configs.trainer.exp import ExpCfg
10+
from internnav.model.encoder.navdp_backbone import LearnablePositionalEncoding
1011
from internnav.model.encoder.navdp_backbone import (
11-
LearnablePositionalEncoding,
12-
NavDP_ImageGoal_Backbone,
13-
NavDP_PixelGoal_Backbone,
14-
NavDP_RGBD_Backbone,
15-
SinusoidalPosEmb,
12+
NavDP_ImageGoal_Backbone as ImageGoal_Backbone,
1613
)
14+
from internnav.model.encoder.navdp_backbone import (
15+
NavDP_PixelGoal_Backbone as PixelGoal_Backbone,
16+
)
17+
from internnav.model.encoder.navdp_backbone import NavDP_RGBD_Backbone as RGBD_Backbone
18+
from internnav.model.encoder.navdp_backbone import SinusoidalPosEmb
1719

1820

1921
class NavDPModelConfig(PretrainedConfig):
@@ -83,13 +85,13 @@ def __init__(self, config: NavDPModelConfig):
8385
self.token_dim = self.config.model_cfg['il']['token_dim']
8486
self.scratch = self.config.model_cfg['il']['scratch']
8587
self.finetune = self.config.model_cfg['il']['finetune']
86-
self.rgbd_encoder = NavDP_RGBD_Backbone(
88+
self.rgbd_encoder = RGBD_Backbone(
8789
self.image_size, self.token_dim, memory_size=self.memory_size, finetune=self.finetune, device=self._device
8890
)
89-
self.pixel_encoder = NavDP_PixelGoal_Backbone(
91+
self.pixel_encoder = PixelGoal_Backbone(
9092
self.image_size, self.token_dim, pixel_channel=self.pixel_channel, device=self._device
9193
)
92-
self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device)
94+
self.image_encoder = ImageGoal_Backbone(self.image_size, self.token_dim, device=self._device)
9395
self.point_encoder = nn.Linear(3, self.token_dim)
9496

9597
if not self.finetune:
@@ -185,23 +187,6 @@ def predict_critic(self, predict_trajectory, rgbd_embed):
185187
return critic_output
186188

187189
def forward(self, goal_point, goal_image, goal_pixel, input_images, input_depths, output_actions, augment_actions):
188-
# """get device safely"""
189-
# # get device safely
190-
# try:
191-
# # try to get device through model parameters
192-
# device = next(self.parameters()).device
193-
# except StopIteration:
194-
# # model has no parameters, use the default device
195-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
196-
# # move all inputs to model device
197-
# goal_point = goal_point.to(device)
198-
# goal_image = goal_image.to(device)
199-
# input_images = input_images.to(device)
200-
# input_depths = input_depths.to(device)
201-
# output_actions = output_actions.to(device)
202-
# augment_actions = augment_actions.to(device)
203-
# device = self._device
204-
# print(f"self.parameters() is:{self.parameters()}")
205190
device = next(self.parameters()).device
206191

207192
assert input_images.shape[1] == self.memory_size
@@ -330,7 +315,6 @@ def predict_pointgoal_batch_action_vel(self, goal_point, input_images, input_dep
330315
naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
331316

332317
critic_values = self.predict_critic(naction, rgbd_embed)
333-
# all_trajectory = torch.cumsum(naction / 4.0, dim=1)
334318

335319
negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]]
336320
positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]]
@@ -349,12 +333,7 @@ def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num
349333
naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
350334

351335
critic_values = self.predict_critic(naction, rgbd_embed)
352-
# all_trajectory = torch.cumsum(naction / 4.0, dim=1)
353336

354337
negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]]
355338
positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]]
356339
return negative_trajectory, positive_trajectory
357-
358-
359-
# if __name__ == "__main__":
360-
# policy = NavDPNet(config=)

scripts/train/train.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def main(config, model_class, model_config_class):
103103
# Initialize distributed training environment
104104
if world_size > 1:
105105
try:
106-
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
107106
dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
108107
print("Distributed initialization SUCCESS")
109108
except Exception as e:

0 commit comments

Comments
 (0)