Add Robot.getattrs_alleq method

This commit is contained in:
Matteo Giantomassi 2024-09-15 11:58:22 +02:00
parent abc2c9af93
commit 7df0777550
3 changed files with 150 additions and 91 deletions

View File

@ -75,6 +75,7 @@ from abipy.eph.v1qavg import V1qAvgFile
from abipy.eph.rta import RtaFile, RtaRobot
from abipy.eph.transportfile import TransportFile
from abipy.eph.gstore import GstoreFile
from abipy.eph.gpath import GpathFile
from abipy.wannier90 import WoutFile, AbiwanFile, AbiwanRobot
from abipy.electrons.lobster import CoxpFile, ICoxpFile, LobsterDoscarFile, LobsterInput, LobsterAnalyzer
@ -173,6 +174,7 @@ abiext2ncfile = collections.OrderedDict([
("A2F.nc", A2fFile),
("SIGEPH.nc", SigEPhFile),
("GSTORE.nc", GstoreFile),
("GPATH.nc", GpathFile),
("TRANSPORT.nc",TransportFile),
("RTA.nc",RtaFile),
("V1SYM.nc", V1symFile),

View File

@ -613,23 +613,50 @@ class Robot(NotebookWriter):
"""Integration with jupyter_ notebooks."""
return '<ol start="0">\n{}\n</ol>'.format("\n".join("<li>%s</li>" % label for label, abifile in self.items()))
def getattr_alleq(self, aname : str):
def getattrs_alleq(self, *aname_args) -> list:
"""
Return the value of attribute aname.
Return list of attribute values for each attribute name in *aname_args.
"""
return [self.getattr_alleq(aname) for aname in aname_args]
def getattr_alleq(self, aname: str):
"""
Return the value of attribute aname. Try firs in self then in self.r
Raises ValueError if value is not the same across all the files in the robot.
"""
val1 = getattr(self.abifiles[0], aname)
for abifile in self.abifiles[1:]:
val2 = getattr(abifile, aname)
if isinstance(val1, (str, int, float)):
eq = val1 == val2
elif isinstance(val1, np.ndarray):
eq = np.allclose(val1, val2)
if not eq:
raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}")
def get_obj_list(what: str):
if what == "abifiles":
return self.abifiles
elif what == "r":
return [abifile.r for abifile in self.abifiles]
return val1
raise ValueError(f"Invalid {what=}")
err_msg = []
for what in ["abifiles", "r"]:
objs = get_obj_list(what)
try:
val1 = getattr(objs[0], aname)
except AttributeError as exc:
err_msg.append(str(exc))
continue
for obj in objs[1:]:
val2 = getattr(obj, aname)
if isinstance(val1, (str, int, float)):
eq = val1 == val2
elif isinstance(val1, np.ndarray):
eq = np.allclose(val1, val2)
if not eq:
raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}")
return val1
if err_msg:
raise ValueError("\n".join(err_msg))
@property
def abifiles(self) -> list:

View File

@ -18,7 +18,7 @@ from abipy.core.mixins import AbinitNcFile, Has_Structure, NotebookWriter
from abipy.tools.typing import PathLike
#from abipy.tools.numtools import nparr_to_df
from abipy.tools.plotting import (add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, set_visible,
rotate_ticklabels, ax_append_title, set_ax_xylabels, linestyles, Marker, set_grid_legend)
rotate_ticklabels, ax_append_title, set_ax_xylabels, linestyles, Marker, set_grid_legend, set_axlims)
from abipy.electrons.ebands import ElectronBands, RobotWithEbands
from abipy.dfpt.phonons import PhononBands
from abipy.dfpt.phtk import NonAnalyticalPh
@ -27,6 +27,13 @@ from abipy.abio.robots import Robot
from abipy.eph.common import BaseEphReader
def k2s(k_vector, fmt=".3f", threshold = 1e-8) -> str:
k_vector = np.asarray(k_vector)
k_vector[np.abs(k_vector) < threshold] = 0
return "[" + ", ".join(f"{x:.3f}" for x in k_vector) + "]"
class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
"""
This file stores the e-ph matrix elements along a k/q path
@ -112,9 +119,23 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
return "\n".join(lines)
@staticmethod
def _get_which_g_list(which_g: str) -> list[str]:
all_choices = ["avg", "raw"]
if which_g == "all":
return all_choices
if which_g not in all_choices:
raise ValueError(f"Invalid {which=}, should be in {all_choices=}")
return [which_g]
def _get_band_range(self, band_range):
return (self.r.bstart, self.r.bstop) if band_range is None else band_range
@add_fig_kwargs
def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1,
with_phbands=True, with_ebands=False,
def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1, gmax_mev=250,
ph_modes=None, with_phbands=True, with_ebands=False,
ax_mat=None, fontsize=8, **kwargs) -> Figure:
"""
Plot the averaged |g(k,q)| in meV units along the q-path
@ -124,21 +145,22 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both.
with_qexp: Multiply |g(q)| by |q|^{with_qexp}.
scale: Scaling factor for the marker size used when with_phbands is True.
gmax_mev: Show results up to gmax in meV.
ph_modes: List of ph branch indices to show (start from 0). If None all modes are shown.
with_phbands: False if phonon bands should now be displayed.
with_ebands: False if electron bands should now be displayed.
ax_mat: List of |matplotlib-Axes| or None if a new figure should be created.
fontsize: fontsize for legends and titles
"""
nrows, ncols = 1 + int((np.array([with_ebands, with_phbands]) == True).sum()), self.r.nsppol
which_g_list = [which_g]
if which_g == "all":
which_g_list = ["avg", "raw"]
nrows += 1
which_g_list = self._get_which_g_list(which_g)
nrows, ncols = len(which_g_list) + int((np.array([with_ebands, with_phbands]) == True).sum()), self.r.nsppol
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
sharex=False, sharey=False, squeeze=False)
marker_color = "gold"
band_range = (self.r.bstart, self.r.bstop) if band_range is None else band_range
band_range = self._get_band_range(band_range)
#facts_q, g_label, g_units = self.get_info(which_g, with_qexp)
facts_q = np.ones(len(self.phbands.qpoints)) if with_qexp == 0 else \
np.array([qpt.norm for qpt in self.phbands.qpoints]) ** with_qexp
@ -150,18 +172,22 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
ax_cnt = -1
for which_g in which_g_list:
# Select data according to which_g and multiply by facts_q
# Select ys according to which_g and multiply by facts_q
g_nuq = dict(avg=g_nuq_avg, raw=g_nuq_raw)[which_g] * facts_q[None,:]
# Plot g_nu(q)
ax_cnt += 1
ax = ax_mat[ax_cnt, spin]
for nu in range(self.r.natom3):
if ph_modes is not None and nu not in ph_modes: continue
ax.plot(g_nuq[nu], label=f"{nu=}")
self.phbands.decorate_ax(ax, units="meV")
g_label = r"$|g^{\text{%s}}_{\mathbf{q}}|$ %s" % (which_g, q_label)
set_grid_legend(ax, fontsize, ylabel="%s %s" % (g_label, g_units))
if gmax_mev is not None and with_qexp == 0:
set_axlims(ax, [0, gmax_mev], "y")
if with_phbands:
# Plot phonons bands + averaged g(q) as markers
ax_cnt += 1
@ -188,15 +214,15 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
# Add title.
if (kpt_name := self.structure.findname_in_hsym_stars(self.r.eph_fix_wavec)) is None:
kpt_name = str(self.r.eph_fix_wavec)
qpt_name = k2s(self.r.eph_fix_wavec)
fig.suptitle(f"k = {kpt_name}" + f" m, n = {band_range[0]} - {band_range[1] - 1}")
return fig
@add_fig_kwargs
def plot_g_kpath(self, band_range=None, which_g="sym", scale=1, with_ebands=True,
ax_mat=None, fontsize=8, **kwargs) -> Figure:
def plot_g_kpath(self, band_range=None, which_g="avg", scale=1, gmax_mev=250, ph_modes=None,
with_ebands=True, ax_mat=None, fontsize=8, **kwargs) -> Figure:
"""
Plot the averaged |g(k,q)| in meV units along the k-path
@ -204,60 +230,49 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
band_range: Band range that will be averaged over (python convention).
which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both.
scale: Scaling factor for the marker size used when with_phbands is True.
gmax_mev: Show results up to gmax in meV.
ph_modes: List of ph branch indices to show (start from 0). If None all modes are show.
with_ebands: False if electron bands should now be displayed.
ax_mat: List of |matplotlib-Axes| or None if a new figure should be created.
fontsize: fontsize for legends and titles
"""
nrows, ncols = 1 + int((np.array([with_ebands]) == True).sum()), self.r.nsppol
which_g_list = [which_g]
if which_g == "all":
which_g_list = ["avg", "raw"]
nrows += 1
which_g_list = self._get_which_g_list(which_g)
nrows, ncols = len(which_g_list) + int((np.array([with_ebands]) == True).sum()), self.r.nsppol
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
sharex=False, sharey=False, squeeze=False)
marker_color = "gold"
band_range = (self.r.bstart, self.r.bstop) if band_range is None else band_range
band_range = self._get_band_range(band_range)
for spin in range(self.r.nsppol):
g_nuk_avg, g_nuk_raw = self.r.get_gnuk_average_spin(spin, band_range)
ax_cnt = -1
for which_g in which_g_list:
# Select data according to which_g
# Select ys according to which_g
g_nuk = dict(avg=g_nuk_avg, raw=g_nuk_raw)[which_g]
# Plot g_nu(q)
ax_cnt += 1
ax = ax_mat[ax_cnt, spin]
for nu in range(self.r.natom3):
if ph_modes is not None and nu not in ph_modes: continue
ax.plot(g_nuk[nu], label=f"{which_g} {nu=}")
# Plot g(k)
self.ebands_k.decorate_ax(ax, units="meV")
set_grid_legend(ax, fontsize, ylabel=r"$|g^{\text{%s}}(\mathbf{k})|$ (meV)" % (which_g))
set_grid_legend(ax, fontsize, ylabel=r"$|g^{\text{%s}}_{\mathbf{k}}|$ (meV)" % (which_g))
if gmax_mev is not None:
set_axlims(ax, [0, gmax_mev], "y")
if with_ebands:
# Plot electron bands + averaged g(k) as markers
# Plot electron bands
ax_cnt += 1
points = None
#x, y, s = [], [], []
#for ik, kpoint in enumerate(self.ebands_k.kpoints):
# omegas_nu = self.phbands.phfreqs[iq,:]
# for w, g2 in zip(omegas_nu, g_nuk[:,iq], strict=True):
# x.append(iq); y.append(w); s.append(scale * g2)
#points = Marker(x, y, s, color=marker_color, edgecolors='gray', alpha=0.8,
# label=r'$|g^{\text{avg}}(\mathbf{k})|$ (meV)')
ax = ax_mat[ax_cnt, spin]
self.ebands_k.plot(ax=ax, spin=spin, band_range=band_range, with_gaps=False, show=False)
set_grid_legend(ax, fontsize) #, xlabel=r"Wavevector $\mathbf{q}$")
#self.phbands.plot(ax=ax, points=points, show=False)
if (qpt_name := self.structure.findname_in_hsym_stars(self.r.eph_fix_wavec)) is None:
qpt_name = str(self.r.eph_fix_wavec)
qpt_name = k2s(self.r.eph_fix_wavec)
fig.suptitle(f"q = {qpt_name}" + f" m, n = {band_range[0]} - {band_range[1] - 1}")
@ -269,7 +284,7 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
"""
if self.r.eph_fix_korq == "k":
#yield self.ebands_kq.plot(show=False)
yield self.phbands.plot(show=False)
#yield self.phbands.plot(show=False)
yield self.plot_g_qpath()
if self.r.eph_fix_korq == "q":
@ -320,7 +335,11 @@ class GpathReader(BaseEphReader):
# Read important variables.
self.eph_fix_korq = self.read_string("eph_fix_korq")
if self.eph_fix_korq not in {"k", "q"}:
raise ValueError(f"Invalid value for {self.eph_fix_korq=}")
self.eph_fix_wavec = self.read_value("eph_fix_wavevec")
self.dbdb_add_lr = self.read_value("dvdb_add_lr")
#self.used_ftinterp = self.read_value("used_ftinterp")
#self.completed = self.read_value("gstore_completed")
# Note conversion Fortran --> C for the bstart index.
@ -594,57 +613,68 @@ class GpathRobot(Robot, RobotWithEbands):
"t05o_GPATH.nc",
])
.. rubric:: Inheritance Diagram
.. inheritance-diagram:: GstoreRobot
.. inheritance-diagram:: GpathRobot
"""
EXT = "GPATH"
#def neq(self, ref_basename: str | None = None, verbose: int = 0) -> int:
# """
# Compare all GPATHE.nc files stored in the robot
# """
# # Find reference gstore. By default the first file in the robot is used.
# ref_gstore = self._get_ref_abifile_from_basename(ref_basename)
@add_fig_kwargs
def plot_g_qpath(self, which_g="avg", gmax_mev=250, ph_modes=None,
colormap="jet", **kwargs) -> Figure:
"""
Compare the g-matrix along a q-path.
# exc_list = []
# ierr = 0
# for other_gstore in self.abifiles:
# if ref_gstore.filepath == other_gstore.filepath:
# continue
# print("Comparing: ", ref_gstore.basename, " with: ", other_gstore.basename)
# try:
# ierr += self._neq_two_gstores(ref_gstore, other_gstore, verbose)
# cprint("EQUAL", color="green")
# except Exception as exc:
# exc_list.append(str(exc))
Args
which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both.
gmax_mev: Show results up to gmax in me
ph_modes: List of ph branch indices to show (start from 0). If None all modes are show.
colormap: Color map. Have a look at the colormaps here and decide which one you like:
http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html
"""
nsppol, nq_path, natom3, eph_fix_wavec, eph_fix_korq = self.getattrs_alleq(
"nsppol", "nq_path", "natom3", "eph_fix_wavec", "eph_fix_korq"
)
xs = np.arange(nq_path)
# for exc in exc_list:
# cprint(exc, color="red")
nrows, ncols = 1, nsppol
ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=False, sharey=False, squeeze=False)
cmap = plt.get_cmap(colormap)
# return ierr
# TODO: Compute common band range.
band_range = None
ref_ifile= 0
#q_label = r"$|q|^{%d}$" % with_qexp if with_qexp else ""
#g_units = "(meV)" if with_qexp == 0 else r"(meV $\AA^-{%s}$)" % with_qexp
#@staticmethod
#def _neq_two_gstores(gstore1: GstoreFile, gstore2: GstoreFile, verbose: int) -> int:
# """
# Helper function to compare two GSTORE files.
# """
# # These quantities must be the same to have a meaningfull comparison.
# aname_list = ["structure", "nsppol", "cplex", "nkbz", "nkibz",
# "nqbz", "nqibz", "completed", "kzone", "qzone", "kfilter", "gmode",
# "brange_spin", "erange_spin", "glob_spin_nq", "glob_nk_spin",
# ]
for spin in range(nsppol):
ax_cnt = 0
ax = ax_mat[ax_cnt, spin]
# for aname in aname_list:
# self._compare_attr_name(aname, gstore1, gstore2)
for ifile, gpath in enumerate(self.abifiles):
g_nuq_avg, g_nuq_raw = gpath.r.get_gnuq_average_spin(spin, band_range)
# Select ys according to which_g and multiply by facts_q
g_nuq = dict(avg=g_nuq_avg, raw=g_nuq_raw)[which_g] # * facts_q[None,:]
# # Now compare the gkq objects for each spin.
# ierr = 0
# for spin in range(gstore1.nsppol):
# gqk1, gqk2 = gstore1.gqk_spin[spin], gstore2.gqk_spin[spin]
# ierr += gqk1.neq(gqk2, verbose)
for nu in range(natom3):
if ph_modes is not None and nu not in ph_modes: continue
color = cmap(nu / natom3)
if ifile == ref_ifile:
ax.scatter(xs, g_nuq[nu], color=color, label=f"{nu=}", marker="o")
gpath.phbands.decorate_ax(ax, units="meV")
#g_label = r"$|g^{\text{%s}}_{\mathbf{q}}|$ %s" % (which_g, q_label)
#set_grid_legend(ax, fontsize, ylabel="%s %s" % (g_label, g_units))
else:
ax.plot(g_nuq[nu], color=color, label=f"{nu=}")
# return ierr
#if gmax_mev is not None and with_qexp == 0:
if gmax_mev is not None:
set_axlims(ax, [0, gmax_mev], "y")
return fig
#@add_fig_kwargs
#def plot_g_kpath(self, **kwargs) --> Figure
def yield_figs(self, **kwargs): # pragma: no cover
"""