Remove OncvPlotter.plotly_ methods, use standard mpl methods with obj.plot(plotly=True)

This commit is contained in:
Matteo Giantomassi 2024-07-23 14:16:11 +02:00
parent f080687b7c
commit 090fca1cbe
3 changed files with 44 additions and 91 deletions

View File

@ -470,7 +470,7 @@ class Function1D:
# return self.__class__(self.mesh, -(2 / np.pi) * wmesh * kk_values)
def plot_ax(self, ax, exchange_xy=False, xfactor=1, yfactor=1, *args, **kwargs) -> list:
def plot_ax(self, ax, exchange_xy=False, normalize=False, xfactor=1, yfactor=1, *args, **kwargs) -> list:
"""
Helper function to plot self on axis ax.
@ -478,6 +478,7 @@ class Function1D:
ax: |matplotlib-Axes|.
exchange_xy: True to exchange the axis in the plot.
args: Positional arguments passed to ax.plot
normalize: Normalize the ydata to 1.
xfactor, yfactor: xvalues and yvalues are multiplied by this factor before plotting.
kwargs: Keyword arguments passed to ``matplotlib``. Accepts
@ -504,6 +505,7 @@ class Function1D:
xx, yy = self.mesh, data_from_cplx_mode(c, self.values)
if xfactor != 1: xx = xx * xfactor
if yfactor != 1: yy = yy * yfactor
if normalize: yy = np.max(yy)
if exchange_xy:
xx, yy = yy, xx

View File

@ -77,7 +77,7 @@ class FrohlichAnalyzer:
#def analyze(self)
ITER_LABELS = [
_LABELS = [
r'$E_{pol}$',
r'$E_{el}$',
r'$E_{ph}$',
@ -185,7 +185,7 @@ class VarpeqFile(AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter)
nstep2cv = nstep2cv_spin[spin]
last_iteration = iter_rec_spin[spin, nstep2cv-1, :] * abu.Ha_eV
return dict(zip(ITER_LABELS, last_iteration))
return dict(zip(_LABELS, last_iteration))
@add_fig_kwargs
def plot_scf_cycle(self, ax_mat=None, fontsize=8, **kwargs) -> Figure:
@ -211,7 +211,7 @@ class VarpeqFile(AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter)
xs = np.arange(1, nstep2cv + 1)
for iax, ax in enumerate(ax_mat[spin]):
for ilab, label in enumerate(ITER_LABELS):
for ilab, label in enumerate(_LABELS):
ys = iterations[:,ilab]
if iax == 0:
# Plot energies in linear scale.
@ -389,15 +389,16 @@ class Polaron:
return self.structure.plot_bz(show=False, **kws)
@add_fig_kwargs
def plot_ank_with_ebands(self, ebands_kpath, ebands_kmesh=None, nksmall: int = 20,
def plot_ank_with_ebands(self, ebands_kpath, ebands_kmesh=None, nksmall: int = 20, normalize: bool = False,
lpratio: int = 5, step: float = 0.1, width: float = 0.2, method: str = "linear",
ax_list=None, ylims=None, scale=10, fontsize=8, **kwargs) -> Figure:
ax_list=None, ylims=None, scale=10, fontsize=12, **kwargs) -> Figure:
"""
Plot electronic energies with markers whose size is proportional to |A_nk|^2.
Args:
ebands_kpath: ElectronBands or Abipy file providing an electronic band structure along a path.
ebands_kmesh: ElectronBands or Abipy file providing an electronic band structure in the IBZ.
normalize: Rescale the two DOS to plot them on the same scale.
method=Interpolation method.
ax_list: List of |matplotlib-Axes| or None if a new figure should be created.
scale: Scaling factor for |A_nk|^2.
@ -427,11 +428,12 @@ class Polaron:
ymin = min(ymin, e)
ymax = max(ymax, e)
points = Marker(x, y, s, color="y")
points = Marker(x, y, s, color="orange")
nrows, ncols = 1, 2
gridspec_kw = {'width_ratios': [2, 1]}
ax_list, fig, plt = get_axarray_fig_plt(ax_list, nrows=nrows, ncols=ncols,
sharex=False, sharey=True, squeeze=False)
sharex=False, sharey=True, squeeze=False, gridspec_kw=gridspec_kw)
ax_list = ax_list.ravel()
ax = ax_list[0]
@ -462,22 +464,20 @@ class Polaron:
enes_n = ebands_kmesh.eigens[self.spin, ik, self.bstart:self.bstop]
a2_n = a2_interp.eval_kpoint(kpoint)
for e, a2 in zip(enes_n, a2_n):
ank_dos += weight * a2 * gaussian(mesh, width, center=e - e0)
ank_dos += weight * a2 * gaussian(mesh, width, center=e-e0)
ank_dos = Function1D(mesh, ank_dos)
print("A2(E) integrates to:", ank_dos.integral_value, " Ideally, it should be 1.")
# Rescale the two DOS to plot them on the same scale.
ank_dos = ank_dos / ank_dos.max
ax = ax_list[1]
edos.plot_ax(ax, e0, spin=self.spin, normalize=True, exchange_xy=True, label="eDOS(E)")
edos.plot_ax(ax, e0, spin=self.spin, normalize=normalize, exchange_xy=True, label="eDOS(E)")
ank_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$A^2$(E)", color=points.color)
ax.set_xlabel("Arbitrary units", fontsize=fontsize)
ank_dos.plot_ax(ax, exchange_xy=True, label=r"$A^2$(E)", color=points.color)
ax.grid(True)
ax.legend(loc="best", shadow=True, fontsize=fontsize)
if ylims is None:
# Automatic ylims.
ymin -= 0.1 * abs(ymin)
ymin -= e0
ymax += 0.1 * abs(ymax)
@ -512,7 +512,7 @@ class Polaron:
**kwargs)
@add_fig_kwargs
def plot_bqnu_with_phbands(self, phbands_qpath, phdos_file=None, ddb=None, width = 0.001,
def plot_bqnu_with_phbands(self, phbands_qpath, phdos_file=None, ddb=None, width = 0.001, normalize: bool=False,
method="linear", verbose=0, anaddb_kwargs=None,
ax=None, scale=10, fontsize=12, **kwargs) -> Figure:
"""
@ -520,17 +520,21 @@ class Polaron:
Args:
phbands_qpath: PhononBands or Abipy file providing a phonon band structure.
normalize: Rescale the two DOS to plot them on the same scale.
phdos_file:
method=Interpolation method.
ax: |matplotlib-Axes| or None if a new figure should be created.
scale: Scaling factor for |B_qnu|^2.
"""
with_phdos = phdos_file is not None and ddb is not None
nrows, ncols = 1, 2 if with_phdos else 1
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=False, sharey=True, squeeze=False)
ax_list = ax_list.ravel()
nrows, ncols = 1, 1
gridspec_kw = None
if with_phdos:
ncols, gridspec_kw = 2, {'width_ratios': [2, 1]}
ax_list, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=False, sharey=True, squeeze=False, gridspec_kw=gridspec_kw)
ax_list = ax_list.ravel()
phbands_qpath = PhononBands.as_phbands(phbands_qpath)
b2_interp = self.get_b2_interpolator(method)
@ -549,8 +553,7 @@ class Polaron:
for w, b2 in zip(omegas_nu, b2_nu):
x.append(iq); y.append(w); s.append(scale * b2)
points = Marker(x, y, s, color="yellow")
points = Marker(x, y, s, color="orange")
phbands_qpath.plot(ax=ax_list[0], points=points, show=False)
if not with_phdos:
@ -575,13 +578,10 @@ class Polaron:
bqnu_dos = Function1D(mesh, bqnu_dos)
# Rescale the two DOS to plot them on the same scale.
phdos = phdos / phdos.max
bqnu_dos = bqnu_dos / bqnu_dos.max
ax = ax_list[1]
phdos.plot_ax(ax, exchange_xy=True, label="phDOS(E)")
bqnu_dos.plot_ax(ax, exchange_xy=True, label=r"$B^2$(E)", color=points.color)
phdos.plot_ax(ax, exchange_xy=True, normalize=normalize, label="phDOS(E)")
bqnu_dos.plot_ax(ax, exchange_xy=True, normalize=normalize, label=r"$B^2$(E)", color=points.color)
ax.set_xlabel("Arbitrary units", fontsize=fontsize)
ax.grid(True)
ax.legend(loc="best", shadow=True, fontsize=fontsize)
@ -771,7 +771,7 @@ class VarpeqRobot(Robot, RobotWithEbands):
data = defaultdict(list)
# Now loop over the sorted files and extract the results of the final iteration.
for i, (label, abifile, nktot) in zip(labels, abifiles, nktot_list):
for i, (label, abifile, nktot) in enumerate(zip(labels, abifiles, nktot_list)):
for k, v in abifile.get_last_iteration_dict_ev(spin).items():
data[k].append(v)
@ -807,13 +807,14 @@ class VarpeqRobot(Robot, RobotWithEbands):
return fig
@add_fig_kwargs
def plot_kdata(self, fontsie=12, **kwargs) -> Figure:
def plot_kconv(self, fontsize=12, **kwargs) -> Figure:
"""
Plot the convergence of the data wrt to the k-point sampling.
"""
nsppol = self.getattr_alleq("nsppol")
# Build grid of plots.
nrows, ncols = len(ITER_LABELS), nsppol
nrows, ncols = len(_LABELS), nsppol
ax_mat, fig, plt = get_axarray_fig_plt(None, nrows=nrows, ncols=ncols,
sharex=True, sharey=True, squeeze=False)
deg = 1
@ -821,16 +822,21 @@ class VarpeqRobot(Robot, RobotWithEbands):
kdata = self.get_kdata_spin(spin)
xs = kdata["minibz_vol"]
xvals = np.linspace(0, 1.1 * xs.max(), 100)
for ix, label in enumerate(ITER_LABELS):
for ix, label in enumerate(_LABELS):
ax = ax_mat[ix, spin]
color = "k"
ys = kdata[label]
# plot ab-initio points.
ax.scatter(xs, ys, color=color, marker="o")
# plot fit.
p = np.poly1d(np.polyfit(xs, ys, deg))
ax = ax_mat[ix,spin]
ax.scatter(xs, ys, marker="o")
ax.plot(xvals, p[xvals], style="k--")
ax.plot(xvals, p(xvals), color=color, ls="--")
ax.grid(True)
ax.legend(loc="best", shadow=True, fontsize=fontsize)
#ax.set_xlabel("Iteration", fontsize=fontsize)
#ax.set_ylabel("Energy (eV)" if iax == 0 else r"$|\Delta|$ Energy (eV)", fontsize=fontsize)
ax.set_ylabel(label, fontsize=fontsize)
#ax.legend(loc="right", shadow=True, fontsize=fontsize)
#print([(0, p(0)), (xs[0], ys[0]), (xs[1], ys[1])])
return fig

View File

@ -100,11 +100,6 @@ class OncvPlotter(NotebookWriter):
ax.axvline(self.parser.rc5, lw=2, color=color, ls="--")
ax._custom_rc_lines.append((self.parser.rc5, color))
def plotly_atan_logders(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_atan_logders(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_atan_logders(self, ax=None, with_xlabel=True,
fontsize: int = 8, **kwargs) -> Figure:
@ -156,11 +151,6 @@ class OncvPlotter(NotebookWriter):
raise ValueError(f"Invalid value for {what=}")
return ae_wfs, ps_wfs
def plotly_radial_wfs(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_radial_wfs(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_radial_wfs(self, ax=None, what="bound_states",
fontsize: int = 8, **kwargs) -> Figure:
@ -201,11 +191,6 @@ class OncvPlotter(NotebookWriter):
return fig
def plotly_projectors(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_projects(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_projectors(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
@ -235,11 +220,6 @@ class OncvPlotter(NotebookWriter):
return fig
def plotly_densities(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_densities(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_densities(self, ax=None, timesr2=False, fontsize: int = 8, **kwargs) -> Figure:
"""
@ -262,11 +242,6 @@ class OncvPlotter(NotebookWriter):
)
return fig
def plotly_der_densities(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_der_densities(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_der_densities(self, ax=None, order=1, acc=4, fontsize=8, **kwargs) -> Figure:
"""
@ -295,11 +270,6 @@ class OncvPlotter(NotebookWriter):
)
return fig
def plotly_potentials(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_potentials(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_potentials(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
@ -322,11 +292,6 @@ class OncvPlotter(NotebookWriter):
return fig
def plotly_vtau(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_vtau(*args, show=False, **kwargs))
def plot_vtau(self, xscale="log", ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
Plot v_tau and v_tau(model+pseudo) potentials on axis ax.
@ -352,11 +317,6 @@ class OncvPlotter(NotebookWriter):
return fig
def plotly_tau(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_tau(*args, show=False, **kwargs))
def plot_tau(self, ax=None, yscale="log", fontsize: int = 8, **kwargs) -> Figure:
"""
Plot kinetic energy densities tauPS and tau(M+PS) on axis ax.
@ -381,11 +341,6 @@ class OncvPlotter(NotebookWriter):
return fig
def plotly_der_potentials(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_der_potentials(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_der_potentials(self, ax=None, order=1, acc=4, fontsize: int = 8, **kwargs) -> Figure:
"""
@ -417,11 +372,6 @@ class OncvPlotter(NotebookWriter):
return fig
def plotly_kene_vs_ecut(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_kene_vs_ecut(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_kene_vs_ecut(self, ax=None, fontsize: int = 8, **kwargs) -> Figure:
"""
@ -449,11 +399,6 @@ class OncvPlotter(NotebookWriter):
return fig
def plotly_atanlogder_econv(self, *args, **kwargs):
"""Generate plotly figure from matplotly."""
from plotly.tools import mpl_to_plotly
return mpl_to_plotly(self.plot_atan_logder_econv(*args, show=False, **kwargs))
@add_fig_kwargs
def plot_atanlogder_econv(self, ax_list=None, fontsize: int = 6, **kwargs) -> Figure:
"""