mirror of https://github.com/abinit/abipy.git
Add Robot.getattrs_alleq method
This commit is contained in:
parent
abc2c9af93
commit
7df0777550
|
@ -75,6 +75,7 @@ from abipy.eph.v1qavg import V1qAvgFile
|
|||
from abipy.eph.rta import RtaFile, RtaRobot
|
||||
from abipy.eph.transportfile import TransportFile
|
||||
from abipy.eph.gstore import GstoreFile
|
||||
from abipy.eph.gpath import GpathFile
|
||||
from abipy.wannier90 import WoutFile, AbiwanFile, AbiwanRobot
|
||||
from abipy.electrons.lobster import CoxpFile, ICoxpFile, LobsterDoscarFile, LobsterInput, LobsterAnalyzer
|
||||
|
||||
|
@ -173,6 +174,7 @@ abiext2ncfile = collections.OrderedDict([
|
|||
("A2F.nc", A2fFile),
|
||||
("SIGEPH.nc", SigEPhFile),
|
||||
("GSTORE.nc", GstoreFile),
|
||||
("GPATH.nc", GpathFile),
|
||||
("TRANSPORT.nc",TransportFile),
|
||||
("RTA.nc",RtaFile),
|
||||
("V1SYM.nc", V1symFile),
|
||||
|
|
|
@ -613,23 +613,50 @@ class Robot(NotebookWriter):
|
|||
"""Integration with jupyter_ notebooks."""
|
||||
return '<ol start="0">\n{}\n</ol>'.format("\n".join("<li>%s</li>" % label for label, abifile in self.items()))
|
||||
|
||||
def getattr_alleq(self, aname : str):
|
||||
def getattrs_alleq(self, *aname_args) -> list:
|
||||
"""
|
||||
Return the value of attribute aname.
|
||||
Return list of attribute values for each attribute name in *aname_args.
|
||||
"""
|
||||
return [self.getattr_alleq(aname) for aname in aname_args]
|
||||
|
||||
def getattr_alleq(self, aname: str):
|
||||
"""
|
||||
Return the value of attribute aname. Try firs in self then in self.r
|
||||
Raises ValueError if value is not the same across all the files in the robot.
|
||||
"""
|
||||
val1 = getattr(self.abifiles[0], aname)
|
||||
|
||||
for abifile in self.abifiles[1:]:
|
||||
val2 = getattr(abifile, aname)
|
||||
if isinstance(val1, (str, int, float)):
|
||||
eq = val1 == val2
|
||||
elif isinstance(val1, np.ndarray):
|
||||
eq = np.allclose(val1, val2)
|
||||
if not eq:
|
||||
raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}")
|
||||
def get_obj_list(what: str):
|
||||
if what == "abifiles":
|
||||
return self.abifiles
|
||||
elif what == "r":
|
||||
return [abifile.r for abifile in self.abifiles]
|
||||
|
||||
return val1
|
||||
raise ValueError(f"Invalid {what=}")
|
||||
|
||||
err_msg = []
|
||||
|
||||
for what in ["abifiles", "r"]:
|
||||
objs = get_obj_list(what)
|
||||
|
||||
try:
|
||||
val1 = getattr(objs[0], aname)
|
||||
except AttributeError as exc:
|
||||
err_msg.append(str(exc))
|
||||
continue
|
||||
|
||||
for obj in objs[1:]:
|
||||
val2 = getattr(obj, aname)
|
||||
if isinstance(val1, (str, int, float)):
|
||||
eq = val1 == val2
|
||||
elif isinstance(val1, np.ndarray):
|
||||
eq = np.allclose(val1, val2)
|
||||
if not eq:
|
||||
raise ValueError(f"Different values of {aname=}, {val1=}, {val2=}")
|
||||
|
||||
return val1
|
||||
|
||||
if err_msg:
|
||||
raise ValueError("\n".join(err_msg))
|
||||
|
||||
@property
|
||||
def abifiles(self) -> list:
|
||||
|
|
|
@ -18,7 +18,7 @@ from abipy.core.mixins import AbinitNcFile, Has_Structure, NotebookWriter
|
|||
from abipy.tools.typing import PathLike
|
||||
#from abipy.tools.numtools import nparr_to_df
|
||||
from abipy.tools.plotting import (add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, set_visible,
|
||||
rotate_ticklabels, ax_append_title, set_ax_xylabels, linestyles, Marker, set_grid_legend)
|
||||
rotate_ticklabels, ax_append_title, set_ax_xylabels, linestyles, Marker, set_grid_legend, set_axlims)
|
||||
from abipy.electrons.ebands import ElectronBands, RobotWithEbands
|
||||
from abipy.dfpt.phonons import PhononBands
|
||||
from abipy.dfpt.phtk import NonAnalyticalPh
|
||||
|
@ -27,6 +27,13 @@ from abipy.abio.robots import Robot
|
|||
from abipy.eph.common import BaseEphReader
|
||||
|
||||
|
||||
def k2s(k_vector, fmt=".3f", threshold = 1e-8) -> str:
|
||||
k_vector = np.asarray(k_vector)
|
||||
k_vector[np.abs(k_vector) < threshold] = 0
|
||||
|
||||
return "[" + ", ".join(f"{x:.3f}" for x in k_vector) + "]"
|
||||
|
||||
|
||||
class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
|
||||
"""
|
||||
This file stores the e-ph matrix elements along a k/q path
|
||||
|
@ -112,9 +119,23 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
|
|||
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _get_which_g_list(which_g: str) -> list[str]:
|
||||
all_choices = ["avg", "raw"]
|
||||
if which_g == "all":
|
||||
return all_choices
|
||||
|
||||
if which_g not in all_choices:
|
||||
raise ValueError(f"Invalid {which=}, should be in {all_choices=}")
|
||||
|
||||
return [which_g]
|
||||
|
||||
def _get_band_range(self, band_range):
|
||||
return (self.r.bstart, self.r.bstop) if band_range is None else band_range
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1,
|
||||
with_phbands=True, with_ebands=False,
|
||||
def plot_g_qpath(self, band_range=None, which_g="avg", with_qexp: int=0, scale=1, gmax_mev=250,
|
||||
ph_modes=None, with_phbands=True, with_ebands=False,
|
||||
ax_mat=None, fontsize=8, **kwargs) -> Figure:
|
||||
"""
|
||||
Plot the averaged |g(k,q)| in meV units along the q-path
|
||||
|
@ -124,21 +145,22 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
|
|||
which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both.
|
||||
with_qexp: Multiply |g(q)| by |q|^{with_qexp}.
|
||||
scale: Scaling factor for the marker size used when with_phbands is True.
|
||||
gmax_mev: Show results up to gmax in meV.
|
||||
ph_modes: List of ph branch indices to show (start from 0). If None all modes are shown.
|
||||
with_phbands: False if phonon bands should now be displayed.
|
||||
with_ebands: False if electron bands should now be displayed.
|
||||
ax_mat: List of |matplotlib-Axes| or None if a new figure should be created.
|
||||
fontsize: fontsize for legends and titles
|
||||
"""
|
||||
nrows, ncols = 1 + int((np.array([with_ebands, with_phbands]) == True).sum()), self.r.nsppol
|
||||
which_g_list = [which_g]
|
||||
if which_g == "all":
|
||||
which_g_list = ["avg", "raw"]
|
||||
nrows += 1
|
||||
which_g_list = self._get_which_g_list(which_g)
|
||||
nrows, ncols = len(which_g_list) + int((np.array([with_ebands, with_phbands]) == True).sum()), self.r.nsppol
|
||||
|
||||
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
|
||||
sharex=False, sharey=False, squeeze=False)
|
||||
marker_color = "gold"
|
||||
band_range = (self.r.bstart, self.r.bstop) if band_range is None else band_range
|
||||
band_range = self._get_band_range(band_range)
|
||||
|
||||
#facts_q, g_label, g_units = self.get_info(which_g, with_qexp)
|
||||
facts_q = np.ones(len(self.phbands.qpoints)) if with_qexp == 0 else \
|
||||
np.array([qpt.norm for qpt in self.phbands.qpoints]) ** with_qexp
|
||||
|
||||
|
@ -150,18 +172,22 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
|
|||
ax_cnt = -1
|
||||
|
||||
for which_g in which_g_list:
|
||||
# Select data according to which_g and multiply by facts_q
|
||||
# Select ys according to which_g and multiply by facts_q
|
||||
g_nuq = dict(avg=g_nuq_avg, raw=g_nuq_raw)[which_g] * facts_q[None,:]
|
||||
|
||||
# Plot g_nu(q)
|
||||
ax_cnt += 1
|
||||
ax = ax_mat[ax_cnt, spin]
|
||||
for nu in range(self.r.natom3):
|
||||
if ph_modes is not None and nu not in ph_modes: continue
|
||||
ax.plot(g_nuq[nu], label=f"{nu=}")
|
||||
|
||||
self.phbands.decorate_ax(ax, units="meV")
|
||||
g_label = r"$|g^{\text{%s}}_{\mathbf{q}}|$ %s" % (which_g, q_label)
|
||||
set_grid_legend(ax, fontsize, ylabel="%s %s" % (g_label, g_units))
|
||||
|
||||
if gmax_mev is not None and with_qexp == 0:
|
||||
set_axlims(ax, [0, gmax_mev], "y")
|
||||
|
||||
if with_phbands:
|
||||
# Plot phonons bands + averaged g(q) as markers
|
||||
ax_cnt += 1
|
||||
|
@ -188,15 +214,15 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
|
|||
|
||||
# Add title.
|
||||
if (kpt_name := self.structure.findname_in_hsym_stars(self.r.eph_fix_wavec)) is None:
|
||||
kpt_name = str(self.r.eph_fix_wavec)
|
||||
qpt_name = k2s(self.r.eph_fix_wavec)
|
||||
|
||||
fig.suptitle(f"k = {kpt_name}" + f" m, n = {band_range[0]} - {band_range[1] - 1}")
|
||||
|
||||
return fig
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot_g_kpath(self, band_range=None, which_g="sym", scale=1, with_ebands=True,
|
||||
ax_mat=None, fontsize=8, **kwargs) -> Figure:
|
||||
def plot_g_kpath(self, band_range=None, which_g="avg", scale=1, gmax_mev=250, ph_modes=None,
|
||||
with_ebands=True, ax_mat=None, fontsize=8, **kwargs) -> Figure:
|
||||
"""
|
||||
Plot the averaged |g(k,q)| in meV units along the k-path
|
||||
|
||||
|
@ -204,60 +230,49 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
|
|||
band_range: Band range that will be averaged over (python convention).
|
||||
which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both.
|
||||
scale: Scaling factor for the marker size used when with_phbands is True.
|
||||
gmax_mev: Show results up to gmax in meV.
|
||||
ph_modes: List of ph branch indices to show (start from 0). If None all modes are show.
|
||||
with_ebands: False if electron bands should now be displayed.
|
||||
ax_mat: List of |matplotlib-Axes| or None if a new figure should be created.
|
||||
fontsize: fontsize for legends and titles
|
||||
"""
|
||||
nrows, ncols = 1 + int((np.array([with_ebands]) == True).sum()), self.r.nsppol
|
||||
which_g_list = [which_g]
|
||||
if which_g == "all":
|
||||
which_g_list = ["avg", "raw"]
|
||||
nrows += 1
|
||||
which_g_list = self._get_which_g_list(which_g)
|
||||
nrows, ncols = len(which_g_list) + int((np.array([with_ebands]) == True).sum()), self.r.nsppol
|
||||
|
||||
ax_mat, fig, plt = get_axarray_fig_plt(ax_mat, nrows=nrows, ncols=ncols,
|
||||
sharex=False, sharey=False, squeeze=False)
|
||||
|
||||
marker_color = "gold"
|
||||
band_range = (self.r.bstart, self.r.bstop) if band_range is None else band_range
|
||||
band_range = self._get_band_range(band_range)
|
||||
|
||||
for spin in range(self.r.nsppol):
|
||||
g_nuk_avg, g_nuk_raw = self.r.get_gnuk_average_spin(spin, band_range)
|
||||
ax_cnt = -1
|
||||
|
||||
for which_g in which_g_list:
|
||||
# Select data according to which_g
|
||||
# Select ys according to which_g
|
||||
g_nuk = dict(avg=g_nuk_avg, raw=g_nuk_raw)[which_g]
|
||||
|
||||
# Plot g_nu(q)
|
||||
ax_cnt += 1
|
||||
ax = ax_mat[ax_cnt, spin]
|
||||
for nu in range(self.r.natom3):
|
||||
if ph_modes is not None and nu not in ph_modes: continue
|
||||
ax.plot(g_nuk[nu], label=f"{which_g} {nu=}")
|
||||
|
||||
# Plot g(k)
|
||||
self.ebands_k.decorate_ax(ax, units="meV")
|
||||
set_grid_legend(ax, fontsize, ylabel=r"$|g^{\text{%s}}(\mathbf{k})|$ (meV)" % (which_g))
|
||||
set_grid_legend(ax, fontsize, ylabel=r"$|g^{\text{%s}}_{\mathbf{k}}|$ (meV)" % (which_g))
|
||||
if gmax_mev is not None:
|
||||
set_axlims(ax, [0, gmax_mev], "y")
|
||||
|
||||
if with_ebands:
|
||||
# Plot electron bands + averaged g(k) as markers
|
||||
# Plot electron bands
|
||||
ax_cnt += 1
|
||||
points = None
|
||||
#x, y, s = [], [], []
|
||||
#for ik, kpoint in enumerate(self.ebands_k.kpoints):
|
||||
# omegas_nu = self.phbands.phfreqs[iq,:]
|
||||
# for w, g2 in zip(omegas_nu, g_nuk[:,iq], strict=True):
|
||||
# x.append(iq); y.append(w); s.append(scale * g2)
|
||||
|
||||
#points = Marker(x, y, s, color=marker_color, edgecolors='gray', alpha=0.8,
|
||||
# label=r'$|g^{\text{avg}}(\mathbf{k})|$ (meV)')
|
||||
|
||||
ax = ax_mat[ax_cnt, spin]
|
||||
self.ebands_k.plot(ax=ax, spin=spin, band_range=band_range, with_gaps=False, show=False)
|
||||
set_grid_legend(ax, fontsize) #, xlabel=r"Wavevector $\mathbf{q}$")
|
||||
|
||||
#self.phbands.plot(ax=ax, points=points, show=False)
|
||||
|
||||
if (qpt_name := self.structure.findname_in_hsym_stars(self.r.eph_fix_wavec)) is None:
|
||||
qpt_name = str(self.r.eph_fix_wavec)
|
||||
qpt_name = k2s(self.r.eph_fix_wavec)
|
||||
|
||||
fig.suptitle(f"q = {qpt_name}" + f" m, n = {band_range[0]} - {band_range[1] - 1}")
|
||||
|
||||
|
@ -269,7 +284,7 @@ class GpathFile(AbinitNcFile, Has_Structure, NotebookWriter):
|
|||
"""
|
||||
if self.r.eph_fix_korq == "k":
|
||||
#yield self.ebands_kq.plot(show=False)
|
||||
yield self.phbands.plot(show=False)
|
||||
#yield self.phbands.plot(show=False)
|
||||
yield self.plot_g_qpath()
|
||||
|
||||
if self.r.eph_fix_korq == "q":
|
||||
|
@ -320,7 +335,11 @@ class GpathReader(BaseEphReader):
|
|||
|
||||
# Read important variables.
|
||||
self.eph_fix_korq = self.read_string("eph_fix_korq")
|
||||
if self.eph_fix_korq not in {"k", "q"}:
|
||||
raise ValueError(f"Invalid value for {self.eph_fix_korq=}")
|
||||
self.eph_fix_wavec = self.read_value("eph_fix_wavevec")
|
||||
self.dbdb_add_lr = self.read_value("dvdb_add_lr")
|
||||
#self.used_ftinterp = self.read_value("used_ftinterp")
|
||||
#self.completed = self.read_value("gstore_completed")
|
||||
|
||||
# Note conversion Fortran --> C for the bstart index.
|
||||
|
@ -594,57 +613,68 @@ class GpathRobot(Robot, RobotWithEbands):
|
|||
"t05o_GPATH.nc",
|
||||
])
|
||||
|
||||
|
||||
.. rubric:: Inheritance Diagram
|
||||
.. inheritance-diagram:: GstoreRobot
|
||||
.. inheritance-diagram:: GpathRobot
|
||||
"""
|
||||
EXT = "GPATH"
|
||||
|
||||
#def neq(self, ref_basename: str | None = None, verbose: int = 0) -> int:
|
||||
# """
|
||||
# Compare all GPATHE.nc files stored in the robot
|
||||
# """
|
||||
# # Find reference gstore. By default the first file in the robot is used.
|
||||
# ref_gstore = self._get_ref_abifile_from_basename(ref_basename)
|
||||
@add_fig_kwargs
|
||||
def plot_g_qpath(self, which_g="avg", gmax_mev=250, ph_modes=None,
|
||||
colormap="jet", **kwargs) -> Figure:
|
||||
"""
|
||||
Compare the g-matrix along a q-path.
|
||||
|
||||
# exc_list = []
|
||||
# ierr = 0
|
||||
# for other_gstore in self.abifiles:
|
||||
# if ref_gstore.filepath == other_gstore.filepath:
|
||||
# continue
|
||||
# print("Comparing: ", ref_gstore.basename, " with: ", other_gstore.basename)
|
||||
# try:
|
||||
# ierr += self._neq_two_gstores(ref_gstore, other_gstore, verbose)
|
||||
# cprint("EQUAL", color="green")
|
||||
# except Exception as exc:
|
||||
# exc_list.append(str(exc))
|
||||
Args
|
||||
which_g: "avg" to plot the symmetrized |g|, "raw" for unsymmetrized |g|."all" for both.
|
||||
gmax_mev: Show results up to gmax in me
|
||||
ph_modes: List of ph branch indices to show (start from 0). If None all modes are show.
|
||||
colormap: Color map. Have a look at the colormaps here and decide which one you like:
|
||||
http://matplotlib.sourceforge.net/examples/pylab_examples/show_colormaps.html
|
||||
"""
|
||||
nsppol, nq_path, natom3, eph_fix_wavec, eph_fix_korq = self.getattrs_alleq(
|
||||
"nsppol", "nq_path", "natom3", "eph_fix_wavec", "eph_fix_korq"
|
||||
)
|
||||
xs = np.arange(nq_path)
|
||||
|
||||
# for exc in exc_list:
|
||||
# cprint(exc, color="red")
|
||||
nrows, ncols = 1, nsppol
|
||||
ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
|
||||
sharex=False, sharey=False, squeeze=False)
|
||||
cmap = plt.get_cmap(colormap)
|
||||
|
||||
# return ierr
|
||||
# TODO: Compute common band range.
|
||||
band_range = None
|
||||
ref_ifile= 0
|
||||
#q_label = r"$|q|^{%d}$" % with_qexp if with_qexp else ""
|
||||
#g_units = "(meV)" if with_qexp == 0 else r"(meV $\AA^-{%s}$)" % with_qexp
|
||||
|
||||
#@staticmethod
|
||||
#def _neq_two_gstores(gstore1: GstoreFile, gstore2: GstoreFile, verbose: int) -> int:
|
||||
# """
|
||||
# Helper function to compare two GSTORE files.
|
||||
# """
|
||||
# # These quantities must be the same to have a meaningfull comparison.
|
||||
# aname_list = ["structure", "nsppol", "cplex", "nkbz", "nkibz",
|
||||
# "nqbz", "nqibz", "completed", "kzone", "qzone", "kfilter", "gmode",
|
||||
# "brange_spin", "erange_spin", "glob_spin_nq", "glob_nk_spin",
|
||||
# ]
|
||||
for spin in range(nsppol):
|
||||
ax_cnt = 0
|
||||
ax = ax_mat[ax_cnt, spin]
|
||||
|
||||
# for aname in aname_list:
|
||||
# self._compare_attr_name(aname, gstore1, gstore2)
|
||||
for ifile, gpath in enumerate(self.abifiles):
|
||||
g_nuq_avg, g_nuq_raw = gpath.r.get_gnuq_average_spin(spin, band_range)
|
||||
# Select ys according to which_g and multiply by facts_q
|
||||
g_nuq = dict(avg=g_nuq_avg, raw=g_nuq_raw)[which_g] # * facts_q[None,:]
|
||||
|
||||
# # Now compare the gkq objects for each spin.
|
||||
# ierr = 0
|
||||
# for spin in range(gstore1.nsppol):
|
||||
# gqk1, gqk2 = gstore1.gqk_spin[spin], gstore2.gqk_spin[spin]
|
||||
# ierr += gqk1.neq(gqk2, verbose)
|
||||
for nu in range(natom3):
|
||||
if ph_modes is not None and nu not in ph_modes: continue
|
||||
color = cmap(nu / natom3)
|
||||
if ifile == ref_ifile:
|
||||
ax.scatter(xs, g_nuq[nu], color=color, label=f"{nu=}", marker="o")
|
||||
gpath.phbands.decorate_ax(ax, units="meV")
|
||||
#g_label = r"$|g^{\text{%s}}_{\mathbf{q}}|$ %s" % (which_g, q_label)
|
||||
#set_grid_legend(ax, fontsize, ylabel="%s %s" % (g_label, g_units))
|
||||
else:
|
||||
ax.plot(g_nuq[nu], color=color, label=f"{nu=}")
|
||||
|
||||
# return ierr
|
||||
#if gmax_mev is not None and with_qexp == 0:
|
||||
if gmax_mev is not None:
|
||||
set_axlims(ax, [0, gmax_mev], "y")
|
||||
|
||||
return fig
|
||||
|
||||
#@add_fig_kwargs
|
||||
#def plot_g_kpath(self, **kwargs) --> Figure
|
||||
|
||||
def yield_figs(self, **kwargs): # pragma: no cover
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue