Fix sphinx.util.compat ImportError

This commit is contained in:
Matteo Giantomassi 2018-02-14 12:48:34 +01:00
parent 48b609b110
commit 28e96c43a7
8 changed files with 44 additions and 29 deletions

View File

@ -858,7 +858,9 @@ Expecting callable or attribute name or key in abifile.params""" % (type(hue), s
ax.set_xlabel("%s" % self._get_label(sortby))
if sortby is None: rotate_ticklabels(ax, 15)
ax.set_ylabel("%s" % self._get_label(item))
ax.legend(loc="best", fontsize=fontsize, shadow=True)
if hue is not None:
ax.legend(loc="best", fontsize=fontsize, shadow=True)
return fig
@ -907,7 +909,7 @@ Expecting callable or attribute name or key in abifile.params""" % (type(hue), s
if callable(item):
yvals = [float(item(gsr)) for gsr in self.abifiles]
else:
yvals = [getattr(gsr, item) for gsr in self.abifiles]
yvals = [getattrd(gsr, item) for gsr in self.abifiles]
ax.plot(params, yvals, marker=marker, **kwargs)
else:
for g in groups:
@ -915,7 +917,7 @@ Expecting callable or attribute name or key in abifile.params""" % (type(hue), s
if callable(item):
yvals = [float(item(gsr)) for gsr in g.abifiles]
else:
yvals = [getattr(gsr, item) for gsr in g.abifiles]
yvals = [getattrd(gsr, item) for gsr in g.abifiles]
label = "%s: %s" % (self._get_label(hue), g.hvalue)
ax.plot(g.xvalues, yvals, label=label, marker=marker, **kwargs)
@ -924,18 +926,20 @@ Expecting callable or attribute name or key in abifile.params""" % (type(hue), s
if i == len(items) - 1:
ax.set_xlabel("%s" % self._get_label(sortby))
if sortby is None: rotate_ticklabels(ax, 15)
if i == 0:
if i == 0 and hue is not None:
ax.legend(loc="best", fontsize=fontsize, shadow=True)
return fig
@add_fig_kwargs
def plot_lattice_convergence(self, sortby=None, hue=None, fontsize=8, **kwargs):
def plot_lattice_convergence(self, what_list=None, sortby=None, hue=None, fontsize=8, **kwargs):
"""
Plot the convergence of the lattice parameters (a, b, c, alpha, beta, gamma).
wrt the``sortby`` parameter. Values can optionally be grouped by ``hue``.
Args:
what_list: List of strings with the quantities to plot e.g. ["a", "alpha", "beta"].
None means all.
item: Define the quantity to plot. Accepts callable or string
If string, it's assumed that the abifile has an attribute
with the same name and `getattr` is invoked.
@ -985,6 +989,9 @@ Expecting callable or attribute name or key in abifile.params""" % (type(hue), s
def c(afile):
"c [Ang]"
return getattr(afile, key).lattice.c
def volume(afile):
r"$V$"
return getattr(afile, key).lattice.volume
def alpha(afile):
r"$\alpha$"
return getattr(afile, key).lattice.alpha
@ -995,7 +1002,10 @@ Expecting callable or attribute name or key in abifile.params""" % (type(hue), s
r"$\gamma$"
return getattr(afile, key).lattice.gamma
items = [a, b, c, alpha, beta, gamma]
items = [a, b, c, volume, alpha, beta, gamma]
if what_list is not None:
locs = locals()
items = [locs[what] for what in list_strings(what_list)]
# Build plot grid.
nrows, ncols = len(items), 1

View File

@ -100,7 +100,8 @@ class HistFileTest(AbipyTest):
for what in robot.what_list:
assert robot.gridplot(what=what, show=False)
assert robot.combiplot(show=False)
assert robot.plot_lattice_convergence(show=False)
assert robot.plot_lattice_convergence(fontsize=10, show=False)
assert robot.plot_lattice_convergence(what_list=("a", "alpha"), show=False)
if self.has_nbformat():
robot.write_notebook(nbpath=self.get_tmpname(text=True))

View File

@ -707,7 +707,7 @@ class MultipleMdfPlotter(object):
self._mdfs[label][mdf_type] = obj.get_mdf(mdf_type=mdf_type)
@add_fig_kwargs
def plot(self, mdf_type="exc", qview="avg", xlims=None, ylims=None, **kwargs):
def plot(self, mdf_type="exc", qview="avg", xlims=None, ylims=None, fontsize=8, **kwargs):
"""
Plot all macroscopic dielectric functions (MDF) stored in the plotter
@ -720,6 +720,7 @@ class MultipleMdfPlotter(object):
xlims: Set the data limits for the y-axis. Accept tuple e.g. `(left, right)`
or scalar e.g. `left`. If left (right) is None, default values are used
ylims: Same meaning as `ylims` but for the y-axis
fontsize: fontsize for titles and legend.
Return: |matplotlib-Figure|
"""
@ -738,18 +739,18 @@ class MultipleMdfPlotter(object):
if qview == "avg":
# Plot averaged values
self.plot_mdftype_cplx(mdf_type, "Re", ax=ax_mat[0, 0], xlims=xlims, ylims=ylims,
with_legend=True, show=False)
fontsize=fontsize, with_legend=True, show=False)
self.plot_mdftype_cplx(mdf_type, "Im", ax=ax_mat[0, 1], xlims=xlims, ylims=ylims,
with_legend=False, show=False)
fontsize=fontsize, with_legend=False, show=False)
elif qview == "all":
# Plot MDF(q)
nqpt = len(qpoints)
for iq, qpt in enumerate(qpoints):
islast = (iq == nqpt - 1)
self.plot_mdftype_cplx(mdf_type, "Re", qpoint=qpt, ax=ax_mat[iq, 0], xlims=xlims, ylims=ylims,
with_legend=(iq == 0), with_xlabel=islast, with_ylabel=islast, show=False)
fontsize=fontsize, with_legend=(iq == 0), with_xlabel=islast, with_ylabel=islast, show=False)
self.plot_mdftype_cplx(mdf_type, "Im", qpoint=qpt, ax=ax_mat[iq, 1], xlims=xlims, ylims=ylims,
with_legend=False, with_xlabel=islast, with_ylabel=islast, show=False)
fontsize=fontsize, with_legend=False, with_xlabel=islast, with_ylabel=islast, show=False)
else:
raise ValueError("Invalid value of qview: %s" % str(qview))
@ -762,12 +763,10 @@ class MultipleMdfPlotter(object):
#@add_fig_kwargs
#def plot_mdftypes(self, qview="avg", xlims=None, ylims=None, **kwargs):
# """
# Args:
# qview:
# xlims
# ylims
# Return: matplotlib figure
# """
# # Build plot grid.
@ -778,10 +777,8 @@ class MultipleMdfPlotter(object):
# ncols, nrows = 2, len(qpoints)
# else:
# raise ValueError("Invalid value of qview: %s" % str(qview))
# import matplotlib.pyplot as plt
# fig, ax_mat = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, squeeze=False)
# if qview == "avg":
# # Plot averaged values
# for mdf_type in self.MDF_TYPES:
@ -799,18 +796,15 @@ class MultipleMdfPlotter(object):
# with_legend=(iq == 0), with_xlabel=islast, with_ylabel=islast, show=False)
# self.plot_mdftype_cplx(mdf_type, "Im", qpoint=qpt, ax=ax_mat[iq, 1], xlims=xlims, ylims=ylims,
# with_legend=False, with_xlabel=islast, with_ylabel=islast, show=False)
# else:
# raise ValueError("Invalid value of qview: %s" % str(qview))
# #ax_mat[0, 0].legend(loc="best", fontsize=fontsize, shadow=True)
# #fig.tight_layout()
# return fig
@add_fig_kwargs
def plot_mdftype_cplx(self, mdf_type, cplx_mode, qpoint=None, ax=None, xlims=None, ylims=None,
with_legend=True, with_xlabel=True, with_ylabel=True, fontsize=12, **kwargs):
with_legend=True, with_xlabel=True, with_ylabel=True, fontsize=8, **kwargs):
"""
Helper function to plot data corresponds to ``mdf_type``, ``cplx_mode``, ``qpoint``.
@ -930,7 +924,6 @@ class MdfRobot(Robot, RobotWithEbands):
Return an instance of :class:`MultipleMdfPlotter` to compare multiple dielectric functions.
"""
plotter = MultipleMdfPlotter() if cls is None else cls()
for label, mdf in self.items():
plotter.add_mdf_file(label, mdf)

View File

@ -1431,13 +1431,15 @@ class ElectronBands(Has_Structure):
spin: Spin index.
valence: Int or iterable with the valence indices.
conduction: Int or iterable with the conduction indices.
method: String defining the method.
method (str): String defining the integraion method.
step: Energy step (eV) of the linear mesh.
width: Standard deviation (eV) of the gaussian.
mesh: Frequency mesh to use. If None, the mesh is computed automatically from the eigenvalues.
Returns: |Function1D| object.
"""
# TODO: Generalize to k+q with
# k2kqg = self.kpoints.get_k2kqg_map(qpt, atol_kdiff=atol_kdiff)
self.kpoints.check_weights()
if not isinstance(valence, Iterable): valence = [valence]
if not isinstance(conduction, Iterable): conduction = [conduction]
@ -1481,7 +1483,7 @@ class ElectronBands(Has_Structure):
jdos += fact * gaussian(mesh, width, center=ec-ev)
else:
raise NotImplementedError("Method %s is not supported" % method)
raise NotImplementedError("Method %s is not supported" % str(method))
return Function1D(mesh, jdos)
@ -1517,7 +1519,7 @@ class ElectronBands(Has_Structure):
ax.grid(True)
ax.set_xlabel('Energy [eV]')
cmap = plt.get_cmap(colormap)
lw = 1.0
lw = kwargs.pop("lw", 1.0)
for s in self.spins:
spin_sign = +1 if s == 0 else -1
@ -1538,7 +1540,7 @@ class ElectronBands(Has_Structure):
num_plots, i = len(jdos_vc), 0
for (v, c), jdos in jdos_vc.items():
label = r"$v=%s \rightarrow c=%s, \sigma=%s$" % (v, c, s)
color = cmap(float(i)/num_plots)
color = cmap(float(i) / num_plots)
x, y = jdos.mesh, jdos.values
ax.plot(x, cumulative + y, lw=lw, label=label, color=color)
ax.fill_between(x, cumulative, cumulative + y, facecolor=color, alpha=alpha)
@ -1548,12 +1550,14 @@ class ElectronBands(Has_Structure):
num_plots, i = len(jdos_vc), 0
for (v, c), jdos in jdos_vc.items():
color = cmap(float(i)/num_plots)
jdos.plot_ax(ax, color=color, lw=lw, label=r"$v=%s \rightarrow c=%s, \sigma=%s$" % (v, c, s))
jdos.plot_ax(ax, color=color, lw=lw,
label=r"$v=%s \rightarrow c=%s, \sigma=%s$" % (v, c, s))
i += 1
tot_jdos.plot_ax(ax, color="k", lw=lw, label=r"Total JDOS, $\sigma=%s$" % s)
ax.legend(loc="best", shadow=True, fontsize=fontsize)
return fig
def apply_scissors(self, scissors):

View File

@ -193,6 +193,8 @@ class GsrRobotTest(AbipyTest):
assert robot.plot_gsr_convergence(show=False)
assert robot.plot_gsr_convergence(sortby="nkpt", hue="tsmear", show=False)
y_vars = ["energy", "structure.lattice.a", "structure.volume"]
assert robot.plot_convergence_items(y_vars, sortby="nkpt", hue="tsmear", show=False)
assert robot.plot_egaps(show=False)
assert robot.plot_egaps(sortby="nkpt", hue="tsmear")

View File

@ -5,10 +5,10 @@ MgB2 Fermi surface
This example shows how to plot the Fermi surface with matplotlib
"""
from abipy.abilab import abiopen
from abipy import abilab
import abipy.data as abidata
with abiopen(abidata.ref_file("mgb2_kmesh181818_FATBANDS.nc")) as fbnc_kmesh:
with abilab.abiopen(abidata.ref_file("mgb2_kmesh181818_FATBANDS.nc")) as fbnc_kmesh:
ebands = fbnc_kmesh.ebands
# Build ebands in full BZ.

View File

@ -172,3 +172,5 @@ TODO list:
* Remove Prettytable
* context manager to change variables (e.g. autoparal)
* Cleanup and refactoring in OpticTask

View File

@ -7,7 +7,10 @@ from __future__ import division
import re
from docutils import nodes
from docutils.parsers.rst import directives
from sphinx.util.compat import Directive
try:
from sphinx.util.compat import Directive
except ImportError:
from docutils.parsers.rst import Directive
CONTROL_HEIGHT = 30