From ae67ff5b096dcea7037056f3a41de793f9211050 Mon Sep 17 00:00:00 2001 From: Ye Luo Date: Tue, 23 Aug 2022 13:51:38 -0500 Subject: [PATCH] Add isRotationSupported. --- .../BsplineFactory/SplineR2R.h | 1 + src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h | 2 + src/QMCWaveFunctions/SPOSet.cpp | 98 ++++++++++++++++++- src/QMCWaveFunctions/SPOSet.h | 23 ++--- src/QMCWaveFunctions/SPOSetBuilder.cpp | 3 + 5 files changed, 110 insertions(+), 17 deletions(-) diff --git a/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h b/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h index 06397cf78..71ce5458f 100644 --- a/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h +++ b/src/QMCWaveFunctions/BsplineFactory/SplineR2R.h @@ -79,6 +79,7 @@ public: virtual std::string getClassName() const override { return "SplineR2R"; } virtual std::string getKeyword() const override { return "SplineR2R"; } bool isComplex() const override { return false; }; + bool isRotationSupported() const override { return true; } std::unique_ptr makeClone() const override { return std::make_unique(*this); } diff --git a/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h b/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h index 799ae3a94..02945a6a5 100644 --- a/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h +++ b/src/QMCWaveFunctions/LCAO/LCAOrbitalSet.h @@ -49,6 +49,8 @@ public: virtual std::string getClassName() const override { return "LCAOrbitalSet"; } + bool isRotationSupported() const override { return true; } + bool hasIonDerivs() const override { return true; } std::unique_ptr makeClone() const override; diff --git a/src/QMCWaveFunctions/SPOSet.cpp b/src/QMCWaveFunctions/SPOSet.cpp index e1d8a6ed8..82815cf42 100644 --- a/src/QMCWaveFunctions/SPOSet.cpp +++ b/src/QMCWaveFunctions/SPOSet.cpp @@ -232,6 +232,89 @@ void SPOSet::evaluateVGHGH(const ParticleSet& P, "::evaluate(P,iat,psi,dpsi,dhpsi,dghpsi) (vector quantities)\n"); } +void SPOSet::applyRotation(const ValueMatrix& rot_mat, bool use_stored_copy) +{ + if (isRotationSupported()) + throw std::logic_error("Bug!! " + getClassName() + + "::applyRotation " + "must be overloaded when the SPOSet supports rotation."); +} + +void SPOSet::evaluateDerivatives(ParticleSet& P, + const opt_variables_type& optvars, + Vector& dlogpsi, + Vector& dhpsioverpsi, + const int& FirstIndex, + const int& LastIndex) +{ + if (isOptimizable()) + throw std::logic_error("Bug!! " + getClassName() + + "::evaluateDerivatives " + "must be overloaded when the SPOSet is optimizable."); +} + +/** Evaluate the derivative of the optimized orbitals with respect to the parameters + * this is used only for MSD, to be refined for better serving both single and multi SD + */ +void SPOSet::evaluateDerivatives(ParticleSet& P, + const opt_variables_type& optvars, + Vector& dlogpsi, + Vector& dhpsioverpsi, + const ValueType& psiCurrent, + const std::vector& Coeff, + const std::vector& C2node_up, + const std::vector& C2node_dn, + const ValueVector& detValues_up, + const ValueVector& detValues_dn, + const GradMatrix& grads_up, + const GradMatrix& grads_dn, + const ValueMatrix& lapls_up, + const ValueMatrix& lapls_dn, + const ValueMatrix& M_up, + const ValueMatrix& M_dn, + const ValueMatrix& Minv_up, + const ValueMatrix& Minv_dn, + const GradMatrix& B_grad, + const ValueMatrix& B_lapl, + const std::vector& detData_up, + const size_t N1, + const size_t N2, + const size_t NP1, + const size_t NP2, + const std::vector>& lookup_tbl) +{ + if (isOptimizable()) + throw std::logic_error("Bug!! " + getClassName() + + "::evaluateDerivatives " + "must be overloaded when the SPOSet is optimizable."); +} + +/** Evaluate the derivative of the optimized orbitals with respect to the parameters + * this is used only for MSD, to be refined for better serving both single and multi SD + */ +void SPOSet::evaluateDerivativesWF(ParticleSet& P, + const opt_variables_type& optvars, + Vector& dlogpsi, + const QTFull::ValueType& psiCurrent, + const std::vector& Coeff, + const std::vector& C2node_up, + const std::vector& C2node_dn, + const ValueVector& detValues_up, + const ValueVector& detValues_dn, + const ValueMatrix& M_up, + const ValueMatrix& M_dn, + const ValueMatrix& Minv_up, + const ValueMatrix& Minv_dn, + const std::vector& detData_up, + const std::vector>& lookup_tbl) +{ + if (isOptimizable()) + throw std::logic_error("Bug!! " + getClassName() + + "::evaluateDerivativesWF " + "must be overloaded when the SPOSet is optimizable."); +} + + void SPOSet::evaluateGradSource(const ParticleSet& P, int first, int last, @@ -239,7 +322,10 @@ void SPOSet::evaluateGradSource(const ParticleSet& P, int iat_src, GradMatrix& gradphi) { - throw std::runtime_error("SPOSetBase::evalGradSource is not implemented"); + if (hasIonDerivs()) + throw std::logic_error("Bug!! " + getClassName() + + "::evaluateGradSource " + "must be overloaded when the SPOSet has ion derivatives."); } void SPOSet::evaluateGradSource(const ParticleSet& P, @@ -251,7 +337,10 @@ void SPOSet::evaluateGradSource(const ParticleSet& P, HessMatrix& grad_grad_phi, GradMatrix& grad_lapl_phi) { - throw std::runtime_error("SPOSetBase::evalGradSource is not implemented"); + if (hasIonDerivs()) + throw std::logic_error("Bug!! " + getClassName() + + "::evaluateGradSource " + "must be overloaded when the SPOSet has ion derivatives."); } void SPOSet::evaluateGradSourceRow(const ParticleSet& P, @@ -260,7 +349,10 @@ void SPOSet::evaluateGradSourceRow(const ParticleSet& P, int iat_src, GradVector& gradphi) { - throw std::runtime_error("SPOSetBase::evalGradSourceRow is not implemented"); + if (hasIonDerivs()) + throw std::logic_error("Bug!! " + getClassName() + + "::evaluateGradSourceRow " + "must be overloaded when the SPOSet has ion derivatives."); } void SPOSet::evaluate_spin(const ParticleSet& P, int iat, ValueVector& psi, ValueVector& dpsi) diff --git a/src/QMCWaveFunctions/SPOSet.h b/src/QMCWaveFunctions/SPOSet.h index d9a09b0da..064976b2c 100644 --- a/src/QMCWaveFunctions/SPOSet.h +++ b/src/QMCWaveFunctions/SPOSet.h @@ -115,23 +115,20 @@ public: virtual void buildOptVariables(const size_t nel) {} // For the MSD case rotations must be created in MultiSlaterDetTableMethod class virtual void buildOptVariables(const std::vector>& rotations) {} - // store parameters before getting destroyed by rotation. + /// return true if this SPOSet can be wrappered by RotatedSPO + virtual bool isRotationSupported() const { return false; } + /// store parameters before getting destroyed by rotation. virtual void storeParamsBeforeRotation() {} - // apply rotation to all the orbitals - virtual void applyRotation(const ValueMatrix& rot_mat, bool use_stored_copy = false) - { - std::ostringstream o; - o << "SPOSet::applyRotation is not implemented by " << getClassName() << std::endl; - APP_ABORT(o.str()); - } + /// apply rotation to all the orbitals + virtual void applyRotation(const ValueMatrix& rot_mat, bool use_stored_copy = false); virtual void evaluateDerivatives(ParticleSet& P, const opt_variables_type& optvars, Vector& dlogpsi, Vector& dhpsioverpsi, const int& FirstIndex, - const int& LastIndex) - {} + const int& LastIndex); + /** Evaluate the derivative of the optimized orbitals with respect to the parameters * this is used only for MSD, to be refined for better serving both single and multi SD */ @@ -160,8 +157,7 @@ public: const size_t N2, const size_t NP1, const size_t NP2, - const std::vector>& lookup_tbl) - {} + const std::vector>& lookup_tbl); /** Evaluate the derivative of the optimized orbitals with respect to the parameters * this is used only for MSD, to be refined for better serving both single and multi SD @@ -180,8 +176,7 @@ public: const ValueMatrix& Minv_up, const ValueMatrix& Minv_dn, const std::vector& detData_up, - const std::vector>& lookup_tbl) - {} + const std::vector>& lookup_tbl); /** set the OrbitalSetSize * @param norbs number of single-particle orbitals diff --git a/src/QMCWaveFunctions/SPOSetBuilder.cpp b/src/QMCWaveFunctions/SPOSetBuilder.cpp index fc66f0991..96d4c8520 100644 --- a/src/QMCWaveFunctions/SPOSetBuilder.cpp +++ b/src/QMCWaveFunctions/SPOSetBuilder.cpp @@ -98,6 +98,9 @@ std::unique_ptr SPOSetBuilder::createSPOSet(xmlNodePtr cur) // create sposet with rotation auto& sposet_ref = *sposet; app_log() << " SPOSet " << sposet_ref.getName() << " is optimizable\n"; + if (!sposet_ref.isRotationSupported()) + myComm->barrier_and_abort("Orbital rotation not supported with '" + sposet_ref.getName() + "' of type '" + + sposet_ref.getClassName() + "'."); auto rot_spo = std::make_unique(sposet_ref.getName(), std::move(sposet)); xmlNodePtr tcur = cur->xmlChildrenNode; while (tcur != NULL)