Add robot.plot_egaps

This commit is contained in:
Matteo Giantomassi 2018-01-19 14:44:05 +01:00
parent e16ef710fa
commit d6e5425caf
5 changed files with 124 additions and 16 deletions

View File

@ -14,7 +14,8 @@ from functools import wraps
from monty.string import is_string, list_strings
from monty.termcolor import cprint
from abipy.core.mixins import NotebookWriter
from abipy.tools.plotting import plot_xy_with_hue, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt
from abipy.tools.plotting import (plot_xy_with_hue, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt,
rotate_ticklabels)
class Robot(NotebookWriter):
@ -755,6 +756,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
ax.grid(True)
ax.set_xlabel("%s" % self._get_label(sortby))
if sortby is None: rotate_ticklabels(ax, 15)
ax.set_ylabel("%s" % self._get_label(item))
ax.legend(loc="best", fontsize=fontsize, shadow=True)
@ -778,6 +780,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
If callable, the output of hue(abifile) is used.
fontsize: legend and label fontsize.
kwargs: keyword arguments are passed to ax.plot
Returns: |matplotlib-Figure|
"""
@ -790,12 +793,13 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
sharex=True, sharey=False, squeeze=False)
ax_list = ax_list.ravel()
# Sort and read QP data.
# Sort and group files if hue.
if hue is None:
labels, ncfiles, params = self.sortby(sortby, unpack=True)
else:
groups = self.group_and_sortby(hue, sortby)
marker = kwargs.pop("marker", "o")
for i, (ax, item) in enumerate(zip(ax_list, items)):
if hue is None:
# Extract data.
@ -803,7 +807,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
yvals = [float(item(gsr)) for gsr in self.abifiles]
else:
yvals = [getattr(gsr, item) for gsr in self.abifiles]
ax.plot(params, yvals, marker="o")
ax.plot(params, yvals, marker=marker, **kwargs)
else:
for g in groups:
# Extract data.
@ -812,12 +816,13 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
else:
yvals = [getattr(gsr, item) for gsr in g.abifiles]
label = "%s: %s" % (self._get_label(hue), g.hvalue)
ax.plot(g.xvalues, yvals, label=label, marker="o")
ax.plot(g.xvalues, yvals, label=label, marker=marker, **kwargs)
ax.grid(True)
ax.set_ylabel(self._get_label(item))
if i == len(items) - 1:
ax.set_xlabel("%s" % self._get_label(sortby))
if sortby is None: rotate_ticklabels(ax, 15)
if i == 0:
ax.legend(loc="best", fontsize=fontsize, shadow=True)
@ -895,9 +900,10 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=False, squeeze=False)
marker = kwargs.pop("marker", "o")
for i, (ax, item) in enumerate(zip(ax_list.ravel(), items)):
self.plot_convergence(item, sortby=sortby, hue=hue, ax=ax, fontsize=fontsize,
marker="o", show=False)
marker=marker, show=False)
if i != 0 and ax.legend():
ax.legend().set_visible(False)
if i != len(items) - 1 and ax.xaxis.label:

View File

@ -31,7 +31,8 @@ from abipy.core.kpoints import (Kpoint, KpointList, Kpath, IrredZone, KSamplingI
from abipy.core.structure import Structure
from abipy.iotools import ETSF_Reader
from abipy.tools import gaussian, duck
from abipy.tools.plotting import set_axlims, add_fig_kwargs, get_ax_fig_plt, get_ax3d_fig_plt
from abipy.tools.plotting import (set_axlims, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt,
get_ax3d_fig_plt, rotate_ticklabels)
import logging
logger = logging.getLogger(__name__)
@ -3762,8 +3763,94 @@ class RobotWithEbands(object):
#def get_ebands_dataframe(self, with_spglib=True):
# return dataframe_from_ebands(self.ncfiles, index=list(self.keys()), with_spglib=with_spglib)
@add_fig_kwargs
def plot_egaps(self, sortby=None, hue=None, fontsize=6, **kwargs):
"""
Plot the convergence of the direct and fundamental gaps
wrt to the ``sortby`` parameter. Values can optionally be grouped by ``hue``.
Args:
sortby: Define the convergence parameter, sort files and produce plot labels.
Can be None, string or function. If None, no sorting is performed.
If string and not empty it's assumed that the abifile has an attribute
with the same name and `getattr` is invoked.
If callable, the output of sortby(abifile) is used.
hue: Variable that define subsets of the data, which will be drawn on separate lines.
Accepts callable or string
If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked.
If callable, the output of hue(abifile) is used.
fontsize: legend and label fontsize.
Returns: |matplotlib-Figure|
"""
# Note: Handling nsppol > 1 and the case in which we have abifiles with different nsppol is a bit tricky
# hence we have to handle the different cases explicitly (see get_xy)
if len(self.abifiles) == 0: return None
max_nsppol = max(f.nsppol for f in self.abifiles)
items = ["fundamental_gaps", "direct_gaps", "bandwidths"]
def get_xy(item, spin, all_xvals, all_abifiles):
"""
Extract (xvals, yvals) from all_abifiles for given (item, spin) and initial all_xvals.
Here we handle the case in which we have files with different nsppol.
"""
xvals, yvals = [], []
for i, af in enumerate(all_abifiles):
if spin > af.nsppol - 1: continue
xvals.append(all_xvals[i])
if callable(item):
yy = float(item(af.ebands))
else:
yy = getattr(af.ebands, item)
if item in ("fundamental_gaps", "direct_gaps"):
yy = yy[spin].energy
else:
yy = yy[spin]
yvals.append(yy)
return xvals, yvals
# Build grid plot.
nrows, ncols = len(items), 1
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=False, squeeze=False)
ax_list = ax_list.ravel()
# Sort and group files if hue.
if hue is None:
labels, ncfiles, params = self.sortby(sortby, unpack=True)
else:
groups = self.group_and_sortby(hue, sortby)
marker_spin = {0: "^", 1: "v"}
for i, (ax, item) in enumerate(zip(ax_list, items)):
for spin in range(max_nsppol):
if hue is None:
# Extract data.
xvals, yvals = get_xy(item, spin, params, self.abifiles)
ax.plot(xvals, yvals, marker=marker_spin[spin], **kwargs)
else:
for g in groups:
# Extract data.
xvals, yvals = get_xy(item, spin, g.xvalues, g.abifiles)
label = "%s: %s" % (self._get_label(hue), g.hvalue)
ax.plot(xvals, yvals, label=label, marker=marker_spin[spin], **kwargs)
ax.grid(True)
ax.set_ylabel(self._get_label(item))
if i == len(items) - 1:
ax.set_xlabel("%s" % self._get_label(sortby))
if sortby is None: rotate_ticklabels(ax, 15)
if i == 0:
ax.legend(loc="best", fontsize=fontsize, shadow=True)
return fig
def get_ebands_code_cells(self, title=None):
"""Return list of notebook cells. """
"""Return list of notebook cells."""
nbformat, nbv = self.get_nbformat_nbv()
title = "## Code to compare multiple ElectronBands objects" if title is None else str(title)
# Try not pollute namespace with lots of variables.
@ -3771,4 +3858,5 @@ class RobotWithEbands(object):
nbv.new_markdown_cell(title),
nbv.new_code_cell("robot.get_ebands_plotter().ipw_select_plot();"),
nbv.new_code_cell("robot.get_edos_plotter().ipw_select_plot();"),
nbv.new_code_cell("#robot.plot_egaps(sorby=None, hue=None);"),
]

View File

@ -185,8 +185,12 @@ class GstRobotTest(AbipyTest):
assert robot.combiplot_edos(show=False)
assert robot.gridplot_edos(show=False)
assert robot.plot_gsr_convergence(show=False)
assert robot.plot_gsr_convergence(sortby="nkpt", hue="tsmear", show=False)
assert robot.plot_egaps(show=False)
assert robot.plot_egaps(sortby="nkpt", hue="tsmear")
# Get pandas dataframe.
df = robot.get_dataframe()
assert "energy" in df

View File

@ -22,7 +22,7 @@ from monty.functools import lazy_property
from monty.termcolor import cprint
from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter
from abipy.core.kpoints import KpointList
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, rotate_ticklabels
from abipy.tools import duck
from abipy.electrons.ebands import ElectronsReader, RobotWithEbands
#from abipy.dfpt.phonons import PhononBands, RobotWithPhbands, factor_ev2units, unit_tag, dos_label_from_units
@ -1065,6 +1065,7 @@ class SigEPhRobot(Robot, RobotWithEbands):
ax.grid(True)
if ik == len(sigma_kpoints) - 1:
ax.set_xlabel("%s" % self._get_label(sortby))
if sortby is None: rotate_ticklabels(ax, 15)
ax.set_ylabel("QP Direct gap [eV]")
ax.legend(loc="best", fontsize=fontsize, shadow=True)
@ -1139,6 +1140,7 @@ class SigEPhRobot(Robot, RobotWithEbands):
ax.set_ylabel(what)
if i == len(what_list) - 1:
ax.set_xlabel("%s" % self._get_label(sortby))
if sortby is None: rotate_ticklabels(ax, 15)
if i == 0:
ax.legend(loc="best", fontsize=fontsize, shadow=True)

View File

@ -60,9 +60,19 @@ def set_axlims(ax, lims, axname):
return left, right
def rotate_ticklabels(ax, rotation, axname="x"):
"""Rotate the ticklables of axis ``ax``"""
if "x" in axname:
for tick in ax.get_xticklabels():
tick.set_rotation(rotation)
if "y" in axname:
for tick in ax.get_yticklabels():
tick.set_rotation(rotation)
def data_from_cplx_mode(cplx_mode, arr):
"""
Extract the data from the numpy array `arr` depending on the values of `cplx_mode`.
Extract the data from the numpy array ``arr`` depending on the values of ``cplx_mode``.
Args:
cplx_mode: Possible values in ("re", "im", "abs", "angle")
@ -86,19 +96,18 @@ def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None,
Useful for convergence tests done wrt to two parameters.
Args:
data: DataFrame containing columns `x`, `y`, and `hue`.
data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`.
x: Name of the column used as x-value
y: Name of the column used as y-value
hue: Variable that define subsets of the data, which will be drawn on separate lines
decimals: Number of decimal places to round `hue` columns. Ignore if None
ax: matplotlib :class:`Axes` or None if a new figure should be created.
ax: |matplotlib-Axes| or None if a new figure should be created.
xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)`
or scalar e.g. `left`. If left (right) is None, default values are used
or scalar e.g. `left`. If left (right) is None, default values are used
fontsize: Legend fontsize.
kwargs: Keywork arguments are passed to ax.plot method.
Returns:
`matplotlib` figure.
Returns: |matplotlib-Figure|
"""
# Check here because pandas messages are a bit criptic.
miss = [k for k in (x, y, hue) if k not in data]
@ -148,8 +157,7 @@ def plot_array(array, color_map=None, cplx_mode="abs", **kwargs):
"abs" means that the absolute value of the complex number is shown.
"angle" will display the phase of the complex number in radians.
Returns:
`matplotlib` figure.
Returns: |matplotlib-Figure|
"""
# Handle vectors
array = np.atleast_2d(array)