@@ -104,6 +104,7 @@ def main(
104104 n_spherical_harmonics : int = 1 ,
105105 theta : float = 0.0 ,
106106 leaf_size : int = 1 ,
107+ tree_build_device : Literal ["cpu" , "cuda" ] | None = None ,
107108 airfrans_task : Literal ["full" , "scarce" , "reynolds" , "aoa" ] = "full" ,
108109 patience_steps : int = 1600 ,
109110 use_profiler : bool = True ,
@@ -140,6 +141,8 @@ def main(
140141 n_spherical_harmonics: Number of Legendre polynomial terms for angle features.
141142 theta: Barnes-Hut opening angle. Larger = more aggressive approximation.
142143 leaf_size: Maximum sources per leaf node in the Barnes-Hut tree.
144+ tree_build_device: Device on which to build cluster trees and run the
145+ dual-tree Barnes-Hut traversal. ``None`` (default) uses the input's device.
143146 airfrans_task: Which AirFRANS dataset task to train on.
144147 patience_steps: ReduceLROnPlateau patience expressed in gradient
145148 steps (world-size independent). Converted to epochs internally.
@@ -270,6 +273,7 @@ def main(
270273 self_regularization_beta = self_regularization_beta ,
271274 latent_compression_scale = latent_compression_scale ,
272275 expand_far_targets = expand_far_targets ,
276+ tree_build_device = tree_build_device ,
273277 ).to (device )
274278
275279 logger0 .info (f"{ output_dir .name = !r} " )
@@ -347,16 +351,13 @@ def main(
347351 min_lr = learning_rate / 64 ,
348352 threshold = 1e-3 ,
349353 )
350- scaler = torch .amp .GradScaler (device = device .type , enabled = amp )
351-
352354 ### [Checkpoint Save/Load]
353355 metadata_dict : dict [str , Any ] = {}
354356 epoch = load_checkpoint (
355357 checkpoint_dir ,
356358 models = base_model ,
357359 optimizer = optimizer ,
358360 scheduler = scheduler ,
359- scaler = scaler ,
360361 metadata_dict = metadata_dict ,
361362 device = dist .device ,
362363 )
@@ -430,7 +431,6 @@ def main(
430431 ** config_settings ,
431432 "optimizer" : optimizer .__class__ .__name__ ,
432433 "scheduler" : scheduler .__class__ .__name__ ,
433- "scaler" : scaler .__class__ .__name__ ,
434434 "physicsnemo_pkg_info" : get_physicsnemo_pkg_info (),
435435 "world_size" : dist .world_size ,
436436 ** {f"n_{ split } _samples" : len (sample_paths [split ]) for split in splits },
@@ -526,15 +526,13 @@ def prepare_sample(sample: AirFRANSSample) -> AirFRANSSample:
526526 if torch .isnan (batch_loss ):
527527 warnings .warn (f"{ batch_loss = } at: { dist .rank = } , { epoch = } " )
528528 with record_function ("backward" ):
529- scaler . scale ( batch_loss ) .backward ()
529+ batch_loss .backward ()
530530 if gradient_clip_norm is not None :
531- scaler .unscale_ (optimizer )
532531 torch .nn .utils .clip_grad_norm_ (
533532 model .parameters (), max_norm = gradient_clip_norm
534533 )
535534 with record_function ("optimizer_step" ):
536- scaler .step (optimizer )
537- scaler .update ()
535+ optimizer .step ()
538536 all_batch_losses .append (batch_loss .detach ().clone ())
539537 for k , v in batch_loss_components .items ():
540538 all_batch_loss_components [k ].append (v .detach ().clone ())
@@ -614,7 +612,6 @@ def save_ckpt() -> None:
614612 models = base_model ,
615613 optimizer = optimizer ,
616614 scheduler = scheduler ,
617- scaler = scaler ,
618615 epoch = epoch ,
619616 metadata = checkpoint_metadata (),
620617 )
0 commit comments