Import df_widgets

This commit is contained in:
Matteo Giantomassi 2017-07-11 00:51:47 +02:00
parent 7257f92fba
commit 6cf681b510
6 changed files with 775 additions and 19 deletions

View File

@ -81,7 +81,7 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
self._structure = Structure.from_abivars(**self.header)
# Add AbinitSpacegroup (needed in guessed_ngkpt)
# FIXME: has_timerev is always True
# FIXME: kptopt is not reported in the header --> has_timerev is always set to True
spgid, has_timerev, h = 0, True, self.header
self._structure.set_abi_spacegroup(AbinitSpaceGroup(spgid, h.symrel, h.tnons, h.symafm, has_timerev))
@ -96,13 +96,12 @@ class DdbFile(TextFile, Has_Structure, NotebookWriter):
"""String representation."""
lines = []
app, extend = lines.append, lines.extend
#extend(super(DdbFile, self).__str__().splitlines())
app(marquee("File Info", mark="="))
app(self.filestat(as_string=True))
app("")
app(marquee("Structure", mark="="))
app(str(self.structure))
app(str(self.structure.to_string(verbose=verbose)))
app(marquee("Q-points", mark="="))
app(str(self.qpoints))

44
abipy/display/pandasw.py Normal file
View File

@ -0,0 +1,44 @@
# coding: utf-8
"""Widgets for Pandas Dataframes."""
from __future__ import print_function, division, unicode_literals, absolute_import
import matplotlib.pyplot as plt
import pandas as pd
import ipywidgets as ipw
import df_widgets.utils as ut
from functools import wraps
@wraps(pd.DataFrame.plot)
def dfw_plot(data, **kwargs):
def plot_dataframe(x, y, kind, sharex, sharey, subplots, grid, legend,
logx, logy, loglog, colorbar, sort_columns):
x, y = ut.widget2py(x, y)
sharex, colorbar = ut.str2bool_or_none(sharex, colorbar)
data.plot(x=x, y=y, kind=kind, subplots=subplots, sharex=None, sharey=sharey,
layout=None, figsize=None, use_index=True, title=None, grid=grid, legend=legend, style=None,
logx=logx, logy=logy, loglog=loglog, xticks=None, yticks=None, xlim=None, ylim=None,
rot=None, fontsize=None, colormap=colorbar, table=False, yerr=None, xerr=None, secondary_y=False,
sort_columns=sort_columns, **kwargs)
# There's a typo in the documentation (colorbar/colormap!)
return plt.gcf()
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
plot_dataframe,
x=allcols,
y=allcols,
sharex=["None", "True", "False"],
sharey=False,
kind=["line", "bar", "barh", "hist", "box", "kde", "density", "area", "pie", "scatter", "hexbin"],
subplots=False,
grid=True,
legend=True,
logx=False,
logy=False,
loglog=False,
colorbar=["None", "True", "False"],
sort_columns=False,
)

472
abipy/display/seabornw.py Normal file
View File

@ -0,0 +1,472 @@
# coding: utf-8
"""
Widgets for Pandas Dataframes based on seaborn API.
API reference
Distribution plots
jointplot(x, y[, data, kind, stat_func, ...]) Draw a plot of two variables with bivariate and univariate graphs.
pairplot(data[, hue, hue_order, palette, ...]) Plot pairwise relationships in a dataset.
distplot(a[, bins, hist, kde, rug, fit, ...]) Flexibly plot a univariate distribution of observations.
kdeplot(data[, data2, shade, vertical, ...]) Fit and plot a univariate or bivariate kernel density estimate.
rugplot(a[, height, axis, ax]) Plot datapoints in an array as sticks on an axis.
Regression plots
lmplot(x, y, data[, hue, col, row, palette, ...]) Plot data and regression model fits across a FacetGrid.
regplot(x, y[, data, x_estimator, x_bins, ...]) Plot data and a linear regression model fit.
residplot(x, y[, data, lowess, x_partial, ...]) Plot the residuals of a linear regression.
interactplot(x1, x2, y[, data, filled, ...]) Visualize a continuous two-way interaction with a contour plot.
coefplot(formula, data[, groupby, ...]) Plot the coefficients from a linear model.
# Categorical plots
factorplot([x, y, hue, data, row, col, ...]) Draw a categorical plot onto a FacetGrid.
boxplot([x, y, hue, data, order, hue_order, ...]) Draw a box plot to show distributions with respect to categories.
violinplot([x, y, hue, data, order, ...]) Draw a combination of boxplot and kernel density estimate.
stripplot([x, y, hue, data, order, ...]) Draw a scatterplot where one variable is categorical.
swarmplot([x, y, hue, data, order, ...]) Draw a categorical scatterplot with non-overlapping points.
pointplot([x, y, hue, data, order, ...]) Show point estimates and confidence intervals using scatter plot glyphs.
barplot([x, y, hue, data, order, hue_order, ...]) Show point estimates and confidence intervals as rectangular bars.
countplot([x, y, hue, data, order, ...]) Show the counts of observations in each categorical bin using bars.
Matrix plots
heatmap(data[, vmin, vmax, cmap, center, ...]) Plot rectangular data as a color-encoded matrix.
clustermap(data[, pivot_kws, method, ...]) Plot a hierarchically clustered heatmap of a pandas DataFrame
Timeseries plots
tsplot(data[, time, unit, condition, value, ...]) Plot one or more timeseries with flexible representation of uncertainty.
Miscellaneous plots
palplot(pal[, size]) Plot the values in a color palette as a horizontal array.
Axis grids
FacetGrid(data[, row, col, hue, col_wrap, ...]) Subplot grid for plotting conditional relationships.
PairGrid(data[, hue, hue_order, palette, ...]) Subplot grid for plotting pairwise relationships in a dataset.
JointGrid(x, y[, data, size, ratio, space, ...]) Grid for drawing a bivariate plot with marginal univariate plots.
"""
from __future__ import print_function, division, unicode_literals, absolute_import
import sys
import ipywidgets as ipw
import seaborn as sns
import df_widgets.utils as ut
from functools import wraps
from collections import OrderedDict
from IPython.display import display, clear_output
__all__ = [
"api_selector",
# Distribution plots
"jointplot",
"pairplot",
#"distplot",
#"kdeplot",
#"rugplot",
# Regression plots
"lmplot",
#"regplot",
#"residplot",
#"interactplot",
#"coefplot",
# Categorical plots
"factorplot",
"boxplot",
"violinplot",
"stripplot",
"swarmplot",
"pointplot",
"barplot",
"countplot",
# Matrix plots
#"heatmap",
#"clustermap",
# Timeseries plots
#"tsplot",
# Miscellaneous plots
#"palplot",
]
def api_selector(data, funcname="countplot"):
"""
A widgets with ToogleButtons that allow the user to select and display
the widget associated to the different seaborn functions.
"""
this_module = sys.modules[__name__]
name2wfunc = OrderedDict()
for a in __all__:
if a == "api_selector": continue
func = this_module.__dict__.get(a)
if not callable(func): continue
name2wfunc[func.__name__] = func
w1 = ipw.ToggleButtons(description='seaborn API', options=list(name2wfunc.keys()))
w1.value = funcname
w2 = name2wfunc[funcname](data)
box = ipw.VBox(children=[w1, w2])
def on_value_change(change):
#print(change)
box.close()
clear_output()
api_selector(data, funcname=change["new"])
w1.observe(on_value_change, names='value')
return display(box)
@wraps(sns.jointplot)
def joinplot(data, joint_kws=None, marginal_kws=None, annot_kws=None, **kwargs):
def sns_joinplot(x, y, kind, color):
x, y, color = ut.widget2py(x, y, color)
# TODO: stat_func
return sns.jointplot(x, y, data=data, kind=kind, # stat_func=<function pearsonr>,
color=color, size=6, ratio=5, space=0.2, dropna=True, xlim=None, ylim=None,
joint_kws=joint_kws, marginal_kws=marginal_kws, annot_kws=annot_kws, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_joinplot,
x=allcols,
y=allcols,
kind=["scatter", "reg", "resid", "kde", "hex"],
color=ut.colors_dropdow(),
)
@wraps(sns.pairplot)
def pairplot(data, plot_kws=None, diag_kws=None, grid_kws=None):
# TODO: Write widget with multiple checkboxes to implement lists.
def sns_pairplot(x_vars, y_vars, hue, kind, diag_kind):
x_vars, y_vars, hue = ut.widget2py(x_vars, y_vars, hue)
return sns.pairplot(data, hue=hue, hue_order=None, palette=None, vars=None, x_vars=x_vars, y_vars=y_vars,
kind=kind, diag_kind=diag_kind, markers=None, size=2.5, aspect=1, dropna=True,
plot_kws=plot_kws, diag_kws=diag_kws, grid_kws=grid_kws)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_pairplot,
x_vars=allcols,
y_vars=allcols,
hue=allcols,
kind=["scatter", "ref"],
diag_kind=["hist", "kde"],
)
"""
@wraps(sns.distplot)
def distplot(data, fit=None, hist_kws=None, kde_kws=None, rug_kws=None, fit_kws=None):
def sns_distplot(hist, kde, rug, color, vertical, norm_hist):
color = ut.widget2py(color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.distplot(a, bins=None, hist=hist, kde=kde, rug=rug, fit=fit,
hist_kws=hist_kws, kde_kws=kde_kws, rug_kws=rug_kws, fit_kws=fit_kws,
color=clor, vertical=vertical, norm_hist=norm_hist, axlabel=None, label=None, ax=ax)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_distplot,
hist=True,
kde=True,
rug=False,
color=ut.colors_dropdow(),
vertical=False,
norm_hist=False,
)
@wraps(sns.kdeplot)
def kdeplot(data, **kwargs):
def sns_kdeplot()
color = ut.widget2py(color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.kdeplot(data, data2=None, shade=False, vertical=False, kernel='gau', bw='scott',
gridsize=100, cut=3, clip=None, legend=True, cumulative=False, shade_lowest=True, ax=ax, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_kdeplot,
color=ut.colors_dropdow(),
)
"""
####################
# Regression plots #
####################
@wraps(sns.lmplot)
def lmplot(data, scatter_kws=None, line_kws=None):
def sns_lmplot(x, y, hue, col, row, legend, size):
x, y, hue, col, row = ut.widget2py(x, y, hue, col, row)
return sns.lmplot(x, y, data, hue=hue, col=col, row=row, palette=None, col_wrap=None,
size=size, aspect=1, markers='o', sharex=True, sharey=True, hue_order=None,
col_order=None, row_order=None, legend=legend, legend_out=True,
x_estimator=None, x_bins=None, x_ci='ci', scatter=True, fit_reg=True,
ci=95, n_boot=1000, units=None, order=1, logistic=False, lowess=False, robust=False,
logx=False, x_partial=None, y_partial=None, truncate=False, x_jitter=None, y_jitter=None,
scatter_kws=scatter_kws, line_kws=line_kws)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_lmplot,
x=allcols,
y=allcols,
hue=allcols,
col=allcols,
row=allcols,
legend=True,
size=ut.size_slider(default=5),
)
@wraps(sns.interactplot)
def interactplot(data, contour_kws=None, scatter_kws=None, **kwargs):
def sns_interactplot(x1, x2, y, filled, colorbar, logistic):
ax, fig, _ = ut.get_ax_fig_plt()
return sns.interactplot(x1, x2, y, data=data, filled=filled, cmap='RdBu_r', colorbar=colorbar,
levels=30, logistic=logistic, contour_kws=contour_kws, scatter_kws=scatter_kws,
ax=ax, **kwargs)
allcols = list(data.keys())
return ipw.interact_manual(
sns_interactplot,
x1=allcols,
x2=allcols,
y=allcols,
filled=False,
colorbar=True,
logistic=False,
)
#####################
# Categorical plots #
#####################
@wraps(sns.factorplot)
def factorplot(data, facet_kws=None, **kwargs):
def sns_factorplot(x, y, hue, color, kind, size, legend):
x, y, hue, color = ut.widget2py(x, y, hue, color)
return sns.factorplot(x=x, y=y, hue=hue, data=data, row=None, col=None, col_wrap=None, # estimator=<function mean>,
ci=95, n_boot=1000, units=None, order=None, hue_order=None, row_order=None, col_order=None,
kind=kind, size=size, aspect=1, orient=None, color=color, palette=None,
legend=legend, legend_out=True, sharex=True, sharey=True, margin_titles=False,
facet_kws=facet_kws, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_factorplot,
x=allcols,
y=allcols,
hue=allcols,
color=ut.colors_dropdow(),
kind=["point", "bar", "count", "box", "violin", "strip"],
size=ut.size_slider(default=4),
legend=True,
)
@wraps(sns.boxplot)
def boxplot(data, **kwargs):
def sns_boxplot(x, y, hue, orient, color, saturation, notch):
x, y, hue, orient, color = ut.widget2py(x, y, hue, orient, color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.boxplot(x=x, y=y, hue=hue, data=data, order=None, hue_order=None, orient=orient,
color=color, palette=None, saturation=saturation, width=0.8, fliersize=5, linewidth=None,
whis=1.5, notch=notch, ax=ax, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_boxplot,
x=allcols,
y=allcols,
hue=allcols,
orient=["None", "v", "h"],
color=ut.colors_dropdow(),
saturation=ut.saturation_slider(default=0.75),
notch=False,
)
@wraps(sns.violinplot)
def violinplot(data, **kwargs):
def sns_violinplot(x, y, hue, bw, scale, inner, split, orient, color, saturation):
x, y, hue, inner, orient, color = ut.widget2py(x, y, hue, inner, orient, color)
ax, fig, _ = ut.get_ax_fig_plt()
sns.violinplot(x=x, y=y, hue=hue, data=data, order=None, hue_order=None,
bw=bw, cut=2, scale=scale, scale_hue=True,
gridsize=100, width=0.8, inner=inner, split=split, orient=orient,
linewidth=None, color=color, palette=None, saturation=saturation, ax=ax, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_violinplot,
x=allcols,
y=allcols,
hue=allcols,
bw=["scott", "silverman", "float"],
scale=["area", "count", "width"],
inner=["box", "quartile", "point", "stick", "None"],
split=False,
orient=["None", "v", "h"],
color=ut.colors_dropdow(),
saturation=ut.saturation_slider(default=0.75),
)
@wraps(sns.stripplot)
def stripplot(data, **kwargs):
def sns_stripplot(x, y, hue, split, orient, color, size, linewidth):
x, y, hue, orient, color = ut.widget2py(x, y, hue, orient, color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.stripplot(x=x, y=y, hue=hue, data=data, order=None, hue_order=None, jitter=False,
split=split, orient=orient, color=color, palette=None, size=size, edgecolor='gray',
linewidth=linewidth, ax=ax, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_stripplot,
x=allcols,
y=allcols,
hue=allcols,
split=False,
orient=["None", "v", "h"],
color=ut.colors_dropdow(),
size=ut.size_slider(default=5),
linewidth=ut.linewidth_slider(default=0),
)
@wraps(sns.swarmplot)
def swarmplot(data, **kwargs):
def sns_swarmplot(x, y, hue, split, orient, color, size, linewidth):
x, y, hue, orient, color = ut.widget2py(x, y, hue, orient, color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.swarmplot(x=x, y=y, hue=hue, data=data, order=None, hue_order=None,
split=split, orient=orient, color=color, palette=None, size=size,
edgecolor='gray', linewidth=linewidth, ax=ax, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_swarmplot,
x=allcols,
y=allcols,
hue=allcols,
split=False,
orient=["None", "v", "h"],
color=ut.colors_dropdow(),
size=ut.size_slider(default=5),
linewidth=ut.linewidth_slider(default=0),
)
@wraps(sns.pointplot)
def pointplot(data, **kwargs):
def sns_pointplot(x, y, hue, split, join, orient, color, linewidth):
x, y, hue, orient, color = ut.widget2py(x, y, hue, orient, color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.pointplot(x=x, y=y, hue=hue, data=data, order=None, hue_order=None, # estimator=<function mean>,
ci=95, n_boot=1000, units=None, markers='o', linestyles='-', dodge=False, join=join, scale=1,
orient=orient, color=color, palette=None, ax=ax, errwidth=None, capsize=None, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_pointplot,
x=allcols,
y=allcols,
hue=allcols,
split=False,
join=True,
orient=["None", "v", "h"],
color=ut.colors_dropdow(),
linewidth=ut.linewidth_slider(default=0),
)
@wraps(sns.barplot)
def barplot(data, **kwargs):
def sns_barplot(x, y, hue, orient, color, saturation):
x, y, hue, orient, color = ut.widget2py(x, y, hue, orient, color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.barplot(x=x, y=y, hue=hue, data=data, order=None, hue_order=None, # estimator=<function mean>,
ci=95, n_boot=1000, units=None, orient=orient, color=color, palette=None,
saturation=saturation, errcolor='.26', ax=ax, **kwargs) # errwidth=None, capsize=None, # New args added in ??
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_barplot,
x=allcols,
y=allcols,
hue=allcols,
orient=["None", "v", "h"],
color=ut.colors_dropdow(),
saturation=ut.saturation_slider(default=0.75),
)
@wraps(sns.countplot)
def countplot(data, **kwargs):
def sns_countplot(x, y, hue, color, saturation):
x, y, hue, color = ut.widget2py(x, y, hue, color)
ax, fig, _ = ut.get_ax_fig_plt()
return sns.countplot(x=x, y=y, hue=hue, data=data, order=None, hue_order=None, orient=None,
color=color, palette=None, saturation=saturation, ax=ax, **kwargs)
allcols = ["None"] + list(data.keys())
return ipw.interact_manual(
sns_countplot,
x=allcols,
y=allcols,
hue=allcols,
color=ut.colors_dropdow(),
saturation=ut.saturation_slider(default=0.75),
)
################
# Matrix plots #
################
@wraps(sns.heatmap)
def heatmap(data, annot_kws=None, cbar_kws=None, **kwargs):
def sns_heatmap():
ax, fig, _ = ut.get_ax_fig_plt()
return sns.heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False, annot=None,
fmt='.2g', annot_kws=annot_kws, linewidths=0, linecolor='white', cbar=True,
cbar_kws=cbar_kws, cbar_ax=None, square=False, ax=ax,
xticklabels=True, yticklabels=True, mask=None, **kwargs)
return ipw.interact_manual(
sns_heatmap,
)
@wraps(sns.clustermap)
def clustermap(data, pivot_kws=None, cbar_kws=None, **kwargs):
def sns_clustermap():
return sns.clustermap(data, pivot_kws=pivot_kws, method='average', metric='euclidean',
z_score=None, standard_scale=None, figsize=None, cbar_kws=cbar_kws,
row_cluster=True, col_cluster=True, row_linkage=None, col_linkage=None,
row_colors=None, col_colors=None, mask=None, **kwargs)
return ipw.interact_manual(
sns_clustermap,
)

236
abipy/display/utils.py Normal file
View File

@ -0,0 +1,236 @@
# coding: utf-8
"""Widgets for Pandas Dataframes."""
from __future__ import print_function, division, unicode_literals, absolute_import
import ipywidgets as ipw
from collections import OrderedDict
def add_docstrings(*tuples):
"""
This decorator adds to the docstring the documentation for functions.
When writing high-level API, it's quite common to call thirdy-party functions
with a restricted set of arguments while optional keyword arguments are
collected in an optional dictionary.
The first item of the tuple contains the function (python object) wrapped by the code.
The second item is list of strings with the name of the actual arguments passed to function.
"""
from functools import wraps
def wrapper(func):
@wraps(func)
def wrapped_func(*args, **kwargs):
return func(*args, **kwargs)
# Add docstrings for the functions that will be called by func.
lines = []
app = lines.append
for t in tuples:
fname = t[0].__name__
# List of strings or string.
if isinstance(t[1], (list, tuple)):
fargs = ",".join("`%s`" % a for a in t[1])
else:
fargs = "`%s`" % t[1]
app("\n%s are passed to function :func:`%s` in module :mod:`%s`" % (fargs, fname, t[0].__module__))
app("Docstring of `%s`:" % fname)
app(t[0].__doc__)
s = "\n".join(lines)
if wrapped_func.__doc__ is not None:
# Add s at the end of the docstring.
wrapped_func.__doc__ += "\n" + s
else:
# Use s
wrapped_func.__doc__ = s
return wrapped_func
return wrapper
def widget2py(*args):
l = [None if a == "None" else a for a in args]
return l[0] if len(l) == 1 else l
def str2bool_or_none(*args):
d = {"None": None, "True": True, "False": False}
l = [d[a] for a in args]
return l[0] if len(l) == 1 else l
def get_ax_fig_plt(ax=None):
"""
Helper function used in plot functions supporting an optional Axes argument.
If ax is None, we build the `matplotlib` figure and create the Axes else
we return the current active figure.
Returns:
ax: :class:`Axes` object
figure: matplotlib figure
plt: matplotlib pyplot module.
"""
import matplotlib.pyplot as plt
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
else:
fig = plt.gcf()
return ax, fig, plt
# Taken from matplotlib.markers.MarkerStyle (replaced dict with OrderedDict).
_mpl_markers = OrderedDict([
('.', 'point'),
(',', 'pixel'),
('o', 'circle'),
('v', 'triangle_down'),
('^', 'triangle_up'),
('<', 'triangle_left'),
('>', 'triangle_right'),
('1', 'tri_down'),
('2', 'tri_up'),
('3', 'tri_left'),
('4', 'tri_right'),
('8', 'octagon'),
('s', 'square'),
('p', 'pentagon'),
('*', 'star'),
('h', 'hexagon1'),
('H', 'hexagon2'),
('+', 'plus'),
('x', 'x'),
('D', 'diamond'),
('d', 'thin_diamond'),
('|', 'vline'),
('_', 'hline'),
#(TICKLEFT: 'tickleft',
#(TICKRIGHT: 'tickright',
#(TICKUP: 'tickup',
#(TICKDOWN: 'tickdown',
#(CARETLEFT: 'caretleft',
#(CARETRIGHT: 'caretright',
#(CARETUP: 'caretup',
#(CARETDOWN: 'caretdown',
("None", 'nothing'),
(None, 'nothing'),
(' ', 'nothing'),
('', 'nothing'),
])
def markers_dropdown(default="o"):
return ipw.Dropdown(
options={name: key for key, name in _mpl_markers.items()},
value=default,
description='marker',
)
_mpl_colors = OrderedDict([
("None", "None"),
("blue", "b"),
("green", "g"),
("red", "r"),
("cyan", "c"),
("magenta", "m"),
("yellow", "y"),
("black", "k"),
("white", "w"),
])
def colors_dropdow(default="None"):
return ipw.Dropdown(
options=_mpl_colors,
value=default,
description='color',
)
def linewidth_slider(default=1, orientation="horizontal"):
return ipw.FloatSlider(
value=default,
min=0,
max=10,
step=0.5,
description='linewidth',
orientation=orientation,
readout_format='.1f'
)
def size_slider(default=5, orientation="horizontal"):
return ipw.FloatSlider(
value=default,
min=0,
max=20,
step=0.5,
description='size',
orientation=orientation,
readout_format='.1f'
)
def saturation_slider(default=0.75, orientation="horizontal"):
return ipw.FloatSlider(
value=default,
min=0,
max=1,
step=0.05,
description='saturation',
orientation=orientation,
readout_format='.1f'
)
# Have colormaps separated into categories:
# http://matplotlib.org/examples/color/colormaps_reference.html
_mpl_categ_cmaps = OrderedDict([
#('Perceptually Uniform Sequential',
('Uniform', ['viridis', 'inferno', 'plasma', 'magma']),
('Sequential', ['Blues', 'BuGn', 'BuPu',
'GnBu', 'Greens', 'Greys', 'Oranges', 'OrRd',
'PuBu', 'PuBuGn', 'PuRd', 'Purples', 'RdPu',
'Reds', 'YlGn', 'YlGnBu', 'YlOrBr', 'YlOrRd']),
('Sequential(2)', ['afmhot', 'autumn', 'bone', 'cool',
'copper', 'gist_heat', 'gray', 'hot',
'pink', 'spring', 'summer', 'winter']),
('Diverging', ['BrBG', 'bwr', 'coolwarm', 'PiYG', 'PRGn', 'PuOr',
'RdBu', 'RdGy', 'RdYlBu', 'RdYlGn', 'Spectral',
'seismic']),
('Qualitative', ['Accent', 'Dark2', 'Paired', 'Pastel1',
'Pastel2', 'Set1', 'Set2', 'Set3']),
('Miscellaneous', ['gist_earth', 'terrain', 'ocean', 'gist_stern',
'brg', 'CMRmap', 'cubehelix',
'gnuplot', 'gnuplot2', 'gist_ncar',
'nipy_spectral', 'jet', 'rainbow',
'gist_rainbow', 'hsv', 'flag', 'prism'])
])
# flat list.
_mpl_cmaps = [cm for sublist in _mpl_categ_cmaps.values() for cm in sublist]
def colormap_widget(default=None):
options = _mpl_cmaps
value = options[0]
if default is not None:
value = default
if default not in _mpl_cmaps: options[:].insert(0, value)
return ipw.Dropdown(options=options, value=value, description='colormap')
#def colormap_widget():
# from IPython.display import display, clear_output
# w_type = ipw.Dropdown(options=list(_mpl_categ_cmaps.keys()), description='colormap category')
# w_cmap = ipw.Dropdown(options=_mpl_categ_cmaps["Uniform"], description='colormap name')
#
# def on_value_change(change):
# print(change)
# print(w_cmap.value)
# w_cmap.options = _mpl_categ_cmaps[w_type.value]
# print(w_cmap.value)
#
# w_type.observe(on_value_change, names='value')
# box = ipw.HBox(children=[w_type, w_cmap])
# return display(box)

View File

@ -56,6 +56,7 @@ class HistFile(AbinitNcFile, NotebookWriter):
app("")
app(marquee("Initial structure", mark="="))
app(str(self.initial_structure))
app("")
app(marquee("Final structure", mark="="))
app("Number of relaxation steps performed: %d" % self.num_steps)
app(str(self.final_structure))
@ -201,30 +202,31 @@ class HistFile(AbinitNcFile, NotebookWriter):
return fig
def mvplot_trajectories(self, colormap="hot", figure=None, show=True, with_forces=True, **kwargs):
def mvplot_trajectories(self, colormap="hot", sampling=1, figure=None, show=True, with_forces=True, **kwargs):
"""
Call mayavi to plot atomic trajectories and variation of the unit cell.
Call mayavi to plot atomic trajectories and the variation of the unit cell.
"""
from abipy.display import mvtk
figure, mlab = mvtk.get_fig_mlab(figure=figure)
style = "labels"
line_width = 2
line_width = 100
mvtk.plot_structure(self.initial_structure, style=style, unit_cell_color=(1, 0, 0), figure=figure)
mvtk.plot_structure(self.final_structure, style=style, unit_cell_color=(0, 0, 0), figure=figure)
steps = np.arange(start=0, stop=self.num_steps, step=sampling)
xcart_list = self.reader.read_value("xcart") * units.bohr_to_ang
t = np.arange(self.num_steps)
for iatom in range(self.reader.natom):
x, y, z = xcart_list[:, iatom, :].T
trajectory = mlab.plot3d(x, y, z, t, colormap=colormap, tube_radius=None,
x, y, z = xcart_list[::sampling, iatom, :].T
#for i in zip(x, y, z): print(i)
trajectory = mlab.plot3d(x, y, z, steps, colormap=colormap, tube_radius=None,
line_width=line_width, figure=figure)
mlab.colorbar(trajectory, title='Iteration', orientation='vertical')
if with_forces:
fcart_list = self.reader.read_cart_forces(unit="eV ang^-1")
for iatom in range(self.reader.natom):
x, y, z = xcart_list[:, iatom, :].T
u, v, w = fcart_list[:, iatom, :].T
x, y, z = xcart_list[::sampling, iatom, :].T
u, v, w = fcart_list[::sampling, iatom, :].T
q = mlab.quiver3d(x, y, z, u, v, w, figure=figure, colormap=colormap,
line_width=line_width, scale_factor=10)
#mlab.colorbar(q, title='Forces [eV/Ang]', orientation='vertical')
@ -232,12 +234,12 @@ class HistFile(AbinitNcFile, NotebookWriter):
if show: mlab.show()
return figure
def mvanimate(self, to_unit_cell=False):
def mvanimate(self, delay=500):
from abipy.display import mvtk
figure, mlab = mvtk.get_fig_mlab(figure=None)
style = "points"
#mvtk.plot_structure(self.initial_structure, to_unit_cell=to_unit_cell, style=style, figure=figure)
#mvtk.plot_structure(self.final_structure, to_unit_cell=to_unit_cell, style=style, figure=figure)
#mvtk.plot_structure(self.initial_structure, style=style, figure=figure)
#mvtk.plot_structure(self.final_structure, style=style, figure=figure)
xcart_list = self.reader.read_value("xcart") * units.bohr_to_ang
#t = np.arange(self.num_steps)
@ -255,14 +257,14 @@ class HistFile(AbinitNcFile, NotebookWriter):
#nodes.mlab_source.dataset.point_data.scalars = np.random.random((5000,))
@mlab.show
@mlab.animate(delay=1000, ui=True)
@mlab.animate(delay=delay, ui=True)
def anim():
"""Animate."""
for it, structure in enumerate(self.structures):
#for it in range(self.num_steps):
print('Updating scene for iteration:', it)
#mlab.clf(figure=figure)
mvtk.plot_structure(structure, to_unit_cell=to_unit_cell, style=style, figure=figure)
mvtk.plot_structure(structure, style=style, figure=figure)
#x, y, z = xcart_list[it, :, :].T
#nodes.mlab_source.set(x=x, y=y, z=z)
#figure.scene.render()
@ -270,7 +272,6 @@ class HistFile(AbinitNcFile, NotebookWriter):
yield
anim()
#mlab.close(figure)
def write_notebook(self, nbpath=None):
"""

View File

@ -46,6 +46,9 @@ def abimovie_hist(options):
for path in options.paths:
with abilab.abiopen(path) as hist:
print(hist)
if options.trajectories:
hist.mvplot_trajectories()
else:
hist.mvanimate()
return 0
@ -153,6 +156,7 @@ Use `-v` to increase verbosity level (can be supplied multiple times e.g -vv).
# Subparser for hist command.
p_hist = subparsers.add_parser('hist', parents=[copts_parser], help=abimovie_hist.__doc__)
p_hist.add_argument("-t", "--trajectories", default=False, action="store_true", help="Plot trajectories.")
# Subparser for ebands command.
p_ebands = subparsers.add_parser('ebands', parents=[copts_parser], help=abimovie_ebands.__doc__)