File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -586,14 +586,18 @@ def _deserialize_to_file_pt2(
586586 data , model_json_override , do_atomic_virial
587587 )
588588
589- # Compile via AOTInductor into a .pt2 package.
590- # realize_opcount_threshold=0 prevents aggressive kernel fusion that
589+ # On CUDA, aggressive kernel fusion (default realize_opcount_threshold=30)
591590 # causes NaN in the backward pass (force/virial) of attention-based
592- # descriptors (DPA1, DPA2) on CUDA for certain coordinate patterns.
591+ # descriptors (DPA1, DPA2). Setting threshold=0 prevents fusion and
592+ # avoids the NaN. Only applied on CUDA; CPU compilation is unaffected.
593593 import torch ._inductor .config as _inductor_config
594594
595+ import deepmd .pt_expt .utils .env as _env
596+
597+ is_cuda = _env .DEVICE .type == "cuda"
595598 saved_threshold = _inductor_config .realize_opcount_threshold
596- _inductor_config .realize_opcount_threshold = 0
599+ if is_cuda :
600+ _inductor_config .realize_opcount_threshold = 0
597601 try :
598602 aoti_compile_and_package (exported , package_path = model_file )
599603 finally :
You can’t perform that action at this time.
0 commit comments