mirror of https://github.com/QMCPACK/qmcpack.git
1570 lines
48 KiB
Python
1570 lines
48 KiB
Python
|
|
|
|
# Python standard library imports
|
|
import os
|
|
import inspect
|
|
from time import process_time
|
|
from copy import deepcopy
|
|
|
|
# Non-standard Python imports
|
|
import numpy as np
|
|
|
|
from developer import unavailable # Nexus unavailable module guard
|
|
try:
|
|
import matplotlib.pyplot as plt
|
|
except:
|
|
plt = unavailable('matplotlib','pyplot')
|
|
#end try
|
|
try:
|
|
import h5py
|
|
except:
|
|
h5py = unavailable('h5py')
|
|
#end try
|
|
|
|
# Nexus imports
|
|
import memory
|
|
from unit_converter import convert
|
|
from generic import obj
|
|
from developer import DevBase,log,error,ci
|
|
from numerics import simstats
|
|
from grid_functions import grid_function,read_grid,StructuredGrid,grid as generate_grid
|
|
from grid_functions import SpheroidGrid
|
|
from structure import Structure,read_structure
|
|
from fileio import XsfFile
|
|
|
|
|
|
|
|
|
|
class VLog(DevBase):
    """Leveled logger that can append elapsed-time and memory figures.

    Instances are callable.  A message is forwarded to ``log`` only when
    the instance verbosity is at least the level named for that message.
    """

    # Named verbosity levels mapped to comparable integers.
    verbosity_levels = obj(
        none = 0,
        low  = 1,
        high = 2,
        )

    def __init__(self):
        """Record time/memory baselines; start at low verbosity, no indent."""
        start_time     = process_time()
        self.tstart    = start_time
        self.tlast     = start_time
        start_memory   = memory.resident(children=True)
        self.mstart    = start_memory
        self.mlast     = start_memory
        self.verbosity = self.verbosity_levels.low
        self.indent    = 0
    #end def __init__

    def __call__(self,msg,level='low',n=0,time=False,mem=False,width=75):
        """Log ``msg`` if the current verbosity permits it.

        Parameters
        ----------
        msg : str
            Message text.
        level : str
            Key into ``verbosity_levels`` required for the message to print.
        n : int
            Added to ``self.indent`` and forwarded to ``log`` as ``n``.
        time : bool
            Append time since the previous timed message and since creation.
        mem : bool
            Append memory change since the previous memory message and
            since creation.
        width : int
            Target width used to pad ``msg`` before statistics are appended.
        """
        levels = self.verbosity_levels
        if self.verbosity==levels.none:
            return
        #end if
        if self.verbosity<levels[level]:
            return
        #end if
        depth = n+self.indent
        if mem or time:
            # Pad the message so appended statistics start in a common column.
            msg += max(0,width-2*depth-len(msg)-36)*' '
            if mem:
                scale = 1e6 # MB
                current = memory.resident(children=True)
                added   = (current-self.mlast)/scale
                total   = (current-self.mstart)/scale
                msg += ' (mem add {:6.2f}, tot {:6.2f})'.format(added,total)
                self.mlast = current
            #end if
            if time:
                current = process_time()
                elapsed = current-self.tlast
                total   = current-self.tstart
                msg += ' (t elap {:7.3f}, tot {:7.3f})'.format(elapsed,total)
                self.tlast = current
            #end if
        #end if
        log(msg,n=depth)
    #end def __call__

    def increment(self,n=1):
        """Increase the persistent indent by ``n``."""
        self.indent = self.indent + n
    #end def increment

    def decrement(self,n=1):
        """Decrease the persistent indent by ``n``."""
        self.indent = self.indent - n
    #end def decrement

    def set_none(self):
        """Silence all messages."""
        self.verbosity = self.verbosity_levels.none
    #end def set_none

    def set_low(self):
        """Print only messages requested at level 'low'."""
        self.verbosity = self.verbosity_levels.low
    #end def set_low

    def set_high(self):
        """Print messages requested at level 'low' or 'high'."""
        self.verbosity = self.verbosity_levels.high
    #end def set_high

    def set_verbosity(self,level):
        """Set verbosity by level name; report an error for unknown names."""
        levels = self.verbosity_levels
        if level not in levels:
            names_by_value = levels.inverse()
            options = [names_by_value[i] for i in sorted(names_by_value.keys())]
            error('Cannot set verbosity level to "{}".\nValid options are: {}'.format(level,options))
        #end if
        self.verbosity = levels[level]
    #end def set_verbosity
#end class VLog

# Module-wide logger instance shared by the functions and classes below.
vlog = VLog()
|
|
|
|
|
|
|
|
def set_verbosity(level):
    """Set the verbosity of the module-level ``vlog`` logger by level name."""
    logger = vlog
    logger.set_verbosity(level)
#end def set_verbosity
|
|
|
|
|
|
class Missing:
    """Sentinel type marking "no value supplied" where ``None`` is meaningful.

    Calling an instance with a value reports whether that value is itself
    a ``Missing`` sentinel.
    """

    def __call__(self,value):
        """Return True when ``value`` is a ``Missing`` instance."""
        is_sentinel = isinstance(value,Missing)
        return is_sentinel
    #end def __call__
#end class Missing

# Shared sentinel instance; also used as a predicate: ``missing(x)``.
missing = Missing()
|
|
|
|
|
|
|
|
class AttributeProperties(DevBase):
    """Record of the declared properties of a single defined attribute.

    ``assigned`` holds the names of the properties that were explicitly
    passed in, as opposed to those filled with a fallback value.
    """

    def __init__(self,**kwargs):
        """Store recognized properties; report an error for any others."""
        self.assigned = set(kwargs.keys())
        # Recognized properties and the fallback used when one is not given.
        fallbacks = (
            ('name'      , None ),
            ('dest'      , None ),
            ('type'      , None ),
            ('default'   , None ),
            ('no_default', False),
            ('deepcopy'  , False),
            ('required'  , False),
            )
        for prop,fallback in fallbacks:
            setattr(self,prop,kwargs.pop(prop,fallback))
        #end for
        # Anything still left in kwargs was not a recognized property.
        if kwargs:
            self.error('Invalid init variable attributes received.\nInvalid attributes:\n{}\nThis is a developer error.'.format(obj(kwargs)))
        #end if
    #end def __init__
#end class AttributeProperties
|
|
|
|
|
|
|
|
class DefinedAttributeBase(DevBase):
    """Base class for objects whose attributes are declared ahead of time.

    Attributes are declared per class through ``define_attributes``.  Each
    declaration is an ``AttributeProperties`` record that may carry a type,
    a default, a destination sub-container (``dest``), and flags for
    deep-copying and being required.  Instances then set, get, and
    validate values against those declarations.
    """

    @classmethod
    def set_unassigned_default(cls,default):
        """Set the default used for attributes declared without a default."""
        cls.unassigned_default = default
    #end def set_unassigned_default

    @classmethod
    def define_attributes(cls,*other_cls,**attribute_properties):
        """Declare (or re-declare) attributes for this class.

        Parameters
        ----------
        *other_cls
            If exactly one ``DefinedAttributeBase`` subclass is given, its
            attribute definitions are copied in first.
        **attribute_properties
            Attribute name -> mapping of properties accepted by
            ``AttributeProperties``.
        """
        if len(other_cls)==1 and issubclass(other_cls[0],DefinedAttributeBase):
            cls.obtain_attributes(other_cls[0])
        #end if
        if cls.class_has('attribute_definitions'):
            attr_defs = cls.attribute_definitions
        else:
            attr_defs = obj()
            cls.class_set(
                attribute_definitions = attr_defs
                )
        #end if
        for name,attr_props in attribute_properties.items():
            attr_props = AttributeProperties(**attr_props)
            attr_props.name = name
            if name not in attr_defs:
                attr_defs[name] = attr_props
            else:
                # Re-declaration: overlay only the explicitly given properties.
                existing = attr_defs[name]
                for n in attr_props.assigned:
                    existing[n] = attr_props[n]
                #end for
                # Remember what is now explicitly assigned; otherwise a
                # default supplied here would be overwritten by
                # ``unassigned_default`` in the loop below.
                existing.assigned = existing.assigned | attr_props.assigned
            #end if
        #end for
        if cls.class_has('unassigned_default'):
            # Iterating the definitions container is expected to yield the
            # property records (they are used via ``p.assigned`` below).
            for p in attr_defs:
                if 'default' not in p.assigned:
                    p.default = cls.unassigned_default
                #end if
            #end for
        #end if
        # Classify attribute names once so instance methods can use sets.
        required_attributes = set()
        deepcopy_attributes = set()
        typed_attributes    = set()
        toplevel_attributes = set()
        sublevel_attributes = set()
        for name,props in attr_defs.items():
            if props.required:
                required_attributes.add(name)
            #end if
            if props.deepcopy:
                deepcopy_attributes.add(name)
            #end if
            if props.type is not None:
                typed_attributes.add(name)
            #end if
            if props.dest is None:
                toplevel_attributes.add(name)
            else:
                sublevel_attributes.add(name)
            #end if
        #end for
        cls.class_set(
            required_attributes = required_attributes,
            deepcopy_attributes = deepcopy_attributes,
            typed_attributes    = typed_attributes,
            toplevel_attributes = toplevel_attributes,
            sublevel_attributes = sublevel_attributes,
            )
    #end def define_attributes

    @classmethod
    def obtain_attributes(cls,super_cls):
        """Start this class's definitions from a copy of ``super_cls``'s."""
        cls.class_set(
            attribute_definitions = super_cls.attribute_definitions.copy()
            )
    #end def obtain_attributes

    def __init__(self,**values):
        """Set defaults and the given values, but only if values are given."""
        if len(values)>0:
            self.set_default_attributes()
            self.set_attributes(**values)
        #end if
    #end def __init__

    def initialize(self,**values):
        """Always set defaults, then apply any given values."""
        self.set_default_attributes()
        if len(values)>0:
            self.set_attributes(**values)
        #end if
    #end def initialize

    def set_default_attributes(self):
        """Assign defaults, top-level attributes first, then sub-level ones.

        Top-level attributes go first so that destination containers exist
        before the sub-level attributes stored inside them are assigned.
        """
        cls   = self.__class__
        props = cls.attribute_definitions
        for name in cls.toplevel_attributes:
            self._set_default_attribute(name,props[name])
        #end for
        for name in cls.sublevel_attributes:
            self._set_default_attribute(name,props[name])
        #end for
    #end def set_default_attributes

    def set_attributes(self,**values):
        """Set several attributes at once.

        Reports an error if any name is undeclared, if any required
        attribute is absent from ``values``, or if a destination container
        does not exist.
        """
        cls         = self.__class__
        props       = cls.attribute_definitions
        value_names = set(values.keys())
        invalid     = value_names - set(props.keys())
        if len(invalid)>0:
            v = obj()
            v.transfer_from(values,invalid)
            self.error('Attempted to set unrecognized attributes\nUnrecognized attributes:\n{}'.format(v))
        #end if
        absent = set(cls.required_attributes) - value_names
        if len(absent)>0:
            msg = ''
            for n in sorted(absent):
                msg += '\n '+n
            #end for
            self.error('Required attributes are missing.\nPlease provide the following attributes during initialization:{}'.format(msg))
        #end if
        toplevel_names = value_names & cls.toplevel_attributes
        for name in toplevel_names:
            self._set_attribute(self,name,values[name],props[name])
        #end for
        sublevel_names = value_names - toplevel_names
        for name in sublevel_names:
            p = props[name]
            if p.dest not in self:
                self.error('Attribute destination "{}" does not exist at the top level.\nThis is a developer error.'.format(p.dest))
            #end if
            self._set_attribute(self[p.dest],name,values[name],p)
        #end for
    #end def set_attributes

    def check_attributes(self,exit=False):
        """Check that required attributes are present and typed correctly.

        Parameters
        ----------
        exit : bool
            When True, problems are passed to ``self.error``.

        Returns
        -------
        bool
            True when no problems were found.
        """
        msg   = ''
        cls   = self.__class__
        props = cls.attribute_definitions
        # Gather every declared attribute that is currently present.
        a = obj()
        for name in cls.toplevel_attributes:
            if name in self:
                a[name] = self[name]
            #end if
        #end for
        for name in cls.sublevel_attributes:
            p = props[name]
            if p.dest in self:
                sub = self[p.dest]
                if name in sub:
                    a[name] = sub[name]
                #end if
            #end if
        #end for
        absent = cls.required_attributes - set(a.keys())
        if len(absent)>0:
            m = ''
            for n in sorted(absent):
                m += '\n '+n
            #end for
            msg += 'Required attributes are missing.\nPlease provide the following attributes during initialization:{}\n'.format(m)
        #end if
        for name in cls.typed_attributes:
            if name in a:
                p = props[name]
                v = a[name]
                if not isinstance(v,p.type):
                    msg += 'Attribute "{}" has invalid type.\n Type expected: {}\n Type present: {}\n'.format(name,p.type.__name__,v.__class__.__name__)
                #end if
            #end if
        #end for
        valid = len(msg)==0
        if not valid and exit:
            self.error(msg)
        #end if
        return valid
    #end def check_attributes

    def check_unassigned(self,value):
        """Return True if ``value`` is the class's unassigned default.

        The comparison is by identity.  False is returned when the class
        has no ``unassigned_default``.
        """
        cls = self.__class__
        unassigned = cls.class_has('unassigned_default') and value is cls.unassigned_default
        return unassigned
    #end def check_unassigned

    def set_attribute(self,name,value):
        """Set one declared attribute after type checking.

        Reports an error for an undeclared name, a value of the wrong
        type, or a missing destination container.  The value is
        deep-copied first when the declaration asks for it.
        """
        cls   = self.__class__
        props = cls.attribute_definitions
        if name not in props:
            self.error('Cannot set unrecognized attribute "{}".\nValid options are: {}'.format(name,sorted(props.keys())))
        #end if
        p = props[name]
        if p.type is not None and not isinstance(value,p.type):
            self.error('Cannot set attribute "{}".\nExpected value with type: {}\nReceived value with type: {}'.format(name,p.type.__name__,value.__class__.__name__))
        #end if
        if p.deepcopy:
            value = deepcopy(value)
        #end if
        if p.dest is None:
            self[name] = value
        elif p.dest not in self:
            self.error('Cannot set attribute "{}".\nAttribute destination "{}" does not exist.'.format(name,p.dest))
        else:
            self[p.dest][name] = value
        #end if
    #end def set_attribute

    def get_attribute(self,name,value=missing,assigned=True):
        """Return the value of a declared attribute.

        Parameters
        ----------
        name : str
            Declared attribute name.
        value
            Fallback returned when the attribute does not exist.  Supplying
            a fallback also disables the "must be assigned" check.
        assigned : bool
            When True (and no fallback is given), report an error if the
            stored value is the unassigned default.
        """
        default_value    = value
        default_provided = not missing(default_value)
        require_assigned = assigned and not default_provided
        cls   = self.__class__
        props = cls.attribute_definitions
        if name not in props:
            self.error('Cannot get unrecognized attribute "{}".\nValid options are: {}'.format(name,sorted(props.keys())))
        #end if
        p = props[name]
        # Look the value up at the top level or inside its destination.
        value = missing
        if p.dest is None:
            if name in self:
                value = self[name]
            #end if
        elif p.dest in self and name in self[p.dest]:
            value = self[p.dest][name]
        #end if
        present = not missing(value)
        if not present and default_provided:
            return default_value
        #end if
        unassigned = True
        if present:
            unassigned = self.check_unassigned(value)
        #end if
        if not present or (unassigned and require_assigned):
            extra = ''
            if p.dest is not None:
                extra = ' at location "{}"'.format(p.dest)
            #end if
            if not present:
                msg = 'Cannot get attribute "{}"{}.\nAttribute does not exist.'.format(name,extra)
            else:
                msg = 'Cannot get attribute "{}"{}.\nAttribute has not been assigned.'.format(name,extra)
            #end if
            self.error(msg)
        #end if
        return value
    #end def get_attribute

    def has_attribute(self,name):
        """Return True if ``name`` exists and holds an assigned value.

        Declared attributes with a ``dest`` are looked up inside their
        destination container, matching ``set_attribute`` and
        ``get_attribute``.  Other names are looked up at the top level.
        """
        cls       = self.__class__
        container = self
        if cls.class_has('attribute_definitions') and name in cls.attribute_definitions:
            dest = cls.attribute_definitions[name].dest
            if dest is not None:
                if dest not in self:
                    return False
                #end if
                container = self[dest]
            #end if
        #end if
        return name in container and not self.check_unassigned(container[name])
    #end def has_attribute

    def _set_default_attribute(self,name,props):
        """Assign the declared default for ``name`` unless it has none.

        A default that is a class or a plain function is called to produce
        the value, so each instance gets its own object.
        """
        p = props
        if p.no_default:
            return
        #end if
        value = p.default
        if inspect.isclass(value) or inspect.isfunction(value):
            value = value()
        #end if
        if p.dest is None:
            self[name] = value
        elif p.dest not in self:
            self.error('Attribute destination "{}" does not exist at the top level.\nThis is a developer error.'.format(p.dest))
        else:
            self[p.dest][name] = value
        #end if
    #end def _set_default_attribute

    def _set_attribute(self,container,name,value,props):
        """Type check ``value`` and store it in ``container`` under ``name``."""
        p = props
        if p.type is not None and not isinstance(value,p.type):
            self.error('Cannot set attribute "{}".\nExpected value with type: {}\nReceived value with type: {}'.format(name,p.type.__name__,value.__class__.__name__))
        #end if
        if p.deepcopy:
            value = deepcopy(value)
        #end if
        container[name] = value
    #end def _set_attribute

#end class DefinedAttributeBase
|
|
|
|
|
|
|
|
class Observable(DefinedAttributeBase):
    """Base observable; defaults are always set on construction.

    Unlike the parent constructor, construction here goes through
    ``initialize``, so default attributes are assigned even when no
    values are passed in.
    """

    def __init__(self,**values):
        """Initialize defaults and any supplied attribute values."""
        self.initialize(**values)
    #end def __init__

    def initialize(self,**values):
        """Set defaults and values; flag ``info.initialized`` if values given."""
        DefinedAttributeBase.initialize(self,**values)
        if values:
            self.info.initialized = True
        #end if
    #end def initialize
#end class Observable


# Attributes declared without an explicit default fall back to None.
Observable.set_unassigned_default(None)

# ``info`` is a top-level container; the other attributes live inside it.
Observable.define_attributes(
    info        = obj(type=obj,default=obj),
    initialized = obj(dest='info',type=bool,default=False),
    structure   = obj(dest='info',type=Structure,default=None,deepcopy=True),
    )
|
|
|
|
|
|
|
|
class ObservableWithComponents(Observable):
    """Observable made of named components (e.g. total, up, down).

    Subclasses set ``component_names`` to the allowed names and
    ``default_component_name`` to the one used when none is requested.
    """

    component_names = None

    default_component_name = None

    def process_component_name(self,name):
        """Return a valid component name, defaulting when ``name`` is None.

        Reports an error if ``name`` is not one of ``component_names``.
        The component does not have to exist on the instance yet.
        """
        if name is None:
            return self.default_component_name
        #end if
        # Validate against the allowed names.  (``self.components`` is a
        # method, so membership must not be tested against it.)
        if name not in self.component_names:
            self.error('"{}" is not a known component.\nValid options are: {}'.format(name,self.component_names))
        #end if
        return name
    #end def process_component_name

    def default_component(self):
        """Return the component named by ``default_component_name``."""
        return self.component(self.default_component_name)
    #end def default_component

    def component(self,name):
        """Return a single component; ``None`` selects the default one.

        Reports an error if the name is unknown or the component is absent.
        """
        if name is None:
            return self.default_component()
        #end if
        self._require_component(name)
        comp = self.get_attribute(name)
        return comp
    #end def component

    def components(self,names=None):
        """Return an ``obj`` of components keyed by name.

        Parameters
        ----------
        names : None, str, or iterable of str
            ``None`` collects every known component present on the
            instance (an error is reported if there are none).  Otherwise
            each requested name must be known and present.
        """
        comps = obj()
        if names is None:
            for c in self.component_names:
                if c in self:
                    comps[c] = self[c]
                #end if
            #end for
            if len(comps)==0:
                self.error('No components found.')
            #end if
        else:
            if isinstance(names,str):
                names = [names]
            #end if
            for name in names:
                self._require_component(name)
                comps[name] = self[name]
            #end for
        #end if
        return comps
    #end def components

    def _require_component(self,name):
        """Report an error unless ``name`` is a known, present component."""
        if name not in self.component_names:
            self.error('"{}" is not a known component.\nValid options are: {}'.format(name,self.component_names))
        elif name not in self:
            self.error('Component "{}" not found.'.format(name))
        #end if
    #end def _require_component

#end class ObservableWithComponents
|
|
|
|
|
|
|
|
|
|
def read_eshdf_nofk_data(filename,Ef):
    """Collect occupied-orbital momentum data from an ESHDF file.

    Parameters
    ----------
    filename : str
        Path of the ESHDF (HDF5) file to read.
    Ef
        Fermi energy.  Orbital eigenvalues are converted from 'Ha' to
        'eV' before being compared with it.

    Returns
    -------
    obj
        Fields ``orbfile``, ``E_fermi`` (the slightly shifted value
        actually used), ``axes``, ``kaxes``, ``nkpoints``, ``nspins``, and
        ``data``.  ``data[k,s]`` holds, for k-point ``k`` and spin ``s``,
        the fields ``kpoint``, ``kin``, ``eig``, ``k``, ``nk``, and ``ne``.
        The ``k`` array is shared between the spins of a k-point.
    """
    from numpy import array,pi,dot,zeros
    from numpy.linalg import inv
    from hdfreader import read_hdf

    def h5int(i):
        # Extract the first element of an HDF dataset as an integer.
        return array(i,dtype=int)[0]
    #end def h5int

    # Use slightly shifted Fermi energy
    E_fermi = Ef + 1e-8

    # Open the HDF file w/o loading the arrays into memory (view mode)
    vlog('Reading '+filename)
    h = read_hdf(filename,view=True)

    # Get the G-vectors in cell coordinates
    gvu = array(h.electrons.kpoint_0.gvectors)

    # Get the untiled cell axes
    axes = array(h.supercell.primitive_vectors)

    # Compute the k-space cell axes
    kaxes = 2*pi*inv(axes).T

    # Convert G-vectors from cell coordinates to atomic units
    gv = dot(gvu,kaxes)

    # Get number of kpoints/twists, spins, and G-vectors
    nkpoints = h5int(h.electrons.number_of_kpoints)
    nspins   = h5int(h.electrons.number_of_spins)
    ngvecs   = len(gv)

    # Process the orbital data
    data = obj()
    for k in range(nkpoints):
        vlog('Processing k-point {:>3}'.format(k),n=1,time=True)
        kp        = h.electrons['kpoint_'+str(k)]
        reduced_k = array(kp.reduced_k)
        # Shift every G-vector by this k-point's offset (broadcast per row).
        gvk     = gv + dot(reduced_k,kaxes)
        kinetic = (gvk**2).sum(1)/2 # Hartree units
        for s in range(nspins):
            kin_s   = []
            eig_s   = []
            nk_s    = zeros((ngvecs,),dtype=float)
            nelec_s = 0
            path    = 'electrons/kpoint_{0}/spin_{1}'.format(k,s)
            spin    = h.get_path(path)
            eigs    = convert(array(spin.eigenvalues),'Ha','eV')
            nstates = h5int(spin.number_of_states)
            for st in range(nstates):
                eig = eigs[st]
                # Only states below the (shifted) Fermi energy contribute.
                if eig<E_fermi:
                    stpath  = path+'/state_{0}/psi_g'.format(st)
                    psi     = array(h.get_path(stpath))
                    nk_orb  = (psi**2).sum(1)
                    kin_orb = (kinetic*nk_orb).sum()
                    nelec_s += nk_orb.sum()
                    nk_s    += nk_orb
                    kin_s.append(kin_orb)
                    eig_s.append(eig)
                #end if
            #end for
            data[k,s] = obj(
                kpoint = reduced_k.copy(),
                kin    = array(kin_s),
                eig    = array(eig_s),
                k      = gvk,
                nk     = nk_s,
                ne     = nelec_s,
                )
        #end for
    #end for
    res = obj(
        orbfile  = filename,
        E_fermi  = E_fermi,
        axes     = axes,
        kaxes    = kaxes,
        nkpoints = nkpoints,
        nspins   = nspins,
        data     = data,
        )

    return res
#end def read_eshdf_nofk_data
|
|
|
|
|
|
|
|
class MomentumDistribution(ObservableWithComponents):
    """Momentum distribution n(k), with raw point data and gridded components."""

    component_names = ('tot','pol','u','d')

    default_component_name = 'tot'

    def get_raw_data(self):
        """Return the 'raw' attribute, reporting an error if it is empty."""
        data = self.get_attribute('raw')
        if len(data)==0:
            self.error('Raw n(k) data is not present.')
        #end if
        return data
    #end def get_raw_data

    def filter_raw_data(self,filter_tol=1e-5,store=True):
        """Drop raw k-points beyond the largest |k| with n(k) > ``filter_tol``.

        Parameters
        ----------
        filter_tol : float
            Occupation threshold used to locate the cutoff radius.
        store : bool
            When True the raw data is overwritten in place and the
            tolerance is recorded in 'raw_filter_tol'.  When False the raw
            data is left untouched and a new container is returned.

        Returns
        -------
        obj
            Filtered data keyed like the raw data.  Entries that were the
            same object in the raw data remain the same object here.
        """
        vlog('Filtering raw n(k) data with tolerance {:6.4e}'.format(filter_tol))
        prior_tol = self.get_attribute('raw_filter_tol',assigned=False)
        data      = self.get_raw_data()
        if prior_tol is not None and prior_tol<=filter_tol:
            vlog('Filtering applied previously with tolerance {:6.4e}, skipping.'.format(prior_tol))
            return data
        #end if
        k     = data.first().k
        km    = np.linalg.norm(k,axis=1)
        kmax  = 0.
        order = km.argsort()
        for s,sdata in data.items():
            vlog('Finding kmax for {} data'.format(s),n=1,time=True)
            nk = sdata.nk
            # Walk inward from the largest |k| to the first significant point.
            for n in reversed(order):
                if nk[n]>filter_tol:
                    break
                #end if
            #end for
            kmax = max(km[n],kmax)
        #end for
        vlog('Original kmax: {:8.4f}'.format(km.max()),n=2)
        vlog('Filtered kmax: {:8.4f}'.format(kmax),n=2)
        vlog('Applying kmax filter to data',n=1,time=True)
        # NOTE(review): the comparison is strict, so points at exactly kmax
        # are dropped as well -- confirm this is intended.
        keep = km<kmax
        k    = k[keep]
        vlog('size before filter: {}'.format(len(keep)),n=2)
        vlog('size after filter: {}'.format(len(k)),n=2)
        vlog('fraction: {:6.4e}'.format(len(k)/len(keep)),n=2)
        # Filter every distinct source object exactly once, and before any
        # in-place update.  Entries may alias each other (e.g. 'd' being
        # the same object as 'u'); filtering an already filtered array
        # with the full-length mask would fail.
        filtered = dict()
        for s,sdata in data.items():
            if id(sdata) not in filtered:
                arrays = obj(k=k)
                # Keep error bars (when present) consistent with nk.
                for name in ('nk','nk_err'):
                    if name in sdata:
                        arrays[name] = sdata[name][keep]
                    #end if
                #end for
                filtered[id(sdata)] = arrays
            #end if
        #end for
        if store:
            new_data = data
            self.set_attribute('raw_filter_tol',filter_tol)
            for s,sdata in data.items():
                for name,arr in filtered[id(sdata)].items():
                    sdata[name] = arr
                #end for
            #end for
            vlog('Overwriting original raw n(k) with filtered data',n=1)
        else:
            new_data = obj()
            for s,sdata in data.items():
                new_data[s] = filtered[id(sdata)]
            #end for
        #end if
        vlog('Filtering complete',n=1,time=True)
        return new_data
    #end def filter_raw_data

    def map_raw_data_onto_grid(self,unfold=False,filter_tol=1e-5):
        """Build a grid function for each raw data set and store it on self.

        Parameters
        ----------
        unfold : bool
            When True, the raw points are replicated under every rotation
            returned by ``structure.point_group_operations()`` before
            being mapped (with averaging).  A structure must be assigned.
        filter_tol : float or None
            When not None, a filtered copy of the raw data
            (``filter_raw_data(..., store=False)``) is mapped instead.
        """
        vlog('\nMapping raw n(k) data onto regular grid')
        data      = self.get_raw_data()
        structure = self.get_attribute('structure',assigned=unfold)
        if structure is not None:
            kaxes = structure.kaxes
        else:
            kaxes = self.get_attribute('kaxes')
        #end if
        if filter_tol is not None:
            vlog.increment()
            data = self.filter_raw_data(filter_tol,store=False)
            vlog.decrement()
        #end if
        if not unfold:
            for s,sdata in data.items():
                vlog('Mapping {} data onto grid'.format(s),n=1,time=True)
                self[s] = grid_function(
                    points = sdata.k,
                    values = sdata.nk,
                    axes   = kaxes,
                    )
            #end for
        else:
            rotations = structure.point_group_operations()
            for s,sdata in data.items():
                # Skip 'd' when it is literally the 'u' data; it is aliased
                # to the 'u' result after the loop.
                if s=='d' and 'u' in data and id(sdata)==id(data.u):
                    continue
                #end if
                vlog('Unfolding {} data'.format(s),n=1,time=True)
                k   = []
                nk  = []
                ks  = sdata.k
                nks = sdata.nk
                for n,R in enumerate(rotations):
                    vlog('Processing rotation {:<3}'.format(n),n=2,mem=True)
                    k.extend(np.dot(ks,R))
                    nk.extend(nks)
                #end for
                k  = np.array(k ,dtype=float)
                nk = np.array(nk,dtype=float)
                vlog('Unfolding finished',n=2,time=True)

                vlog('Mapping {} data onto grid'.format(s),n=1,time=True)
                vlog.increment(2)
                self[s] = grid_function(
                    points  = k,
                    values  = nk,
                    axes    = kaxes,
                    average = True,
                    )
                vlog.decrement(2)
            #end for
        #end if
        if 'd' not in self and 'u' in self:
            self.d = self.u
        #end if
        vlog('Mapping complete',n=1,time=True)
        vlog('Current memory: ',n=1,mem=True)
    #end def map_raw_data_onto_grid

    def backfold(self):
        """Print k-space grid diagnostics, then call ``ci()`` and ``exit()``.

        NOTE(review): this looks like unfinished debugging code; it never
        returns normally.  Behavior is left exactly as found.
        """
        structure = self.get_attribute('structure',assigned=True)
        kaxes = structure.kaxes
        c  = self.default_component()
        dk = c.grid.dr
        print(kaxes)
        print(dk)
        print(np.diag(kaxes)/np.diag(dk))
        print(c.grid.cell_grid_shape)
        ci()
        exit()
    #end def backfold

    def plot_plane_contours(self,
                            quantity     = None,
                            origin       = None,
                            a1           = None,
                            a2           = None,
                            a1_range     = (0,1),
                            a2_range     = (0,1),
                            grid_spacing = 0.3,
                            unit_in      = False,
                            unit_out     = False,
                            boundary     = True,
                            ):
        """Interpolate a component onto a plane and plot its contours.

        Parameters
        ----------
        quantity : str or None
            Component name; ``None`` selects the default component.
        origin, a1, a2 : array-like
            Plane origin and the end points of its two edges.  The inputs
            are copied and never modified.
        a1_range, a2_range : pair of numbers
            Fractional extent plotted along each edge vector.
        grid_spacing : float
            Grid spacing passed as ``dr`` in both plane directions.
        unit_in : bool
            When True the points are given in units of the reciprocal
            primitive lattice obtained from ``get_seekpath_full``.
        unit_out : bool
            Currently unused; kept for interface compatibility.
        boundary : bool
            Forwarded to ``plot_contours``.
        """
        c = self.component(quantity)
        if origin is None or a1 is None or a2 is None:
            self.error('Cannot plot plane contours.\n"origin", "a1", and "a2" must all be provided.')
        #end if
        # Work on private float copies so that the arithmetic below cannot
        # alter the caller's arrays or fail on integer input.
        o  = np.array(origin,dtype=float)
        a1 = np.array(a1,dtype=float)
        a2 = np.array(a2,dtype=float)
        if unit_in:
            s = self.get_attribute('structure',assigned=True)
            from structure import get_seekpath_full
            skp   = get_seekpath_full(structure=s,primitive=True)
            kaxes = np.asarray(skp.reciprocal_primitive_lattice)
            o  = np.dot(o ,kaxes)
            a1 = np.dot(a1,kaxes)
            a2 = np.dot(a2,kaxes)
        #end if
        # Turn the edge end points into edge vectors relative to the origin.
        a1 = a1 - o
        a2 = a2 - o
        corner = o + a1_range[0]*a1 + a2_range[0]*a2
        a1 = a1*(a1_range[1] - a1_range[0])
        a2 = a2*(a2_range[1] - a2_range[0])
        g = generate_grid(
            type   = 'parallelotope',
            corner = corner,
            axes   = [a1,a2],
            dr     = (grid_spacing,grid_spacing),
            )
        gf = c.interpolate(g)
        gf.plot_contours(boundary=boundary)
    #end def plot_plane_contours

    def plot_radial_raw(self,quants='all',kmax=None,fmt='b.',fig=True,show=True):
        """Plot raw n(k) against |k| for each requested data set.

        Parameters
        ----------
        quants : 'all' or iterable of str
            Raw data keys to plot.
        kmax : float or None
            When given, only points with |k| < kmax are plotted.
        fmt : str
            Matplotlib format string.
        fig : bool
            Open a new figure for each data set.
        show : bool
            Call ``plt.show()`` after plotting.
        """
        data = self.get_raw_data()
        if quants=='all':
            quants = list(data.keys())
        #end if
        for q in quants:
            d  = data[q]
            k  = np.linalg.norm(d.k,axis=1)
            nk = d.nk
            has_error = 'nk_err' in d
            if has_error:
                nke = d.nk_err
            #end if
            if kmax is not None:
                rng = k<kmax
                k   = k[rng]
                nk  = nk[rng]
                if has_error:
                    nke = nke[rng]
                #end if
            #end if
            if fig:
                plt.figure()
            #end if
            if not has_error:
                plt.plot(k,nk,fmt)
            else:
                plt.errorbar(k,nk,nke,fmt=fmt)
            #end if
            plt.xlabel('k (a.u.)')
            plt.ylabel('n(k) {}'.format(q))
        #end for
        if show:
            plt.show()
        #end if
    #end def plot_radial_raw

    def plot_directional_raw(self,kdir,quants='all',kmax=None,fmt='b.',fig=True,show=True,reflect=False):
        """Plot raw n(k) for the k-points lying along direction ``kdir``.

        Parameters
        ----------
        kdir : array-like
            Direction vector; it is copied and normalized internally.
        quants : 'all' or iterable of str
            Raw data keys to plot.
        kmax : float or None
            When given, only points with |k| < kmax are considered.
        fmt : str
            Matplotlib format string.
        fig : bool
            Open a new figure for each data set.
        show : bool
            Call ``plt.show()`` after plotting.
        reflect : bool
            Also plot the data mirrored to -k.
        """
        data = self.get_raw_data()
        kdir = np.array(kdir,dtype=float)
        kdir /= np.linalg.norm(kdir)
        if quants=='all':
            quants = list(data.keys())
        #end if
        for q in quants:
            d  = data[q]
            k  = d.k
            nk = d.nk
            has_error = 'nk_err' in d
            if has_error:
                nke = d.nk_err
            #end if
            km = np.linalg.norm(d.k,axis=1)
            if kmax is not None:
                rng = km<kmax
                km  = km[rng]
                k   = k[rng]
                nk  = nk[rng]
                if has_error:
                    nke = nke[rng]
                #end if
            #end if
            kd = np.dot(k,kdir)
            # A point lies along kdir when |k.kdir| equals |k| (to relative
            # tolerance); the origin is always included.
            along_dir = (np.abs(km-np.abs(kd)) < 1e-8*km) | (km<1e-8)
            kd = kd[along_dir]
            nk = nk[along_dir]
            if has_error:
                nke = nke[along_dir]
            #end if
            if fig:
                plt.figure()
            #end if
            if not has_error:
                plt.plot(kd,nk,fmt)
                if reflect:
                    plt.plot(-kd,nk,fmt)
                #end if
            else:
                plt.errorbar(kd,nk,nke,fmt=fmt)
                if reflect:
                    plt.errorbar(-kd,nk,nke,fmt=fmt)
                #end if
            #end if
            plt.xlabel('k (a.u.)')
            plt.ylabel('directional n(k) {}'.format(q))
        #end for
        # Honor the ``show`` argument, as plot_radial_raw does.
        if show:
            plt.show()
        #end if
    #end def plot_directional_raw
#end class MomentumDistribution


# Raw data and components are top-level attributes without defaults;
# ``kaxes`` and ``raw_filter_tol`` are stored inside ``info``.
MomentumDistribution.define_attributes(
    Observable,
    raw = obj(
        type       = obj,
        no_default = True,
        ),
    u = obj(
        type       = obj,
        no_default = True,
        ),
    d = obj(
        type       = obj,
        no_default = True,
        ),
    tot = obj(
        type       = obj,
        no_default = True,
        ),
    pol = obj(
        type       = obj,
        no_default = True,
        ),
    kaxes = obj(
        dest       = 'info',
        type       = np.ndarray,
        no_default = True,
        ),
    raw_filter_tol = obj(
        dest    = 'info',
        type    = float,
        default = None,
        ),
    )
|
|
|
|
|
|
|
|
class MomentumDistributionDFT(MomentumDistribution):
    """Momentum distribution populated from orbital data in an ESHDF file."""

    def read_eshdf(self,filepath,E_fermi=None,savefile=None,unfold=False,grid=True):
        """Read n(k) from an ESHDF file, or load it from an existing save file.

        Parameters
        ----------
        filepath : str
            ESHDF file to read.
        E_fermi : float or None
            Fermi energy (eV).  When None the value stored in
            ``info.E_fermi`` is used; otherwise it is stored there.
        savefile : str or None
            If this file exists, the object is loaded from it and nothing
            else is done.  If it does not exist, the object is saved to it
            after reading.
        unfold : bool
            Forwarded to ``map_raw_data_onto_grid``.
        grid : bool
            When True, map the raw data onto a grid after reading.
        """
        save_after_read = False
        if savefile is not None:
            if os.path.exists(savefile):
                vlog('\nLoading from save file {}'.format(savefile))
                self.load(savefile)
                vlog('Done',n=1,time=True)
                return
            #end if
            save_after_read = True
        #end if

        vlog('\nExtracting n(k) data from {}'.format(filepath))

        # Fall back on, or record, the Fermi energy kept in ``info``.
        if E_fermi is not None:
            self.info.E_fermi = E_fermi
        else:
            E_fermi = self.info.E_fermi
        #end if
        if E_fermi is None:
            self.error('Cannot read n(k) from ESHDF file. Fermi energy (eV) is required to populate n(k) from ESHDF data.\nFile being read: {}'.format(filepath))
        #end if

        vlog.increment()
        d = read_eshdf_nofk_data(filepath,E_fermi)
        vlog.decrement()

        # Spin index in the file -> component name.
        spin_labels = {0:'u',1:'d'}

        # Concatenate the per-k-point arrays for each spin channel.
        spin_data = obj()
        for (ki,si) in sorted(d.data.keys()):
            vlog('Appending data for k-point {:>3} and spin {}'.format(ki,si),n=1,time=True)
            kdata = d.data[ki,si]
            label = spin_labels[si]
            if label not in spin_data:
                spin_data[label] = obj(k=[],nk=[])
            #end if
            channel = spin_data[label]
            channel.k.extend(kdata.k)
            channel.nk.extend(kdata.nk)
        #end for
        # Iterating the container is expected to yield the per-spin
        # records (they are used via attribute access here).
        for channel in spin_data:
            channel.k  = np.array(channel.k)
            channel.nk = np.array(channel.nk)
        #end for
        # With a single spin channel, 'd' is the same object as 'u'.
        if 'd' not in spin_data:
            spin_data.d = spin_data.u
        #end if
        up   = spin_data.u
        down = spin_data.d
        spin_data.tot = obj(
            k  = up.k,
            nk = up.nk + down.nk,
            )

        self.set_attribute('raw'  ,spin_data)
        self.set_attribute('kaxes',d.kaxes  )

        if grid:
            self.map_raw_data_onto_grid(unfold=unfold)
        #end if

        if save_after_read:
            vlog('Saving to file {}'.format(savefile),n=1)
            self.save(savefile)
        #end if

        vlog('n(k) data extraction complete',n=1,time=True)

    #end def read_eshdf
#end class MomentumDistributionDFT


# Adds the Fermi energy (kept inside ``info``) to the inherited attributes.
MomentumDistributionDFT.define_attributes(
    MomentumDistribution,
    E_fermi = obj(dest='info',type=float,default=None),
    )
|
|
|
|
|
|
class MomentumDistributionQMC(MomentumDistribution):
    """Momentum distribution populated from QMC stat.h5 output."""

    def read_stat_h5(self,*files,equil=0,savefile=None):
        """Read n(k) from stat.h5 files, or load it from a save file.

        Parameters
        ----------
        *files
            ``StatFile`` objects or file paths.  A single list/tuple
            argument is unpacked.
        equil : int
            Number of leading entries of the 'value' data discarded before
            statistics are computed.
        savefile : str or None
            If this file exists, the object is loaded from it and nothing
            else is done.  If it does not exist, the object is saved to it
            after reading.
        """
        save_after_read = False
        if savefile is not None:
            if os.path.exists(savefile):
                vlog('\nLoading from save file {}'.format(savefile))
                self.load(savefile)
                vlog('Done',n=1,time=True)
                return
            #end if
            save_after_read = True
        #end if

        vlog('\nReading n(k) data from stat.h5 files',time=True)
        all_k   = []
        all_nk  = []
        all_nke = []
        # Accept both read_stat_h5(a,b) and read_stat_h5([a,b]).
        if len(files)==1 and isinstance(files[0],(list,tuple)):
            files = files[0]
        #end if
        for file in files:
            if not isinstance(file,StatFile):
                vlog('Reading stat.h5 file',n=1,time=True)
                stat = StatFile(file,observables=['momentum_distribution'])
            else:
                stat = file
            #end if
            vlog('Processing n(k) data from stat.h5 file',n=1,time=True)
            vlog('filename = {}'.format(stat.filepath),n=2)
            group = stat.observable_groups(self,single=True)

            kpoints = np.array(group['kpoints'])
            nofk    = np.array(group['value'])

            # Only the mean and error estimate are kept.
            nk_mean,nk_var,nk_error,nk_kappa = simstats(nofk[equil:],dim=0)

            all_k.extend(kpoints)
            all_nk.extend(nk_mean)
            all_nke.extend(nk_error)
        #end for
        vlog('Converting concatenated lists to arrays',n=1,time=True)
        total = obj(
            k      = np.array(all_k),
            nk     = np.array(all_nk),
            nk_err = np.array(all_nke),
            )
        data = obj(tot=total)
        self.set_attribute('raw',data)

        if save_after_read:
            vlog('Saving to file {}'.format(savefile),n=1)
            self.save(savefile)
        #end if

        vlog('stat.h5 file read finished',n=1,time=True)
    #end def read_stat_h5
#end class MomentumDistributionQMC
|
|
|
|
|
|
|
|
class Density(ObservableWithComponents):
|
|
component_names = ('tot','pol','u','d')
|
|
|
|
default_component_name = 'tot'
|
|
|
|
|
|
def read_xsf(self,filepath,component=None):
    """Read one density component from XSF data.

    Parameters
    ----------
    filepath : str or XsfFile
        Path to an XSF file, or an already loaded ``XsfFile`` (in which
        case its density values are copied rather than shared).
    component : str or None
        Component to populate; ``None`` selects the default component.
    """
    component = self.process_component_name(component)

    vlog('Reading density data from XSF file for component "{}"'.format(component),time=True)

    already_loaded = isinstance(filepath,XsfFile)
    if already_loaded:
        vlog('XSF file already loaded, reusing data.')
        xsf = filepath
    else:
        vlog('Loading data from file',n=1,time=True)
        vlog('file location: {}'.format(filepath),n=2)
        vlog('memory before: ',n=2,mem=True)
        xsf = XsfFile(filepath)
        vlog('load complete',n=2,time=True)
        vlog('memory after: ',n=2,mem=True)
    #end if

    # read structure
    if not self.has_attribute('structure'):
        vlog('Reading structure from XSF data',n=1,time=True)
        structure = Structure()
        structure.read_xsf(xsf)
        self.set_attribute('structure',structure)
    #end if

    # read grid
    if not self.has_attribute('grid'):
        vlog('Reading grid from XSF data',n=1,time=True)
        self.set_attribute('grid',read_grid(xsf))
        self.set_attribute('distance_units','B')
    #end if

    # read values
    xsf.remove_ghost()
    density = xsf.get_density()
    values  = density.values_noghost.ravel()
    if already_loaded:
        # The XsfFile belongs to the caller; do not share its array.
        values = values.copy()
    #end if

    # create grid function for component
    vlog('Constructing grid function from XSF data',n=1,time=True)
    f = grid_function(
        type   = 'parallelotope',
        grid   = self.grid,
        values = values,
        copy   = False,
        )

    self.set_attribute(component,f)
    # NOTE(review): this always overrides the 'B' set in the grid branch
    # above -- confirm which distance unit is intended.
    self.set_attribute('distance_units','A')

    vlog('Read complete',n=1,time=True)
    vlog('Current memory:',n=1,mem=True)
#end def read_xsf
|
|
|
|
|
|
def volume_normalize(self):
    """Divide every component's values, in place, by the volume per grid cell."""
    grid = self.get_attribute('grid')
    cell_volume = grid.volume()/grid.ncells
    for comp in self.components():
        comp.values /= cell_volume
    #end for
#end def volume_normalize
|
|
|
|
|
|
def norm(self,component=None):
    """Return the sum of component values times the volume per grid cell.

    Parameters
    ----------
    component : None, str, or iterable of str
        Forwarded to ``components``.

    Returns
    -------
    The single norm when ``component`` is a string, otherwise an ``obj``
    of norms keyed by component name.
    """
    norms = obj()
    for name,comp in self.components(component).items():
        grid = comp.grid
        cell_volume = grid.volume()/grid.ncells
        norms[name] = comp.values.sum()*cell_volume
    #end for
    if not isinstance(component,str):
        return norms
    #end if
    return norms[component]
#end def norm
|
|
|
|
|
|
def change_distance_units(self,units):
    """Rescale the grid points, in place, from the stored distance units to ``units``."""
    current_units = self.get_attribute('distance_units')
    scale = convert(1.0,current_units,units)
    self.get_attribute('grid').points *= scale
    # Update the object info to reflect the conversion
    self.set_attribute('distance_units',units)
#end def change_distance_units
|
|
|
|
|
|
def change_density_units(self,units):
    """Rescale all component values, in place, from the stored density units to ``units``.

    Values are multiplied by the cube of the inverse of the conversion
    factor returned by ``convert(1.0, old_units, units)``.
    """
    current_units = self.get_attribute('density_units')
    inverse_scale = 1.0/convert(1.0,current_units,units)
    for comp in self.components():
        comp.values *= inverse_scale**3
    #end for
    # Update the object info to reflect the conversion
    self.set_attribute('density_units',units)
#end def change_density_units
|
|
|
|
|
|
def radial_density(self,component=None,dr=0.01,ntheta=100,rmax=None,single=False,interp_kwargs=None,comps_return=False,species=None):
    """
    Compute radial densities around atomic sites for the selected components.

    A ``SpheroidGrid`` of radius ``rmax`` is built for each species.  For
    every component the grid points are translated to each selected atom of
    the species, the component is interpolated on those points, the result
    is averaged over all but the first grid dimension and weighted by
    ``4*pi*r**2``.  The per-atom results are averaged over the selected atoms.

    Parameters
    ----------
    component : forwarded to ``self.components`` to select components.
    dr : radial spacing; the number of radial cells is ``ceil(rmax/dr)``.
    ntheta : the grid is built with ``(nr,ntheta,2*ntheta)`` cells.
    rmax : cutoff radius.  ``None`` uses
        ``structure.voronoi_species_radii()``; a single number applies to all
        species; a list/tuple gives one radius per entry of ``species`` (a
        one-element sequence applies to all); anything else is treated as a
        mapping from species to radius, whose keys then define ``species``.
    single : if true, only the first equivalent atom of each species is used.
    interp_kwargs : keyword arguments forwarded to ``interpolate``.
    comps_return : if true, always return the per-component container, even
        when ``component`` is a string.
    species : species labels to process; defaults to the keys of
        ``structure.equivalent_atoms()``.

    Returns
    -------
    ``obj`` keyed by component name, each entry an ``obj`` keyed by species
    holding ``radius`` and ``density`` arrays; or the single component's
    entry when ``component`` is a string and ``comps_return`` is false.
    """

    vlog('Computing radial density',time=True)
    vlog('Current memory:',n=1,mem=True)
    if interp_kwargs is None:
        interp_kwargs = obj()
    #end if
    struct = self.get_attribute('structure')
    if rmax is None:
        rmax = struct.voronoi_species_radii()
    #end if
    vlog('Finding equivalent atomic sites',n=1,time=True)
    equiv_atoms = struct.equivalent_atoms()

    # Resolve the cutoff radius of every requested species
    species_rmax = obj()
    if isinstance(rmax,(int,float,np.integer,np.floating)):
        # a single number (python or numpy scalar) applies to all species
        if species is None:
            species = list(equiv_atoms.keys())
        #end if
        for sname in species:
            species_rmax[sname] = rmax
        #end for
    elif isinstance(rmax,(list,tuple)):
        if species is None:
            species = list(equiv_atoms.keys())
        #end if
        for si,sname in enumerate(species):
            if len(rmax)>1:
                species_rmax[sname] = rmax[si]
            else:
                species_rmax[sname] = rmax[0]
            #end if
        #end for
    else:
        # mapping species -> radius; its keys define the species list
        species = list(rmax.keys())
        species_rmax.transfer_from(rmax)
    #end if

    vlog('Constructing spherical grid for each species',n=1,time=True)
    species_grids = obj()
    for sname in species:
        srmax = species_rmax[sname]
        if srmax<1e-3:
            self.error('Cannot compute radial density.\n"rmax" must be set to a finite value.\nrmax provided for species "{}": {}'.format(sname,srmax))
        #end if
        nr = int(np.ceil(srmax/dr))
        species_grids[sname] = SpheroidGrid(
            axes     = srmax*np.eye(3),
            cells    = (nr,ntheta,2*ntheta),
            centered = True,
            )
    #end for

    rdfs = obj()
    for cname,d in self.components(component).items():
        vlog('Processing radial density for component "{}"'.format(cname),n=1,time=True)
        rdf = obj()
        rdfs[cname] = rdf
        for sname,sgrid in species_grids.items():
            rrad    = sgrid.radii()
            # NOTE(review): sgrid.r is presumably the grid's own point array
            # (not a copy), so the in-place shifts below move the grid itself
            # -- confirm against grid_functions.
            rsphere = sgrid.r
            drad    = np.zeros(rrad.shape,dtype=d.dtype)
            if single:
                atom_indices = [equiv_atoms[sname][0]]
            else:
                atom_indices = equiv_atoms[sname]
            #end if
            vlog('Averaging radial data for species "{}" over {} sites'.format(sname,len(atom_indices)),n=2,time=True)
            rcenter = np.zeros((3,),dtype=float)
            for i in atom_indices:
                # translate the sphere from the previous center to this atom
                new_center = struct.pos[i]
                shift      = new_center-rcenter
                rsphere   += shift
                rcenter    = new_center
                dsphere = d.interpolate(rsphere,**interp_kwargs)
                dsphere.shape = sgrid.shape
                # collapse all but the first dimension so they can be averaged
                dsphere.shape = len(dsphere),dsphere.size//len(dsphere)
                drad += dsphere.mean(axis=1)*4*np.pi*rrad**2
            #end for
            # Move the sphere back to the origin.  The grid is reused for the
            # next component, where rcenter starts again from zero; without
            # this the later components would be sampled at displaced points.
            rsphere -= rcenter
            drad /= len(atom_indices)
            rdf[sname] = obj(
                radius  = rrad,
                density = drad,
                )
        #end for
        d.clear_ghost()
        vlog('Current memory:',n=2,mem=True)
    #end for
    if isinstance(component,str) and not comps_return:
        return rdfs[component]
    else:
        return rdfs
    #end if
#end def radial_density
|
|
|
|
|
|
def cumulative_radial_density(self,rdfs=None,comps_return=False,**kwargs):
    """
    Turn radial densities into cumulative radial densities.

    Parameters
    ----------
    rdfs : previously computed radial densities.  When given, a copy
        (``rdfs.copy()``) is processed; when ``None`` they are computed via
        ``self.radial_density(**kwargs)`` with ``comps_return=True``.
    comps_return : if true, always return the per-component container, even
        when a string ``component`` is present in ``kwargs``.
    **kwargs : forwarded to ``radial_density``; ``component`` is also read
        here to decide what is returned.

    Returns
    -------
    The container of cumulative densities, or the entry for ``component``
    when it is a string and ``comps_return`` is false.  Each density is
    replaced by its cumulative sum times the spacing of the first two radii.
    """
    component = kwargs.get('component',None)
    if rdfs is not None:
        crdfs = rdfs.copy()
    else:
        kwargs['comps_return'] = True
        crdfs = self.radial_density(**kwargs)
    #end if
    # iteration over these containers is relied on to yield the stored
    # entries (values), as in the rest of this class
    for comp_rdfs in crdfs:
        for srdf in comp_rdfs:
            spacing = srdf.radius[1]-srdf.radius[0]
            srdf.density = srdf.density.cumsum()*spacing
        #end for
    #end for
    single_component = isinstance(component,str) and not comps_return
    if single_component:
        return crdfs[component]
    #end if
    return crdfs
#end def cumulative_radial_density
|
|
|
|
|
|
def plot_radial_density(self,component=None,show=True,cumulative=False,**kwargs):
    """
    Plot radial (or cumulative radial) densities, one figure per species
    and component.

    Parameters
    ----------
    component : forwarded to ``radial_density`` /
        ``cumulative_radial_density`` to select components.
    show : if true, ``plt.show()`` is called after all figures are made.
    cumulative : if true, plot cumulative radial densities instead.
    **kwargs : forwarded to the density computation; ``comps_return`` is
        forced to ``True`` so results are always keyed by component.

    Components are visited in the order of ``self.component_names`` and
    species in sorted order.  Nothing is plotted for components missing
    from the computed results.
    """
    vlog('Plotting radial density')
    kwargs['comps_return'] = True
    if cumulative:
        rdfs = self.cumulative_radial_density(component=component,**kwargs)
    else:
        rdfs = self.radial_density(component=component,**kwargs)
    #end if

    dist_units    = self.get_attribute('distance_units',None)
    density_units = self.get_attribute('density_units',None)

    # Axis labels are the same for every figure, so build them once
    xlabel = 'Radius'
    if dist_units is not None:
        xlabel += ' ({})'.format(dist_units)
    #end if
    if cumulative:
        ylabel = 'Cumulative radial density'
    else:
        ylabel = 'Radial density'
    #end if
    if density_units is not None:
        ylabel += ' (e/{}^3)'.format(density_units)
    #end if

    for cname in self.component_names:
        if cname not in rdfs:
            continue
        #end if
        rdf = rdfs[cname]
        for s in sorted(rdf.keys()):
            srdf = rdf[s]
            plt.figure()
            plt.plot(srdf.radius,srdf.density,'b.-')
            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
            plt.title('{} {} density'.format(s,cname))
        #end for
    #end for
    if show:
        plt.show()
    #end if
#end def plot_radial_density
|
|
|
|
|
|
def save_radial_density(self,prefix,rdfs=None,**kwargs):
    """
    Write radial and cumulative radial densities to text files.

    Parameters
    ----------
    prefix : file prefix, optionally including a directory part.
    rdfs : previously computed radial densities; when ``None`` they are
        computed via ``self.radial_density(**kwargs)`` with
        ``comps_return=True``.
    **kwargs : forwarded to ``radial_density``.

    One file is written per group, species and component, named
    ``<prefix>.<group>.<species>_<component>.dat`` with ``<group>`` being
    ``rad_dens`` or ``rad_dens_cum``.  Each line holds a radius and the
    corresponding density value.
    """
    # os.path.split returns an empty directory part when there is none
    path,prefix = os.path.split(prefix)
    vlog('Saving radial density with file prefix "{}"'.format(prefix))
    vlog.increment()
    try:
        kwargs['comps_return'] = True
        if rdfs is None:
            rdfs = self.radial_density(**kwargs)
        #end if
        crdfs = self.cumulative_radial_density(rdfs)
    finally:
        # restore the log indentation even if the computation fails
        vlog.decrement()
    #end try
    groups = obj(
        rad_dens     = rdfs,
        rad_dens_cum = crdfs,
        )
    for gname,dfs in groups.items():
        for cname,rdf in dfs.items():
            for sname,srdf in rdf.items():
                filename = '{}.{}.{}_{}.dat'.format(prefix,gname,sname,cname)
                filepath = os.path.join(path,filename)
                vlog('Saving file '+filepath,n=1)
                # context manager closes the file even if a write fails
                with open(filepath,'w') as f:
                    f.writelines(
                        '{: 16.8e} {: 16.8e}\n'.format(r,d)
                        for r,d in zip(srdf.radius,srdf.density)
                        )
                #end with
            #end for
        #end for
    #end for
#end def save_radial_density
|
|
#end class Density
|
|
|
|
# Attribute definitions for Density, passed together with Observable.
Density.define_attributes(
    Observable,
    # the five density components share the same definition
    **{cname: obj(type=obj,no_default=True)
       for cname in ('raw','u','d','tot','pol')},
    # grid shared by the components
    grid = obj(type=StructuredGrid,no_default=True),
    # unit labels, defined with dest='info'
    distance_units = obj(dest='info',type=str),
    density_units  = obj(dest='info',type=str),
    )
|
|
|
|
|
|
|
|
class ChargeDensity(Density):
    # No additions or overrides; behaves exactly like Density.
    pass
#end class ChargeDensity
|
|
|
|
|
|
class EnergyDensity(Density):
    # No additions or overrides; behaves exactly like Density.
    pass
#end class EnergyDensity
|
|
|
|
|
|
|
|
|
|
class StatFile(DevBase):
    """
    Collects observable groups from a stat.h5 file.

    After ``read``, each loaded observable name (e.g.
    ``momentum_distribution``) maps to an ``obj`` of the HDF5 groups whose
    names matched that observable or one of its aliases.
    """

    # Top-level HDF5 names treated as scalar quantities; skipped by read()
    scalars = set('''
        LocalEnergy
        LocalEnergy_sq
        Kinetic
        LocalPotential
        ElecElec
        IonIon
        LocalECP
        NonLocalECP
        KEcorr
        MPC
        '''.split())

    # Initially observable -> list of aliases; the loop below rewrites this
    # into a flat lookup mapping every alias, and the observable name
    # itself, to the observable name.
    observable_aliases = obj(
        momentum_distribution = ['nofk'],
        )
    for observable in list(observable_aliases.keys()):
        for alias in observable_aliases[observable]:
            observable_aliases[alias] = observable
        #end for
        observable_aliases[observable] = observable
    #end for

    # Observable name -> class used to represent it
    observable_classes = obj(
        momentum_distribution = MomentumDistributionQMC,
        )

    # Reverse lookup: observable class name -> observable (stat group) name
    observable_class_to_stat_group = obj()
    for name,cls in observable_classes.items():
        observable_class_to_stat_group[cls.__name__] = name
    #end for


    def __init__(self,filepath=None,**read_kwargs):
        """
        Optionally read a file on construction.

        Parameters
        ----------
        filepath : path of the file to read; nothing is read when ``None``.
        **read_kwargs : forwarded to ``read``.
        """
        self.filepath = None

        if filepath is not None:
            self.filepath = filepath
            self.read(filepath,**read_kwargs)
        #end if
    #end def __init__


    def read(self,filepath,observables='all'):
        """
        Open an HDF5 file and store the groups of the requested observables.

        Top-level entries named in ``scalars`` are skipped.  Every other
        entry is assigned to each observable for which its condensed name
        (see ``condense_name``) starts with a condensed alias.

        Parameters
        ----------
        filepath : path to the file; ``self.error`` is called if it does
            not exist.
        observables : ``'all'`` to store every observable found, a single
            observable name, or an iterable of observable names.  Names
            with no matching group are ignored.
        """
        if not os.path.exists(filepath):
            self.error('Cannot read file.\nFile path does not exist: {}'.format(filepath))
        #end if
        # NOTE(review): the file handle is not closed here; the stored
        # groups presumably need the file to remain open -- confirm before
        # adding a close.
        h5 = h5py.File(filepath,'r')
        observable_groups = obj()
        for name,group in h5.items():
            # Skip scalar quantities
            if name in self.scalars:
                continue
            #end if
            # Identify observable type by name, for now
            cond_name = self.condense_name(name) # same for every alias
            for alias,observable in self.observable_aliases.items():
                cond_alias = self.condense_name(alias)
                if cond_name.startswith(cond_alias):
                    if observable not in observable_groups:
                        observable_groups[observable] = obj()
                    #end if
                    observable_groups[observable][name] = group
                #end if
            #end for
        #end for
        if isinstance(observables,str) and observables=='all':
            self.transfer_from(observable_groups)
        else:
            if isinstance(observables,str):
                # a single observable name is a one-element selection
                observables = [observables]
            #end if
            for obs in observables:
                if obs in observable_groups:
                    self[obs] = observable_groups[obs]
                #end if
            #end for
        #end if
    #end def read


    def condense_name(self,name):
        """Return ``name`` lower-cased with all underscores removed."""
        return name.lower().replace('_','')
    #end def condense_name


    def observable_groups(self,observable,single=False):
        """
        Look up the stored groups for an observable.

        Parameters
        ----------
        observable : an observable name, a class, or an ``Observable``
            instance.  Classes and instances are reduced to their class
            name, which is mapped to an observable name through
            ``observable_class_to_stat_group`` when it is not itself a
            stored key.
        single : if true and groups were found, return the only group;
            ``self.error`` is called unless exactly one group is present.

        Returns
        -------
        The groups stored for the observable, a single group when
        ``single`` is true, or ``None`` when nothing is stored for it.
        """
        if inspect.isclass(observable):
            observable = observable.__name__
        elif isinstance(observable,Observable):
            observable = observable.__class__.__name__
        #end if
        groups = None
        if observable in self:
            groups = self[observable]
        elif observable in self.observable_class_to_stat_group:
            observable = self.observable_class_to_stat_group[observable]
            if observable in self:
                groups = self[observable]
            #end if
        #end if
        if single and groups is not None:
            if len(groups)==1:
                return groups.first()
            else:
                self.error('Single stat.h5 observable group requested, but multiple are present.\nGroups present: {}'.format(sorted(groups.keys())))
            #end if
        else:
            return groups
        #end if
    #end def observable_groups
#end class StatFile
|