mirror of https://github.com/abinit/abipy.git
Fix unit tests
This commit is contained in:
parent
9399078334
commit
469d7c7cf1
|
@ -30,7 +30,5 @@ omit =
|
|||
abipy/examples/*
|
||||
abipy/gui/*
|
||||
abipy/gw/*
|
||||
abipy/extensions/*
|
||||
abipy/gw/*
|
||||
abipy/scripts/*
|
||||
./docs/
|
||||
|
|
|
@ -258,6 +258,32 @@ class AbstractInput(six.with_metaclass(abc.ABCMeta, MutableMapping, object)):
|
|||
task: Task object
|
||||
"""
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
This function generates new inputs by replacing the variables specified in kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: keyword arguments with the values used for each variable.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
gs_inp = call_function_to_generate_initial_template()
|
||||
|
||||
# To generate two input files with different values of ecut:
|
||||
for inp_ecut in gs_inp.generate(ecut=[10, 20]):
|
||||
print("do something with inp_ecut %s" % inp_ecut)
|
||||
|
||||
# To generate four input files with all the possible combinations of ecut and nsppol:
|
||||
for inp_ecut in gs_inpt.generate(ecut=[10, 20], nsppol=[1, 2]):
|
||||
print("do something with inp_ecut %s" % inp_ecut)
|
||||
"""
|
||||
for new_vars in product_dict(kwargs):
|
||||
new_inp = self.deepcopy()
|
||||
# Remove the variable names to avoid annoying warnings if the variable is overwritten.
|
||||
new_inp.remove_vars(new_vars.keys())
|
||||
new_inp.set_vars(**new_vars)
|
||||
yield new_inp
|
||||
|
||||
|
||||
class AbinitInputError(Exception):
|
||||
"""Base error class for exceptions raised by ``AbinitInput``."""
|
||||
|
@ -876,10 +902,8 @@ class AbinitInput(six.with_metaclass(abc.ABCMeta, AbstractInput, MSONable, Has_S
|
|||
In that case, the sequence consists of all but the last of ``ndtset + 1``
|
||||
evenly spaced samples, so that `stop` is excluded. Note that the step
|
||||
size changes when `endpoint` is False.
|
||||
num : int, optional
|
||||
Number of samples to generate. Default is 50.
|
||||
endpoint : bool, optional
|
||||
If True, `stop` is the last sample. Otherwise, it is not included.
|
||||
num (int): Number of samples to generate. Default is 50.
|
||||
endpoint (bool): optional. If True, `stop` is the last sample. Otherwise, it is not included.
|
||||
Default is True.
|
||||
"""
|
||||
inps = []
|
||||
|
@ -3149,3 +3173,45 @@ class Cut3DInput(MSONable, object):
|
|||
"""
|
||||
return cls(infile_path=d.get('infile_path', None), output_filepath=d.get('output_filepath', None),
|
||||
options=d.get('options', None))
|
||||
|
||||
|
||||
def product_dict(d):
|
||||
"""
|
||||
This function receives a dictionary d where each key defines a list of items or a simple scalar.
|
||||
It constructs the Cartesian product of the values (equivalent to nested for-loops),
|
||||
and returns a list of dictionaries with the values that would be used inside the loop.
|
||||
|
||||
>>> d = OrderedDict([("foo", [2, 4]), ("bar", 1)])
|
||||
>>> product_dict(d) == [OrderedDict([('foo', 2), ('bar', 1)]), OrderedDict([('foo', 4), ('bar', 1)])]
|
||||
True
|
||||
>>> d = OrderedDict([("bar", [1,2]), ('foo', [3,4])])
|
||||
>>> product_dict(d) == [{'bar': 1, 'foo': 3},
|
||||
... {'bar': 1, 'foo': 4},
|
||||
... {'bar': 2, 'foo': 3},
|
||||
... {'bar': 2, 'foo': 4}]
|
||||
True
|
||||
|
||||
.. warning:
|
||||
|
||||
Dictionaries are not ordered, therefore one cannot assume that
|
||||
the order of the keys in the output equals the one used to loop.
|
||||
If the order is important, one should pass a :class:`OrderedDict` in input.
|
||||
"""
|
||||
keys, vals = d.keys(), d.values()
|
||||
|
||||
# Each item in vals must be iterable.
|
||||
values = []
|
||||
|
||||
for v in vals:
|
||||
if not isinstance(v, collections.Iterable): v = [v]
|
||||
values.append(v)
|
||||
|
||||
# Build list of dictionaries. Use ordered dicts so that
|
||||
# we preserve the order when d is an OrderedDict.
|
||||
vars_prod = []
|
||||
|
||||
for prod_values in itertools.product(*values):
|
||||
dprod = OrderedDict(zip(keys, prod_values))
|
||||
vars_prod.append(dprod)
|
||||
|
||||
return vars_prod
|
||||
|
|
|
@ -122,6 +122,14 @@ class TestAbinitInput(AbipyTest):
|
|||
self.serialize_with_pickle(inp, test_eq=False)
|
||||
self.assertMSONable(inp)
|
||||
|
||||
# Test generate method.
|
||||
ecut_list = [10, 20]
|
||||
for i, ginp in enumerate(inp.generate(ecut=ecut_list)):
|
||||
assert ginp["ecut"] == ecut_list[i]
|
||||
|
||||
inp_list = list(inp.generate(ecut=[10, 20], nsppol=[1, 2]))
|
||||
assert len(inp_list) == 4
|
||||
|
||||
# Test tags
|
||||
assert isinstance(inp.tags, set)
|
||||
assert len(inp.tags) == 0
|
||||
|
|
|
@ -21,7 +21,7 @@ def make_input(paw=False):
|
|||
structure = abidata.structure_from_ucell("SiO2-alpha")
|
||||
|
||||
inp = abilab.AbinitInput(structure, pseudos)
|
||||
inp.set_kmesh(ngkpt=[1,1,1], shiftk=[0,0,0])
|
||||
inp.set_kmesh(ngkpt=[1, 1, 1], shiftk=[0, 0, 0])
|
||||
|
||||
# Global variables
|
||||
ecut = 24
|
||||
|
@ -64,7 +64,7 @@ def build_flow(options):
|
|||
for npfft in mpi_list:
|
||||
if not options.accept_mpi_omp(npfft, omp_threads): continue
|
||||
manager = options.manager.new_with_fixed_mpi_omp(npfft, omp_threads)
|
||||
for inp in abilab.input_gen(template, fftalg=fftalg, npfft=npfft, ecut=ecut_list):
|
||||
for inp in template.generate(fftalg=fftalg, npfft=npfft, ecut=ecut_list):
|
||||
work.register_scf_task(inp, manager=manager)
|
||||
flow.register_work(work)
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@ root = os.path.dirname(__file__)
|
|||
|
||||
__all__ = [
|
||||
"AbipyTest",
|
||||
"AbipyFileTest",
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -102,7 +102,7 @@ def build_flow(options):
|
|||
# read a submatrix when we test the convergence wrt to ecuteps.
|
||||
scr_work = flowtk.Work()
|
||||
|
||||
for inp in abilab.input_gen(scr_inp, nband=[10, 15]):
|
||||
for inp in scr_inp.generate(nband=[10, 15]):
|
||||
inp.set_vars(ecuteps=max_ecuteps)
|
||||
scr_work.register_scr_task(inp, deps={bands.nscf_task: "WFK"})
|
||||
|
||||
|
@ -112,7 +112,7 @@ def build_flow(options):
|
|||
# different SCR file computed with a different value of nband.
|
||||
|
||||
# Build a list of sigma inputs with different ecuteps
|
||||
sigma_inputs = list(abilab.input_gen(sig_inp, ecuteps=ecuteps_list))
|
||||
sigma_inputs = list(sig_inp.generate(ecuteps=ecuteps_list))
|
||||
|
||||
for scr_task in scr_work:
|
||||
sigma_conv = flowtk.SigmaConvWork(wfk_node=bands.nscf_task, scr_node=scr_task, sigma_inputs=sigma_inputs)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import print_function, division
|
||||
|
||||
from abipy.tools.numtools import *
|
||||
from abipy.core.testing import *
|
||||
from abipy.core.testing import AbiPyTest
|
||||
|
||||
|
||||
class TestTools(AbipyTest):
|
||||
|
|
Loading…
Reference in New Issue