Skip to content

Commit 93b5b61

Browse files
authored
Add SAD guess (#30)
* Add SAD guess * remove unused typedef
1 parent 395872f commit 93b5b61

4 files changed

Lines changed: 166 additions & 0 deletions

File tree

src/scf/guess/guess.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,23 @@
2020
namespace scf::guess {
2121

2222
DECLARE_MODULE(Core);
23+
DECLARE_MODULE(SAD);
2324

2425
inline void load_modules(pluginplay::ModuleManager& mm) {
2526
mm.add_module<Core>("Core guess");
27+
mm.add_module<SAD>("SAD guess");
2628
}
2729

2830
inline void set_defaults(pluginplay::ModuleManager& mm) {
2931
mm.change_submod("Core guess", "Build Fock operator",
3032
"Restricted One-Electron Fock Op");
3133
mm.change_submod("Core guess", "Guess updater",
3234
"Diagonalization Fock update");
35+
36+
mm.change_submod("SAD guess", "Build Fock operator",
37+
"Restricted One-Electron Fock Op");
38+
mm.change_submod("SAD guess", "Guess updater",
39+
"Diagonalization Fock update");
3340
}
3441

3542
} // namespace scf::guess

src/scf/guess/sad.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright 2025 NWChemEx-Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "guess.hpp"
18+
19+
namespace scf::guess {
20+
namespace {
21+
const auto desc = R"(
22+
SAD Guess
23+
---------
24+
25+
TODO: Write me!!!
26+
)";
27+
}
28+
29+
using rscf_wf = simde::type::rscf_wf;
30+
using density_t = simde::type::decomposable_e_density;
31+
using pt = simde::InitialGuess<rscf_wf>;
32+
using fock_op_pt = simde::FockOperator<density_t>;
33+
using update_pt = simde::UpdateGuess<rscf_wf>;
34+
using initial_rho_pt = simde::InitialDensity;
35+
36+
using simde::type::tensor;
37+
38+
// TODO: move to chemist?
39+
struct NElectronCounter : public chemist::qm_operator::OperatorVisitor {
40+
NElectronCounter() : chemist::qm_operator::OperatorVisitor(false) {}
41+
42+
void run(const simde::type::T_e_type& T_e) { set_n(T_e.particle().size()); }
43+
44+
void run(const simde::type::V_en_type& V_en) {
45+
set_n(V_en.lhs_particle().size());
46+
}
47+
48+
void run(const simde::type::V_ee_type& V_ee) {
49+
set_n(V_ee.lhs_particle().size());
50+
set_n(V_ee.rhs_particle().size());
51+
}
52+
53+
void set_n(unsigned int n) {
54+
if(n_electrons == 0)
55+
n_electrons = n;
56+
else if(n_electrons != n) {
57+
throw std::runtime_error("Deduced a different number of electrons");
58+
}
59+
}
60+
61+
unsigned int n_electrons = 0;
62+
};
63+
64+
MODULE_CTOR(SAD) {
65+
description(desc);
66+
satisfies_property_type<pt>();
67+
add_submodule<fock_op_pt>("Build Fock operator");
68+
add_submodule<update_pt>("Guess updater");
69+
add_submodule<initial_rho_pt>("SAD Density");
70+
}
71+
72+
MODULE_RUN(SAD) {
73+
const auto&& [H, aos] = pt::unwrap_inputs(inputs);
74+
75+
// Step 1: Build Fock Operator with zero density
76+
auto& initial_rho_mod = submods.at("SAD Density");
77+
const auto& rho = initial_rho_mod.run_as<initial_rho_pt>(H);
78+
auto& fock_op_mod = submods.at("Build Fock operator");
79+
const auto& f = fock_op_mod.run_as<fock_op_pt>(H, rho);
80+
81+
// Step 2: Get number of electrons and occupations
82+
simde::type::cmos cmos(tensor{}, aos, tensor{});
83+
NElectronCounter visitor;
84+
H.visit(visitor);
85+
auto n_electrons = visitor.n_electrons;
86+
if(n_electrons % 2 != 0)
87+
throw std::runtime_error("Assumed even number of electrons");
88+
89+
typename rscf_wf::orbital_index_set_type occs;
90+
using value_type = typename rscf_wf::orbital_index_set_type::value_type;
91+
for(value_type i = 0; i < n_electrons / 2; ++i) occs.insert(i);
92+
93+
rscf_wf zero_guess(occs, cmos);
94+
auto& update_mod = submods.at("Guess updater");
95+
const auto& Psi0 = update_mod.run_as<update_pt>(f, zero_guess);
96+
97+
auto rv = results();
98+
return pt::wrap_results(rv, Psi0);
99+
}
100+
101+
} // namespace scf::guess
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright 2024 NWChemEx-Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include "../integration_tests.hpp"
18+
19+
using simde::type::tensor;
20+
using shape_type = tensorwrapper::shape::Smooth;
21+
using cmos_type = simde::type::cmos;
22+
using density_type = simde::type::decomposable_e_density;
23+
using rscf_wf = simde::type::rscf_wf;
24+
using occ_index = typename rscf_wf::orbital_index_set_type;
25+
26+
using pt = simde::InitialGuess<rscf_wf>;
27+
using initial_rho_pt = simde::InitialDensity;
28+
29+
using tensorwrapper::operations::approximately_equal;
30+
31+
TEMPLATE_LIST_TEST_CASE("SAD", "", test_scf::float_types) {
32+
using float_type = TestType;
33+
using allocator_type = tensorwrapper::allocator::Eigen<float_type>;
34+
35+
auto mm = test_scf::load_modules<float_type>();
36+
auto aos = test_scf::h2_aos();
37+
auto H = test_scf::h2_hamiltonian();
38+
auto rt = mm.get_runtime();
39+
40+
auto mod = mm.at("SAD guess");
41+
auto psi = mod.template run_as<pt>(H, aos);
42+
const auto& evals = psi.orbitals().diagonalized_matrix();
43+
44+
occ_index occs{0};
45+
allocator_type alloc(rt);
46+
shape_type shape_corr{2};
47+
auto pbuffer = alloc.construct({-0.498376, 0.594858});
48+
tensor corr(shape_corr, std::move(pbuffer));
49+
50+
REQUIRE(psi.orbital_indices() == occs);
51+
REQUIRE(psi.orbitals().from_space() == aos);
52+
REQUIRE(approximately_equal(corr, evals, 1E-6));
53+
}

tests/cxx/integration_tests/integration_tests.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#pragma once
1818
#include "../test_scf.hpp"
19+
#include <chemcache/chemcache.hpp>
1920
#include <integrals/integrals.hpp>
2021
#include <nux/nux.hpp>
2122
#include <scf/scf.hpp>
@@ -31,6 +32,7 @@ pluginplay::ModuleManager load_modules() {
3132
scf::load_modules(mm);
3233
integrals::load_modules(mm);
3334
nux::load_modules(mm);
35+
chemcache::load_modules(mm);
3436

3537
mm.change_submod("SCF Driver", "Hamiltonian",
3638
"Born-Oppenheimer approximation");
@@ -44,13 +46,16 @@ pluginplay::ModuleManager load_modules() {
4446

4547
mm.change_submod("Loop", "Overlap matrix builder", "Overlap");
4648

49+
mm.change_submod("SAD guess", "SAD Density", "sto-3g SAD density");
50+
4751
if constexpr(!std::is_same_v<FloatType, double>) {
4852
mm.change_input("Evaluate 2-Index BraKet", "With UQ?", true);
4953
mm.change_input("Evaluate 4-Index BraKet", "With UQ?", true);
5054
mm.change_input("Overlap", "With UQ?", true);
5155
mm.change_input("ERI4", "With UQ?", true);
5256
mm.change_input("Kinetic", "With UQ?", true);
5357
mm.change_input("Nuclear", "With UQ?", true);
58+
mm.change_input("sto-3g atomic density matrix", "With UQ?", true);
5459
}
5560

5661
return mm;

0 commit comments

Comments
 (0)