Merge pull request #909 from markdewing/extract_min

Extract 1D minimization routines from CuspCorr.cpp
This commit is contained in:
Ye Luo 2018-07-02 16:50:01 -05:00 committed by GitHub
commit af170d59bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 260 additions and 1 deletions

View File

@ -0,0 +1,137 @@
//////////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2018 Jeongnim Kim and QMCPACK developers.
//
// File developed by: Mark Dewing, mdewing@anl.gov, Argonne National Laboratory
//
// File created by: Mark Dewing, mdewing@anl.gov, Argonne National Laboratory
//////////////////////////////////////////////////////////////////////////////////////
#ifndef QMCPLUSPLUS_MINIMIZE_ONED_H
#define QMCPLUSPLUS_MINIMIZE_ONED_H
#include <algorithm>
#include <stdexcept>
#include <tuple>
// Storage for bracketed minimum.
template<typename T>
struct Bracket_min_t {
T a;
T b;
T c;
bool success;
Bracket_min_t(T a1, T b1, T c1, bool okay=true) : a(a1), b(b1), c(c1), success(okay) {}
};
// Minimize a function in one dimension
// Bracket a minimum in preparation for minimization
// If 'bound_max' is a positive number, the search range is bounded between zero and 'bound_max'.
// If the search falls outside that range, the function returns with bracket.success set to 'false',
// and the position in bracket.a. This means the minimum occurs at the edge of the boundary, and
// there is no need to call 'find_minimum' (nor should it be called).
template <class F, typename T> Bracket_min_t<T> bracket_minimum(const F &f, T initial_value, T bound_max = -1.0)
{
T xa = initial_value;
T fa = f(xa);
T xb = xa + 0.005;
T fb = f(xb);
// swap a and b
if (fb > fa) {
std::swap(xa, xb);
std::swap(fa, fb);
}
bool check_bound = false;
if (bound_max > 0.0) {
check_bound = true;
}
T best_val = xb;
T delx = 1.61*(xb - xa);
T xd = xb + delx;
T fd = f(xd);
int cnt = 0;
while (fb > fd) {
T xtmp = xb; T ftmp = fb;
xb = xd; fb = fd;
xa = xtmp; fa = ftmp;
xd += delx;
if (check_bound && (xd < 0.0 || xd > bound_max)) {
// minimum occurs at the boundary of the range
return Bracket_min_t<T>(best_val, 0.0, 0.0, false);
}
fd = f(xd);
if (cnt == 50) {
delx *= 5;
}
if (cnt == 100) {
delx *= 5;
}
cnt++;
if (cnt == 1000) {
throw std::runtime_error("Failed to bracket minimum");
}
}
if (xa > xd) std::swap(xa, xd);
return Bracket_min_t<T>(xa, xb, xd);
}
// Use a golden-section search
// Returns a pair with the location of the minimum and the value of the function.
template <class F, typename T> std::pair<T, T> find_minimum(const F &f, Bracket_min_t<T> &bracket)
{
// assert(bracket.success == true);
T xa = bracket.a;
T xb = bracket.b;
T xd = bracket.c;
T fa = f(xa);
T fb = f(xb);
T xc = xb + 0.4*(xd - xb);
T fc = f(xc);
T tol = 1e-5;
while (std::abs(xa-xd) > tol*(std::abs(xb) + std::abs(xc)))
{
if (fb > fc) {
xa = xb;
xb = xa + 0.4*(xd-xa);
fb = f(xb);
xc = xa + 0.6*(xd-xa);
fc = f(xc);
} else {
xd = xc;
xb = xa + 0.4*(xd-xa);
fb = f(xb);
xc = xa + 0.6*(xd-xa);
fc = f(xc);
}
}
T final_value;
T final_x;
if (fb < fc) {
final_x = xb;
} else {
final_x = xc;
}
final_value = f(final_x);
return std::pair<T, T>(final_x, final_value);
}
#endif

View File

@ -22,7 +22,7 @@ SET(UTEST_NAME unit_test_${SRC_DIR})
SET(UTEST_SRCS test_grid_functor.cpp test_nr_spline.cpp test_stdlib.cpp test_bessel.cpp
test_ylm.cpp test_e2iphi.cpp test_aligned_allocator.cpp
test_gaussian_basis.cpp test_cartesian_tensor.cpp test_soa_cartesian_tensor.cpp
test_transform.cpp)
test_transform.cpp test_min_oned.cpp)
# Run gen_gto.py to create these files. They may take a long time to compile.
#SET(UTEST_SRCS ${UTEST_SRCS} test_full_cartesian_tensor.cpp test_full_soa_cartesian_tensor.cpp)

View File

@ -0,0 +1,122 @@
//////////////////////////////////////////////////////////////////////////////////////
// This file is distributed under the University of Illinois/NCSA Open Source License.
// See LICENSE file in top directory for details.
//
// Copyright (c) 2018 Jeongnim Kim and QMCPACK developers.
//
// File developed by: Mark Dewing, mdewing@anl.gov, Argonne National Laboratory
//
// File created by: Mark Dewing, mewing@anl.gov, Argonne National Laboratory
//////////////////////////////////////////////////////////////////////////////////////
#include "catch.hpp"
#include <iostream>
#include "Numerics/MinimizeOneDim.h"
namespace qmcplusplus
{
typedef double RealType;
class MinTest
{
public:
MinTest(double value=0.0) : min_value(value) {}
RealType min_value;
RealType one_cycle(RealType x)
{
return (x-min_value)*(x-min_value);
}
void find_bracket(RealType x0)
{
auto bracket = bracket_minimum([this](RealType x) -> RealType{return one_cycle(x);}, x0);
REQUIRE(bracket.success == true);
RealType xa = bracket.a;
RealType xb = bracket.b;
RealType xc = bracket.c;
//std::cout << " xa = " << xa;
//std::cout << " xb = " << xb;
//std::cout << " xc = " << xc;
//std::cout << std::endl;
REQUIRE(xa < xb);
REQUIRE(xb < xc);
// For a starting point of 1.3
//REQUIRE(xa == Approx(-0.0041));
//REQUIRE(xb == Approx( 0.03615));
//REQUIRE(xc == Approx(-0.04435));
RealType fa = one_cycle(xa);
RealType fb = one_cycle(xb);
RealType fc = one_cycle(xc);
REQUIRE(fa > fb);
REQUIRE(fc > fb);
}
// ensure the bracket search will find a minimum at the edge of the bound
void find_bracket_bound(RealType x0, RealType bound)
{
auto bracket = bracket_minimum([this](RealType x) -> RealType{return one_cycle(x);}, x0, bound);
REQUIRE(bracket.success == false);
}
void find_min(RealType x0)
{
auto bracket = bracket_minimum([this](RealType x) -> RealType{return one_cycle(x);}, x0);
auto m = find_minimum([this](RealType x) -> RealType{return one_cycle(x);}, bracket);
REQUIRE(m.first == Approx(min_value));
REQUIRE(m.second == Approx(0.0));
}
};
TEST_CASE("bracket minimum", "[numerics]")
{
MinTest min_test;
min_test.find_bracket(1.3);
min_test.find_bracket(-1.3);
min_test.find_bracket(10.0);
MinTest min_test2(1.5);
min_test2.find_bracket(1.3);
min_test2.find_bracket(-1.3);
min_test2.find_bracket(10.0);
min_test2.find_bracket_bound(1.2, 1.4);
MinTest min_test3(-0.5);
min_test3.find_bracket(1.3);
min_test3.find_bracket(-1.3);
min_test3.find_bracket(10.0);
min_test3.find_bracket_bound(1.0, 2.0);
}
TEST_CASE("find minimum", "[numerics]")
{
MinTest min_test;
min_test.find_min(1.3);
min_test.find_min(-1.3);
min_test.find_min(10.0);
MinTest min_test2(1.5);
min_test2.find_min(1.3);
min_test2.find_min(-1.3);
min_test2.find_min(10.0);
MinTest min_test3(-0.5);
min_test3.find_min(1.3);
min_test3.find_min(-1.3);
min_test3.find_min(10.0);
}
}