Skip to content

Commit 7632db8

Browse files
author
Han Wang
committed
fix(pt_expt): fail fast on with-comm artifact errors instead of silently zeroing
Address @iProzd review on PR #5430: - border_op_export: throw on empty output list rather than returning empty_like(g1), which masked internal kernel bugs as zero outputs. - DeepPotPTExpt / DeepSpinPTExpt: if the with-comm artifact is declared in metadata but fails to load, keep has_comm_artifact_=true so multi-rank dispatch (nswap>0) throws explicitly. Previously has_comm_artifact_ was reset to false on load failure, making multi-rank silently fall through to the single-rank artifact and skip the MPI ghost-embedding exchange.
1 parent 4f8240e commit 7632db8

3 files changed

Lines changed: 42 additions & 14 deletions

File tree

source/api_cc/src/DeepPotPTExpt.cc

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,10 @@ void DeepPotPTExpt::init(const std::string& model,
166166
// inference. Pre-Phase-3 .pt2 files lack ``has_comm_artifact``;
167167
// default to false so old artifacts keep working. If the metadata
168168
// flag is set but the nested artifact fails to extract or compile,
169-
// fall back to single-rank mode rather than aborting init -- the
170-
// hard error then surfaces in ``run_model_with_comm()`` only when
171-
// multi-rank actually needs it.
169+
// keep ``has_comm_artifact_=true`` and let single-rank dispatch
170+
// continue working; multi-rank dispatch then fails fast at
171+
// ``run_model_with_comm()`` rather than silently dropping the MPI
172+
// exchange and producing wrong results.
172173
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
173174
metadata["has_comm_artifact"].as_bool();
174175
if (has_comm_artifact_) {
@@ -186,11 +187,12 @@ void DeepPotPTExpt::init(const std::string& model,
186187
: static_cast<c10::DeviceIndex>(-1));
187188
} catch (const std::exception& e) {
188189
std::cerr << "DeepPotPTExpt: failed to load with-comm artifact ("
189-
<< e.what() << "); falling back to single-rank-only dispatch."
190+
<< e.what()
191+
<< "); single-rank inference will still work, but multi-rank "
192+
"LAMMPS dispatch will throw."
190193
<< std::endl;
191194
with_comm_tempfile_.reset();
192195
with_comm_loader.reset();
193-
has_comm_artifact_ = false;
194196
}
195197
}
196198

@@ -244,9 +246,12 @@ std::vector<torch::Tensor> DeepPotPTExpt::run_model_with_comm(
244246
const std::vector<at::Tensor>& comm_tensors) {
245247
if (!with_comm_loader) {
246248
throw deepmd::deepmd_exception(
247-
"run_model_with_comm called but the .pt2 file has no with-comm "
248-
"artifact. This is a programming error: the caller should check "
249-
"has_comm_artifact_ before invoking this path.");
249+
"run_model_with_comm called but the with-comm artifact is not "
250+
"available. Either the .pt2 file has no with-comm artifact compiled "
251+
"(programming error: the caller should check has_comm_artifact_ "
252+
"before invoking this path), or the artifact was present in the "
253+
".pt2 metadata but failed to load at init time (see earlier stderr "
254+
"log). Multi-rank LAMMPS requires a working with-comm artifact.");
250255
}
251256
if (comm_tensors.size() != 8) {
252257
throw deepmd::deepmd_exception(
@@ -431,6 +436,12 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
431436
// tensor to gather ghost embeddings from local atoms.
432437
std::vector<torch::Tensor> flat_outputs;
433438
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
439+
if (use_with_comm && !with_comm_loader) {
440+
throw deepmd::deepmd_exception(
441+
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
442+
"to load at init time. See the earlier stderr log for the underlying "
443+
"error.");
444+
}
434445
// When NULL-type atoms exist, remapped storage must outlive comm
435446
// tensors (the int** pointer-array tensor references it).
436447
std::vector<std::vector<int>> remapped_sendlist;

source/api_cc/src/DeepSpinPTExpt.cc

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,9 @@ void DeepSpinPTExpt::init(const std::string& model,
174174

175175
// Phase 4: load the optional with-comm artifact for multi-rank GNN
176176
// spin inference. Mirrors DeepPotPTExpt; see its init() comment for
177-
// the rationale on the try/catch fallback.
177+
// the rationale on keeping ``has_comm_artifact_=true`` on load
178+
// failure so multi-rank dispatch fails fast rather than silently
179+
// dropping the MPI exchange.
178180
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
179181
metadata["has_comm_artifact"].as_bool();
180182
if (has_comm_artifact_) {
@@ -189,11 +191,12 @@ void DeepSpinPTExpt::init(const std::string& model,
189191
: static_cast<c10::DeviceIndex>(-1));
190192
} catch (const std::exception& e) {
191193
std::cerr << "DeepSpinPTExpt: failed to load with-comm artifact ("
192-
<< e.what() << "); falling back to single-rank-only dispatch."
194+
<< e.what()
195+
<< "); single-rank inference will still work, but multi-rank "
196+
"LAMMPS dispatch will throw."
193197
<< std::endl;
194198
with_comm_tempfile_.reset();
195199
with_comm_loader.reset();
196-
has_comm_artifact_ = false;
197200
}
198201
}
199202

@@ -249,8 +252,11 @@ std::vector<torch::Tensor> DeepSpinPTExpt::run_model_with_comm(
249252
const std::vector<at::Tensor>& comm_tensors) {
250253
if (!with_comm_loader) {
251254
throw deepmd::deepmd_exception(
252-
"DeepSpinPTExpt::run_model_with_comm called but the .pt2 has no "
253-
"with-comm artifact.");
255+
"DeepSpinPTExpt::run_model_with_comm called but the with-comm "
256+
"artifact is not available. Either the .pt2 file has no with-comm "
257+
"artifact compiled, or the artifact was present in the .pt2 metadata "
258+
"but failed to load at init time (see earlier stderr log). Multi-rank "
259+
"LAMMPS requires a working with-comm artifact.");
254260
}
255261
if (comm_tensors.size() != 8) {
256262
throw deepmd::deepmd_exception(
@@ -448,6 +454,12 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
448454
// (pre atom-doubling); the spin override halves them internally.
449455
std::vector<torch::Tensor> flat_outputs;
450456
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
457+
if (use_with_comm && !with_comm_loader) {
458+
throw deepmd::deepmd_exception(
459+
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
460+
"to load at init time. See the earlier stderr log for the underlying "
461+
"error.");
462+
}
451463
std::vector<std::vector<int>> remapped_sendlist;
452464
std::vector<int*> remapped_sendlist_ptrs;
453465
std::vector<int> remapped_sendnum, remapped_recvnum;

source/op/pt/comm.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,12 @@ torch::Tensor border_op_export(const torch::Tensor& sendlist_tensor,
523523
communicator_tensor, nlocal_tensor, nghost_tensor);
524524
// border_op returns {g1_tensor} — a list whose first element aliases
525525
// g1_tensor. Clone for AOTI graph-output correctness.
526-
return out.empty() ? torch::empty_like(g1_tensor) : out[0].clone();
526+
if (out.empty()) {
527+
throw std::runtime_error(
528+
"border_op_export: border_op returned an empty output list, which "
529+
"indicates an internal error in the underlying border_op kernel.");
530+
}
531+
return out[0].clone();
527532
}
528533

529534
torch::Tensor border_op_backward_export(

0 commit comments

Comments
 (0)