atomic resource collection

This commit is contained in:
Kevin Gasperich 2023-10-10 16:28:32 -05:00
parent aec55c0f42
commit cc55cc2e03
6 changed files with 168 additions and 5 deletions

View File

@ -176,6 +176,22 @@ struct SoaBasisSetBase
/// Determine which orbitals are S-type. Used for cusp correction.
virtual void queryOrbitalsForSType(const std::vector<bool>& corrCenter, std::vector<bool>& is_s_orbital) const {}
/** initialize a shared resource and hand it to collection
*/
virtual void createResource(ResourceCollection& collection) const {}
/** acquire a shared resource from collection
*/
virtual void acquireResource(ResourceCollection& collection,
const RefVectorWithLeader<SoaBasisSetBase>& bset_list) const
{}
/** return a shared resource to collection
*/
virtual void releaseResource(ResourceCollection& collection,
const RefVectorWithLeader<SoaBasisSetBase>& bset_list) const
{}
};
} // namespace qmcplusplus

View File

@ -109,13 +109,18 @@ void LCAOrbitalSet::checkObject() const
void LCAOrbitalSet::createResource(ResourceCollection& collection) const
{
myBasisSet->createResource(collection);
auto resource_index = collection.addResource(std::make_unique<LCAOMultiWalkerMem>());
}
void LCAOrbitalSet::acquireResource(ResourceCollection& collection, const RefVectorWithLeader<SPOSet>& spo_list) const
{
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
spo_leader.myBasisSet->acquireResource(collection, extractBasRefList(spo_list));
spo_leader.mw_mem_handle_ = collection.lendResource<LCAOMultiWalkerMem>();
}
@ -123,9 +128,25 @@ void LCAOrbitalSet::releaseResource(ResourceCollection& collection, const RefVec
{
assert(this == &spo_list.getLeader());
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
spo_leader.myBasisSet->releaseResource(collection, extractBasRefList(spo_list));
collection.takebackResource(spo_leader.mw_mem_handle_);
}
RefVectorWithLeader<typename LCAOrbitalSet::basis_type> LCAOrbitalSet::extractBasRefList(
const RefVectorWithLeader<SPOSet>& spo_list) const
{
auto& spo_leader = spo_list.getCastedLeader<LCAOrbitalSet>();
RefVectorWithLeader<basis_type> bas_list(*spo_leader.myBasisSet);
bas_list.reserve(spo_list.size());
for (size_t iw = 0; iw < spo_list.size(); iw++)
{
auto& spo_i = spo_list.getCastedElement<LCAOrbitalSet>(iw);
bas_list.push_back(*spo_i.myBasisSet);
}
return bas_list;
}
std::unique_ptr<SPOSet> LCAOrbitalSet::makeClone() const { return std::make_unique<LCAOrbitalSet>(*this); }
void LCAOrbitalSet::evaluateValue(const ParticleSet& P, int iat, ValueVector& psi)

View File

@ -308,6 +308,8 @@ private:
void mw_evaluateValueVPsImplGEMM(const RefVectorWithLeader<SPOSet>& spo_list,
const RefVectorWithLeader<const VirtualParticleSet>& vp_list,
OffloadMWVArray& phi_v) const;
RefVectorWithLeader<basis_type> extractBasRefList(const RefVectorWithLeader<SPOSet>& spo_list) const;
struct LCAOMultiWalkerMem;
ResourceHandle<LCAOMultiWalkerMem> mw_mem_handle_;
/// timer for basis set

View File

@ -16,6 +16,7 @@
#include "CPU/math.hpp"
#include "OptimizableObject.h"
#include <ResourceCollection.h>
namespace qmcplusplus
{
@ -29,10 +30,15 @@ namespace qmcplusplus
template<typename ROT, typename SH>
struct SoaAtomicBasisSet
{
using RadialOrbital_t = ROT;
using RealType = typename ROT::RealType;
using GridType = typename ROT::GridType;
using ValueType = typename QMCTraits::ValueType;
using RadialOrbital_t = ROT;
using RealType = typename ROT::RealType;
using GridType = typename ROT::GridType;
using ValueType = typename QMCTraits::ValueType;
using OffloadNelecVGLPBCArray = Array<ValueType, 4, OffloadPinnedAllocator<ValueType>>; // [VGL, elec, PBC, Rnl/Ylm]
using OffloadNelecVPBCArray = Array<ValueType, 3, OffloadPinnedAllocator<ValueType>>; // [elec, PBC, Rnl/Ylm/xyz]
using OffloadNelecPBCArray = Array<ValueType, 2, OffloadPinnedAllocator<ValueType>>; // [elec, PBC]
using OffloadRPBCArray = Array<ValueType, 2, OffloadPinnedAllocator<ValueType>>; // [xyz, PBC]
using OffloadVector = Vector<ValueType, OffloadPinnedAllocator<ValueType>>;
///size of the basis set
int BasisSetSize;
@ -56,6 +62,8 @@ struct SoaAtomicBasisSet
std::vector<QuantumNumberType> RnlID;
///temporary storage
VectorSoaContainer<RealType, 4> tempS;
struct SoaAtomicBSetMultiWalkerMem;
ResourceHandle<SoaAtomicBSetMultiWalkerMem> mw_mem_handle_;
///the constructor
explicit SoaAtomicBasisSet(int lmax, bool addsignforM = false) : Ylm(lmax, addsignforM) {}
@ -656,6 +664,54 @@ struct SoaAtomicBasisSet
}
}
}
void createResource(ResourceCollection& collection) const
{
// Ylm.createResource(collection);
// MultiRnl.createResource(collection);
auto resource_index = collection.addResource(std::make_unique<SoaAtomicBSetMultiWalkerMem>());
}
void acquireResource(ResourceCollection& collection, const RefVectorWithLeader<SoaAtomicBasisSet>& atom_bs_list) const
{
assert(this == &atom_bs_list.getLeader());
// auto& atom_bs_leader = atom_bs_list.getCastedLeader<SoaAtomicBasisSet>();
// auto& atom_bs_leader = atom_bs_list.getCastedLeader();
// SoaAtomicBasisSet& atom_bs_leader = atom_bs_list.getCastedLeader();
// const SoaAtomicBasisSet& atom_bs_leader = atom_bs_list.getCastedLeader();
// const auto ylm_list(extractYlmRefList(atom_bs_list));
auto& atom_bs_leader = atom_bs_list.template getCastedLeader<SoaAtomicBasisSet>();
atom_bs_leader.mw_mem_handle_ = collection.lendResource<SoaAtomicBSetMultiWalkerMem>();
}
void releaseResource(ResourceCollection& collection, const RefVectorWithLeader<SoaAtomicBasisSet>& atom_bs_list) const
{
assert(this == &atom_bs_list.getLeader());
// auto& atom_bs_leader = atom_bs_list.getCastedLeader();
// const SoaAtomicBasisSet& atom_bs_leader = atom_bs_list.getCastedLeader();
auto& atom_bs_leader = atom_bs_list.template getCastedLeader<SoaAtomicBasisSet>();
collection.takebackResource(atom_bs_leader.mw_mem_handle_);
}
struct SoaAtomicBSetMultiWalkerMem : public Resource
{
SoaAtomicBSetMultiWalkerMem() : Resource("SoaAtomicBasisSet") {}
SoaAtomicBSetMultiWalkerMem(const SoaAtomicBSetMultiWalkerMem&) : SoaAtomicBSetMultiWalkerMem() {}
std::unique_ptr<Resource> makeClone() const override
{
return std::make_unique<SoaAtomicBSetMultiWalkerMem>(*this);
}
OffloadNelecVPBCArray ylm_v; // [Nelec][PBC][NYlm]
OffloadNelecVPBCArray rnl_v; // [Nelec][PBC][NRnl]
OffloadNelecVGLPBCArray ylm_vgl; // [5][Nelec][PBC][NYlm]
OffloadNelecVGLPBCArray rnl_vgl; // [5][Nelec][PBC][NRnl]
OffloadRPBCArray dr_pbc; // [PBC][xyz]
OffloadNelecVPBCArray dr; // [Nelec][PBC][xyz]
OffloadNelecPBCArray r; // [Nelec][PBC]
OffloadVector correctphase; // [Nelec]
};
};
} // namespace qmcplusplus

View File

@ -21,6 +21,56 @@
namespace qmcplusplus
{
template<class COT, typename ORBT>
void SoaLocalizedBasisSet<COT, ORBT>::createResource(ResourceCollection& collection) const
{
for (int i = 0; i < LOBasisSet.size(); i++)
LOBasisSet[i]->createResource(collection);
}
template<class COT, typename ORBT>
void SoaLocalizedBasisSet<COT, ORBT>::acquireResource(ResourceCollection& collection,
const RefVectorWithLeader<SoaBasisSetBase<ORBT>>& bs_list) const
{
auto& loc_bs_leader = bs_list.template getCastedLeader<SoaLocalizedBasisSet<COT, ORBT>>();
auto& atom_bs_leader = loc_bs_leader.LOBasisSet;
const int num_ctr = loc_bs_leader.LOBasisSet.size();
for (int i = 0; i < num_ctr; i++)
{
const auto atom_bs_list(extractLOBasisRefList(bs_list, i));
atom_bs_leader[i]->acquireResource(collection, atom_bs_list);
}
}
template<class COT, typename ORBT>
void SoaLocalizedBasisSet<COT, ORBT>::releaseResource(ResourceCollection& collection,
const RefVectorWithLeader<SoaBasisSetBase<ORBT>>& bs_list) const
{
auto& loc_bs_leader = bs_list.template getCastedLeader<SoaLocalizedBasisSet<COT, ORBT>>();
auto& atom_bs_leader = loc_bs_leader.LOBasisSet;
const int num_ctr = loc_bs_leader.LOBasisSet.size();
for (int i = 0; i < num_ctr; i++)
{
const auto atom_bs_list(extractLOBasisRefList(bs_list, i));
atom_bs_leader[i]->releaseResource(collection, atom_bs_list);
}
}
template<class COT, typename ORBT>
RefVectorWithLeader<COT> SoaLocalizedBasisSet<COT, ORBT>::extractLOBasisRefList(
const RefVectorWithLeader<SoaBasisSetBase<ORBT>>& bs_list,
int id)
{
auto& bs_leader = bs_list.template getCastedLeader<SoaLocalizedBasisSet<COT, ORBT>>();
RefVectorWithLeader<COT> atom_bs_list(*bs_leader.LOBasisSet[id]);
atom_bs_list.reserve(bs_list.size());
for (size_t iw = 0; iw < bs_list.size(); iw++)
{
auto& bs_i = bs_list.template getCastedElement<SoaLocalizedBasisSet<COT, ORBT>>(iw);
atom_bs_list.push_back(*bs_i.LOBasisSet[id]);
}
return atom_bs_list;
}
template<class COT, typename ORBT>
SoaLocalizedBasisSet<COT, ORBT>::SoaLocalizedBasisSet(ParticleSet& ions, ParticleSet& els)
: ions_(ions),

View File

@ -172,6 +172,24 @@ public:
* @param aos a set of Centered Atomic Orbitals
*/
void add(int icenter, std::unique_ptr<COT> aos);
/** initialize a shared resource and hand it to collection
*/
void createResource(ResourceCollection& collection) const override;
/** acquire a shared resource from collection
*/
void acquireResource(ResourceCollection& collection,
const RefVectorWithLeader<SoaBasisSetBase<ORBT>>& spo_list) const override;
/** return a shared resource to collection
*/
void releaseResource(ResourceCollection& collection,
const RefVectorWithLeader<SoaBasisSetBase<ORBT>>& spo_list) const override;
static RefVectorWithLeader<COT> extractLOBasisRefList(const RefVectorWithLeader<SoaBasisSetBase<ORBT>>& bs_list,
int id);
};
} // namespace qmcplusplus
#endif