Skip to content

Commit b73c85e

Browse files
committed
[FIX] update navdp training parameters
1 parent eb1d0e3 commit b73c85e

3 files changed

Lines changed: 13 additions & 40 deletions

File tree

internnav/model/basemodel/navdp/navdp_policy.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,13 @@
77

88
from internnav.configs.model.base_encoders import ModelCfg
99
from internnav.configs.trainer.exp import ExpCfg
10-
from internnav.model.encoder.navdp_backbone import *
10+
from internnav.model.encoder.navdp_backbone import (
11+
LearnablePositionalEncoding,
12+
NavDP_ImageGoal_Backbone,
13+
NavDP_PixelGoal_Backbone,
14+
NavDP_RGBD_Backbone,
15+
SinusoidalPosEmb,
16+
)
1117

1218

1319
class NavDPModelConfig(PretrainedConfig):
@@ -324,7 +330,7 @@ def predict_pointgoal_batch_action_vel(self, goal_point, input_images, input_dep
324330
naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
325331

326332
critic_values = self.predict_critic(naction, rgbd_embed)
327-
all_trajectory = torch.cumsum(naction / 4.0, dim=1)
333+
# all_trajectory = torch.cumsum(naction / 4.0, dim=1)
328334

329335
negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]]
330336
positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]]
@@ -343,12 +349,12 @@ def predict_nogoal_batch_action_vel(self, input_images, input_depths, sample_num
343349
naction = self.noise_scheduler.step(model_output=noise_pred, timestep=k, sample=naction).prev_sample
344350

345351
critic_values = self.predict_critic(naction, rgbd_embed)
346-
all_trajectory = torch.cumsum(naction / 4.0, dim=1)
352+
# all_trajectory = torch.cumsum(naction / 4.0, dim=1)
347353

348354
negative_trajectory = torch.cumsum(naction / 4.0, dim=1)[(critic_values).argsort()[0:8]]
349355
positive_trajectory = torch.cumsum(naction / 4.0, dim=1)[(-critic_values).argsort()[0:8]]
350356
return negative_trajectory, positive_trajectory
351357

352358

353359
# if __name__ == "__main__":
354-
# policy = NavDPNet(config=)
360+
# policy = NavDPNet(config=)

scripts/train/configs/navdp.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
),
3030
il=IlCfg(
3131
epochs=1000,
32-
batch_size=16,
32+
batch_size=32,
3333
lr=1e-4,
3434
num_workers=8,
3535
weight_decay=1e-4, # TODO
@@ -57,6 +57,7 @@
5757
prior_sample=False,
5858
memory_size=8,
5959
predict_size=24,
60+
pixel_channel=4,
6061
temporal_depth=16,
6162
heads=8,
6263
token_dim=384,

scripts/train/train.py

Lines changed: 1 addition & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -60,20 +60,16 @@ def on_save(self, args, state, control, **kwargs):
6060

6161

6262
def _make_dir(config):
63-
config.tensorboard_dir = config.tensorboard_dir % config.name
6463
config.tensorboard_dir = config.tensorboard_dir % config.name
6564
config.checkpoint_folder = config.checkpoint_folder % config.name
6665
config.log_dir = config.log_dir % config.name
6766
config.output_dir = config.output_dir % config.name
6867
if not os.path.exists(config.tensorboard_dir):
6968
os.makedirs(config.tensorboard_dir, exist_ok=True)
70-
os.makedirs(config.tensorboard_dir, exist_ok=True)
7169
if not os.path.exists(config.checkpoint_folder):
7270
os.makedirs(config.checkpoint_folder, exist_ok=True)
73-
os.makedirs(config.checkpoint_folder, exist_ok=True)
7471
if not os.path.exists(config.log_dir):
7572
os.makedirs(config.log_dir, exist_ok=True)
76-
os.makedirs(config.log_dir, exist_ok=True)
7773

7874

7975
def main(config, model_class, model_config_class):
@@ -98,14 +94,12 @@ def main(config, model_class, model_config_class):
9894
world_size = int(os.getenv('WORLD_SIZE', '1'))
9995
rank = int(os.getenv('RANK', '0'))
10096

101-
10297
# Set CUDA device for each process
10398
device_id = local_rank
10499
torch.cuda.set_device(device_id)
105100
device = torch.device(f'cuda:{device_id}')
106101
print(f"World size: {world_size}, Local rank: {local_rank}, Global rank: {rank}")
107102

108-
109103
# Initialize distributed training environment
110104
if world_size > 1:
111105
try:
@@ -116,7 +110,6 @@ def main(config, model_class, model_config_class):
116110
print(f"Distributed initialization FAILED: {str(e)}")
117111
world_size = 1
118112

119-
print("=" * 50)
120113
print("=" * 50)
121114
print("After distributed init:")
122115
print(f"LOCAL_RANK: {local_rank}")
@@ -146,13 +139,10 @@ def main(config, model_class, model_config_class):
146139
print(f"Buffer {name} is on wrong device {buffer.device}, should be moved to {device}")
147140
buffer.data = buffer.data.to(device)
148141

149-
150142
# If distributed training, wrap the model with DDP
151143
if world_size > 1:
152144
model = torch.nn.parallel.DistributedDataParallel(
153-
model, device_ids=[local_rank],
154-
output_device=local_rank,
155-
find_unused_parameters=True
145+
model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True
156146
)
157147
# ------------ load logger ------------
158148
train_logger_filename = os.path.join(config.log_dir, 'train.log')
@@ -162,15 +152,10 @@ def main(config, model_class, model_config_class):
162152
level=logging.INFO,
163153
format_str='%(asctime)-15s %(message)s',
164154
filename=train_logger_filename,
165-
name='train',
166-
level=logging.INFO,
167-
format_str='%(asctime)-15s %(message)s',
168-
filename=train_logger_filename,
169155
)
170156
else:
171157
# Other processes use console logging
172158
train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s')
173-
train_logger = MyLogger(name='train', level=logging.INFO, format_str='%(asctime)-15s %(message)s')
174159
transformers_logger = logging.getLogger("transformers")
175160
if transformers_logger.hasHandlers():
176161
transformers_logger.handlers = []
@@ -180,18 +165,6 @@ def main(config, model_class, model_config_class):
180165

181166
# ------------ load dataset ------------
182167
if config.model_name == "navdp":
183-
train_dataset_data = NavDP_Base_Datset(
184-
config.il.root_dir,
185-
config.il.dataset_navdp,
186-
config.il.memory_size,
187-
config.il.predict_size,
188-
config.il.batch_size,
189-
config.il.image_size,
190-
config.il.scene_scale,
191-
preload=config.il.preload,
192-
random_digit=config.il.random_digit,
193-
prior_sample=config.il.prior_sample,
194-
)
195168
train_dataset_data = NavDP_Base_Datset(
196169
config.il.root_dir,
197170
config.il.dataset_navdp,
@@ -239,7 +212,6 @@ def main(config, model_class, model_config_class):
239212
config.il.lerobot_features_dir,
240213
dataset_data=train_dataset_data,
241214
batch_size=config.il.batch_size,
242-
batch_size=config.il.batch_size,
243215
)
244216
collate_fn = rdp_collate_fn(global_batch_size=global_batch_size)
245217
elif config.model_name == 'navdp':
@@ -255,7 +227,6 @@ def main(config, model_class, model_config_class):
255227
deepspeed='',
256228
gradient_checkpointing=False,
257229
bf16=False, # fp16=False,
258-
bf16=False, # fp16=False,
259230
tf32=False,
260231
per_device_train_batch_size=config.il.batch_size,
261232
gradient_accumulation_steps=1,
@@ -267,7 +238,6 @@ def main(config, model_class, model_config_class):
267238
logging_steps=10.0,
268239
num_train_epochs=config.il.epochs,
269240
save_strategy='epoch', # no
270-
save_strategy='epoch', # no
271241
save_steps=config.il.save_interval_epochs,
272242
save_total_limit=8,
273243
report_to=config.il.report_to,
@@ -279,7 +249,6 @@ def main(config, model_class, model_config_class):
279249
dataloader_drop_last=True,
280250
disable_tqdm=True,
281251
log_level="info",
282-
log_level="info",
283252
)
284253

285254
# Create the trainer
@@ -299,17 +268,14 @@ def main(config, model_class, model_config_class):
299268
except Exception as e:
300269
import traceback
301270

302-
303271
print(f"Unhandled exception: {str(e)}")
304272
print("Stack trace:")
305273
traceback.print_exc()
306274

307-
308275
# If distributed environment, ensure all processes exit
309276
if dist.is_initialized():
310277
dist.destroy_process_group()
311278

312-
313279
raise
314280

315281

0 commit comments

Comments
 (0)