Merge pull request #7 from abinit/gp_qha

Test (i,j) and (j, i) if select=all
This commit is contained in:
Guido Petretto 2018-08-01 09:33:30 +02:00 committed by GitHub
commit babcebc0fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 16 deletions

View File

@ -1209,7 +1209,7 @@ class AbinitInput(six.with_metaclass(abc.ABCMeta, AbiAbstractInput, MSONable, Ha
return ph_inputs
def make_ddk_inputs(self, tolerance=None, kptopt=2):
def make_ddk_inputs(self, tolerance=None, kptopt=2, manager=None):
"""
Return inputs for performing DDK calculations.
This functions should be called with an input the represents a GS run.
@ -1218,6 +1218,7 @@ class AbinitInput(six.with_metaclass(abc.ABCMeta, AbiAbstractInput, MSONable, Ha
kptopt: 2 to take into account time-reversal symmetry. note that kptopt 1 is not available.
tolerance: dict {varname: value} with the tolerance to be used in the DFPT run.
Defaults to {"tolwfr": 1.0e-22}.
manager: |TaskManager| of the task. If None, the manager is initialized from the config file.
Return:
List of |AbinitInput| objects for DFPT runs.
@ -1243,6 +1244,7 @@ class AbinitInput(six.with_metaclass(abc.ABCMeta, AbiAbstractInput, MSONable, Ha
for rfdir, ddk_input in zip(ddk_rfdirs, ddk_inputs):
ddk_input.set_vars(
rfelfd=2, # Activate the calculation of the d/dk perturbation
# only the derivative of ground-state wavefunctions with respect to k
rfdir=rfdir, # Direction of the per ddk.
nqpt=1, # One wavevector is to be considered
qpt=(0, 0, 0), # q-wavevector.
@ -1425,7 +1427,7 @@ class AbinitInput(six.with_metaclass(abc.ABCMeta, AbiAbstractInput, MSONable, Ha
return multi
def make_strain_perts_inputs(self, tolerance=None, manager=None, phonon_pert=True, kptopt=2):
def make_strain_perts_inputs(self, tolerance=None, phonon_pert=True, kptopt=2, manager=None):
"""
Return inputs for the strain perturbation calculation.
This functions should be called with an input that represents a GS run.

View File

@ -152,6 +152,9 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
app("")
app("Number of q-points in DDB: %d" % len(self.qpoints))
app("guessed_ngqpt: %s (guess for the q-mesh divisions made by AbiPy)" % self.guessed_ngqpt)
app("Has total energy: %s" % (self.total_energy is not None))
app("Has forces: %s" % (self.cart_forces is not None))
app("Has stress tensor: %s" % (self.cart_stress_tensor is not None))
app("Has (at least one) atomic pertubation: %s" % self.has_at_least_one_atomic_perturbation())
app("Has (at least one) electric-field perturbation: %s" % self.has_emacro_terms(select="at_least_one"))
app("Has (at least one) Born effective charge: %s" % self.has_bec_terms(select="at_least_one"))
@ -614,10 +617,12 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
for p1 in ep_list:
for p2 in ep_list:
p12 = p1 + p2
p21 = p2 + p1
if select == "at_least_one":
if p12 in index_set: return True
elif select == "all":
if p12 not in index_set: return False
if p12 not in index_set and p21 not in index_set:
return False
else:
raise ValueError("Wrong select %s" % str(select))
@ -645,10 +650,12 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
for ap1 in ap_list:
for ep2 in ep_list:
p12 = ap1 + ep2
p21 = ep2 + ap1
if select == "at_least_one":
if p12 in index_set: return True
elif select == "all":
if p12 not in index_set: return False
if p12 not in index_set and p21 not in index_set:
return False
else:
raise ValueError("Wrong select %s" % str(select))
@ -682,10 +689,11 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
for p1 in sp_list:
for p2 in sp_list:
p12 = p1 + p2
p21 = p2 + p1
if select == "at_least_one":
if p12 in index_set: return True
elif select == "all":
if p12 not in index_set:
if p12 not in index_set and p21 not in index_set:
#print("p12", p12, "not in index_set")
return False
else:
@ -722,10 +730,13 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
for p1 in sp_list:
for p2 in ap_list:
p12 = p1 + p2
p21 = p2 + p1
if select == "at_least_one":
if p12 in index_set: return True
elif select == "all":
if p12 not in index_set: return False
if p12 not in index_set and p21 not in index_set:
#print("p12", p12, "non in index")
return False
else:
raise ValueError("Wrong select %s" % str(select))
@ -760,10 +771,11 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
for p1 in sp_list:
for p2 in ep_list:
p12 = p1 + p2
p21 = p2 + p1
if select == "at_least_one":
if p12 in index_set: return True
elif select == "all":
if p12 not in index_set: return False
if p12 not in index_set and p21 not in index_set: return False
else:
raise ValueError("Wrong select %s" % str(select))
@ -1932,6 +1944,8 @@ class DdbRobot(Robot):
# from abipy.dfpt.anaddbnc import AnaddbNcRobot
# return AnaddbNcRobot.from_files(anaddbnc_paths)
#def compare_computed_dynmat(self):
def yield_figs(self, **kwargs): # pragma: no cover
"""
This function *generates* a predefined list of matplotlib figures with minimal input from the user.

View File

@ -24,28 +24,35 @@ class ElasticWork(Work, MergeDdb):
# Register task for SCF calculation.
scf_task = new.register_scf_task(scf_input)
multi = scf_task.input.make_strain_perts_inputs(tolerance=tolerance, manager=manager, phonon_pert=False, kptopt=2)
multi = scf_task.input.make_strain_perts_inputs(tolerance=tolerance, manager=manager,
phonon_pert=with_internal_strain, kptopt=2)
ddk_tasks = []
if with_piezoelectric:
#sfc_task.input.make_ddk_inputs(tolerance=None, manager=None):
for inp in multi.filter_by_tags(tags=tags.DDK):
ddk_multi = scf_task.input.make_ddk_inputs(tolerance=None, manager=manager)
#for inp in multi.filter_by_tags(tags=tags.DDK):
for inp in ddk_multi:
ddk_task = new.register_ddk_task(inp, deps={scf_task: "WFK"})
ddk_tasks.append(ddk_task)
assert len(ddk_tasks) == 3
ddk_deps = {ddk_task: "DDK" for ddk_task in ddk_tasks}
if with_internal_strain:
ph_deps = {scf_task: "WFK"}
#if with_piezoelectric: ph_deps.update()
for inp in multi.filter_by_tags(tags=tags.PHONON):
new.register_phonon_task(inp, deps=ph_deps)
if with_piezoelectric: ph_deps.update(ddk_deps)
#for inp in multi.filter_by_tags(tags=tags.PHONON):
for inp in multi:
if inp.get("rfphon", 0) == 1:
#new.register_phonon_task(inp, deps=ph_deps)
new.register_bec_task(inp, deps=ph_deps)
#bec_deps = {ddk_task: "DDK" for ddk_task in ddk_tasks}
elast_deps = {scf_task: "WFK"}
#if with_piezoelectric: ph_deps.update()
if with_piezoelectric: elast_deps.update(ddk_deps)
#for inp in multi.filter_by_tags(tags=tags.STRAIN)
for inp in multi:
new.register_elastic_task(inp, deps=elast_deps)
# FIXME
#new.register_elastic_task(inp, deps=elast_deps)
new.register_bec_task(inp, deps=elast_deps)
return new