Fix unit tests

This commit is contained in:
Matteo Giantomassi 2018-01-12 14:29:34 +01:00
parent 9399078334
commit 469d7c7cf1
7 changed files with 83 additions and 12 deletions

View File

@ -30,7 +30,5 @@ omit =
abipy/examples/*
abipy/gui/*
abipy/gw/*
abipy/extensions/*
abipy/gw/*
abipy/scripts/*
./docs/

View File

@ -258,6 +258,32 @@ class AbstractInput(six.with_metaclass(abc.ABCMeta, MutableMapping, object)):
task: Task object
"""
def generate(self, **kwargs):
"""
This function generates new inputs by replacing the variables specified in kwargs.
Args:
kwargs: keyword arguments with the values used for each variable.
.. code-block:: python
gs_inp = call_function_to_generate_initial_template()
# To generate two input files with different values of ecut:
for inp_ecut in gs_inp.generate(ecut=[10, 20]):
print("do something with inp_ecut %s" % inp_ecut)
# To generate four input files with all the possible combinations of ecut and nsppol:
for inp_ecut in gs_inpt.generate(ecut=[10, 20], nsppol=[1, 2]):
print("do something with inp_ecut %s" % inp_ecut)
"""
for new_vars in product_dict(kwargs):
new_inp = self.deepcopy()
# Remove the variable names to avoid annoying warnings if the variable is overwritten.
new_inp.remove_vars(new_vars.keys())
new_inp.set_vars(**new_vars)
yield new_inp
class AbinitInputError(Exception):
"""Base error class for exceptions raised by ``AbinitInput``."""
@ -876,10 +902,8 @@ class AbinitInput(six.with_metaclass(abc.ABCMeta, AbstractInput, MSONable, Has_S
In that case, the sequence consists of all but the last of ``ndtset + 1``
evenly spaced samples, so that `stop` is excluded. Note that the step
size changes when `endpoint` is False.
num : int, optional
Number of samples to generate. Default is 50.
endpoint : bool, optional
If True, `stop` is the last sample. Otherwise, it is not included.
num (int): Number of samples to generate. Default is 50.
endpoint (bool): optional. If True, `stop` is the last sample. Otherwise, it is not included.
Default is True.
"""
inps = []
@ -3149,3 +3173,45 @@ class Cut3DInput(MSONable, object):
"""
return cls(infile_path=d.get('infile_path', None), output_filepath=d.get('output_filepath', None),
options=d.get('options', None))
def product_dict(d):
"""
This function receives a dictionary d where each key defines a list of items or a simple scalar.
It constructs the Cartesian product of the values (equivalent to nested for-loops),
and returns a list of dictionaries with the values that would be used inside the loop.
>>> d = OrderedDict([("foo", [2, 4]), ("bar", 1)])
>>> product_dict(d) == [OrderedDict([('foo', 2), ('bar', 1)]), OrderedDict([('foo', 4), ('bar', 1)])]
True
>>> d = OrderedDict([("bar", [1,2]), ('foo', [3,4])])
>>> product_dict(d) == [{'bar': 1, 'foo': 3},
... {'bar': 1, 'foo': 4},
... {'bar': 2, 'foo': 3},
... {'bar': 2, 'foo': 4}]
True
.. warning:
Dictionaries are not ordered, therefore one cannot assume that
the order of the keys in the output equals the one used to loop.
If the order is important, one should pass a :class:`OrderedDict` in input.
"""
keys, vals = d.keys(), d.values()
# Each item in vals must be iterable.
values = []
for v in vals:
if not isinstance(v, collections.Iterable): v = [v]
values.append(v)
# Build list of dictionaries. Use ordered dicts so that
# we preserve the order when d is an OrderedDict.
vars_prod = []
for prod_values in itertools.product(*values):
dprod = OrderedDict(zip(keys, prod_values))
vars_prod.append(dprod)
return vars_prod

View File

@ -122,6 +122,14 @@ class TestAbinitInput(AbipyTest):
self.serialize_with_pickle(inp, test_eq=False)
self.assertMSONable(inp)
# Test generate method.
ecut_list = [10, 20]
for i, ginp in enumerate(inp.generate(ecut=ecut_list)):
assert ginp["ecut"] == ecut_list[i]
inp_list = list(inp.generate(ecut=[10, 20], nsppol=[1, 2]))
assert len(inp_list) == 4
# Test tags
assert isinstance(inp.tags, set)
assert len(inp.tags) == 0

View File

@ -21,7 +21,7 @@ def make_input(paw=False):
structure = abidata.structure_from_ucell("SiO2-alpha")
inp = abilab.AbinitInput(structure, pseudos)
inp.set_kmesh(ngkpt=[1,1,1], shiftk=[0,0,0])
inp.set_kmesh(ngkpt=[1, 1, 1], shiftk=[0, 0, 0])
# Global variables
ecut = 24
@ -64,7 +64,7 @@ def build_flow(options):
for npfft in mpi_list:
if not options.accept_mpi_omp(npfft, omp_threads): continue
manager = options.manager.new_with_fixed_mpi_omp(npfft, omp_threads)
for inp in abilab.input_gen(template, fftalg=fftalg, npfft=npfft, ecut=ecut_list):
for inp in template.generate(fftalg=fftalg, npfft=npfft, ecut=ecut_list):
work.register_scf_task(inp, manager=manager)
flow.register_work(work)

View File

@ -29,7 +29,6 @@ root = os.path.dirname(__file__)
__all__ = [
"AbipyTest",
"AbipyFileTest",
]

View File

@ -102,7 +102,7 @@ def build_flow(options):
# read a submatrix when we test the convergence wrt to ecuteps.
scr_work = flowtk.Work()
for inp in abilab.input_gen(scr_inp, nband=[10, 15]):
for inp in scr_inp.generate(nband=[10, 15]):
inp.set_vars(ecuteps=max_ecuteps)
scr_work.register_scr_task(inp, deps={bands.nscf_task: "WFK"})
@ -112,7 +112,7 @@ def build_flow(options):
# different SCR file computed with a different value of nband.
# Build a list of sigma inputs with different ecuteps
sigma_inputs = list(abilab.input_gen(sig_inp, ecuteps=ecuteps_list))
sigma_inputs = list(sig_inp.generate(ecuteps=ecuteps_list))
for scr_task in scr_work:
sigma_conv = flowtk.SigmaConvWork(wfk_node=bands.nscf_task, scr_node=scr_task, sigma_inputs=sigma_inputs)

View File

@ -1,7 +1,7 @@
from __future__ import print_function, division
from abipy.tools.numtools import *
from abipy.core.testing import *
from abipy.core.testing import AbiPyTest
class TestTools(AbipyTest):