mirror of https://github.com/abinit/abipy.git
Merge pull request #186 from henriquemiranda/workflow_refactoring
Workflow refactoring
This commit is contained in:
commit
64d28941aa
|
@ -1,2 +0,0 @@
|
|||
""" This module provides class and methods to interface Abipy with Boltztrap calculations"""
|
||||
from .boltztrap import *
|
|
@ -1,877 +0,0 @@
|
|||
# coding: utf-8
|
||||
"""
|
||||
This module containes a Bolztrap2 class to interpolate and analyse the results
|
||||
It also provides interfaces with Abipy objects allowing to
|
||||
initialize the Boltztrap2 calculation from Abinit files
|
||||
|
||||
Warning:
|
||||
|
||||
Work in progress
|
||||
"""
|
||||
import pickle
|
||||
import numpy as np
|
||||
import abipy.core.abinit_units as abu
|
||||
from monty.string import marquee
|
||||
from monty.termcolor import cprint
|
||||
from monty.dev import deprecated
|
||||
from abipy.tools.plotting import add_fig_kwargs
|
||||
from abipy.tools import duck
|
||||
from abipy.electrons.ebands import ElectronBands
|
||||
from abipy.core.kpoints import Kpath
|
||||
from abipy.core.structure import Structure
|
||||
from abipy.tools.plotting import add_fig_kwargs, get_ax_fig_plt, get_axarray_fig_plt #, set_axlims, set_visible, set_ax_xylabels
|
||||
from abipy.tools.decorators import timeit
|
||||
|
||||
class AbipyBoltztrap():
|
||||
"""
|
||||
Wrapper to Boltztrap2 interpolator
|
||||
This class contains the same quantities as the Loader classes from dft.py in Boltztrap2
|
||||
Additionally it has methods to call the Boltztrap2 interpolator.
|
||||
It creates multiple instances of BolztrapResult storing the results of the interpolation
|
||||
Enter with quantities in the IBZ and interpolate to a fine BZ mesh
|
||||
"""
|
||||
def __init__(self,fermi,structure,nelect,kpoints,eig,volume,linewidths=None,tmesh=None,
|
||||
mommat=None,magmom=None,lpratio=5):
|
||||
#data needed by boltztrap
|
||||
self.fermi = fermi
|
||||
self.atoms = structure.to_ase_atoms()
|
||||
self.nelect = nelect
|
||||
self.kpoints = np.array(kpoints)
|
||||
self.volume = volume
|
||||
self.mommat = mommat
|
||||
self.magmom = magmom
|
||||
|
||||
#additional parameters
|
||||
self.eig = eig
|
||||
self.structure = structure
|
||||
self.linewidths = linewidths
|
||||
self.tmesh = tmesh
|
||||
self.lpratio = lpratio
|
||||
|
||||
@property
|
||||
def nkpoints(self):
|
||||
return len(self.kpoints)
|
||||
|
||||
@property
|
||||
def equivalences(self):
|
||||
if not hasattr(self,'_equivalences'):
|
||||
self.compute_equivalences()
|
||||
return self._equivalences
|
||||
|
||||
@property
|
||||
def coefficients(self):
|
||||
if not hasattr(self,'_coefficients'):
|
||||
self.compute_coefficients()
|
||||
return self._coefficients
|
||||
|
||||
@property
|
||||
def linewidth_coefficients(self):
|
||||
if not hasattr(self,'_linewidth_coefficients'):
|
||||
self.compute_coefficients()
|
||||
return self._linewidth_coefficients
|
||||
|
||||
@property
|
||||
def linewidth_coefficients(self):
|
||||
if not hasattr(self,'_linewidth_coefficients'):
|
||||
self.compute_coefficients()
|
||||
return self._linewidth_coefficients
|
||||
|
||||
@property
|
||||
def rmesh(self):
|
||||
if not hasattr(self,'_rmesh'):
|
||||
self.get_interpolation_mesh()
|
||||
return self._rmesh
|
||||
|
||||
@property
|
||||
def nequivalences(self):
|
||||
return len(self.equivalences)
|
||||
|
||||
@property
|
||||
def ncoefficients(self):
|
||||
return len(self.coefficients)
|
||||
|
||||
@property
|
||||
def ntemps(self):
|
||||
return len(self.linewidths)
|
||||
|
||||
def pickle(self,filename):
|
||||
with open(filename,'wb') as f:
|
||||
pickle.dump(self,f)
|
||||
|
||||
@classmethod
|
||||
def from_pickle(cls,filename):
|
||||
with open(filename,'rb') as f:
|
||||
cls = pickle.load(f)
|
||||
return cls
|
||||
|
||||
@classmethod
|
||||
def from_ebands(cls):
|
||||
"""Initialize from an ebands object"""
|
||||
raise NotImplementedError('TODO')
|
||||
|
||||
@classmethod
|
||||
def from_evk(cls):
|
||||
"""Intialize from a EVK file"""
|
||||
raise NotImplementedError('TODO')
|
||||
|
||||
@classmethod
|
||||
def from_dftdata(cls,dftdata,tmesh,lpratio=5):
|
||||
"""
|
||||
Initialize an instance of this class from a DFTData instance from Boltztrap
|
||||
|
||||
Args:
|
||||
dftdata: DFTData
|
||||
tmesh: a list of temperatures to use in the fermi integrations
|
||||
lpratio: ratio to multiply by the number of k-points in the IBZ and give the
|
||||
number of real space points inside a sphere
|
||||
"""
|
||||
structure = Structure.from_ase_atoms(dftdata.atoms)
|
||||
return cls(dftdata.fermi,structure,dftdata.nelect,dftdata.kpoints,dftdata.ebands,
|
||||
dftdata.get_volume(),linewidths=None,tmesh=tmesh,
|
||||
mommat=dftdata.mommat,magmom=None,lpratio=lpratio)
|
||||
|
||||
@classmethod
|
||||
def from_sigeph(cls, sigeph, itemp_list=None, bstart=None, bstop=None, lpratio=5):
|
||||
"""
|
||||
Initialize interpolation of the bands and lifetimes from a SigEphFile object
|
||||
|
||||
Args:
|
||||
sigeph: |SigEphFile| instance
|
||||
itemp_list: list of the temperature indexes to consider
|
||||
bstart, bstop: only consider bands between bstart and bstop
|
||||
lpratio: ratio to multiply by the number of k-points in the IBZ and give the
|
||||
number of real space points inside a sphere
|
||||
"""
|
||||
#get the lifetimes as an array
|
||||
qpes = sigeph.get_qp_array(mode='ks+lifetimes')
|
||||
|
||||
#get other dimensions
|
||||
if bstart is None: bstart = sigeph.reader.max_bstart
|
||||
if bstop is None: bstop = sigeph.reader.min_bstop
|
||||
fermi = sigeph.ebands.fermie*abu.eV_Ha
|
||||
structure = sigeph.ebands.structure
|
||||
volume = sigeph.ebands.structure.volume*abu.Ang_Bohr**3
|
||||
nelect = sigeph.ebands.nelect
|
||||
kpoints = [k.frac_coords for k in sigeph.sigma_kpoints]
|
||||
|
||||
if sigeph.nsppol == 2:
|
||||
raise NotImplementedError("nsppol 2 not implemented")
|
||||
|
||||
#TODO handle spin
|
||||
eig = qpes[0,:,bstart:bstop,0].real.T*abu.eV_Ha
|
||||
|
||||
itemp_list = list(range(sigeph.ntemp)) if itemp_list is None else duck.list_ints(itemp_list)
|
||||
linewidths = []
|
||||
tmesh = []
|
||||
for itemp in itemp_list:
|
||||
tmesh.append(sigeph.tmesh[itemp])
|
||||
fermi = sigeph.mu_e[itemp]*abu.eV_Ha
|
||||
#TODO handle spin
|
||||
linewidth = qpes[0, :, bstart:bstop, itemp].imag.T*abu.eV_Ha
|
||||
linewidths.append(linewidth)
|
||||
|
||||
return cls(fermi, structure, nelect, kpoints, eig, volume, linewidths=linewidths,
|
||||
tmesh=tmesh, lpratio=lpratio)
|
||||
|
||||
def get_lattvec(self):
|
||||
"""this method is required by Bolztrap"""
|
||||
return self.lattvec
|
||||
|
||||
@property
|
||||
def nbands(self):
|
||||
nbands, rpoints = self.coefficients.shape
|
||||
return nbands
|
||||
|
||||
@property
|
||||
def lattvec(self):
|
||||
if not hasattr(self,"_lattvec"):
|
||||
self._lattvec = self.atoms.get_cell().T / abu.Bohr_Ang
|
||||
return self._lattvec
|
||||
|
||||
def get_ebands(self,kpath=None,line_density=20,vertices_names=None,linewidth_itemp=False):
|
||||
"""
|
||||
Compute the band-structure using the computed coefficients
|
||||
|
||||
Args:
|
||||
kpath: |Kpath| instance where to interpolate the eigenvalues and linewidths
|
||||
line_density: Number of points used to sample the smallest segment of the path
|
||||
vertices_names: List of tuple, each tuple is of the form (kfrac_coords, kname) where
|
||||
kfrac_coords are the reduced coordinates of the k-point and kname is a string with the name of
|
||||
the k-point. Each point represents a vertex of the k-path.
|
||||
linewith_itemp: list of indexes refering to the temperatures where the linewidth will be interpolated
|
||||
"""
|
||||
from BoltzTraP2 import fite
|
||||
|
||||
if kpath is None:
|
||||
if vertices_names is None:
|
||||
vertices_names = [(k.frac_coords, k.name) for k in self.structure.hsym_kpoints]
|
||||
|
||||
kpath = Kpath.from_vertices_and_names(self.structure, vertices_names, line_density=line_density)
|
||||
|
||||
#call boltztrap to interpolate
|
||||
coeffs = self.coefficients
|
||||
eigens_kpath, vvband = fite.getBands(kpath.frac_coords, self.equivalences, self.lattvec, coeffs)
|
||||
|
||||
linewidths_kpath = None
|
||||
if linewidth_itemp is not False:
|
||||
coeffs = self.linewidth_coefficients[linewidth_itemp]
|
||||
linewidths_kpath, vvband = fite.getBands(kpath.frac_coords, self.equivalences, self.lattvec, coeffs)
|
||||
linewidths_kpath = linewidths_kpath.T[np.newaxis,:,:]*abu.Ha_eV
|
||||
|
||||
#convert units and shape
|
||||
eigens_kpath = eigens_kpath.T[np.newaxis,:,:]*abu.Ha_eV
|
||||
occfacts_kpath = np.zeros_like(eigens_kpath)
|
||||
nspinor1 = 1
|
||||
nspden1 = 1
|
||||
|
||||
#return a ebands object
|
||||
return ElectronBands(self.structure, kpath, eigens_kpath, self.fermi*abu.Ha_eV, occfacts_kpath,
|
||||
self.nelect, nspinor1, nspden1, linewidths=linewidths_kpath)
|
||||
|
||||
@deprecated(message="get_bands is deprecated, use get_ebands")
|
||||
def get_bands(self, **kwargs):
|
||||
return self.get_ebands(**kwargs)
|
||||
|
||||
def get_interpolation_mesh(self):
|
||||
"""From the array of equivalences determine the mesh that was used"""
|
||||
max1, max2, max3 = 0,0,0
|
||||
for equiv in self.equivalences:
|
||||
max1 = max(np.max(equiv[:,0]),max1)
|
||||
max2 = max(np.max(equiv[:,1]),max2)
|
||||
max3 = max(np.max(equiv[:,2]),max3)
|
||||
self._rmesh = (2*max1+1,2*max2+1,2*max3+1)
|
||||
return self._rmesh
|
||||
|
||||
def dump_rsphere(self,filename):
|
||||
""" Write a file with the real space points"""
|
||||
with open(filename, 'wt') as f:
|
||||
for iband in range(self.nbands):
|
||||
for ie,equivalence in enumerate(self.equivalences):
|
||||
coeff = self.coefficients[iband,ie]
|
||||
for ip,point in enumerate(equivalence):
|
||||
f.write("%5d %5d %5d "%tuple(point)+"%lf\n"%((abs(coeff))**(1./3)))
|
||||
f.write("\n\n")
|
||||
|
||||
@timeit
|
||||
def compute_equivalences(self):
|
||||
"""Compute equivalent k-points"""
|
||||
from BoltzTraP2 import sphere
|
||||
try:
|
||||
self._equivalences = sphere.get_equivalences(self.atoms, self.magmom, self.lpratio*self.nkpoints)
|
||||
except TypeError:
|
||||
self._equivalences = sphere.get_equivalences(self.atoms, self.lpratio*self.nkpoints)
|
||||
|
||||
@timeit
|
||||
def compute_coefficients(self):
|
||||
"""Call fitde3D routine from Boltztrap2"""
|
||||
from BoltzTraP2 import fite
|
||||
#we will set ebands to compute teh coefficients
|
||||
self.ebands = self.eig
|
||||
self._coefficients = fite.fitde3D(self, self.equivalences)
|
||||
|
||||
if self.linewidths:
|
||||
self._linewidth_coefficients = []
|
||||
for itemp in range(self.ntemps):
|
||||
self.ebands = self.linewidths[itemp]
|
||||
coeffs = fite.fitde3D(self, self.equivalences)
|
||||
self._linewidth_coefficients.append(coeffs)
|
||||
|
||||
#at the end we always unset ebands
|
||||
delattr(self,"ebands")
|
||||
|
||||
@timeit
|
||||
def run(self,npts=500,dos_method='gaussian:0.05 eV',erange=None,margin=0.1,nworkers=1,verbose=0):
|
||||
"""
|
||||
Interpolate the eingenvalues to compute dos and vvdos
|
||||
This part is quite memory intensive
|
||||
|
||||
Args:
|
||||
npts: number of frequency points
|
||||
dos_method: when using a patched version of Boltztrap
|
||||
"""
|
||||
boltztrap_results = []; app = boltztrap_results.append
|
||||
import inspect
|
||||
from BoltzTraP2 import fite
|
||||
import BoltzTraP2.bandlib as BL
|
||||
|
||||
def BTPDOS(eband,vvband,cband=None,erange=None,npts=None,scattering_model="uniform_tau",mode=dos_method):
|
||||
"""
|
||||
This is a small wrapper for Boltztrap2 to use the official version or a modified
|
||||
verison using gaussian or lorentzian smearing
|
||||
"""
|
||||
try:
|
||||
return BL.BTPDOS(eband, vvband, erange=erange, npts=npts, scattering_model=scattering_model, mode=dos_method)
|
||||
except TypeError:
|
||||
return BL.BTPDOS(eband, vvband, erange=erange, npts=npts, scattering_model=scattering_model)
|
||||
|
||||
|
||||
#TODO change this!
|
||||
if erange is None: erange = (np.min(self.eig),np.max(self.eig))
|
||||
else: erange = np.array(erange)/abu.Ha_eV+self.fermi
|
||||
|
||||
#interpolate the electronic structure
|
||||
if verbose: print('interpolating bands')
|
||||
results = fite.getBTPbands(self.equivalences, self.coefficients,
|
||||
self.lattvec, nworkers=nworkers)
|
||||
eig_fine, vvband, cband = results
|
||||
|
||||
#calculate DOS and VDOS without lifetimes
|
||||
if verbose: print('calculating dos and vvdos without lifetimes')
|
||||
wmesh,dos,vvdos,_ = BTPDOS(eig_fine, vvband, erange=erange, npts=npts, mode=dos_method)
|
||||
app(BoltztrapResult(self,wmesh,dos,vvdos,self.fermi,self.tmesh,self.volume,margin=margin))
|
||||
|
||||
#if we have linewidths
|
||||
if self.linewidths:
|
||||
for itemp in range(self.ntemps):
|
||||
if verbose: print('itemp %d\ninterpolating bands')
|
||||
#calculate the lifetimes on the fine grid
|
||||
results = fite.getBTPbands(self.equivalences, self._linewidth_coefficients[itemp],
|
||||
self.lattvec, nworkers=nworkers)
|
||||
linewidth_fine, vvband, cband = results
|
||||
tau_fine = 1.0/np.abs(2*linewidth_fine*abu.eV_s)
|
||||
|
||||
#calculate vvdos with the lifetimes
|
||||
if verbose: print('calculating dos and vvdos with lifetimes')
|
||||
wmesh, dos_tau, vvdos_tau, _ = BTPDOS(eig_fine, vvband, erange=erange, npts=npts,
|
||||
scattering_model=tau_fine, mode=dos_method)
|
||||
#store results
|
||||
app(BoltztrapResult(self,wmesh,dos_tau,vvdos_tau,self.fermi,self.tmesh,
|
||||
self.volume,tau_temp=self.tmesh[itemp],margin=margin))
|
||||
|
||||
return BoltztrapResultRobot(boltztrap_results)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string()
|
||||
|
||||
def to_string(self, verbose=2):
|
||||
lines = []; app = lines.append
|
||||
app(marquee(self.__class__.__name__,mark="="))
|
||||
app("equivalent points: {}".format(self.nequivalences))
|
||||
app("real space mesh: {}".format(self.rmesh))
|
||||
app("lpratio: {}".format(self.lpratio))
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class BoltztrapResult():
|
||||
"""
|
||||
Container for BoltztraP2 results
|
||||
Provides a object oriented interface to BoltztraP2
|
||||
for plotting, storing and analysing the results
|
||||
"""
|
||||
_attrs = ['_L0','_L1','_L2','_sigma','_seebeck','_kappa']
|
||||
|
||||
def __init__(self,abipyboltztrap,wmesh,dos,vvdos,fermi,tmesh,volume,tau_temp=None,margin=0.1):
|
||||
self.abipyboltztrap = abipyboltztrap
|
||||
|
||||
self.fermi = fermi
|
||||
self.volume = volume
|
||||
self.wmesh = np.array(wmesh)
|
||||
idx_margin = int(margin*len(wmesh))
|
||||
self.mumesh = self.wmesh[idx_margin:-(idx_margin+1)]
|
||||
self.tmesh = np.array(tmesh)
|
||||
|
||||
#Temperature fix
|
||||
if any(self.tmesh < 1):
|
||||
cprint("Boltztrap does not handle 0K well.\n"
|
||||
"I avoid potential problems by setting all T<1K to T=1K",color="yellow")
|
||||
self.tmesh[self.tmesh < 1] = 1
|
||||
|
||||
self.tau_temp = tau_temp
|
||||
|
||||
self.dos = dos
|
||||
self.vvdos = vvdos
|
||||
|
||||
@property
|
||||
def has_tau(self):
|
||||
return self.tau_temp is not None
|
||||
|
||||
@property
|
||||
def ntemp(self):
|
||||
return len(self.tmesh)
|
||||
|
||||
@property
|
||||
def L0(self):
|
||||
if not hasattr(self,'_L0'):
|
||||
self.compute_fermiintegrals()
|
||||
return self._L0
|
||||
|
||||
@property
|
||||
def L1(self):
|
||||
if not hasattr(self,'_L1'):
|
||||
self.compute_fermiintegrals()
|
||||
return self._L1
|
||||
|
||||
@property
|
||||
def L2(self):
|
||||
if not hasattr(self,'_L2'):
|
||||
self.compute_fermiintegrals()
|
||||
return self._L2
|
||||
|
||||
@property
|
||||
def sigma(self):
|
||||
if not hasattr(self,'_sigma'):
|
||||
self.compute_onsager_coefficients()
|
||||
return self._sigma
|
||||
|
||||
@property
|
||||
def seebeck(self):
|
||||
if not hasattr(self,'_seebeck'):
|
||||
self.compute_onsager_coefficients()
|
||||
return self._seebeck
|
||||
|
||||
@property
|
||||
def powerfactor(self):
|
||||
return self.sigma * self.seebeck**2
|
||||
|
||||
@property
|
||||
def kappa(self):
|
||||
if not hasattr(self,'_kappa'):
|
||||
self.compute_onsager_coefficients()
|
||||
return self._kappa
|
||||
|
||||
def set_tmesh(self,tmesh):
|
||||
""" Set the temperature mesh"""
|
||||
self.tmesh = tmesh
|
||||
|
||||
def set_tmesh(self,tmesh):
|
||||
""" Set the temperature mesh"""
|
||||
self.tmesh = tmesh
|
||||
|
||||
def del_attrs(self):
|
||||
""" Remove all the atributes so they are recomputed """
|
||||
for attr in self._attrs:
|
||||
delattr(attr)
|
||||
|
||||
def set_mumesh(self,emin,emax):
|
||||
"""
|
||||
Set the range in which to plot the change of the doping
|
||||
|
||||
Args:
|
||||
emin: minimun energy in eV
|
||||
emax: maximum energy in eV
|
||||
"""
|
||||
start_idx = np.abs(self.wmesh - emin*abu.eV_Ha - self.fermi).argmin()
|
||||
stop_idx = np.abs(self.wmesh - emax*abu.eV_Ha - self.fermi).argmin()
|
||||
self.mumesh = self.wmesh[start_idx:stop_idx]
|
||||
|
||||
def compute_fermiintegrals(self):
|
||||
"""Compute and store the results of the Fermi integrals"""
|
||||
import BoltzTraP2.bandlib as BL
|
||||
results = BL.fermiintegrals(self.wmesh, self.dos, self.vvdos, mur=self.mumesh, Tr=self.tmesh)
|
||||
_, self._L0, self._L1, self._L2, self._Lm11 = results
|
||||
|
||||
def compute_onsager_coefficients(self):
|
||||
"""Compute Onsager coefficients"""
|
||||
import BoltzTraP2.bandlib as BL
|
||||
L0,L1,L2 = self.L0,self.L1,self.L2
|
||||
results = BL.calc_Onsager_coefficients(L0,L1,L2,mur=self.mumesh,Tr=self.tmesh,vuc=self.volume)
|
||||
self._sigma, self._seebeck, self._kappa, self._hall = results
|
||||
|
||||
@staticmethod
|
||||
def from_pickle(filename):
|
||||
"""Load BoltztrapResult from a pickle file"""
|
||||
with open(filename,'rb') as f:
|
||||
instance = pickle.load(f)
|
||||
return instance
|
||||
|
||||
def pickle(self,filename):
|
||||
"""Write a file with the results from the calculation"""
|
||||
with open(filename,'wb') as f:
|
||||
pickle.dump(self,f)
|
||||
|
||||
def istensor(self,what):
|
||||
"""Check if a certain quantity is a tensor"""
|
||||
if not hasattr(self,what): return None
|
||||
return len(getattr(self,what).shape) > 2
|
||||
|
||||
def get_component(self,what,component,itemp):
|
||||
i,j = abu.s2itup(component)
|
||||
return getattr(self,what)[itemp,:,i,j]
|
||||
|
||||
def plot_dos_ax(self,ax,fontsize=8,**kwargs):
|
||||
"""
|
||||
Plot the density of states on axis ax.
|
||||
|
||||
Args:
|
||||
ax: |matplotlib-Axes|.
|
||||
kwargs: Passed to ax.plot
|
||||
"""
|
||||
wmesh = (self.wmesh-self.fermi) * abu.Ha_eV
|
||||
ax.plot(wmesh,self.dos,label=self.get_letter('dos'),**kwargs)
|
||||
ax.set_xlabel('Energy (eV)',fontsize=fontsize)
|
||||
|
||||
def plot_vvdos_ax(self,ax,components=('xx',),fontsize=8,**kwargs):
|
||||
"""
|
||||
Plot components of vvdos on the axis ax.
|
||||
|
||||
Args:
|
||||
ax: |matplotlib-Axes|.
|
||||
components: Choose the components of the tensor to plot ['xx','xy','xz','yy',(...)]
|
||||
kwargs: Passed to ax.plot
|
||||
"""
|
||||
wmesh = (self.wmesh-self.fermi) * abu.Ha_eV
|
||||
|
||||
for component in components:
|
||||
i,j = abu.s2itup(component)
|
||||
label = "%s $_{%s}$" % (self.get_letter('vvdos'),component)
|
||||
if self.tau_temp: label += r" $\tau_T$ = %dK" % self.tau_temp
|
||||
ax.plot(wmesh,self.vvdos[i,j,:],label=label,**kwargs)
|
||||
ax.set_xlabel('Energy (eV)',fontsize=fontsize)
|
||||
|
||||
def plot_ax(self, ax, what, components=('xx',), itemp_list=None, fontsize=8, **kwargs):
|
||||
"""
|
||||
Plot a quantity for all the dopings as a function of temperature on the axis ax.
|
||||
|
||||
Args:
|
||||
ax: |matplotlib-Axes|.
|
||||
what: choose the quantity to plot can be: ['sigma','kappa','powerfactor']
|
||||
components: Choose the components of the tensor to plot ['xx','xy','xz','yy',(...)]
|
||||
itemp_list: list of indexes of the tempratures to plot
|
||||
colormap: Colormap used to plot the results
|
||||
kwargs: Passed to ax.plot
|
||||
"""
|
||||
from matplotlib import pyplot as plt
|
||||
colormap = kwargs.pop('colormap','plasma')
|
||||
cmap = plt.get_cmap(colormap)
|
||||
color = None
|
||||
|
||||
itemp_list = list(range(self.ntemp)) if itemp_list is None else duck.list_ints(itemp_list)
|
||||
maxitemp = max(itemp_list)
|
||||
minitemp = min(itemp_list)
|
||||
if maxitemp > self.ntemp or minitemp < 0:
|
||||
raise ValueError('Invalid itemp_list, should be between 0 and %d. Got %d.'%(self.ntemp,maxitemp))
|
||||
|
||||
mumesh = (self.mumesh-self.fermi) * abu.Ha_eV
|
||||
|
||||
if self.istensor(what):
|
||||
kwargs.pop('c',None)
|
||||
for itemp in itemp_list:
|
||||
for component in components:
|
||||
y = self.get_component(what,component,itemp)
|
||||
if len(itemp_list) > 1: color=cmap(itemp/len(itemp_list))
|
||||
label = "%s $_{%s}$ $b_T$ = %dK" % (self.get_letter(what),component,self.tmesh[itemp])
|
||||
if self.has_tau: label += r" $\tau_T$ = %dK" % self.tau_temp
|
||||
ax.plot(mumesh,y,label=label,c=color,**kwargs)
|
||||
else:
|
||||
ax.plot(mumesh,getattr(self,what), label=what, **kwargs)
|
||||
|
||||
ax.set_ylabel(self.get_ylabel(what), fontsize=fontsize)
|
||||
ax.set_xlabel('Energy (eV)', fontsize=fontsize)
|
||||
|
||||
def get_ylabel(self,what):
|
||||
"""
|
||||
Get a label with units for the quntities stores in this object.
|
||||
"""
|
||||
if self.has_tau: tau = ''
|
||||
else: tau = 's^{-1}'
|
||||
if what == 'sigma': return r'$\sigma$ [$Sm^{-1}%s$]'%tau
|
||||
if what == 'seebeck': return r'$S$ [$VSm^{-1}%s$]'%tau
|
||||
if what == 'kappa': return r'$\kappa_e$ [$VJSm^{-1}%s$]'%tau
|
||||
if what == 'powerfactor': return r'$S^2\sigma$ [$VJSm^{-1}%s$]'%tau
|
||||
return ''
|
||||
|
||||
def get_letter(self,what):
|
||||
letters = {'sigma': r'$\sigma$',
|
||||
'seebeck': r'$S$',
|
||||
'kappa': r'$\kappa_e$',
|
||||
'powerfactor':r'$S^2\sigma$',
|
||||
'vvdos': r'$v\otimes v$',
|
||||
'dos': r'$n(\epsilon)$'}
|
||||
return letters[what]
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot(self, what, colormap='plasma', directions=('xx'), ax=None, fontsize=8, **kwargs):
|
||||
"""
|
||||
Plot the qantity for all the temperatures as a function of the doping
|
||||
"""
|
||||
ax, fig, plt = get_ax_fig_plt(ax=ax)
|
||||
self.plot_ax(ax, what, colormap=colormap, directions=directions, **kwargs)
|
||||
ax.legend(loc="best", shadow=True, fontsize=fontsize)
|
||||
|
||||
return fig
|
||||
|
||||
def to_string(self, title=None, mark="=", verbose=0):
|
||||
"""
|
||||
String representation of the class
|
||||
"""
|
||||
lines = []; app = lines.append
|
||||
if title is None: app(marquee(self.__class__.__name__,mark=mark))
|
||||
app("fermi: %8.5lf eV"%(self.fermi*abu.Ha_eV))
|
||||
app("mumesh: %8.5lf <-> %8.5lf eV"%(self.mumesh[0]*abu.Ha_eV,self.mumesh[-1]*abu.Ha_eV))
|
||||
app("tmesh: %s K"%self.tmesh)
|
||||
app("has_tau: %s"%self.has_tau)
|
||||
if self.tau_temp: app("tau_temp: %.1lf K"%self.tau_temp)
|
||||
return "\n".join(lines)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string()
|
||||
|
||||
class BoltztrapResultRobot():
|
||||
"""
|
||||
Robot to analyse multiple Boltztrap calculations
|
||||
Behaves as a list of BoltztrapResult
|
||||
Provides methods to plot multiple results on a single figure
|
||||
"""
|
||||
def __init__(self,results,erange=None):
|
||||
if not all([isinstance(r,BoltztrapResult) for r in results]):
|
||||
raise ValueError('Must provide BolztrapResult instances.')
|
||||
|
||||
#consistency check in the temperature meshes
|
||||
res0 = results[0]
|
||||
if np.any([res0.tmesh != res.tmesh for res in results]):
|
||||
cprint("Comparing BoltztrapResults with different temperature meshes.", color="yellow")
|
||||
|
||||
#consistency check in chemical potential meshes
|
||||
if np.any([res0.wmesh != res.wmesh for res in results]):
|
||||
cprint("Comparing BoltztrapResults with different energy meshes.", color="yellow")
|
||||
|
||||
#store the results
|
||||
self.results = results
|
||||
self.erange = erange
|
||||
|
||||
if not all([np.allclose(results[0].mumesh,result.mumesh) for result in results[1:]]):
|
||||
raise ValueError('The doping meshes of the results differ, cannot continue')
|
||||
self.mumesh = results[0].mumesh
|
||||
|
||||
if not all([np.allclose(results[0].tmesh,result.tmesh) for result in results[1:]]):
|
||||
raise ValueError('The temperature meshes of the results differ, cannot continue')
|
||||
self.tmesh = results[0].tmesh
|
||||
|
||||
def __getitem__(self,index):
|
||||
"""Access the results stored in the class as a list"""
|
||||
return self.results[index]
|
||||
|
||||
@property
|
||||
def ntemp(self):
|
||||
return len(self.tmesh)
|
||||
|
||||
@property
|
||||
def tau_list(self):
|
||||
"""Get all the results with tau included"""
|
||||
return [ res.tau_temp for res in self.results if res.tau_temp is not None ]
|
||||
|
||||
@property
|
||||
def notau_results(self):
|
||||
"""Get all the results without the tau included"""
|
||||
instance = self.__class__([ res for res in self.results if res.tau_temp is None ])
|
||||
if self.erange: instance.erange = self.erange
|
||||
return instance
|
||||
|
||||
@property
|
||||
def tau_results(self):
|
||||
"""Return all the results that have temperature dependence"""
|
||||
instance = self.__class__([ res for res in self.results if res.tau_temp ])
|
||||
if self.erange: instance.erange = self.erange
|
||||
return instance
|
||||
|
||||
@property
|
||||
def nresults(self):
|
||||
return len(self.results)
|
||||
|
||||
@staticmethod
|
||||
def from_pickle(filename):
|
||||
"""
|
||||
Load results from file
|
||||
"""
|
||||
with open(filename,'rb') as f:
|
||||
instance = pickle.load(f)
|
||||
return instance
|
||||
|
||||
def pickle(self,filename):
|
||||
"""
|
||||
Write a file with the results from the calculation
|
||||
"""
|
||||
with open(filename,'wb') as f:
|
||||
pickle.dump(self,f)
|
||||
|
||||
def plot_vvdos_ax(self,ax,legend=True,components=('xx',),itau_list=None,fontsize=8,erange=None,**kwargs):
|
||||
"""
|
||||
Plot the vvdos for all the results in the robot
|
||||
"""
|
||||
from matplotlib import pyplot as plt
|
||||
colormap = kwargs.pop('colormap','plasma')
|
||||
cmap = plt.get_cmap(colormap)
|
||||
|
||||
#set erange
|
||||
erange = erange or self.erange
|
||||
if erange is not None: ax.set_xlim(erange)
|
||||
|
||||
if itau_list:
|
||||
#filter results by temperature
|
||||
tau_list = [self.tmesh[itau] for itau in itau_list]
|
||||
filtered_results = sorted([res for res in self.results if res.tau_temp in tau_list],key=lambda x: x.tau_temp)
|
||||
|
||||
for itemp,result in enumerate(filtered_results):
|
||||
color = kwargs.pop('c',cmap(itemp/len(filtered_results)))
|
||||
result.plot_vvdos_ax(ax,fontsize=fontsize,c=color,components=components,**kwargs)
|
||||
ax.set_ylabel(r'with $\tau$',fontsize=fontsize)
|
||||
if legend: ax.legend(loc="best", shadow=True, fontsize=fontsize)
|
||||
else:
|
||||
#results without temperature
|
||||
for result in self.notau_results:
|
||||
result.plot_vvdos_ax(ax,fontsize=fontsize,components=components,**kwargs)
|
||||
ax.set_ylabel(r'without $\tau$',fontsize=fontsize)
|
||||
if legend: ax.legend(loc="best", shadow=True, fontsize=fontsize)
|
||||
|
||||
def plot_dos_ax(self, ax1, legend=True, fontsize=8, erange=None, **kwargs):
|
||||
"""
|
||||
Plot the dos for all the results in the robot
|
||||
"""
|
||||
#set erange
|
||||
erange = erange or self.erange
|
||||
if erange is not None: ax1.set_xlim(erange)
|
||||
|
||||
for result in self.results:
|
||||
result.plot_dos_ax(ax1,fontsize=fontsize,**kwargs)
|
||||
if legend: ax1.legend(loc="best", shadow=True, fontsize=fontsize)
|
||||
|
||||
def plot_ax(self,ax1,what,components=('xx',),itemp_list=None,itau_list=None,fontsize=8,erange=None,**kwargs):
|
||||
"""
|
||||
Plot the same quantity for all the results on axis ax1
|
||||
|
||||
Args:
|
||||
ax1: |matplotlib-Axes|.
|
||||
what: choose the quantity to plot can be: ['sigma','kappa','powerfactor']
|
||||
itemp_list: list of indexes of the tempratures to plot
|
||||
itau_list: list of indexes of the tempratures at which the lifetimes were computed
|
||||
components: Choose the components of the tensor to plot ['xx','xy','xz','yy',(...)]
|
||||
erange: choose energy range of the plot
|
||||
kwargs: Passed to ax.plot
|
||||
"""
|
||||
from matplotlib import pyplot as plt
|
||||
colormap = kwargs.pop('colormap','plasma')
|
||||
cmap = plt.get_cmap(colormap)
|
||||
|
||||
#set erange
|
||||
erange = erange or self.erange
|
||||
if erange is not None: ax1.set_xlim(erange)
|
||||
|
||||
if itau_list:
|
||||
#filter results by temperature
|
||||
tau_list = self.tmesh if itau_list is None else [self.tmesh[itau] for itau in itau_list]
|
||||
filtered_results = [res for res in self.results if res.tau_temp in tau_list]
|
||||
|
||||
#plot the results
|
||||
for itemp,result in enumerate(filtered_results):
|
||||
color = kwargs.pop('c',cmap(itemp/len(filtered_results)))
|
||||
result.plot_ax(ax1,what,components,itemp_list,fontsize=fontsize,c=color,**kwargs)
|
||||
else:
|
||||
#plot result without tau
|
||||
for result in self.notau_results:
|
||||
result.plot_ax(ax1,what,components,itemp_list,fontsize=fontsize,**kwargs)
|
||||
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot_transport(self, itemp_list=None, itau_list=None, components=('xx',),
|
||||
erange=None, ax_array=None, fontsize=8, legend=True, **kwargs):
|
||||
"""
|
||||
Plot the different quantities relevant for transport for all the results in the robot
|
||||
"""
|
||||
ax_array, fig, plt = get_axarray_fig_plt(ax_array,nrows=2,ncols=2)
|
||||
self.plot_ax(ax_array[0,0],'sigma', itemp_list=itemp_list,itau_list=itau_list,fontsize=fontsize,**kwargs)
|
||||
self.plot_ax(ax_array[0,1],'seebeck', itemp_list=itemp_list,itau_list=itau_list,fontsize=fontsize,**kwargs)
|
||||
self.plot_ax(ax_array[1,0],'kappa', itemp_list=itemp_list,itau_list=itau_list,fontsize=fontsize,**kwargs)
|
||||
self.plot_ax(ax_array[1,1],'powerfactor',itemp_list=itemp_list,itau_list=itau_list,fontsize=fontsize,**kwargs)
|
||||
|
||||
if legend:
|
||||
for ax in ax_array.flatten(): ax.legend(loc="best", shadow=True, fontsize=fontsize)
|
||||
|
||||
#fig.tight_layout()
|
||||
return fig
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot(self,what,itemp_list=None,itau_list=None,components=('xx',),
|
||||
erange=None,fontsize=8,legend=True,**kwargs):
|
||||
"""
|
||||
Plot all the boltztrap results in the Robot
|
||||
|
||||
Args:
|
||||
what: choose the quantity to plot can be: ['sigma','kappa','powerfactor']
|
||||
itemp_list: list of indexes of the tempratures to plot
|
||||
itau_list: list of indexes of the tempratures at which the lifetimes were computed
|
||||
components: Choose the components of the tensor to plot ['xx','xy','xz','yy',(...)]
|
||||
erange: choose energy range of the plot
|
||||
kwargs: Passed to ax.plot
|
||||
"""
|
||||
ax1, fig, plt = get_ax_fig_plt(ax=None)
|
||||
self.plot_ax(ax1,what,components=components,itemp_list=itemp_list,itau_list=itau_list,
|
||||
fontsize=fontsize,erange=erange,**kwargs)
|
||||
if legend: ax1.legend(loc="best", shadow=True, fontsize=fontsize)
|
||||
return fig
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot_dos_vvdos(self,dos_color=None,erange=None,ax_array=None,components=('xx',),fontsize=8,legend=True,**kwargs):
|
||||
"""
|
||||
Plot dos and vvdos on the same figure
|
||||
"""
|
||||
ax_array, fig, plt = get_axarray_fig_plt(ax_array,nrows=3)
|
||||
self.plot_dos_ax(ax_array[0],erange=erange,legend=legend,fontsize=fontsize,**kwargs)
|
||||
self.plot_vvdos_ax(ax_array[1],components=components,erange=erange,fontsize=fontsize,legend=legend)
|
||||
self.plot_vvdos_ax(ax_array[2],itau_list=range(self.ntemp),components=components,erange=erange,
|
||||
fontsize=fontsize,legend=legend)
|
||||
return fig
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot_dos(self,ax=None,erange=None,fontsize=8,legend=True,**kwargs):
|
||||
"""
|
||||
Plot dos for the results in the Robot
|
||||
"""
|
||||
ax1, fig, plt = get_ax_fig_plt(ax=ax)
|
||||
self.plot_dos_ax(ax1,erange=erange,legend=legend,fontsize=fontsize,**kwargs)
|
||||
return fig
|
||||
|
||||
@add_fig_kwargs
|
||||
def plot_vvdos(self,ax_array=None,itau_list=None,components=('xx',),erange=None,fontsize=8,legend=True,**kwargs):
|
||||
"""
|
||||
Plot vvdos for all the results in the Robot
|
||||
"""
|
||||
ax_array, fig, plt = get_axarray_fig_plt(ax_array=ax_array,sharex=True,nrows=2)
|
||||
|
||||
self.plot_vvdos_ax(ax_array[0],components=components,erange=erange,fontsize=fontsize,legend=legend)
|
||||
self.plot_vvdos_ax(ax_array[1],itau_list=range(self.ntemp),components=components,erange=erange,
|
||||
fontsize=fontsize,legend=legend)
|
||||
return fig
|
||||
|
||||
def set_erange(self,emin,emax):
|
||||
""" Get an energy range based on an energy margin above and bellow the fermi level"""
|
||||
self.erange = (emin,emax)
|
||||
|
||||
def unset_erange(self):
|
||||
""" Unset the energy range"""
|
||||
self.erange = None
|
||||
|
||||
def to_string(self, verbose=0):
|
||||
"""
|
||||
Return a string representation of the data in this class
|
||||
"""
|
||||
lines = []; app = lines.append
|
||||
app(marquee(self.__class__.__name__,mark="="))
|
||||
app('nresults: %d'%self.nresults)
|
||||
for result in self.results:
|
||||
app(result.to_string(mark='-'))
|
||||
return "\n".join(lines)
|
||||
|
||||
def set_mumesh(self,emin,emax):
|
||||
"""
|
||||
Set the range in which to plot the change of the doping
|
||||
for all the results
|
||||
|
||||
Args:
|
||||
emin: minimun energy in eV
|
||||
emax: maximum energy in eV
|
||||
"""
|
||||
for result in self.results:
|
||||
result.set_mumesh(emin,emax)
|
||||
|
||||
def set_tmesh(self,tmesh):
|
||||
"""
|
||||
Set the temperature mesh of all the results
|
||||
|
||||
Args:
|
||||
tmesh: array with temperatures at which to compute the Fermi integrals
|
||||
"""
|
||||
for result in self.results:
|
||||
result.set_tmesh(tmesh)
|
||||
|
||||
def __str__(self):
|
||||
return self.to_string()
|
|
@ -1,59 +0,0 @@
|
|||
"""Tests for boltztrap module."""
|
||||
from __future__ import print_function, division, unicode_literals, absolute_import
|
||||
|
||||
import os
|
||||
import collections
|
||||
import numpy as np
|
||||
import abipy.data as abidata
|
||||
|
||||
from abipy.core.testing import AbipyTest
|
||||
from abipy.boltztrap import AbipyBoltztrap, BoltztrapResult
|
||||
from abipy import abilab
|
||||
|
||||
|
||||
class AbipyBoltztrapTest(AbipyTest):
|
||||
|
||||
# TODO: Need new files with IBZ.
|
||||
def test_sigeph_boltztrap(self):
|
||||
"""Test boltztrap interpolation"""
|
||||
self.skip_if_not_bolztrap2()
|
||||
|
||||
with abilab.abiopen(abidata.ref_file("diamond_444q_full_SIGEPH.nc")) as sigeph:
|
||||
bt = AbipyBoltztrap.from_sigeph(sigeph)
|
||||
repr(bt); str(bt)
|
||||
assert bt.to_string(verbose=2)
|
||||
|
||||
# get equivalences
|
||||
assert bt.rmesh == (17, 17, 17)
|
||||
assert bt.nequivalences == 67
|
||||
|
||||
# get coefficients
|
||||
assert bt.ncoefficients == 53
|
||||
bt.dump_rsphere(self.get_tmpname(text=True))
|
||||
|
||||
# get ebands using boltztrap
|
||||
bt_ebands = bt.get_ebands()
|
||||
|
||||
# Get boltztrap results using different DOS methods
|
||||
btr = bt.run(dos_method="histogram")
|
||||
btr = bt.run(npts=500,dos_method="gaussian:0.5 eV")
|
||||
btr = bt.run(npts=500,dos_method="lorentzian:0.5 eV")
|
||||
repr(btr); str(btr)
|
||||
assert btr.to_string(verbose=2)
|
||||
|
||||
# Test pickle
|
||||
pickle_file = self.get_tmpname(suffix="diamond.npy")
|
||||
btr.pickle(pickle_file)
|
||||
same_result = BoltztrapResult.from_pickle(pickle_file)
|
||||
self.assert_equal(btr.tmesh, same_result.tmesh)
|
||||
|
||||
if self.has_matplotlib():
|
||||
# Plot the density of states and VVDOS for multiple temperatures
|
||||
assert btr.plot_dos_vvdos(show=False)
|
||||
|
||||
# Plot transport related quantities for different combinations of
|
||||
# tau temperature and boltztrap temperature
|
||||
assert btr.plot('sigma', itemp_list=None, itau_list=[3], show=False)
|
||||
assert btr.plot('seebeck', itemp_list=[3], itau_list=[1,2], show=False)
|
||||
assert btr.plot('powerfactor', itemp_list=[3], itau_list=None, show=False)
|
||||
assert btr.plot_transport(show=False)
|
|
@ -818,6 +818,11 @@ class PhononBands(object):
|
|||
|
||||
return h
|
||||
|
||||
def reasonable_repetitions(natoms):
|
||||
if (natoms < 4): return (3,3,3)
|
||||
if (4 < natoms < 50): return (2,2,2)
|
||||
if (50 < natoms): return (1,1,1)
|
||||
|
||||
# http://henriquemiranda.github.io/phononwebsite/index.html
|
||||
data = {}
|
||||
data["name"] = name or self.structure.composition.reduced_formula
|
||||
|
@ -826,7 +831,7 @@ class PhononBands(object):
|
|||
data["atom_types"] = [e.name for e in self.structure.species]
|
||||
data["atom_numbers"] = self.structure.atomic_numbers
|
||||
data["formula"] = self.structure.formula.replace(" ", "")
|
||||
data["repetitions"] = repetitions or (3, 3, 3)
|
||||
data["repetitions"] = repetitions or reasonable_repetitions(self.num_atoms)
|
||||
data["atom_pos_car"] = self.structure.cart_coords.tolist()
|
||||
data["atom_pos_red"] = self.structure.frac_coords.tolist()
|
||||
data["chemical_symbols"] = self.structure.symbol_set
|
||||
|
@ -870,7 +875,8 @@ class PhononBands(object):
|
|||
self.split_matched_indices[i][...,None],
|
||||
np.arange(vect.shape[2])[None, None,:]]
|
||||
v = vect.reshape((len(vect), self.num_branches,self.num_atoms, 3))
|
||||
v /= np.linalg.norm(v[0,0,0])
|
||||
norm = [np.linalg.norm(vi) for vi in v[0,0]]
|
||||
v /= max(norm)
|
||||
v = np.stack([v.real, v.imag], axis=-1)
|
||||
|
||||
vectors.extend(v.tolist())
|
||||
|
|
|
@ -1310,7 +1310,7 @@ class MergeDdb(object):
|
|||
|
||||
return ddk_tasks, bec_tasks
|
||||
|
||||
def merge_ddb_files(self, delete_source_ddbs=True, only_dfpt_tasks=True,
|
||||
def merge_ddb_files(self, delete_source_ddbs=False, only_dfpt_tasks=True,
|
||||
exclude_tasks=None, include_tasks=None):
|
||||
"""
|
||||
This method is called when all the q-points have been computed.
|
||||
|
|
Loading…
Reference in New Issue