@@ -82,8 +82,9 @@ def __init__(
8282
8383 # default policy
8484 if policy_type == "act" :
85+ # Map device to 'cpu' for SafeTensors compatibility
8586 self .policy_cfg = ACTConfig (
86- repo_id = "local_policy" , device = self . device , push_to_hub = False
87+ repo_id = "local_policy" , device = "cpu" , push_to_hub = False
8788 )
8889
8990 self .config_path = None
@@ -135,6 +136,11 @@ def __init__(
135136 ds_meta = self .dataset .meta ,
136137 rename_map = self .train_cfg .rename_map ,
137138 )
139+
140+ # Move policy to actual XPU device after loading
141+ if str (self .device ).startswith ('xpu' ):
142+ self .policy = self .policy .to (self .device )
143+
138144 self .accelerator .wait_for_everyone ()
139145
140146 processor_kwargs = {}
@@ -147,7 +153,7 @@ def __init__(
147153
148154 if self .train_cfg .policy .pretrained_path is not None :
149155 processor_kwargs ["preprocessor_overrides" ] = {
150- "device_processor" : {"device" : self . device . type },
156+ "device_processor" : {"device" : "cpu" }, # Map device for processor compatibility
151157 "normalizer_processor" : {
152158 "stats" : self .dataset .meta .stats ,
153159 "features" : {
@@ -227,13 +233,26 @@ def run(self):
227233 initial_step = self .step ,
228234 )
229235
236+ # Comprehensive device transfer for all tensor types
def move_to_device(obj, device):
    """Recursively move every tensor found in ``obj`` to ``device``.

    Args:
        obj: A ``torch.Tensor``, or a dict / list / tuple (including
            namedtuples) nested arbitrarily deep, or any other object.
        device: Target device, in any form accepted by ``Tensor.to``.

    Returns:
        A structure with the same layout as ``obj`` in which each tensor has
        been transferred with ``non_blocking=True``. Objects that are neither
        tensors nor one of the handled containers are returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device, non_blocking=True)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    if isinstance(obj, tuple):
        moved = [move_to_device(item, device) for item in obj]
        # A namedtuple must be rebuilt through its own type (positional
        # fields); a bare tuple(...) would silently drop the field names.
        if hasattr(obj, "_fields"):
            return type(obj)(*moved)
        return tuple(moved)
    return obj
247+
230248 for _ in range (self .step , self .train_cfg .steps ):
231249 if self .is_training_stopped .is_set ():
232250 break
233251
234252 start_time = time .perf_counter ()
235253 batch = next (dl_iter )
236254 batch = self .preprocessor (batch )
255+ batch = move_to_device (batch , self .device )
237256 train_tracker .dataloading_s = time .perf_counter () - start_time
238257
239258 train_tracker , output_dict = update_policy (
0 commit comments