mirror of https://github.com/abinit/abipy.git
Add robot.plot_egaps
This commit is contained in:
parent
e16ef710fa
commit
d6e5425caf
|
@ -14,7 +14,8 @@ from functools import wraps
|
|||
from monty.string import is_string, list_strings
|
||||
from monty.termcolor import cprint
|
||||
from abipy.core.mixins import NotebookWriter
|
||||
from abipy.tools.plotting import plot_xy_with_hue, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt
|
||||
from abipy.tools.plotting import (plot_xy_with_hue, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt,
|
||||
rotate_ticklabels)
|
||||
|
||||
|
||||
class Robot(NotebookWriter):
|
||||
|
@ -755,6 +756,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
|
|||
|
||||
ax.grid(True)
|
||||
ax.set_xlabel("%s" % self._get_label(sortby))
|
||||
if sortby is None: rotate_ticklabels(ax, 15)
|
||||
ax.set_ylabel("%s" % self._get_label(item))
|
||||
ax.legend(loc="best", fontsize=fontsize, shadow=True)
|
||||
|
||||
|
@ -778,6 +780,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
|
|||
If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
|
||||
If callable, the output of hue(abifile) is used.
|
||||
fontsize: legend and label fontsize.
|
||||
kwargs: keyword arguments are passed to ax.plot
|
||||
|
||||
Returns: |matplotlib-Figure|
|
||||
"""
|
||||
|
@ -790,12 +793,13 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
|
|||
sharex=True, sharey=False, squeeze=False)
|
||||
ax_list = ax_list.ravel()
|
||||
|
||||
# Sort and read QP data.
|
||||
# Sort and group files if hue.
|
||||
if hue is None:
|
||||
labels, ncfiles, params = self.sortby(sortby, unpack=True)
|
||||
else:
|
||||
groups = self.group_and_sortby(hue, sortby)
|
||||
|
||||
marker = kwargs.pop("marker", "o")
|
||||
for i, (ax, item) in enumerate(zip(ax_list, items)):
|
||||
if hue is None:
|
||||
# Extract data.
|
||||
|
@ -803,7 +807,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
|
|||
yvals = [float(item(gsr)) for gsr in self.abifiles]
|
||||
else:
|
||||
yvals = [getattr(gsr, item) for gsr in self.abifiles]
|
||||
ax.plot(params, yvals, marker="o")
|
||||
ax.plot(params, yvals, marker=marker, **kwargs)
|
||||
else:
|
||||
for g in groups:
|
||||
# Extract data.
|
||||
|
@ -812,12 +816,13 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
|
|||
else:
|
||||
yvals = [getattr(gsr, item) for gsr in g.abifiles]
|
||||
label = "%s: %s" % (self._get_label(hue), g.hvalue)
|
||||
ax.plot(g.xvalues, yvals, label=label, marker="o")
|
||||
ax.plot(g.xvalues, yvals, label=label, marker=marker, **kwargs)
|
||||
|
||||
ax.grid(True)
|
||||
ax.set_ylabel(self._get_label(item))
|
||||
if i == len(items) - 1:
|
||||
ax.set_xlabel("%s" % self._get_label(sortby))
|
||||
if sortby is None: rotate_ticklabels(ax, 15)
|
||||
if i == 0:
|
||||
ax.legend(loc="best", fontsize=fontsize, shadow=True)
|
||||
|
||||
|
@ -895,9 +900,10 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
|
|||
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
|
||||
sharex=True, sharey=False, squeeze=False)
|
||||
|
||||
marker = kwargs.pop("marker", "o")
|
||||
for i, (ax, item) in enumerate(zip(ax_list.ravel(), items)):
|
||||
self.plot_convergence(item, sortby=sortby, hue=hue, ax=ax, fontsize=fontsize,
|
||||
marker="o", show=False)
|
||||
marker=marker, show=False)
|
||||
if i != 0 and ax.legend():
|
||||
ax.legend().set_visible(False)
|
||||
if i != len(items) - 1 and ax.xaxis.label:
|
||||
|
|
|
@ -31,7 +31,8 @@ from abipy.core.kpoints import (Kpoint, KpointList, Kpath, IrredZone, KSamplingI
|
|||
from abipy.core.structure import Structure
|
||||
from abipy.iotools import ETSF_Reader
|
||||
from abipy.tools import gaussian, duck
|
||||
from abipy.tools.plotting import set_axlims, add_fig_kwargs, get_ax_fig_plt, get_ax3d_fig_plt
|
||||
from abipy.tools.plotting import (set_axlims, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt,
|
||||
get_ax3d_fig_plt, rotate_ticklabels)
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -3762,8 +3763,94 @@ class RobotWithEbands(object):
|
|||
#def get_ebands_dataframe(self, with_spglib=True):
|
||||
# return dataframe_from_ebands(self.ncfiles, index=list(self.keys()), with_spglib=with_spglib)
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot_egaps(self, sortby=None, hue=None, fontsize=6, **kwargs):
|
||||
"""
|
||||
Plot the convergence of the direct and fundamental gaps
|
||||
wrt to the ``sortby`` parameter. Values can optionally be grouped by ``hue``.
|
||||
|
||||
Args:
|
||||
sortby: Define the convergence parameter, sort files and produce plot labels.
|
||||
Can be None, string or function. If None, no sorting is performed.
|
||||
If string and not empty it's assumed that the abifile has an attribute
|
||||
with the same name and `getattr` is invoked.
|
||||
If callable, the output of sortby(abifile) is used.
|
||||
hue: Variable that define subsets of the data, which will be drawn on separate lines.
|
||||
Accepts callable or string
|
||||
If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
|
||||
If callable, the output of hue(abifile) is used.
|
||||
fontsize: legend and label fontsize.
|
||||
|
||||
Returns: |matplotlib-Figure|
|
||||
"""
|
||||
# Note: Handling nsppol > 1 and the case in which we have abifiles with different nsppol is a bit tricky
|
||||
# hence we have to handle the different cases explicitly (see get_xy)
|
||||
if len(self.abifiles) == 0: return None
|
||||
max_nsppol = max(f.nsppol for f in self.abifiles)
|
||||
|
||||
items = ["fundamental_gaps", "direct_gaps", "bandwidths"]
|
||||
|
||||
def get_xy(item, spin, all_xvals, all_abifiles):
|
||||
"""
|
||||
Extract (xvals, yvals) from all_abifiles for given (item, spin) and initial all_xvals.
|
||||
Here we handle the case in which we have files with different nsppol.
|
||||
"""
|
||||
xvals, yvals = [], []
|
||||
|
||||
for i, af in enumerate(all_abifiles):
|
||||
if spin > af.nsppol - 1: continue
|
||||
xvals.append(all_xvals[i])
|
||||
if callable(item):
|
||||
yy = float(item(af.ebands))
|
||||
else:
|
||||
yy = getattr(af.ebands, item)
|
||||
if item in ("fundamental_gaps", "direct_gaps"):
|
||||
yy = yy[spin].energy
|
||||
else:
|
||||
yy = yy[spin]
|
||||
|
||||
yvals.append(yy)
|
||||
|
||||
return xvals, yvals
|
||||
|
||||
# Build grid plot.
|
||||
nrows, ncols = len(items), 1
|
||||
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
|
||||
sharex=True, sharey=False, squeeze=False)
|
||||
ax_list = ax_list.ravel()
|
||||
|
||||
# Sort and group files if hue.
|
||||
if hue is None:
|
||||
labels, ncfiles, params = self.sortby(sortby, unpack=True)
|
||||
else:
|
||||
groups = self.group_and_sortby(hue, sortby)
|
||||
|
||||
marker_spin = {0: "^", 1: "v"}
|
||||
for i, (ax, item) in enumerate(zip(ax_list, items)):
|
||||
for spin in range(max_nsppol):
|
||||
if hue is None:
|
||||
# Extract data.
|
||||
xvals, yvals = get_xy(item, spin, params, self.abifiles)
|
||||
ax.plot(xvals, yvals, marker=marker_spin[spin], **kwargs)
|
||||
else:
|
||||
for g in groups:
|
||||
# Extract data.
|
||||
xvals, yvals = get_xy(item, spin, g.xvalues, g.abifiles)
|
||||
label = "%s: %s" % (self._get_label(hue), g.hvalue)
|
||||
ax.plot(xvals, yvals, label=label, marker=marker_spin[spin], **kwargs)
|
||||
|
||||
ax.grid(True)
|
||||
ax.set_ylabel(self._get_label(item))
|
||||
if i == len(items) - 1:
|
||||
ax.set_xlabel("%s" % self._get_label(sortby))
|
||||
if sortby is None: rotate_ticklabels(ax, 15)
|
||||
if i == 0:
|
||||
ax.legend(loc="best", fontsize=fontsize, shadow=True)
|
||||
|
||||
return fig
|
||||
|
||||
def get_ebands_code_cells(self, title=None):
|
||||
"""Return list of notebook cells. """
|
||||
"""Return list of notebook cells."""
|
||||
nbformat, nbv = self.get_nbformat_nbv()
|
||||
title = "## Code to compare multiple ElectronBands objects" if title is None else str(title)
|
||||
# Try not pollute namespace with lots of variables.
|
||||
|
@ -3771,4 +3858,5 @@ class RobotWithEbands(object):
|
|||
nbv.new_markdown_cell(title),
|
||||
nbv.new_code_cell("robot.get_ebands_plotter().ipw_select_plot();"),
|
||||
nbv.new_code_cell("robot.get_edos_plotter().ipw_select_plot();"),
|
||||
nbv.new_code_cell("#robot.plot_egaps(sorby=None, hue=None);"),
|
||||
]
|
||||
|
|
|
@ -185,8 +185,12 @@ class GstRobotTest(AbipyTest):
|
|||
assert robot.combiplot_edos(show=False)
|
||||
assert robot.gridplot_edos(show=False)
|
||||
|
||||
assert robot.plot_gsr_convergence(show=False)
|
||||
assert robot.plot_gsr_convergence(sortby="nkpt", hue="tsmear", show=False)
|
||||
|
||||
assert robot.plot_egaps(show=False)
|
||||
assert robot.plot_egaps(sortby="nkpt", hue="tsmear")
|
||||
|
||||
# Get pandas dataframe.
|
||||
df = robot.get_dataframe()
|
||||
assert "energy" in df
|
||||
|
|
|
@ -22,7 +22,7 @@ from monty.functools import lazy_property
|
|||
from monty.termcolor import cprint
|
||||
from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter
|
||||
from abipy.core.kpoints import KpointList
|
||||
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims
|
||||
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, rotate_ticklabels
|
||||
from abipy.tools import duck
|
||||
from abipy.electrons.ebands import ElectronsReader, RobotWithEbands
|
||||
#from abipy.dfpt.phonons import PhononBands, RobotWithPhbands, factor_ev2units, unit_tag, dos_label_from_units
|
||||
|
@ -1065,6 +1065,7 @@ class SigEPhRobot(Robot, RobotWithEbands):
|
|||
ax.grid(True)
|
||||
if ik == len(sigma_kpoints) - 1:
|
||||
ax.set_xlabel("%s" % self._get_label(sortby))
|
||||
if sortby is None: rotate_ticklabels(ax, 15)
|
||||
ax.set_ylabel("QP Direct gap [eV]")
|
||||
ax.legend(loc="best", fontsize=fontsize, shadow=True)
|
||||
|
||||
|
@ -1139,6 +1140,7 @@ class SigEPhRobot(Robot, RobotWithEbands):
|
|||
ax.set_ylabel(what)
|
||||
if i == len(what_list) - 1:
|
||||
ax.set_xlabel("%s" % self._get_label(sortby))
|
||||
if sortby is None: rotate_ticklabels(ax, 15)
|
||||
if i == 0:
|
||||
ax.legend(loc="best", fontsize=fontsize, shadow=True)
|
||||
|
||||
|
|
|
@ -60,9 +60,19 @@ def set_axlims(ax, lims, axname):
|
|||
return left, right
|
||||
|
||||
|
||||
def rotate_ticklabels(ax, rotation, axname="x"):
|
||||
"""Rotate the ticklables of axis ``ax``"""
|
||||
if "x" in axname:
|
||||
for tick in ax.get_xticklabels():
|
||||
tick.set_rotation(rotation)
|
||||
if "y" in axname:
|
||||
for tick in ax.get_yticklabels():
|
||||
tick.set_rotation(rotation)
|
||||
|
||||
|
||||
def data_from_cplx_mode(cplx_mode, arr):
|
||||
"""
|
||||
Extract the data from the numpy array `arr` depending on the values of `cplx_mode`.
|
||||
Extract the data from the numpy array ``arr`` depending on the values of ``cplx_mode``.
|
||||
|
||||
Args:
|
||||
cplx_mode: Possible values in ("re", "im", "abs", "angle")
|
||||
|
@ -86,19 +96,18 @@ def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None,
|
|||
Useful for convergence tests done wrt to two parameters.
|
||||
|
||||
Args:
|
||||
data: DataFrame containing columns `x`, `y`, and `hue`.
|
||||
data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`.
|
||||
x: Name of the column used as x-value
|
||||
y: Name of the column used as y-value
|
||||
hue: Variable that define subsets of the data, which will be drawn on separate lines
|
||||
decimals: Number of decimal places to round `hue` columns. Ignore if None
|
||||
ax: matplotlib :class:`Axes` or None if a new figure should be created.
|
||||
ax: |matplotlib-Axes| or None if a new figure should be created.
|
||||
xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)`
|
||||
or scalar e.g. `left`. If left (right) is None, default values are used
|
||||
or scalar e.g. `left`. If left (right) is None, default values are used
|
||||
fontsize: Legend fontsize.
|
||||
kwargs: Keywork arguments are passed to ax.plot method.
|
||||
|
||||
Returns:
|
||||
`matplotlib` figure.
|
||||
Returns: |matplotlib-Figure|
|
||||
"""
|
||||
# Check here because pandas messages are a bit criptic.
|
||||
miss = [k for k in (x, y, hue) if k not in data]
|
||||
|
@ -148,8 +157,7 @@ def plot_array(array, color_map=None, cplx_mode="abs", **kwargs):
|
|||
"abs" means that the absolute value of the complex number is shown.
|
||||
"angle" will display the phase of the complex number in radians.
|
||||
|
||||
Returns:
|
||||
`matplotlib` figure.
|
||||
Returns: |matplotlib-Figure|
|
||||
"""
|
||||
# Handle vectors
|
||||
array = np.atleast_2d(array)
|
||||
|
|
Loading…
Reference in New Issue