diff --git a/abipy/abio/robots.py b/abipy/abio/robots.py index 15284431..c8cced3f 100644 --- a/abipy/abio/robots.py +++ b/abipy/abio/robots.py @@ -14,7 +14,8 @@ from functools import wraps from monty.string import is_string, list_strings from monty.termcolor import cprint from abipy.core.mixins import NotebookWriter -from abipy.tools.plotting import plot_xy_with_hue, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt +from abipy.tools.plotting import (plot_xy_with_hue, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, + rotate_ticklabels) class Robot(NotebookWriter): @@ -755,6 +756,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._ ax.grid(True) ax.set_xlabel("%s" % self._get_label(sortby)) + if sortby is None: rotate_ticklabels(ax, 15) ax.set_ylabel("%s" % self._get_label(item)) ax.legend(loc="best", fontsize=fontsize, shadow=True) @@ -778,6 +780,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._ If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked. If callable, the output of hue(abifile) is used. fontsize: legend and label fontsize. + kwargs: keyword arguments are passed to ax.plot Returns: |matplotlib-Figure| """ @@ -790,12 +793,13 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._ sharex=True, sharey=False, squeeze=False) ax_list = ax_list.ravel() - # Sort and read QP data. + # Sort and group files if hue. if hue is None: labels, ncfiles, params = self.sortby(sortby, unpack=True) else: groups = self.group_and_sortby(hue, sortby) + marker = kwargs.pop("marker", "o") for i, (ax, item) in enumerate(zip(ax_list, items)): if hue is None: # Extract data. @@ -803,7 +807,7 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._ yvals = [float(item(gsr)) for gsr in self.abifiles] else: yvals = [getattr(gsr, item) for gsr in self.abifiles] - ax.plot(params, yvals, marker="o") + ax.plot(params, yvals, marker=marker, **kwargs) else: for g in groups: # Extract data. @@ -812,12 +816,13 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._ else: yvals = [getattr(gsr, item) for gsr in g.abifiles] label = "%s: %s" % (self._get_label(hue), g.hvalue) - ax.plot(g.xvalues, yvals, label=label, marker="o") + ax.plot(g.xvalues, yvals, label=label, marker=marker, **kwargs) ax.grid(True) ax.set_ylabel(self._get_label(item)) if i == len(items) - 1: ax.set_xlabel("%s" % self._get_label(sortby)) + if sortby is None: rotate_ticklabels(ax, 15) if i == 0: ax.legend(loc="best", fontsize=fontsize, shadow=True) @@ -895,9 +900,10 @@ Not all entries are sortable (Please select number-like quantities)""" % (self._ ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, sharex=True, sharey=False, squeeze=False) + marker = kwargs.pop("marker", "o") for i, (ax, item) in enumerate(zip(ax_list.ravel(), items)): self.plot_convergence(item, sortby=sortby, hue=hue, ax=ax, fontsize=fontsize, - marker="o", show=False) + marker=marker, show=False) if i != 0 and ax.legend(): ax.legend().set_visible(False) if i != len(items) - 1 and ax.xaxis.label: diff --git a/abipy/electrons/ebands.py b/abipy/electrons/ebands.py index ce6478b5..80ca9251 100644 --- a/abipy/electrons/ebands.py +++ b/abipy/electrons/ebands.py @@ -31,7 +31,8 @@ from abipy.core.kpoints import (Kpoint, KpointList, Kpath, IrredZone, KSamplingI from abipy.core.structure import Structure from abipy.iotools import ETSF_Reader from abipy.tools import gaussian, duck -from abipy.tools.plotting import set_axlims, add_fig_kwargs, get_ax_fig_plt, get_ax3d_fig_plt +from abipy.tools.plotting import (set_axlims, add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, + get_ax3d_fig_plt, rotate_ticklabels) import logging logger = logging.getLogger(__name__) @@ -3762,8 +3763,94 @@ class RobotWithEbands(object): #def get_ebands_dataframe(self, with_spglib=True): # return dataframe_from_ebands(self.ncfiles, index=list(self.keys()), with_spglib=with_spglib) + @add_fig_kwargs + def plot_egaps(self, sortby=None, hue=None, fontsize=6, **kwargs): + """ + Plot the convergence of the direct and fundamental gaps + wrt to the ``sortby`` parameter. Values can optionally be grouped by ``hue``. + + Args: + sortby: Define the convergence parameter, sort files and produce plot labels. + Can be None, string or function. If None, no sorting is performed. + If string and not empty it's assumed that the abifile has an attribute + with the same name and `getattr` is invoked. + If callable, the output of sortby(abifile) is used. + hue: Variable that define subsets of the data, which will be drawn on separate lines. + Accepts callable or string + If string, it's assumed that the abifile has an attribute with the same name and getattr is invoked. + If callable, the output of hue(abifile) is used. + fontsize: legend and label fontsize. + + Returns: |matplotlib-Figure| + """ + # Note: Handling nsppol > 1 and the case in which we have abifiles with different nsppol is a bit tricky + # hence we have to handle the different cases explicitly (see get_xy) + if len(self.abifiles) == 0: return None + max_nsppol = max(f.nsppol for f in self.abifiles) + + items = ["fundamental_gaps", "direct_gaps", "bandwidths"] + + def get_xy(item, spin, all_xvals, all_abifiles): + """ + Extract (xvals, yvals) from all_abifiles for given (item, spin) and initial all_xvals. + Here we handle the case in which we have files with different nsppol. + """ + xvals, yvals = [], [] + + for i, af in enumerate(all_abifiles): + if spin > af.nsppol - 1: continue + xvals.append(all_xvals[i]) + if callable(item): + yy = float(item(af.ebands)) + else: + yy = getattr(af.ebands, item) + if item in ("fundamental_gaps", "direct_gaps"): + yy = yy[spin].energy + else: + yy = yy[spin] + + yvals.append(yy) + + return xvals, yvals + + # Build grid plot. + nrows, ncols = len(items), 1 + ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols, + sharex=True, sharey=False, squeeze=False) + ax_list = ax_list.ravel() + + # Sort and group files if hue. + if hue is None: + labels, ncfiles, params = self.sortby(sortby, unpack=True) + else: + groups = self.group_and_sortby(hue, sortby) + + marker_spin = {0: "^", 1: "v"} + for i, (ax, item) in enumerate(zip(ax_list, items)): + for spin in range(max_nsppol): + if hue is None: + # Extract data. + xvals, yvals = get_xy(item, spin, params, self.abifiles) + ax.plot(xvals, yvals, marker=marker_spin[spin], **kwargs) + else: + for g in groups: + # Extract data. + xvals, yvals = get_xy(item, spin, g.xvalues, g.abifiles) + label = "%s: %s" % (self._get_label(hue), g.hvalue) + ax.plot(xvals, yvals, label=label, marker=marker_spin[spin], **kwargs) + + ax.grid(True) + ax.set_ylabel(self._get_label(item)) + if i == len(items) - 1: + ax.set_xlabel("%s" % self._get_label(sortby)) + if sortby is None: rotate_ticklabels(ax, 15) + if i == 0: + ax.legend(loc="best", fontsize=fontsize, shadow=True) + + return fig + def get_ebands_code_cells(self, title=None): - """Return list of notebook cells. """ + """Return list of notebook cells.""" nbformat, nbv = self.get_nbformat_nbv() title = "## Code to compare multiple ElectronBands objects" if title is None else str(title) # Try not pollute namespace with lots of variables. @@ -3771,4 +3858,5 @@ class RobotWithEbands(object): nbv.new_markdown_cell(title), nbv.new_code_cell("robot.get_ebands_plotter().ipw_select_plot();"), nbv.new_code_cell("robot.get_edos_plotter().ipw_select_plot();"), + nbv.new_code_cell("#robot.plot_egaps(sorby=None, hue=None);"), ] diff --git a/abipy/electrons/tests/test_gsr.py b/abipy/electrons/tests/test_gsr.py index 45ddc189..1d2f7342 100644 --- a/abipy/electrons/tests/test_gsr.py +++ b/abipy/electrons/tests/test_gsr.py @@ -185,8 +185,12 @@ class GstRobotTest(AbipyTest): assert robot.combiplot_edos(show=False) assert robot.gridplot_edos(show=False) + assert robot.plot_gsr_convergence(show=False) assert robot.plot_gsr_convergence(sortby="nkpt", hue="tsmear", show=False) + assert robot.plot_egaps(show=False) + assert robot.plot_egaps(sortby="nkpt", hue="tsmear") + # Get pandas dataframe. df = robot.get_dataframe() assert "energy" in df diff --git a/abipy/eph/sigeph.py b/abipy/eph/sigeph.py index ef340e3a..2eb4ee15 100644 --- a/abipy/eph/sigeph.py +++ b/abipy/eph/sigeph.py @@ -22,7 +22,7 @@ from monty.functools import lazy_property from monty.termcolor import cprint from abipy.core.mixins import AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter from abipy.core.kpoints import KpointList -from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims +from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt, set_axlims, rotate_ticklabels from abipy.tools import duck from abipy.electrons.ebands import ElectronsReader, RobotWithEbands #from abipy.dfpt.phonons import PhononBands, RobotWithPhbands, factor_ev2units, unit_tag, dos_label_from_units @@ -1065,6 +1065,7 @@ class SigEPhRobot(Robot, RobotWithEbands): ax.grid(True) if ik == len(sigma_kpoints) - 1: ax.set_xlabel("%s" % self._get_label(sortby)) + if sortby is None: rotate_ticklabels(ax, 15) ax.set_ylabel("QP Direct gap [eV]") ax.legend(loc="best", fontsize=fontsize, shadow=True) @@ -1139,6 +1140,7 @@ class SigEPhRobot(Robot, RobotWithEbands): ax.set_ylabel(what) if i == len(what_list) - 1: ax.set_xlabel("%s" % self._get_label(sortby)) + if sortby is None: rotate_ticklabels(ax, 15) if i == 0: ax.legend(loc="best", fontsize=fontsize, shadow=True) diff --git a/abipy/tools/plotting.py b/abipy/tools/plotting.py index 4152f4e4..e5ab6ca2 100644 --- a/abipy/tools/plotting.py +++ b/abipy/tools/plotting.py @@ -60,9 +60,19 @@ def set_axlims(ax, lims, axname): return left, right +def rotate_ticklabels(ax, rotation, axname="x"): + """Rotate the ticklables of axis ``ax``""" + if "x" in axname: + for tick in ax.get_xticklabels(): + tick.set_rotation(rotation) + if "y" in axname: + for tick in ax.get_yticklabels(): + tick.set_rotation(rotation) + + def data_from_cplx_mode(cplx_mode, arr): """ - Extract the data from the numpy array `arr` depending on the values of `cplx_mode`. + Extract the data from the numpy array ``arr`` depending on the values of ``cplx_mode``. Args: cplx_mode: Possible values in ("re", "im", "abs", "angle") @@ -86,19 +96,18 @@ def plot_xy_with_hue(data, x, y, hue, decimals=None, ax=None, Useful for convergence tests done wrt to two parameters. Args: - data: DataFrame containing columns `x`, `y`, and `hue`. + data: |pandas-DataFrame| containing columns `x`, `y`, and `hue`. x: Name of the column used as x-value y: Name of the column used as y-value hue: Variable that define subsets of the data, which will be drawn on separate lines decimals: Number of decimal places to round `hue` columns. Ignore if None - ax: matplotlib :class:`Axes` or None if a new figure should be created. + ax: |matplotlib-Axes| or None if a new figure should be created. xlims ylims: Set the data limits for the x(y)-axis. Accept tuple e.g. `(left, right)` - or scalar e.g. `left`. If left (right) is None, default values are used + or scalar e.g. `left`. If left (right) is None, default values are used fontsize: Legend fontsize. kwargs: Keywork arguments are passed to ax.plot method. - Returns: - `matplotlib` figure. + Returns: |matplotlib-Figure| """ # Check here because pandas messages are a bit criptic. miss = [k for k in (x, y, hue) if k not in data] @@ -148,8 +157,7 @@ def plot_array(array, color_map=None, cplx_mode="abs", **kwargs): "abs" means that the absolute value of the complex number is shown. "angle" will display the phase of the complex number in radians. - Returns: - `matplotlib` figure. + Returns: |matplotlib-Figure| """ # Handle vectors array = np.atleast_2d(array)