Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion source/api_c/include/c_api.h
Comment thread
njzjz marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extern "C" {
/** C API version. Bumped whenever the API is changed.
* @since API version 22
*/
#define DP_C_API_VERSION 25
#define DP_C_API_VERSION 26

/**
* @brief Neighbor list.
Expand Down Expand Up @@ -1447,6 +1447,16 @@ int DP_DeepBaseModelGetDimAParam(DP_DeepBaseModel* dpbase);
*/
bool DP_DeepBaseModelIsAParamNAll(DP_DeepBaseModel* dpbase);

/**
* @brief Check if the model has default frame parameters.
*
* @param[in] dpbase The DP to use.
* @return true the model has default frame parameters
* @return false the model does not have default frame parameters
* @since API version 26
*/
bool DP_DeepBaseModelHasDefaultFParam(DP_DeepBaseModel* dpbase);

/**
* @brief Get the type map of a DP.
* @param[in] dpbase The DP to use.
Expand Down Expand Up @@ -1490,6 +1500,16 @@ int DP_DeepBaseModelDeviGetDimAParam(DP_DeepBaseModelDevi* dpbase);
*/
bool DP_DeepBaseModelDeviIsAParamNAll(DP_DeepBaseModelDevi* dpbase);

/**
* @brief Check if the model deviation has default frame parameters.
*
* @param[in] dpbase The DP Model Deviation to use.
* @return true the model has default frame parameters
* @return false the model does not have default frame parameters
* @since API version 26
*/
bool DP_DeepBaseModelDeviHasDefaultFParam(DP_DeepBaseModelDevi* dpbase);

/**
* @brief Get the type map of a DP model deviation.
* @param[in] dpbase The DP model deviation to use.
Expand Down Expand Up @@ -1569,6 +1589,15 @@ int DP_DeepPotGetDimAParam(DP_DeepPot* dp);
*/
bool DP_DeepPotIsAParamNAll(DP_DeepPot* dp);

/**
* @brief Check if the DP has default frame parameters.
* @param[in] dp The DP to use.
* @return true the model has default frame parameters
* @return false the model does not have default frame parameters
* @since API version 26
*/
bool DP_DeepPotHasDefaultFParam(DP_DeepPot* dp);

/**
* @brief Get the type map of a DP.
* @param[in] dp The DP to use.
Expand Down Expand Up @@ -1607,6 +1636,15 @@ int DP_DeepPotModelDeviGetDimAParam(DP_DeepPotModelDevi* dp);
*/
bool DP_DeepPotModelDeviIsAParamNAll(DP_DeepPotModelDevi* dp);

/**
* @brief Check if the DP model deviation has default frame parameters.
* @param[in] dp The DP model deviation to use.
* @return true the model has default frame parameters
* @return false the model does not have default frame parameters
* @since API version 26
*/
bool DP_DeepPotModelDeviHasDefaultFParam(DP_DeepPotModelDevi* dp);

/**
* @brief Get the type map of a DP model deviation.
* @param[in] dp The DP model deviation to use.
Expand Down Expand Up @@ -1688,6 +1726,15 @@ int DP_DeepSpinGetDimAParam(DP_DeepSpin* dp);
*/
bool DP_DeepSpinIsAParamNAll(DP_DeepSpin* dp);

/**
* @brief Check if the DP Spin Model has default frame parameters.
* @param[in] dp The DP Spin Model to use.
* @return true the model has default frame parameters
* @return false the model does not have default frame parameters
* @since API version 26
*/
bool DP_DeepSpinHasDefaultFParam(DP_DeepSpin* dp);

/**
* @brief Get the type map of a DP Spin Model.
* @param[in] dp The DP Spin Model to use.
Expand Down Expand Up @@ -1731,6 +1778,15 @@ int DP_DeepSpinModelDeviGetDimAParam(DP_DeepSpinModelDevi* dp);
*/
bool DP_DeepSpinModelDeviIsAParamNAll(DP_DeepSpinModelDevi* dp);

/**
* @brief Check if the DP Spin Model Deviation has default frame parameters.
* @param[in] dp The DP Spin Model Deviation to use.
* @return true the model has default frame parameters
* @return false the model does not have default frame parameters
* @since API version 26
*/
bool DP_DeepSpinModelDeviHasDefaultFParam(DP_DeepSpinModelDevi* dp);

/**
* @brief Get the type map of a DP model deviation.
* @param[in] dp The DP model deviation to use.
Expand Down
2 changes: 2 additions & 0 deletions source/api_c/include/c_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ struct DP_DeepBaseModel {
int dfparam;
int daparam;
bool aparam_nall;
bool has_default_fparam;
};

struct DP_DeepBaseModelDevi {
Expand All @@ -57,6 +58,7 @@ struct DP_DeepBaseModelDevi {
int dfparam;
int daparam;
bool aparam_nall;
bool has_default_fparam;
};

struct DP_DeepPot : DP_DeepBaseModel {
Expand Down
29 changes: 26 additions & 3 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,19 +965,29 @@ class DeepBaseModel {
assert(dpbase);
return daparam;
}
/**
* @brief Check if the model has default frame parameters.
* @return true if the model has default frame parameters.
**/
bool has_default_fparam() const {
assert(dpbase);
return has_default_fparam_;
}

protected:
DP_DeepBaseModel* dpbase;
int dfparam;
int daparam;
bool aparam_nall;
bool has_default_fparam_;
template <typename VALUETYPE>
void validate_fparam_aparam(const int& nframes,
const int& nloc,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam) const {
if (fparam.size() != dfparam &&
fparam.size() != static_cast<size_t>(nframes) * dfparam) {
fparam.size() != static_cast<size_t>(nframes) * dfparam &&
!(fparam.empty() && has_default_fparam_)) {
throw deepmd::hpp::deepmd_exception(
"the dim of frame parameter provided is not consistent with what the "
"model uses");
Expand Down Expand Up @@ -1058,6 +1068,7 @@ class DeepPot : public DeepBaseModel {
dfparam = DP_DeepPotGetDimFParam(dp);
daparam = DP_DeepPotGetDimAParam(dp);
aparam_nall = DP_DeepPotIsAParamNAll(dp);
has_default_fparam_ = DP_DeepPotHasDefaultFParam(dp);
dpbase = (DP_DeepBaseModel*)dp;
};

Expand Down Expand Up @@ -1502,6 +1513,7 @@ class DeepSpin : public DeepBaseModel {
dfparam = DP_DeepSpinGetDimFParam(dp);
daparam = DP_DeepSpinGetDimAParam(dp);
aparam_nall = DP_DeepSpinIsAParamNAll(dp);
has_default_fparam_ = DP_DeepSpinHasDefaultFParam(dp);
dpbase = (DP_DeepBaseModel*)dp;
};

Expand Down Expand Up @@ -1860,7 +1872,14 @@ class DeepBaseModelDevi {
return daparam;
}
/**
* @brief Compute the average of vectors.
* @brief Check if the model has default frame parameters.
* @return true if the model has default frame parameters.
**/
bool has_default_fparam() const {
assert(dpbase);
return has_default_fparam_;
}
/**
* @param[out] avg The average of vectors.
* @param[in] xx The vectors of all models.
**/
Expand Down Expand Up @@ -1981,13 +2000,15 @@ class DeepBaseModelDevi {
int dfparam;
int daparam;
bool aparam_nall;
bool has_default_fparam_;
template <typename VALUETYPE>
void validate_fparam_aparam(const int& nframes,
const int& nloc,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam) const {
if (fparam.size() != dfparam &&
fparam.size() != static_cast<size_t>(nframes) * dfparam) {
fparam.size() != static_cast<size_t>(nframes) * dfparam &&
!(fparam.empty() && has_default_fparam_)) {
throw deepmd::hpp::deepmd_exception(
"the dim of frame parameter provided is not consistent with what the "
"model uses");
Expand Down Expand Up @@ -2081,6 +2102,7 @@ class DeepPotModelDevi : public DeepBaseModelDevi {
dfparam = DP_DeepPotModelDeviGetDimFParam(dp);
daparam = DP_DeepPotModelDeviGetDimAParam(dp);
aparam_nall = DP_DeepPotModelDeviIsAParamNAll(dp);
has_default_fparam_ = DP_DeepPotModelDeviHasDefaultFParam(dp);
dpbase = (DP_DeepBaseModelDevi*)dp;
};

Expand Down Expand Up @@ -2513,6 +2535,7 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
dfparam = DP_DeepSpinModelDeviGetDimFParam(dp);
daparam = DP_DeepSpinModelDeviGetDimAParam(dp);
aparam_nall = DP_DeepSpinModelDeviIsAParamNAll(dp);
has_default_fparam_ = DP_DeepSpinModelDeviHasDefaultFParam(dp);
dpbase = (DP_DeepBaseModelDevi*)dp;
};

Expand Down
34 changes: 32 additions & 2 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,26 @@ void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

// DP Base Model
DP_DeepBaseModel::DP_DeepBaseModel() {}
DP_DeepBaseModel::DP_DeepBaseModel()
: dfparam(0), daparam(0), aparam_nall(false), has_default_fparam(false) {}
DP_DeepBaseModel::DP_DeepBaseModel(deepmd::DeepBaseModel& dpbase)
: dpbase(dpbase) {
dfparam = dpbase.dim_fparam();
daparam = dpbase.dim_aparam();
aparam_nall = dpbase.is_aparam_nall();
has_default_fparam = dpbase.has_default_fparam();
}
void DP_DeleteDeepBaseModel(DP_DeepBaseModel* dpbase) { delete dpbase; }

// DP Base Model Devi
DP_DeepBaseModelDevi::DP_DeepBaseModelDevi() {}
DP_DeepBaseModelDevi::DP_DeepBaseModelDevi()
: dfparam(0), daparam(0), aparam_nall(false), has_default_fparam(false) {}
DP_DeepBaseModelDevi::DP_DeepBaseModelDevi(deepmd::DeepBaseModelDevi& dpbase)
: dpbase(dpbase) {
dfparam = dpbase.dim_fparam();
daparam = dpbase.dim_aparam();
aparam_nall = dpbase.is_aparam_nall();
has_default_fparam = dpbase.has_default_fparam();
}
void DP_DeleteDeepBaseModelDevi(DP_DeepBaseModelDevi* dp) { delete dp; }

Expand Down Expand Up @@ -2019,6 +2023,10 @@ bool DP_DeepBaseModelIsAParamNAll(DP_DeepBaseModel* dpbase) {
return dpbase->aparam_nall;
}

bool DP_DeepBaseModelHasDefaultFParam(DP_DeepBaseModel* dpbase) {
return dpbase->has_default_fparam;
}

const char* DP_DeepBaseModelCheckOK(DP_DeepBaseModel* dpbase) {
return string_to_char(dpbase->exception);
}
Expand Down Expand Up @@ -2047,6 +2055,10 @@ bool DP_DeepBaseModelDeviIsAParamNAll(DP_DeepBaseModelDevi* dpbase) {
return dpbase->aparam_nall;
}

bool DP_DeepBaseModelDeviHasDefaultFParam(DP_DeepBaseModelDevi* dpbase) {
return dpbase->has_default_fparam;
}

const char* DP_DeepBaseModelDeviCheckOK(DP_DeepBaseModelDevi* dpbase) {
return string_to_char(dpbase->exception);
}
Expand Down Expand Up @@ -2080,6 +2092,10 @@ bool DP_DeepPotIsAParamNAll(DP_DeepPot* dp) {
return DP_DeepBaseModelIsAParamNAll(static_cast<DP_DeepBaseModel*>(dp));
}

bool DP_DeepPotHasDefaultFParam(DP_DeepPot* dp) {
return DP_DeepBaseModelHasDefaultFParam(static_cast<DP_DeepBaseModel*>(dp));
}

const char* DP_DeepPotCheckOK(DP_DeepPot* dp) {
return DP_DeepBaseModelCheckOK(static_cast<DP_DeepBaseModel*>(dp));
}
Expand Down Expand Up @@ -2113,6 +2129,11 @@ bool DP_DeepPotModelDeviIsAParamNAll(DP_DeepPotModelDevi* dp) {
static_cast<DP_DeepBaseModelDevi*>(dp));
}

bool DP_DeepPotModelDeviHasDefaultFParam(DP_DeepPotModelDevi* dp) {
return DP_DeepBaseModelDeviHasDefaultFParam(
static_cast<DP_DeepBaseModelDevi*>(dp));
}

const char* DP_DeepPotModelDeviCheckOK(DP_DeepPotModelDevi* dp) {
return DP_DeepBaseModelDeviCheckOK(static_cast<DP_DeepBaseModelDevi*>(dp));
}
Expand Down Expand Up @@ -2146,6 +2167,10 @@ bool DP_DeepSpinIsAParamNAll(DP_DeepSpin* dp) {
return DP_DeepBaseModelIsAParamNAll(static_cast<DP_DeepBaseModel*>(dp));
}

bool DP_DeepSpinHasDefaultFParam(DP_DeepSpin* dp) {
return DP_DeepBaseModelHasDefaultFParam(static_cast<DP_DeepBaseModel*>(dp));
}

const char* DP_DeepSpinCheckOK(DP_DeepSpin* dp) {
return DP_DeepBaseModelCheckOK(static_cast<DP_DeepBaseModel*>(dp));
}
Expand Down Expand Up @@ -2179,6 +2204,11 @@ bool DP_DeepSpinModelDeviIsAParamNAll(DP_DeepSpinModelDevi* dp) {
static_cast<DP_DeepBaseModelDevi*>(dp));
}

bool DP_DeepSpinModelDeviHasDefaultFParam(DP_DeepSpinModelDevi* dp) {
return DP_DeepBaseModelDeviHasDefaultFParam(
static_cast<DP_DeepBaseModelDevi*>(dp));
}

const char* DP_DeepSpinModelDeviCheckOK(DP_DeepSpinModelDevi* dp) {
return DP_DeepBaseModelDeviCheckOK(static_cast<DP_DeepBaseModelDevi*>(dp));
}
Expand Down
Loading
Loading