22// This file is distributed under the University of Illinois/NCSA Open Source License.
33// See LICENSE file in top directory for details.
44//
5- // Copyright (c) 2016 Jeongnim Kim and QMCPACK developers.
5+ // Copyright (c) 2025 QMCPACK developers.
66//
77// File developed by: Jeremy McMinnis, jmcminis@gmail.com, University of Illinois at Urbana-Champaign
88// Jeongnim Kim, jeongnim.kim@gmail.com, University of Illinois at Urbana-Champaign
99// Mark A. Berrill, berrillma@ornl.gov, Oak Ridge National Laboratory
10+ // Peter W. Doak, doakpw@ornl.gov, Oak Ridge National Laboratory
1011//
1112// File created by: Jeongnim Kim, jeongnim.kim@gmail.com, University of Illinois at Urbana-Champaign
1213// ////////////////////////////////////////////////////////////////////////////////////
2122
2223namespace qmcplusplus
2324{
24- void KContainer::updateKLists (const ParticleLayout& lattice,
25- RealType kc,
26- unsigned ndim,
27- const PosType& twist,
28- bool useSphere)
25+
26+ template <typename REAL>
27+ const std::vector<typename KContainerT<REAL>::AppPosition>& KContainerT<
28+ REAL>::getKptsCartWorking() const
29+ {
30+ // This is an `if constexpr` so it should not cost a branch at runtime.
31+ if constexpr (std::is_same_v<decltype (kpts_cart_), decltype (kpts_cart_working_)>)
32+ return kpts_cart_;
33+ else
34+ return kpts_cart_working_;
35+ }
36+
37+ template <typename REAL>
38+ const std::vector<REAL>& KContainerT<REAL>::getKSQWorking() const
39+ {
40+ // This is an `if constexpr` so it should not cost a branch at runtime.
41+ if constexpr (std::is_same<decltype (ksq_), decltype (ksq_working_)>::value)
42+ return ksq_;
43+ else
44+ return ksq_working_;
45+ }
46+
47+ template <typename REAL>
48+ int KContainerT<REAL>::getMinusK(int k) const
49+ {
50+ assert (k < minusk.size ());
51+ return minusk[k];
52+ }
53+
54+ template <typename REAL>
55+ void KContainerT<REAL>::updateKLists(const Lattice& lattice,
56+ FullPrecReal kc,
57+ unsigned ndim,
58+ const Position& twist,
59+ bool useSphere)
2960{
3061 kcutoff = kc;
3162 if (kcutoff <= 0.0 )
@@ -37,11 +68,12 @@ void KContainer::updateKLists(const ParticleLayout& lattice,
3768
3869 app_log () << " KContainer initialised with cutoff " << kcutoff << std::endl;
3970 app_log () << " # of K-shell = " << kshell.size () << std::endl;
40- app_log () << " # of K points = " << kpts .size () << std::endl;
71+ app_log () << " # of K points = " << kpts_ .size () << std::endl;
4172 app_log () << std::endl;
4273}
4374
44- void KContainer::findApproxMMax (const ParticleLayout& lattice, unsigned ndim)
75+ template <typename REAL>
76+ void KContainerT<REAL>::findApproxMMax(const Lattice& lattice, unsigned ndim)
4577{
4678 // Estimate the size of the parallelpiped that encompasses a sphere of kcutoff.
4779 // mmax is stored as integer translations of the reciprocal cell vectors.
@@ -102,19 +134,22 @@ void KContainer::findApproxMMax(const ParticleLayout& lattice, unsigned ndim)
102134 mmax[1 ] = 0 ;
103135}
104136
105- void KContainer::BuildKLists (const ParticleLayout& lattice, const PosType& twist, bool useSphere)
137+ template <typename REAL>
138+ void KContainerT<REAL>::BuildKLists(const Lattice& lattice,
139+ const Position& twist,
140+ bool useSphere)
106141{
107142 TinyVector<int , DIM + 1 > TempActualMax;
108143 TinyVector<int , DIM> kvec;
109- TinyVector<RealType , DIM> kvec_cart;
110- RealType modk2;
144+ TinyVector<FullPrecReal , DIM> kvec_cart;
145+ FullPrecReal modk2;
111146 std::vector<TinyVector<int , DIM>> kpts_tmp;
112- std::vector<PosType > kpts_cart_tmp;
113- std::vector<RealType > ksq_tmp;
147+ std::vector<PositionFull > kpts_cart_tmp;
148+ std::vector<FullPrecReal > ksq_tmp;
114149 // reserve the space for memory efficiency
115150 if (useSphere)
116151 {
117- const RealType kcut2 = kcutoff * kcutoff;
152+ const FullPrecReal kcut2 = kcutoff * kcutoff;
118153 // Loop over guesses for valid k-points.
119154 for (int i = -mmax[0 ]; i <= mmax[0 ]; i++)
120155 {
@@ -208,10 +243,10 @@ void KContainer::BuildKLists(const ParticleLayout& lattice, const PosType& twist
208243 }
209244 }
210245 std::map<int64_t , std::vector<int >*>::iterator it (kpts_sorted.begin ());
211- kpts .resize (numk);
212- kpts_cart .resize (numk);
246+ kpts_ .resize (numk);
247+ kpts_cart_ .resize (numk);
213248 kpts_cart_soa_.resize (numk);
214- ksq .resize (numk);
249+ ksq_ .resize (numk);
215250 kshell.resize (kpts_sorted.size () + 1 , 0 );
216251 int ok = 0 , ish = 0 ;
217252 while (it != kpts_sorted.end ())
@@ -220,10 +255,10 @@ void KContainer::BuildKLists(const ParticleLayout& lattice, const PosType& twist
220255 while (vit != (*it).second ->end ())
221256 {
222257 int ik = (*vit);
223- kpts [ok] = kpts_tmp[ik];
224- kpts_cart [ok] = kpts_cart_tmp[ik];
258+ kpts_ [ok] = kpts_tmp[ik];
259+ kpts_cart_ [ok] = kpts_cart_tmp[ik];
225260 kpts_cart_soa_ (ok) = kpts_cart_tmp[ik];
226- ksq [ok] = ksq_tmp[ik];
261+ ksq_ [ok] = ksq_tmp[ik];
227262 ++vit;
228263 ++ok;
229264 }
@@ -232,6 +267,13 @@ void KContainer::BuildKLists(const ParticleLayout& lattice, const PosType& twist
232267 ++ish;
233268 }
234269 kpts_cart_soa_.updateTo ();
270+ if constexpr (!std::is_same<Real, FullPrecReal>::value)
271+ {
272+ // This copy implicity does the precision reduction.
273+ // the working vectors are not used or initialized for full precision builds.
274+ std::copy (kpts_cart_.begin (), kpts_cart_.end (), std::back_inserter (kpts_cart_working_));
275+ std::copy (ksq_.begin (), ksq_.end (), std::back_inserter (ksq_working_));
276+ }
235277 it = kpts_sorted.begin ();
236278 std::map<int64_t , std::vector<int >*>::iterator e_it (kpts_sorted.end ());
237279 while (it != e_it)
@@ -262,13 +304,19 @@ void KContainer::BuildKLists(const ParticleLayout& lattice, const PosType& twist
262304 std::map<int64_t , int > hashToIndex;
263305 for (int ki = 0 ; ki < numk; ki++)
264306 {
265- hashToIndex[getHashOfVec (kpts [ki], numk)] = ki;
307+ hashToIndex[getHashOfVec (kpts_ [ki], numk)] = ki;
266308 }
267309 // Use the map to find the index of -k from the index of k
268310 for (int ki = 0 ; ki < numk; ki++)
269311 {
270- minusk[ki] = hashToIndex[getHashOfVec (-1 * kpts [ki], numk)];
312+ minusk[ki] = hashToIndex[getHashOfVec (-1 * kpts_ [ki], numk)];
271313 }
272314}
273315
316+ #ifdef MIXED_PRECISION
317+ template class KContainerT <float >;
318+ template class KContainerT <double >;
319+ #else
320+ template class KContainerT <double >;
321+ #endif
274322} // namespace qmcplusplus
0 commit comments