# train.py
# Training script for an ACT policy using tinygrad on LeRobot ALOHA simulation datasets.
from pathlib import Path
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from torch.utils.data import DataLoader
import torch
import tinygrad
from tinygrad import Tensor, nn, TinyJit
from omegaconf import ListConfig, OmegaConf
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
from act import *
from utils import clip_grad_norm_
import argparse
# Start of training code
# ---------------------------------------------------------------------------
# Command-line interface.
#   --env_name                 which LeRobot dataset (lerobot/<env_name>) to train on
#   --model_starting_point     optional path to a safetensors checkpoint to load
#   --model_start_step_count   optional step counter value to resume from
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description="Argument Parser for ACT training on simulated environments"
)
parser.add_argument(
    "--env_name",
    type=str,
    choices=['aloha_sim_transfer_cube_human', 'aloha_sim_insertion_human'],
    default='aloha_sim_insertion_human',
)
parser.add_argument("--model_starting_point", type=str)
parser.add_argument("--model_start_step_count", type=int)
args = parser.parse_args()
env_name = args.env_name

# Create a directory to store the training checkpoints.
output_directory = Path(f"outputs/train/{env_name}")
output_directory.mkdir(parents=True, exist_ok=True)

# Number of offline training steps (we'll only do offline training for this example.)
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
training_steps = 100000
log_freq = 10

# Set up the dataset. For every sample, "action" is requested at the 100
# offsets 0/50, 1/50, ..., 99/50 relative to the sample's timestamp.
delta_timestamps = {
    "action": [i / 50.0 for i in range(100)],
}
dataset = LeRobotDataset(f'lerobot/{env_name}', delta_timestamps=delta_timestamps)
print(dataset.meta.stats)

# Build the policy from the default ACT configuration and the dataset statistics.
cfg = ACTConfig()
policy = ACTPolicy(cfg, dataset_stats=dataset.meta.stats)
policy.reset()

# Optionally resume from a checkpoint.
step = 0
if args.model_starting_point:
    if Path(args.model_starting_point).is_file():
        state_dict = safe_load(args.model_starting_point)
        load_state_dict(policy, state_dict)
    else:
        # Previously this case was skipped without any message, so a typo in the
        # path silently trained from scratch. Keep going, but say so.
        print(
            f"WARNING: --model_starting_point '{args.model_starting_point}' is not a file; "
            "continuing with the freshly initialised policy"
        )
# NOTE(review): the original indentation was lost, so it is unclear whether this
# override was meant to apply only when a checkpoint is actually loaded - confirm.
if args.model_start_step_count:
    step = args.model_start_step_count

policy.model.training = True

# Evaluate the flag once so that parameter grouping and optimizer creation can
# never disagree (the original tested truthiness in one place and `== True` in
# the other).
train_backbone_separately = bool(cfg.train_backbone_separately)

if train_backbone_separately:
    # Split the trainable tensors by state-dict name prefix. Tensors whose
    # requires_grad is explicitly False are excluded; None counts as trainable.
    trainable = {
        name: param
        for name, param in get_state_dict(policy).items()
        if param.requires_grad is not False
    }
    params_not_backbone = [
        param for name, param in trainable.items() if not name.startswith("model.backbone")
    ]
    params_backbone = [
        param for name, param in trainable.items() if name.startswith("model.backbone")
    ]
else:
    params_not_backbone = nn.state.get_parameters(policy)
    params_backbone = []

Tensor.manual_seed(1000)

# Apply per-key statistic overrides from the config, if any are present.
# NOTE(review): the policy above was already constructed from dataset.meta.stats;
# confirm that it observes these in-place overrides rather than a copy.
stats_overrides = getattr(cfg, 'override_dataset_stats', None) or {}
for key, stats_dict in stats_overrides.items():
    for stats_type, listconfig in stats_dict.items():
        # example of stats_type: min, max, mean, std
        print(f'listconfig: {listconfig}')
        dataset.meta.stats[key][stats_type] = torch.tensor(listconfig, dtype=torch.float32)

# One optimizer for everything outside the backbone, plus a second one for the
# backbone when it is trained separately (same hyper-parameters for both).
opt = nn.optim.AdamW(params_not_backbone, lr=1e-5, weight_decay=1e-4)
opt_backbone = (
    nn.optim.AdamW(params_backbone, lr=1e-5, weight_decay=1e-4)
    if train_backbone_separately
    else None
)
@TinyJit
@Tensor.train()
def train_step(
    observation_state: Tensor | None = None,
    observation_images: Tensor | None = None,
    # observation_environment_state: Tensor | None = None,
    action: Tensor | None = None,
    action_is_pad: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor]:
    """Run one optimisation step of the policy on a single batch.

    Performs the forward pass, back-propagates the policy's "loss" output,
    clips gradient norms to 10.0 and steps the optimizer(s).

    Args:
        observation_state: passed to the policy as its first argument.
        observation_images: passed to the policy as its second argument.
        action: passed to the policy as its fourth argument (the third,
            the environment state, is always ``None`` here).
        action_is_pad: passed to the policy as its fifth argument.

    Returns:
        A tuple ``(loss, grad_norm_backbone, grad_norm_not_backbone)`` of
        realized tensors. When the backbone is not trained separately the
        second element is the same tensor as the third.
    """
    Tensor.training = True
    max_grad_norm = 10.0
    separate_backbone = cfg.train_backbone_separately

    output_dict = policy(observation_state, observation_images, None, action, action_is_pad)
    loss = output_dict["loss"]

    opt.zero_grad()
    if separate_backbone:
        opt_backbone.zero_grad()
    loss.backward()

    # The non-backbone group is clipped in both configurations; the backbone
    # group only exists (and is only clipped) when it is trained separately.
    grad_norm_not_backbone = clip_grad_norm_(params_not_backbone, max_grad_norm)
    if separate_backbone:
        grad_norm_backbone = clip_grad_norm_(params_backbone, max_grad_norm)
    else:
        grad_norm_backbone = grad_norm_not_backbone

    # Filter out parameters without gradients before stepping.
    # NOTE(review): this rewrites the optimizer's parameter list in place; if the
    # optimizer keeps per-parameter state aligned by index with `params`, dropping
    # entries could misalign that state - verify against the tinygrad version in use.
    opt.params = [p for p in opt.params if p.grad is not None]
    opt.step()
    if separate_backbone:
        opt_backbone.params = [p for p in opt_backbone.params if p.grad is not None]
        opt_backbone.step()

    return (
        loss.realize(),
        grad_norm_backbone.realize(),
        grad_norm_not_backbone.realize(),
    )
print('Starting training loop')

# Create dataloader for offline training.
dataloader = DataLoader(
    dataset,
    num_workers=0,
    batch_size=8,
    shuffle=True,
    pin_memory=False,
    drop_last=True,
)
# With drop_last=True a dataset smaller than one batch yields nothing, and the
# outer `while` below would then spin forever without ever advancing `step`.
if len(dataloader) == 0:
    raise ValueError("dataloader yields no batches; dataset is smaller than the batch size")

# Intermediate checkpoints are written every `checkpoint_freq` steps.
checkpoint_freq = 5000

done = False
with Tensor.train():
    # Keep cycling over the dataset until `training_steps` is reached.
    while not done:
        for batch in dataloader:
            # Convert torch tensors to tinygrad tensors; keep strings, lists, etc. as-is.
            batch = {
                k: Tensor(v.detach().cpu().numpy(), requires_grad=False)
                if isinstance(v, torch.Tensor)
                else v
                for k, v in batch.items()
            }
            batch = policy.normalize_batch_inputs_and_targets(batch)
            print(f'batch: {batch}')

            loss, grad_norm_backbone, grad_norm_not_backbone = train_step(
                batch["observation.state"].realize(),
                batch["observation.images"].realize(),
                # batch["observation.environment_state"].realize() if "observation.environment_state" in batch else None,
                batch["action"].realize(),
                batch["action_is_pad"].realize(),
            )

            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.numpy():.3f}")
                print(f"grad_norm_backbone: {grad_norm_backbone.numpy():.3f}")
                print(f"grad_norm_not_backbone: {grad_norm_not_backbone.numpy():.3f}")
            step += 1

            if step % checkpoint_freq == 0:
                # Best effort: a failed intermediate save must not abort training,
                # but report what went wrong and do not swallow KeyboardInterrupt /
                # SystemExit the way a bare `except:` would.
                try:
                    state_dict = get_state_dict(policy)
                    safe_save(state_dict, f'{output_directory}/model_{step}.safetensors')
                except Exception as exc:
                    print(f'Exception with safe save occurred: {exc!r}')

            if step >= training_steps:
                done = True
                break

# Save a policy checkpoint.
state_dict = get_state_dict(policy)
safe_save(state_dict, f'{output_directory}/model_final.safetensors')