Skip to content

Commit f056888

Browse files
authored
Merge pull request #397 from open-edge-platform/update-branch
bug: robotic training tool resume training error (#918)
2 parents 70310f5 + c1627c4 commit f056888

6 files changed

Lines changed: 926 additions & 2679 deletions

File tree

usecases/robotic/training-ui/server/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ wheels/

 # Virtual environments
 .venv
 data/
-output/
+output/
+outputs/

usecases/robotic/training-ui/server/modules/lerobot/finetune.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ def __init__(

         # default policy
         if policy_type == "act":
+            # Map device to 'cpu' for SafeTensors compatibility
             self.policy_cfg = ACTConfig(
-                repo_id="local_policy", device=self.device, push_to_hub=False
+                repo_id="local_policy", device="cpu", push_to_hub=False
             )

         self.config_path = None
@@ -135,6 +136,11 @@ def __init__(
             ds_meta=self.dataset.meta,
             rename_map=self.train_cfg.rename_map,
         )
+
+        # Move policy to actual XPU device after loading
+        if str(self.device).startswith('xpu'):
+            self.policy = self.policy.to(self.device)
+
         self.accelerator.wait_for_everyone()
processor_kwargs = {}
@@ -147,7 +153,7 @@ def __init__(

         if self.train_cfg.policy.pretrained_path is not None:
             processor_kwargs["preprocessor_overrides"] = {
-                "device_processor": {"device": self.device.type},
+                "device_processor": {"device": "cpu"},  # Map device for processor compatibility
                 "normalizer_processor": {
                     "stats": self.dataset.meta.stats,
                     "features": {
@@ -227,13 +233,26 @@ def run(self):
             initial_step=self.step,
         )

+        # Comprehensive device transfer for all tensor types
+        def move_to_device(obj, device):
+            if isinstance(obj, torch.Tensor):
+                return obj.to(device, non_blocking=True)
+            elif isinstance(obj, dict):
+                return {k: move_to_device(v, device) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [move_to_device(item, device) for item in obj]
+            elif isinstance(obj, tuple):
+                return tuple(move_to_device(item, device) for item in obj)
+            return obj
+
         for _ in range(self.step, self.train_cfg.steps):
             if self.is_training_stopped.is_set():
                 break

             start_time = time.perf_counter()
             batch = next(dl_iter)
             batch = self.preprocessor(batch)
+            batch = move_to_device(batch, self.device)
             train_tracker.dataloading_s = time.perf_counter() - start_time

             train_tracker, output_dict = update_policy(

0 commit comments

Comments
 (0)