qmcpack/tests/scripts/check_stats.py

2553 lines
79 KiB
Python
Executable File

#! /usr/bin/env python3
from __future__ import print_function, division
# Statical error checking code for use by testing framework (stat.h5 files)
# Jaron Krogel/ORNL
# check_stats.py packages obj and HDFreader classes from Nexus.
# Note that h5py is required (which depends on numpy).
######################################################################
# from generic.py
######################################################################
import sys
import traceback
from copy import deepcopy
import pickle
from random import randint
class generic_settings:
devlog = sys.stdout
raise_error = False
#end class generic_settings
class NexusError(Exception):
None
#end class NexusError
exit_call = sys.exit
def nocopy(value):
return value
#end def nocopy
def log(*items,**kwargs):
indent=None
logfile=generic_settings.devlog
if len(kwargs)>0:
n=0
if 'indent' in kwargs:
indent = kwargs['indent']
n+=1
#end if
if 'logfile' in kwargs:
logfile = kwargs['logfile']
n+=1
#end if
if n!=len(kwargs):
valid = 'indent logfile'.split()
invalid = set(kwargs.keys())-set(valid)
error('invalid keyword arguments provided\ninvalid keywords: {0}\nvalid options are: {1}'.format(sorted(invalid),valid))
#end if
#end if
if len(items)==1 and isinstance(items[0],str):
s = items[0]
else:
s=''
for item in items:
s+=str(item)+' '
#end for
#end if
if len(s)>0:
if isinstance(indent,str):
s=indent+s.replace('\n','\n'+indent)
#end if
s += '\n'
#end if
logfile.write(s)
#end def log
def message(msg,header=None,post_header=' message:',indent=' ',logfile=None):
if logfile is None:
logfile = generic_settings.devlog
#end if
if header is None:
header = post_header.lstrip()
else:
header += post_header
#end if
log('\n '+header,logfile=logfile)
log(msg.rstrip(),indent=indent,logfile=logfile)
#end def message
def warn(msg,header=None,indent=' ',logfile=None):
if logfile is None:
logfile = generic_settings.devlog
#end if
post_header=' warning:'
message(msg,header,post_header,indent,logfile)
#end def warn
def error(msg,header=None,exit=True,trace=True,indent=' ',logfile=None):
if generic_settings.raise_error:
raise NexusError(msg)
#end if
if logfile is None:
logfile = generic_settings.devlog
#end if
post_header=' error:'
message(msg,header,post_header,indent,logfile)
if exit:
log(' exiting.\n')
if trace:
traceback.print_stack()
#end if
exit_call()
#end if
#end def error
class object_interface(object):
_logfile = sys.stdout
def __len__(self):
return len(self.__dict__)
#end def __len__
def __contains__(self,name):
return name in self.__dict__
#end def
def __getitem__(self,name):
return self.__dict__[name]
#end def __getitem__
def __setitem__(self,name,value):
self.__dict__[name]=value
#end def __setitem__
def __delitem__(self,name):
del self.__dict__[name]
#end def __delitem__
def __iter__(self):
for item in self.__dict__:
yield self.__dict__[item]
#end for
#end def __iter__
def __repr__(self):
s=''
for k in sorted(self._keys()):
if not isinstance(k,str) or k[0]!='_':
v=self.__dict__[k]
if hasattr(v,'__class__'):
s+=' {0:<20} {1:<20}\n'.format(k,v.__class__.__name__)
else:
s+=' {0:<20} {1:<20}\n'.format(k,type(v))
#end if
#end if
#end for
return s
#end def __repr__
def __str__(self,nindent=1):
pad = ' '
npad = nindent*pad
s=''
normal = []
qable = []
for k,v in self._items():
if not isinstance(k,str) or k[0]!='_':
if isinstance(v,object_interface):
qable.append(k)
else:
normal.append(k)
#end if
#end if
#end for
normal.sort()
qable.sort()
indent = npad+18*' '
for k in normal:
v = self[k]
vstr = str(v).replace('\n','\n'+indent)
s+=npad+'{0:<15} = '.format(k)+vstr+'\n'
#end for
for k in qable:
v = self[k]
s+=npad+str(k)+'\n'
s+=v.__str__(nindent+1)
if isinstance(k,str):
s+=npad+'end '+k+'\n'
#end if
#end for
return s
#end def __str__
def __eq__(self,other):
if not hasattr(other,'__dict__'):
return False
#end if
eq = True
for sname in self.__dict__:
if sname not in other.__dict__:
return False
#end if
svar = self.__dict__[sname]
ovar = other.__dict__[sname]
stype = type(svar)
otype = type(ovar)
if stype!=otype:
return False
#end if
eqval = svar==ovar
if isinstance(eqval,bool):
eq &= eqval
else:
try: # accommodate numpy arrays implicitly
eq &= eqval.all()
except:
return False
#end try
#end if
#end for
return eq
#end def __eq__
def tree(self,depth=None,all=False,types=False,nindent=1):
if depth==nindent-1:
return ''
#end if
pad = ' '
npad = nindent*pad
s=''
normal = []
qable = []
for k,v in self._items():
if not isinstance(k,str) or k[0]!='_':
if isinstance(v,object_interface):
qable.append(k)
else:
normal.append(k)
#end if
#end if
#end for
normal.sort()
qable.sort()
indent = npad+18*' '
if all:
for k in normal:
v = self[k]
if types:
s+=npad+'{0:<15} = '.format(k)
if hasattr(v,'__class__'):
s+='{0:<20}'.format(v.__class__.__name__)
else:
s+='{0:<20}'.format(type(v))
#end if
else:
s+=npad+str(k)
#end if
s+='\n'
#end for
#end if
if all and depth!=nindent:
for k in qable:
v = self[k]
s+=npad+str(k)+'\n'
s+=v.tree(depth,all,types,nindent+1)
if isinstance(k,str):
s+=npad+'end '+k+'\n'
#end if
#end for
else:
for k in qable:
v = self[k]
if types:
s+=npad+'{0:<15} = '.format(k)
if hasattr(v,'__class__'):
s+='{0:<20}'.format(v.__class__.__name__)
else:
s+='{0:<20}'.format(type(v))
#end if
else:
s+=npad+str(k)
#end if
s+='\n'
s+=v.tree(depth,all,types,nindent+1)
#end for
#end if
return s
#end def tree
# dict interface
def keys(self):
return self.__dict__.keys()
#end def keys
def values(self):
return self.__dict__.values()
#end def values
def items(self):
return self.__dict__.items()
#end def items
def copy(self):
return deepcopy(self)
#end def copy
def clear(self):
self.__dict__.clear()
#end def clear
# save/load
def save(self,fpath=None):
if fpath==None:
fpath='./'+self.__class__.__name__+'.p'
#end if
fobj = open(fpath,'w')
binary = pickle.HIGHEST_PROTOCOL
pickle.dump(self,fobj,binary)
fobj.close()
del fobj
del binary
return
#end def save
def load(self,fpath=None):
if fpath==None:
fpath='./'+self.__class__.__name__+'.p'
#end if
fobj = open(fpath,'r')
tmp = pickle.load(fobj)
fobj.close()
d = self.__dict__
d.clear()
for k,v in tmp.__dict__.items():
d[k] = v
#end for
del fobj
del tmp
return
#end def load
# log, warning, and error messages
def open_log(self,filepath):
self._logfile = open(filepath,'w')
#end def open_log
def close_log(self):
self._logfile.close()
#end def close_log
def write(self,s):
self._logfile.write(s)
#end def write
def log(self,*items,**kwargs):
if 'logfile' not in kwargs:
kwargs['logfile'] = self._logfile
#end if
log(*items,**kwargs)
#end def log
def warn(self,message,header=None):
if header is None:
header=self.__class__.__name__
#end if
warn(message,header,logfile=self._logfile)
#end def warn
def error(self,message,header=None,exit=True,trace=True):
if header==None:
header = self.__class__.__name__
#end if
error(message,header,exit,trace,logfile=self._logfile)
#end def error
@classmethod
def class_log(cls,message):
log(message,logfile=cls._logfile)
#end def class_log
@classmethod
def class_warn(cls,message,header=None,post_header=' Warning:'):
if header==None:
header=cls.__name__
#end if
warn(message,header,logfile=cls._logfile)
#end def class_warn
@classmethod
def class_error(cls,message,header=None,exit=True,trace=True,post_header=' Error:'):
if header==None:
header = cls.__name__
#end if
error(message,header,exit,trace,logfile=cls._logfile)
#end def class_error
@classmethod
def class_has(cls,k):
return hasattr(cls,k)
#end def classmethod
@classmethod
def class_keys(cls):
return cls.__dict__.keys()
#end def class_keys
@classmethod
def class_get(cls,k):
return getattr(cls,k)
#end def class_set
@classmethod
def class_set(cls,**kwargs):
for k,v in kwargs.items():
setattr(cls,k,v)
#end for
#end def class_set
@classmethod
def class_set_single(cls,k,v):
setattr(cls,k,v)
#end def class_set_single
@classmethod
def class_set_optional(cls,**kwargs):
for k,v in kwargs.items():
if not hasattr(cls,k):
setattr(cls,k,v)
#end if
#end for
#end def class_set_optional
# access preserving functions
# dict interface
def _keys(self,*args,**kwargs):
return object_interface.keys(self,*args,**kwargs)
def _values(self,*args,**kwargs):
object_interface.values(self,*args,**kwargs)
def _items(self,*args,**kwargs):
return object_interface.items(self,*args,**kwargs)
def _copy(self,*args,**kwargs):
return object_interface.copy(self,*args,**kwargs)
def _clear(self,*args,**kwargs):
object_interface.clear(self,*args,**kwargs)
# save/load
def _save(self,*args,**kwargs):
object_interface.save(self,*args,**kwargs)
def _load(self,*args,**kwargs):
object_interface.load(self,*args,**kwargs)
# log, warning, and error messages
def _open_log(self,*args,**kwargs):
object_interface.open_log(self,*args,**kwargs)
def _close_log(self,*args,**kwargs):
object_interface.close_log(self,*args,**kwargs)
def _write(self,*args,**kwargs):
object_interface.write(self,*args,**kwargs)
def _log(self,*args,**kwargs):
object_interface.log(self,*args,**kwargs)
def _error(self,*args,**kwargs):
object_interface.error(self,*args,**kwargs)
def _warn(self,*args,**kwargs):
object_interface.warn(self,*args,**kwargs)
#end class object_interface
class obj(object_interface):
def __init__(self,*vars,**kwargs):
for var in vars:
if isinstance(var,(dict,object_interface)):
for k,v in var.items():
self[k] = v
#end for
else:
self[var] = None
#end if
#end for
for k,v in kwargs.items():
self[k] = v
#end for
#end def __init__
# list interface
def append(self,value):
self[len(self)] = value
#end def append
# return representations
def list(self,*keys):
nkeys = len(keys)
if nkeys==0:
keys = sorted(self._keys())
elif nkeys==1 and isinstance(keys[0],(list,tuple)):
keys = keys[0]
#end if
values = []
for key in keys:
values.append(self[key])
#end if
return values
#end def list
def list_optional(self,*keys):
nkeys = len(keys)
if nkeys==0:
keys = sorted(self._keys())
elif nkeys==1 and isinstance(keys[0],(list,tuple)):
keys = keys[0]
#end if
values = []
for key in keys:
if key in self:
values.append(self[key])
else:
values.append(None)
#end if
#end if
return values
#end def list_optional
def tuple(self,*keys):
return tuple(obj.list(self,*keys))
#end def tuple
def dict(self,*keys):
nkeys = len(keys)
if nkeys==0:
keys = sorted(self._keys())
elif nkeys==1 and isinstance(keys[0],(list,tuple)):
keys = keys[0]
#end if
d = dict()
for k in keys:
d[k] = self[k]
#end for
return d
#end def dict
def to_dict(self):
d = dict()
for k,v in self._items():
if isinstance(v,obj):
d[k] = v._to_dict()
else:
d[k] = v
#end if
#end for
return d
#end def to_dict
def obj(self,*keys):
nkeys = len(keys)
if nkeys==0:
keys = sorted(self._keys())
elif nkeys==1 and isinstance(keys[0],(list,tuple)):
keys = keys[0]
#end if
o = obj()
for k in keys:
o[k] = self[k]
#end for
return o
#end def obj
# list extensions
def first(self):
return self[min(self._keys())]
#end def first
def last(self):
return self[max(self._keys())]
#end def last
def select_random(self):
return self[randint(0,len(self)-1)]
#end def select_random
# dict extensions
def random_key(self):
key = None
nkeys = len(self)
if nkeys>0:
key = self._keys()[randint(0,nkeys-1)]
#end if
return key
#end def random_key
def set(self,*objs,**kwargs):
for key,value in kwargs.items():
self[key]=value
#end for
if len(objs)>0:
for o in objs:
for k,v in o.items():
self[k] = v
#end for
#end for
#end if
return self
#end def set
def set_optional(self,*objs,**kwargs):
for key,value in kwargs.items():
if key not in self:
self[key]=value
#end if
#end for
if len(objs)>0:
for o in objs:
for k,v in o.items():
if k not in self:
self[k] = v
#end if
#end for
#end for
#end if
return self
#end def set_optional
def get(self,key,value=None): # follow dict interface, no plural
if key in self:
value = self[key]
#end if
return value
#end def get
def get_optional(self,key,value=None):
if key in self:
value = self[key]
#end if
return value
#end def get_optional
def get_required(self,key):
if key in self:
value = self[key]
else:
obj.error(self,'a required key is not present\nkey required: {0}\nkeys present: {1}'.format(key,sorted(self._keys())))
#end if
return value
#end def get_required
def delete(self,*keys):
nkeys = len(keys)
single = False
if nkeys==0:
keys = sorted(self._keys())
elif nkeys==1 and isinstance(keys[0],(list,tuple)):
keys = keys[0]
elif nkeys==1:
single = True
#end if
values = []
for key in keys:
values.append(self[key])
del self[key]
#end for
if single:
return values[0]
else:
return values
#end if
#end def delete
def delete_optional(self,key,value=None):
if key in self:
value = self[key]
del self[key]
#end if
return value
#end def delete_optional
def delete_required(self,key):
if key in self:
value = self[key]
del self[key]
else:
obj.error(self,'a required key is not present\nkey required: {0}\nkeys present: {1}'.format(key,sorted(self._keys())))
#end if
return value
#end def delete_required
def add(self,key,value):
self[key] = value
#end def add
def add_optional(self,key,value):
if key not in self:
self[key] = value
#end if
#end def add_optional
def transfer_from(self,other,keys=None,copy=False,overwrite=True):
if keys==None:
if isinstance(other,object_interface):
keys = other._keys()
else:
keys = other.keys()
#end if
#end if
if copy:
copier = deepcopy
else:
copier = nocopy
#end if
if overwrite:
for k in keys:
self[k]=copier(other[k])
#end for
else:
for k in keys:
if k not in self:
self[k]=copier(other[k])
#end if
#end for
#end if
#end def transfer_from
def transfer_to(self,other,keys=None,copy=False,overwrite=True):
if keys==None:
keys = self._keys()
#end if
if copy:
copier = deepcopy
else:
copier = nocopy
#end if
if overwrite:
for k in keys:
other[k]=copier(self[k])
#end for
else:
for k in keys:
if k not in self:
other[k]=copier(self[k])
#end if
#end for
#end if
#end def transfer_to
def move_from(self,other,keys=None):
if keys==None:
if isinstance(other,object_interface):
keys = other._keys()
else:
keys = other.keys()
#end if
#end if
for k in keys:
self[k]=other[k]
del other[k]
#end for
#end def move_from
def move_to(self,other,keys=None):
if keys==None:
keys = self._keys()
#end if
for k in keys:
other[k]=self[k]
del self[k]
#end for
#end def move_to
def copy_from(self,other,keys=None,deep=True):
obj.transfer_from(self,other,keys,copy=deep)
#end def copy_from
def copy_to(self,other,keys=None,deep=True):
obj.transfer_to(self,other,keys,copy=deep)
#end def copy_to
def shallow_copy(self):
new = self.__class__()
for k,v in self._items():
new[k] = v
#end for
return new
#end def shallow_copy
def inverse(self):
new = self.__class__()
for k,v in self._items():
new[v] = k
#end for
return new
#end def inverse
def path_exists(self,path):
o = self
if isinstance(path,str):
path = path.split('/')
#end if
for p in path:
if not p in o:
return False
#end if
o = o[p]
#end for
return True
#end def path_exists
def set_path(self,path,value=None):
o = self
cls = self.__class__
if isinstance(path,str):
path = path.split('/')
#end if
for p in path[0:-1]:
if not p in o:
o[p] = cls()
#end if
o = o[p]
#end for
o[path[-1]] = value
#end def set_path
def get_path(self,path,value=None):
o = self
if isinstance(path,str):
path = path.split('/')
#end if
for p in path[0:-1]:
if not p in o:
return value
#end if
o = o[p]
#end for
lp = path[-1]
if lp not in o:
return value
else:
return o[lp]
#end if
#end def get_path
def serial(self,s=None,path=None):
first = s is None
if first:
s = obj()
path = ''
#end if
for k,v in self._items():
p = path+str(k)
if isinstance(v,obj):
if len(v)==0:
s[p]=v
else:
v._serial(s,p+'/')
#end if
else:
s[p]=v
#end if
#end for
if first:
return s
#end if
#end def serial
# access preserving functions
# list interface
def _append(self,*args,**kwargs):
obj.append(self,*args,**kwargs)
# return representations
def _list(self,*args,**kwargs):
return obj.list(self,*args,**kwargs)
def _list_optional(self,*args,**kwargs):
return obj.list_optional(self,*args,**kwargs)
def _tuple(self,*args,**kwargs):
return obj.tuple(self,*args,**kwargs)
def _dict(self,*args,**kwargs):
return obj.dict(self,*args,**kwargs)
def _to_dict(self,*args,**kwargs):
return obj.to_dict(self,*args,**kwargs)
def _obj(self,*args,**kwargs):
return obj.obj(self,*args,**kwargs)
# list extensions
def _first(self,*args,**kwargs):
return obj.first(self,*args,**kwargs)
def _last(self,*args,**kwargs):
return obj.last(self,*args,**kwargs)
def _select_random(self,*args,**kwargs):
return obj.select_random(self,*args,**kwargs)
# dict extensions
def _random_key(self,*args,**kwargs):
obj.random_key(self,*args,**kwargs)
def _set(self,*args,**kwargs):
obj.set(self,*args,**kwargs)
def _set_optional(self,*args,**kwargs):
obj.set_optional(self,*args,**kwargs)
def _get(self,*args,**kwargs):
obj.get(self,*args,**kwargs)
def _get_optional(self,*args,**kwargs):
obj.get_optional(self,*args,**kwargs)
def _get_required(self,*args,**kwargs):
obj.get_required(self,*args,**kwargs)
def _delete(self,*args,**kwargs):
obj.delete(self,*args,**kwargs)
def _delete_optional(self,*args,**kwargs):
obj.delete_optional(self,*args,**kwargs)
def _delete_required(self,*args,**kwargs):
obj.delete_required(self,*args,**kwargs)
def _add(self,*args,**kwargs):
obj.add(self,*args,**kwargs)
def _add_optional(self,*args,**kwargs):
obj.add_optional(self,*args,**kwargs)
def _transfer_from(self,*args,**kwargs):
obj.transfer_from(self,*args,**kwargs)
def _transfer_to(self,*args,**kwargs):
obj.transfer_to(self,*args,**kwargs)
def _move_from(self,*args,**kwargs):
obj.move_from(self,*args,**kwargs)
def _move_to(self,*args,**kwargs):
obj.move_to(self,*args,**kwargs)
def _copy_from(self,*args,**kwargs):
obj.copy_from(self,*args,**kwargs)
def _copy_to(self,*args,**kwargs):
obj.copy_to(self,*args,**kwargs)
def _shallow_copy(self,*args,**kwargs):
obj.shallow_copy(self,*args,**kwargs)
def _inverse(self,*args,**kwargs):
return obj.inverse(self,*args,**kwargs)
def _path_exists(self,*args,**kwargs):
obj.path_exists(self,*args,**kwargs)
def _set_path(self,*args,**kwargs):
obj.set_path(self,*args,**kwargs)
def _get_path(self,*args,**kwargs):
obj.get_path(self,*args,**kwargs)
def _serial(self,*args,**kwargs):
return obj.serial(self,*args,**kwargs)
#end class obj
######################################################################
# end from generic.py
######################################################################
######################################################################
# from superstring.py
######################################################################
import string
def contains_any(str, set):
for c in set:
if c in str: return 1;
return 0;
#end def contains_any
invalid_variable_name_chars=set('!"#$%&\'()*+,-./:;<=>?@[\\]^`{|}-\n\t ')
def valid_variable_name(s):
return not contains_any(s,invalid_variable_name_chars)
#end def valid_variable_name
######################################################################
# end from superstring.py
######################################################################
######################################################################
# from debug.py
######################################################################
import code
import inspect
def ci(locs=None,globs=None):
if locs is None or globs is None:
cur_frame = inspect.currentframe()
caller_frame = cur_frame.f_back
locs = caller_frame.f_locals
globs = caller_frame.f_globals
#end if
code.interact(local=dict(globs,**locs))
#end def ci
ls = locals
gs = globals
interact = ci
######################################################################
# end from debug.py
######################################################################
######################################################################
# from developer.py
######################################################################
class DevBase(obj):
def not_implemented(self):
self.error('a base class function has not been implemented',trace=True)
#end def not_implemented
#end class DevBase
######################################################################
# end from developer.py
######################################################################
######################################################################
# from hdfreader.py
######################################################################
from numpy import array,ndarray,minimum,abs,ix_,resize
import sys
import keyword
from inspect import getmembers
import h5py
class HDFglobals(DevBase):
view = False
#end class HDFglobals
class HDFgroup(DevBase):
def _escape_name(self,name):
if name in self._escape_names:
name=name+'_'
#end if
return name
#end def escape_name
def _set_parent(self,parent):
self._parent=parent
return
#end def set_parent
def _add_dataset(self,name,dataset):
self._datasets[name]=dataset
return
#end def add_dataset
def _add_group(self,name,group):
group._name=name
self._groups[name]=group
return
#end def add_group
def _contains_group(self,name):
return name in self._groups.keys()
#end def _contains_group
def _contains_dataset(self,name):
return name in self._datasets.keys()
#end def _contains_dataset
def _to_string(self):
s=''
if len(self._datasets)>0:
s+=' datasets:\n'
for k,v in self._datasets.items():
s+= ' '+k+'\n'
#end for
#end if
if len(self._groups)>0:
s+= ' groups:\n'
for k,v in self._groups.items():
s+= ' '+k+'\n'
#end for
#end if
return s
#end def list
# def __str__(self):
# return self._to_string()
# #end def __str__
#
# def __repr__(self):
# return self._to_string()
# #end def __repr__
def __init__(self):
self._name=''
self._parent=None
self._groups={};
self._datasets={};
self._group_counts={}
self._escape_names=None
self._escape_names=set(dict(getmembers(self)).keys()) | set(keyword.kwlist)
return
#end def __init__
def _remove_hidden(self,deep=True):
if '_parent' in self:
del self._parent
#end if
if deep:
for name,value in self.items():
if isinstance(value,HDFgroup):
value._remove_hidden()
#end if
#end for
#end if
for name in list(self.keys()):
if name[0]=='_':
del self[name]
#end if
#end for
#end def _remove_hidden
# read in all data views (h5py datasets) into arrays
# useful for converting a single group read in view form to full arrays
def read_arrays(self):
self._remove_hidden()
for k,v in self.items():
if isinstance(v,HDFgroup):
v.read_arrays()
else:
self[k] = array(v)
#end if
#end for
#end def read_arrays
def get_keys(self):
if '_groups' in self:
keys = list(self._groups.keys())
else:
keys = list(self.keys())
#end if
return keys
#end def get_keys
#end class HDFgroup
class HDFreader(DevBase):
datasets = set(["<class 'h5py.highlevel.Dataset'>","<class 'h5py._hl.dataset.Dataset'>","<class 'h5py._debian_h5py_serial._hl.dataset.Dataset'>"])
groups = set(["<class 'h5py.highlevel.Group'>","<class 'h5py._hl.group.Group'>","<class 'h5py._debian_h5py_serial._hl.group.Group'>"])
def __init__(self,fpath,verbose=False,view=False):
HDFglobals.view = view
if verbose:
print(' Initializing HDFreader')
self.fpath=fpath
if verbose:
print(' loading h5 file')
try:
self.hdf = h5py.File(fpath,'r')
except IOError:
self._success = False
self.hdf = obj(obj=obj())
else:
self._success = True
#end if
if verbose:
print(' converting h5 file to dynamic object')
#convert the hdf 'dict' into a dynamic object
self.nlevels=1
self.ilevel=0
# Set the current hdf group
self.obj = HDFgroup()
self.cur=[self.obj]
self.hcur=[self.hdf]
if self._success:
cur = self.cur[self.ilevel]
hcur = self.hcur[self.ilevel]
for kr,v in hcur.items():
k=cur._escape_name(kr)
if valid_variable_name(k):
vtype = str(type(v))
if vtype in HDFreader.datasets:
self.add_dataset(cur,k,v)
elif vtype in HDFreader.groups:
self.add_group(hcur,cur,k,v)
else:
raise Exception('hdfreader error: encountered invalid type: '+vtype)
#end if
else:
print('hdfreader warning: attribute '+k+' is not a valid variable name and has been ignored')
#end if
#end for
#end if
if verbose:
print(' end HDFreader Initialization')
return
#end def __init__
def increment_level(self):
self.ilevel+=1
self.nlevels = max(self.ilevel+1,self.nlevels)
if self.ilevel+1==self.nlevels:
self.cur.append(None)
self.hcur.append(None)
#end if
self.pad = self.ilevel*' '
return
#end def increment_level
def decrement_level(self):
self.ilevel-=1
self.pad = self.ilevel*' '
return
#end def decrement_level
def add_dataset(self,cur,k,v):
if not HDFglobals.view:
cur[k]=array(v)
else:
cur[k] = v
#end if
cur._add_dataset(k,cur[k])
return
#end def add_dataset
def add_group(self,hcur,cur,k,v):
cur[k] = HDFgroup()
cur._add_group(k,cur[k])
cur._groups[k]._parent = cur
self.increment_level()
self.cur[self.ilevel] = cur._groups[k]
self.hcur[self.ilevel] = hcur[k]
cur = self.cur[self.ilevel]
hcur = self.hcur[self.ilevel]
for kr,v in hcur.items():
k=cur._escape_name(kr)
if valid_variable_name(k):
vtype = str(type(v))
if vtype in HDFreader.datasets:
self.add_dataset(cur,k,v)
elif vtype in HDFreader.groups:
self.add_group(hcur,cur,k,v)
#end if
else:
print('hdfreader warning: attribute '+k+' is not a valid variable name and has been ignored')
#end if
#end for
return
#end def add_group
#end class HDFreader
def read_hdf(fpath,verbose=False,view=False):
return HDFreader(fpath=fpath,verbose=verbose,view=view).obj
#end def read_hdf
######################################################################
# end from hdfreader.py
######################################################################
import os
import sys
from optparse import OptionParser
from numpy import zeros,sqrt,longdouble,loadtxt
can_plot = False
try:
import matplotlib
gui_envs = ['GTKAgg','TKAgg','Qt4Agg','WXAgg']
for gui in gui_envs:
try:
matplotlib.use(gui,warn=False, force=True)
from matplotlib import pyplot
can_plot = True
break
except:
continue
#end try
#end for
from matplotlib.pyplot import figure,plot,xlabel,ylabel,title,show,ylim,legend,xlim,rcParams,savefig,bar,xticks,subplot,grid,setp,errorbar,loglog,semilogx,semilogy,text
params = {'legend.fontsize':14,'figure.facecolor':'white','figure.subplot.hspace':0.,
'axes.labelsize':16,'xtick.labelsize':14,'ytick.labelsize':14}
rcParams.update(params)
except (ImportError,RuntimeError):
can_plot = False
#end try
class ColorWheel(DevBase):
def __init__(self):
colors = 'Black Maroon DarkOrange Green DarkBlue Purple Gray Firebrick Orange MediumSeaGreen DodgerBlue MediumOrchid'.split()
lines = '- -- -. :'.split()
markers = '. v s o ^ d p'.split()
ls = []
for line in lines:
for color in colors:
ls.append((color,line))
#end for
#end for
ms = []
for i in range(len(markers)):
ms.append((colors[i],markers[i]))
#end for
mls = []
ic=-1
for line in lines:
for marker in markers:
ic = (ic+1)%len(colors)
mls.append((colors[ic],marker+line))
#end for
#end for
self.line_styles = ls
self.marker_styles = ms
self.marker_line_styles = mls
self.reset()
#end def __init__
def next_line(self):
self.iline = (self.iline+1)%len(self.line_styles)
return self.line_styles[self.iline]
#end def next_line
def next_marker(self):
self.imarker = (self.imarker+1)%len(self.marker_styles)
return self.marker_styles[self.imarker]
#end def next_marker
def next_marker_line(self):
self.imarker_line = (self.imarker_line+1)%len(self.marker_line_styles)
return self.marker_line_styles[self.imarker_line]
#end def next_marker_line
def reset(self):
self.iline = -1
self.imarker = -1
self.imarker_line = -1
#end def reset
def reset_line(self):
self.iline = -1
#end def reset_line
def reset_marker(self):
self.imarker = -1
#end def reset_marker
def reset_marker_line(self):
self.imarker_line = -1
#end def reset_marker_line
#end class ColorWheel
color_wheel = ColorWheel()
checkstats_settings = obj(
verbose = False,
)
def vlog(*args,**kwargs):
if checkstats_settings.verbose:
n = kwargs.get('n',0)
if n==0:
log(*args)
else:
log(*args,indent=n*' ')
#end if
#end if
#end def vlog
# standalone definition of error function from Abramowitz & Stegun
# credit: http://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/
# consider also: https://math.stackexchange.com/questions/42920/efficient-and-accurate-approximation-of-error-function
def erf(x):
# constants
a1 = 0.254829592
a2 = -0.284496736
a3 = 1.421413741
a4 = -1.453152027
a5 = 1.061405429
p = 0.3275911
# Save the sign of x
sign = 1
if x < 0:
sign = -1
x = abs(x)
# A & S 7.1.26
t = 1.0/(1.0 + p*x)
y = 1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*math.exp(-x*x)
return sign*y
#end def erf
# standalone inverse error function
# credit: https://stackoverflow.com/questions/42381244/pure-python-inverse-error-function
import math
def polevl(x, coefs, N):
ans = 0
power = len(coefs) - 1
for coef in coefs:
ans += coef * x**power
power -= 1
return ans
#end def polevl
def p1evl(x, coefs, N):
return polevl(x, [1] + coefs, N)
#end def p1evl
def erfinv(z):
if z < -1 or z > 1:
raise ValueError("'z' must be between -1 and 1 inclusive")
if z == 0:
return 0
if z == 1:
return float('inf')
if z == -1:
return -float('inf')
# From scipy special/cephes/ndrti.c
def ndtri(y):
# approximation for 0 <= abs(z - 0.5) <= 3/8
P0 = [
-5.99633501014107895267E1,
9.80010754185999661536E1,
-5.66762857469070293439E1,
1.39312609387279679503E1,
-1.23916583867381258016E0,
]
Q0 = [
1.95448858338141759834E0,
4.67627912898881538453E0,
8.63602421390890590575E1,
-2.25462687854119370527E2,
2.00260212380060660359E2,
-8.20372256168333339912E1,
1.59056225126211695515E1,
-1.18331621121330003142E0,
]
# Approximation for interval z = sqrt(-2 log y ) between 2 and 8
# i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
P1 = [
4.05544892305962419923E0,
3.15251094599893866154E1,
5.71628192246421288162E1,
4.40805073893200834700E1,
1.46849561928858024014E1,
2.18663306850790267539E0,
-1.40256079171354495875E-1,
-3.50424626827848203418E-2,
-8.57456785154685413611E-4,
]
Q1 = [
1.57799883256466749731E1,
4.53907635128879210584E1,
4.13172038254672030440E1,
1.50425385692907503408E1,
2.50464946208309415979E0,
-1.42182922854787788574E-1,
-3.80806407691578277194E-2,
-9.33259480895457427372E-4,
]
# Approximation for interval z = sqrt(-2 log y ) between 8 and 64
# i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
P2 = [
3.23774891776946035970E0,
6.91522889068984211695E0,
3.93881025292474443415E0,
1.33303460815807542389E0,
2.01485389549179081538E-1,
1.23716634817820021358E-2,
3.01581553508235416007E-4,
2.65806974686737550832E-6,
6.23974539184983293730E-9,
]
Q2 = [
6.02427039364742014255E0,
3.67983563856160859403E0,
1.37702099489081330271E0,
2.16236993594496635890E-1,
1.34204006088543189037E-2,
3.28014464682127739104E-4,
2.89247864745380683936E-6,
6.79019408009981274425E-9,
]
s2pi = 2.50662827463100050242
code = 1
if y > (1.0 - 0.13533528323661269189): # 0.135... = exp(-2)
y = 1.0 - y
code = 0
if y > 0.13533528323661269189:
y = y - 0.5
y2 = y * y
x = y + y * (y2 * polevl(y2, P0, 4) / p1evl(y2, Q0, 8))
x = x * s2pi
return x
x = math.sqrt(-2.0 * math.log(y))
x0 = x - math.log(x) / x
z = 1.0 / x
if x < 8.0: # y > exp(-32) = 1.2664165549e-14
x1 = z * polevl(z, P1, 8) / p1evl(z, Q1, 8)
else:
x1 = z * polevl(z, P2, 8) / p1evl(z, Q2, 8)
x = x0 - x1
if code != 0:
x = -x
return x
result = ndtri((z + 1) / 2.0) / math.sqrt(2)
return result
#end def inv_erf
def_atol = 0.0
def_rtol = 1e-6
# determine if two floats differ
def float_diff(v1,v2,atol=def_atol,rtol=def_rtol):
return abs(v1-v2)>atol+rtol*abs(v2)
#end def float_diff
# Returns failure error code to OS.
# Explicitly prints 'fail' after an optional message.
def exit_fail(msg=None):
if msg!=None:
print(msg)
#end if
print('Test status: fail')
exit(1)
#end def exit_fail
# Returns success error code to OS.
# Explicitly prints 'pass' after an optional message.
def exit_pass(msg=None):
if msg!=None:
print(msg)
#end if
print('Test status: pass')
exit(0)
#end def exit_pass
# Calculates the mean, variance, errorbar, and autocorrelation time
# for a N-d array of statistical data values.
# If 'exclude' is provided, the first 'exclude' values will be
# excluded from the analysis.
def simstats(x,dim=None,exclude=None):
if exclude!=None:
x = x[exclude:]
#end if
shape = x.shape
ndim = len(shape)
if dim==None:
dim=ndim-1
#end if
permute = dim!=ndim-1
reshape = ndim>2
nblocks = shape[dim]
if permute:
r = list(range(ndim))
r.pop(dim)
r.append(dim)
permutation = tuple(r)
r = list(range(ndim))
r.pop(ndim-1)
r.insert(dim,ndim-1)
invperm = tuple(r)
x=x.transpose(permutation)
shape = tuple(array(shape)[array(permutation)])
dim = ndim-1
#end if
if reshape:
nvars = prod(shape[0:dim])
x=x.reshape(nvars,nblocks)
rdim=dim
dim=1
else:
nvars = shape[0]
#end if
mean = x.mean(dim)
var = x.var(dim)
N=nblocks
if ndim==1:
i=0
tempC=0.5
kappa=0.0
mtmp=mean
if abs(var)<1e-15:
kappa = 1.0
else:
ovar=1.0/var
while (tempC>0 and i<(N-1)):
kappa=kappa+2.0*tempC
i=i+1
#tempC=corr(i,x,mean,var)
tempC = ovar/(N-i)*sum((x[0:N-i]-mtmp)*(x[i:N]-mtmp))
#end while
if kappa == 0.0:
kappa = 1.0
#end if
#end if
Neff=(N+0.0)/(kappa+0.0)
if (Neff == 0.0):
Neff = 1.0
#end if
error=sqrt(var/Neff)
else:
error = zeros(mean.shape)
kappa = zeros(mean.shape)
for v in range(nvars):
i=0
tempC=0.5
kap=0.0
vtmp = var[v]
mtmp = mean[v]
if abs(vtmp)<1e-15:
kap = 1.0
else:
ovar = 1.0/vtmp
while (tempC>0 and i<(N-1)):
i += 1
kap += 2.0*tempC
tempC = ovar/(N-i)*sum((x[v,0:N-i]-mtmp)*(x[v,i:N]-mtmp))
#end while
if kap == 0.0:
kap = 1.0
#end if
#end if
Neff=(N+0.0)/(kap+0.0)
if (Neff == 0.0):
Neff = 1.0
#end if
kappa[v]=kap
error[v]=sqrt(vtmp/Neff)
#end for
#end if
if reshape:
x = x.reshape(shape)
mean = mean.reshape(shape[0:rdim])
var = var.reshape(shape[0:rdim])
error = error.reshape(shape[0:rdim])
kappa = kappa.reshape(shape[0:rdim])
#end if
if permute:
x=x.transpose(invperm)
#end if
return (mean,var,error,kappa)
#end def simstats
def load_scalar_file(options,selector):
output_files = options.output_files
if selector=='auto':
if 'dmc' in output_files:
selector = 'dmc'
elif 'scalar' in output_files:
selector = 'scalar'
else:
exit_fail('could not load scalar file, no files present')
#end if
elif selector not in ('scalar','dmc'):
exit_fail('could not load scalar file, invalid selector\ninvalid selector: {0}\nvalid options: scalar, dmc'.format(selector))
#end if
if selector not in output_files:
exit_fail('could not load scalar file, file is not present\nfile type requested: {0}'.format(selector))
#end if
filepath = os.path.join(options.path,output_files[selector])
lt = loadtxt(filepath)
if len(lt.shape)==1:
lt.shape = (1,len(lt))
#end if
data = lt[:,1:].transpose()
fobj = open(filepath,'r')
variables = fobj.readline().split()[2:]
fobj.close()
scalars = obj(
file_type = selector,
data = obj(),
)
for i,var in enumerate(variables):
scalars.data[var] = data[i,:]
#end for
return scalars
#end def load_scalar_file
######################################################################
# IMPORTANT DEVELOPER INFO #
######################################################################
# Information for individual quantities in stat.h5 file.
# For each quantity, list a default label and data paths.
# To add new quantities to be checked, update this data structure.
stat_info = obj({
'density' : obj(
default_label = 'Density',
data_paths = obj(tot='value'),
),
'spindensity' : obj(
default_label = 'SpinDensity',
data_paths = obj(u='u/value',
d='d/value'),
),
'energydensity' : obj(
default_label = 'EnergyDensity',
data_paths = obj(W=('spacegrid1/value',0,3),
T=('spacegrid1/value',1,3),
V=('spacegrid1/value',2,3)),
),
'1rdm' : obj(
default_label = 'DensityMatrices',
data_paths = obj(u='number_matrix/u/value',
d='number_matrix/d/value'),
),
'1redm' : obj(
default_label = 'DensityMatrices',
data_paths = obj(u='energy_matrix/u/value',
d='energy_matrix/d/value'),
),
'obdm' : obj(
default_label = 'OneBodyDensityMatrices',
data_paths = obj(u='number_matrix/u/value',
d='number_matrix/d/value'),
),
'momentum' : obj(
default_label = 'nofk',
data_paths = obj(tot='value'),
),
'sh_coeff' : obj(
default_label = 'sh_coeff',
data_paths = obj(coeff='value'),
),
})
# Reads command line options.
def read_command_line():
try:
parser = OptionParser(
usage='usage: %prog [options]',
add_help_option=False,
version='%prog 0.1'
)
parser.add_option('-h','--help',dest='help',
action='store_true',default=False,
help='Print help information and exit (default=%default).'
)
parser.add_option('-p','--prefix',dest='prefix',
default='qmc',
help='Prefix for output files (default=%default). Can be a path including the file prefix.'
)
parser.add_option('-s','--series',dest='series',
default='0',
help='Output series to analyze (default=%default).'
)
parser.add_option('-e','--equilibration',dest='equilibration',
default='0',
help='Equilibration length in blocks (default=%default).'
)
parser.add_option('-n','--nsigma',dest='nsigma',
default='3',
help='Sigma requirement for pass/fail (default=%default).'
)
parser.add_option('-a','--abs_err',dest='abs_err',
default='none',
help='Absolute error requirement for pass/fail (default=%default).'
)
parser.add_option('-q','--quantity',dest='quantity',
default='none',
help = 'Quantity to check (required). If a non-default name for the quantity is used, pass in the quantity and name as a pair.'
)
parser.add_option('-c','--npartial_sums',dest='npartial_sums',
default='none',
help = 'Partial sum count for the reference data (required)'
)
parser.add_option('-r','--ref','--reference',dest='reference_file',
default='none',
help = 'Path to reference file containing full and partial sum reference information. The test fails if any full or partial sum exceeds nsigma deviation from the reference values. For cases like the density or spin density, the -f option should additionally be used (see below). For the energy density, a block by block check against relevant summed energy terms in scalar.dat or dmc.dat files is additionally made.'
)
parser.add_option('-f','--fixed','--fixed_sum',dest='fixed_sum',
action='store_true',default=False,
help = 'Full sum of data takes on a fixed, non-stochastic value. In this case, when checking against reference data, check that each block satisfies the fixed sum condition. This is appropriate, e.g. for the electron density where the full sum of each block must equal the number of electrons. Typically the appropriate value is inferred automatically and applied by default (in other cases default=%default).'
)
parser.add_option('-m','--make_ref','--make_reference',dest='make_reference',
default='none',
help='Used during test construction phase. Pass an integer list via -m corresponding to the number of partial sums to perform on the reference stat data followed by a series of MC step factors. The number of partial means must divide evenly into the number of stat field values for the quantity in question. The step factors relate the length of the test run (shorter) to the reference run (longer): #MC_test*factor=#MC_reference. Files containing the reference data will be produced, one for each step factor. For the partial sums, the reference sigma is increased so that the test fails with the expected probability specified by the inputted nsigma.'
)
parser.add_option('-t','--plot_trace',dest='plot_trace',
action='store_true',default=False,
help='Plot traces of full and partial sums (default=%default).'
)
parser.add_option('-v','--verbose',
action='store_true',default=False,
help='Print detailed information (default=%default).'
)
allowed_quantities = list(stat_info.keys())
opt,files_in = parser.parse_args()
options = obj()
options.transfer_from(opt.__dict__)
if options.help:
print('\n'+parser.format_help().strip())
print('\n\nExample usage:')
print('\n Making reference data to create a test:')
print(" check_stats.py -p qmc -s 0 -q spindensity -e 10 -c 8 -v -m '0 10 100'")
print('\n Using reference data to perform a test:')
print(' check_stats.py -p qmc -s 0 -q spindensity -e 10 -c 8 -n 3 -r qmc.s000.stat_ref_spindensity_10.dat')
print()
exit()
#end if
if len(files_in)>0:
exit_fail('check_stats does not accept file as input, only command line arguments\nfiles provided: {0}'.format(files_in))
#end if
checkstats_settings.verbose = options.verbose
vlog('\nreading command line inputs')
options.series = int(options.series)
options.equilibration = int(options.equilibration)
options.nsigma = float(options.nsigma)
options.path,options.prefix = os.path.split(options.prefix)
if options.plot_trace and not can_plot:
vlog('trace plots requested, but plotting libraries are not available\ndisabling plots',n=1)
options.plot_trace = False
#end if
if options.path=='':
options.path = './'
#end if
options.qlabel = None
if ' ' in options.quantity or ',' in options.quantity:
qlist = options.quantity.strip('"').strip("'").replace(',',' ').split()
if len(qlist)!=2:
exit_fail('quantity can accept only one or two values\nyou provided {0}: {1}'.format(len(qlist),qlist))
#end if
options.quantity,options.qlabel = qlist
#end if
if options.qlabel is None:
options.qlabel = stat_info[options.quantity].default_label
#end if
if options.quantity=='none':
exit_fail('must provide quantity')
elif options.quantity not in allowed_quantities:
exit_fail('unrecognized quantity provided\nallowed quantities: {0}\nquantity provided: {1}'.format(allowed_quantities,options.quantity))
#end if
if options.npartial_sums=='none':
exit_fail('-c option is required')
#end if
options.npartial_sums = int(options.npartial_sums)
if options.reference_file!='none':
if not os.path.exists(options.reference_file):
exit_fail('reference file does not exist\nreference file provided: {0}'.format(options.reference_file))
#end if
options.make_reference = False
elif options.make_reference!='none':
try:
mr = array(options.make_reference.split(),dtype=int)
except:
exit_fail('make_reference must be a list of integers\nyou provided: {0}'.format(options.make_reference))
#end try
if len(mr)<1:
exit_fail('make_reference must contain at least one MC length factor')
#end if
options.mc_factors = mr
options.make_reference = True
else:
exit_fail('must provide either reference_file or make_reference')
#end if
fixed_sum_quants = set(['density','spindensity','energydensity'])
if options.quantity in fixed_sum_quants:
options.fixed_sum = True
#end if
if options.abs_err!='none':
try:
options.abs_err = float(options.abs_err)
except:
exit_fail('abs_err must be a real number\nyou provided: {}'.format(options.abs_err))
#end try
if options.abs_err<0:
exit_fail('abs_err must be positive\nyou provided: {}'.format(options.abs_err))
#end if
else:
options.abs_err = None
#end if
vlog('inputted options:\n'+str(options),n=1)
except Exception as e:
import traceback
import io
ftmp = io.StringIO()
traceback.print_exc(file=ftmp)
exit_fail('error during command line read:\n'+str(ftmp.getvalue()))
#end try
return options
#end def read_command_line
def process_stat_file(options):
vlog('processing stat.h5 file')
values = obj()
try:
# find all output files matching prefix
vlog('searching for qmcpack output files',n=1)
vlog('search path:\n '+options.path,n=2)
prefix = options.prefix+'.s'+str(options.series).zfill(3)
files = os.listdir(options.path)
output_files = obj()
for file in files:
if file.startswith(prefix):
if file.endswith('.stat.h5'):
output_files.stat = file
elif file.endswith('.scalar.dat'):
output_files.scalar = file
elif file.endswith('.dmc.dat'):
output_files.dmc = file
#end if
#end if
#end for
options.output_files = output_files
vlog('files found:\n'+str(output_files).rstrip(),n=2)
if 'stat' not in output_files:
exit_fail('stat.h5 file matching prefix {0} was not found\nsearch path: {1}'.format(prefix,options.path))
#end if
# read data from the stat file
vlog('opening stat.h5 file',n=1)
stat = read_hdf(os.path.join(options.path,output_files.stat),verbose=options.verbose,view=True)
vlog('file contents:\n'+repr(stat).rstrip(),n=2)
vlog('extracting {0} data'.format(options.quantity),n=1)
vlog('searching for {0} with label {1}'.format(options.quantity,options.qlabel),n=2)
if options.qlabel in stat:
qstat = stat[options.qlabel]
vlog('{0} data contents:\n{1}'.format(options.quantity,repr(qstat).rstrip()),n=2)
else:
exit_fail('could not find {0} data with label {1}'.format(options.quantity,options.qlabel))
#end if
qpaths = stat_info[options.quantity].data_paths
vlog('search paths:\n{0}'.format(str(qpaths).rstrip()),n=2)
qdata = obj()
dfull = None
for dname,dpath in qpaths.items():
packed = isinstance(dpath,tuple)
if packed:
dpath,dindex,dcount = dpath
#end if
if not qstat.path_exists(dpath):
exit_fail('{0} data not found in file {1}\npath searched: {2}'.format(options.quantity,output_files.stat,dpath))
#end if
if not packed:
d = array(qstat.get_path(dpath),dtype=float)
else:
if dfull is None:
dfull = array(qstat.get_path(dpath),dtype=float)
dfull.shape = dfull.shape[0],dfull.shape[1]//dcount,dcount
#end if
d = dfull[:,:,dindex]
d.shape = dfull.shape[0],dfull.shape[1]
#end if
qdata[dname] = d
vlog('{0} data found with shape {1}'.format(dname,d.shape),n=2)
if len(d.shape)>2:
d.shape = d.shape[0],d.size//d.shape[0]
vlog('reshaped {0} data to {1}'.format(dname,d.shape),n=2)
#end if
options.nblocks = d.shape[0]
#end for
# process the data, taking full and partial sums
vlog('processing {0} data'.format(options.quantity),n=1)
for dname,d in qdata.items():
vlog('processing {0} data'.format(dname),n=2)
if d.shape[1]%options.npartial_sums!=0:
exit_fail('cannot make partial sums\nnumber of requested partial sums does not divide evenly into the number of values available\nrequested partial sums: {0}\nnumber of values present: {1}\nnvalue/npartial_sums: {2}'.format(options.npartial_sums,d.shape[1],float(d.shape[1])/options.npartial_sums))
#end if
data = obj()
data.full_sum = d.sum(1)
vlog('full sum data shape: {0}'.format(data.full_sum.shape),n=3)
data.partial_sums = zeros((d.shape[0],options.npartial_sums))
psize = d.shape[1]//options.npartial_sums
for p in range(options.npartial_sums):
data.partial_sums[:,p] = d[:,p*psize:(p+1)*psize].sum(1)
#end for
vlog('partial sum data shape: {0}'.format(data.partial_sums.shape),n=3)
fmean,var,ferror,kappa = simstats(data.full_sum,exclude=options.equilibration)
vlog('full sum mean : {0}'.format(fmean),n=3)
vlog('full sum error: {0}'.format(ferror),n=3)
pmean,var,perror,kappa = simstats(data.partial_sums,dim=0,exclude=options.equilibration)
vlog('partial sum mean : {0}'.format(pmean),n=3)
vlog('partial sum error: {0}'.format(perror),n=3)
values[dname] = obj(
full_mean = fmean,
full_error = ferror,
partial_mean = pmean,
partial_error = perror,
data = data,
)
#end for
# check that all values have been processed
missing = set(qpaths.keys())-set(values.keys())
if len(missing)>0:
exit_fail('some values not processed\nvalues missing: {0}'.format(sorted(missing)))
#end if
# plot quantity traces, if requested
if options.plot_trace:
vlog('creating trace plots of full and partial sums',n=1)
for dname,dvalues in values.items():
label = options.quantity
if len(values)>1:
label+=' '+dname
#end if
data = dvalues.data
figure()
plot(data.full_sum)
title('Trace of {0} full sum'.format(label))
xlabel('Block index')
figure()
plot(data.partial_sums)
title('Trace of {0} partial sums'.format(label))
xlabel('Block index')
#end for
show()
#end if
except Exception as e:
import traceback
import io
ftmp = io.StringIO()
traceback.print_exc(file=ftmp)
exit_fail('error during stat file processing:\n'+str(ftmp.getvalue()))
#end try
return values
#end def process_stat_file
def make_reference_files(options,values):
vlog('\nmaking reference files')
# create a reference file for each Monte Carlo sample factor
for mcfac in options.mc_factors:
errfac = sqrt(1.0+mcfac)
filename = '{0}.s{1}.stat_ref_{2}_{3}.dat'.format(options.prefix,str(options.series).zfill(3),options.quantity,mcfac)
filepath = os.path.join(options.path,filename)
vlog('writing reference file for {0}x shorter test runs'.format(mcfac),n=1)
vlog('reference file location: '+filepath,n=2)
f = open(filepath,'w')
# write descriptive header line
line = '# '
for dname in sorted(values.keys()):
line += ' {0:<16} {1:<16}'.format(dname,dname+'_err')
#end for
f.write(line+'\n')
# write means and errors of full sum
line = ''
for dname in sorted(values.keys()):
dvalues = values[dname]
fmean = dvalues.full_mean
ferror = dvalues.full_error
if options.abs_err is not None:
err = 0.0
else:
err = errfac*ferror
#end if
line += ' {0: 16.12e} {1: 16.12e}'.format(fmean,err)
#end for
f.write(line+'\n')
# write means and errors of partial sums
for p in range(options.npartial_sums):
line = ''
for dname in sorted(values.keys()):
dvalues = values[dname]
pmean = dvalues.partial_mean
perror = dvalues.partial_error
if options.abs_err is not None:
err = 0.0
else:
err = errfac*perror[p]
#end if
line += ' {0: 16.12e} {1: 16.12e}'.format(pmean[p],err)
#end for
f.write(line+'\n')
#end for
f.close()
#end for
# create a trace file containing full and partial sum data per block
filename = '{0}.s{1}.stat_trace_{2}.dat'.format(options.prefix,str(options.series).zfill(3),options.quantity)
filepath = os.path.join(options.path,filename)
vlog('writing trace file containing full and partial sums per block',n=1)
vlog('trace file location: '+filepath,n=2)
f = open(filepath,'w')
# write descriptive header line
line = '# index '
for dname in sorted(values.keys()):
line += ' {0:<16}'.format(dname+'_full')
for p in range(options.npartial_sums):
line += ' {0:<16}'.format(dname+'_partial_'+str(p))
#end for
#end for
f.write(line+'\n')
# write full and partial sum data per block
for b in range(options.nblocks):
line = ' {0:>6}'.format(b)
for dname in sorted(values.keys()):
dvalues = values[dname].data
fsum = dvalues.full_sum
psums = dvalues.partial_sums[b]
line += ' {0: 16.12e}'.format(fsum[b])
for psum in psums:
line += ' {0: 16.12e}'.format(psum)
#end for
#end for
f.write(line+'\n')
#end for
f.close()
vlog('\n')
#end def make_reference_files
def read_reference_file(filepath):
vlog('reading reference file',n=1)
vlog('reference file location: {0}'.format(options.reference_file),n=2)
f = open(options.reference_file,'r')
dnames = f.readline().split()[1::2]
vlog('sub-quantities found: {0}'.format(dnames),n=2)
if set(dnames)!=set(values.keys()):
missing = set(values.keys())-set(dnames)
extra = set(dnames)-set(values.keys())
if missing>0:
exit_fail('some sub-quantities are missing\npresent in test files: {0}\npresent in reference files: {1}\nmissing: {2}'.format(sorted(values.keys()),sorted(dnames),sorted(missing)))
elif extra>0:
exit_fail('some sub-quantities are extra\npresent in test files: {0}\npresent in reference files: {1}\nextra: {2}'.format(sorted(values.keys()),sorted(dnames),sorted(extra)))
else:
exit_fail('developer error, this point should be impossible to reach')
#end if
#end if
ref = array(f.read().split(),dtype=float)
ref.shape = len(ref)//(2*len(dnames)),2*len(dnames)
full = ref[0,:].ravel()
partial = ref[1:,:].T
if len(ref)-1!=options.npartial_sums:
exit_fail('test and reference partial sum counts do not match\ntest partial sum count: {0}\nreference partial sum count: {1}'.format(options.npartial_sums,len(ref)-1))
#end if
vlog('partial sum count found: {0}'.format(len(ref)-1),n=2)
ref_values = obj()
for dname in dnames:
ref_values[dname] = obj()
#end for
n=0
for dname in dnames:
ref_values[dname].set(
full_mean = full[n],
full_error = full[n+1],
partial_mean = partial[n,:].ravel(),
partial_error = partial[n+1,:].ravel(),
)
n+=2
#end for
npartial = len(ref)-1
f.close()
dnames = sorted(dnames)
vlog('reference file read successfully',n=2)
return ref_values,dnames
#end def read_reference_file
# Checks computed values from stat.h5 files against specified reference values.
passfail = {True:'pass',False:'fail'}
def check_values(options,values):
# check_values is for statistical tests
vlog('\nchecking against reference values')
success = True
msg = ''
try:
msg += '\nTests for series {0} quantity "{1}"\n'.format(options.series,options.quantity)
# find nsigma for each partial sum
# overall probability of partial sum failure is according to original nsigma
vlog('adjusting nsigma to account for partial sum count',n=1)
x = longdouble(options.nsigma/sqrt(2.))
N = options.npartial_sums
nsigma_partial = sqrt(2.)*erfinv(erf(x)**(1./N))
vlog('overall full/partial test nsigma: {0}'.format(options.nsigma),n=2)
vlog('adjusted per partial sum nsigma : {0}'.format(nsigma_partial),n=2)
# read in the reference file
ref_values,dnames = read_reference_file(options)
# for cases with fixed full sum, check the per block sum
if options.fixed_sum:
vlog('checking per block fixed sums',n=1)
msg+='\n Fixed sum per block tests:\n'
dnames_fixed = dnames
if options.quantity=='energydensity':
dnames_fixed = ['W']
#end if
fixed_sum_success = True
ftol = 1e-8
eq = options.equilibration
for dname in dnames_fixed:
ref_vals = ref_values[dname]
ref_mean = ref_vals.full_mean
ref_error = ref_vals.full_error
if abs(ref_error/ref_mean)>ftol:
exit_fail('reference fixed sum is not fixed as asserted\ncannot check per block fixed sums\nplease check reference data')
#end if
test_vals = values[dname].data.full_sum
for i,v in enumerate(test_vals[eq:]):
if abs((v-ref_mean)/ref_mean)>ftol:
fixed_sum_success = False
msg += ' {0} {1} {2}!={3}\n'.format(dname,i,v,ref_mean)
#end if
#end for
#end for
if fixed_sum_success:
fmsg = 'all per block sums match the reference'
else:
fmsg = 'some per block sums do not match the reference'
#end if
vlog(fmsg,n=2)
msg += ' '+fmsg+'\n'
msg += ' status of this test: {0}\n'.format(passfail[fixed_sum_success])
success &= fixed_sum_success
#end if
# for the energy density, check per block sums against the scalar file
if options.quantity=='energydensity':
vlog('checking energy density terms per block',n=1)
msg+='\n Energy density sums vs. scalar file per block tests:\n'
scalars = load_scalar_file(options,'auto')
ed_success = True
ftol = 1e-8
ed_values = obj(
T = values.T.data.full_sum,
V = values.V.data.full_sum,
)
ed_values.E = ed_values.T + ed_values.V
if scalars.file_type=='scalar':
comparisons = obj(
E='LocalEnergy',
T='Kinetic',
V='LocalPotential',
)
elif scalars.file_type=='dmc':
comparisons = obj(E='LocalEnergy')
else:
exit_fail('unrecognized scalar file type: {0}'.format(scalars.file_type))
#end if
for k in sorted(comparisons.keys()):
ed_vals = ed_values[k]
sc_vals = scalars.data[comparisons[k]]
if scalars.file_type=='dmc':
if len(sc_vals)%len(ed_vals)==0 and len(sc_vals)>len(ed_vals):
steps = len(sc_vals)//len(ed_vals)
sc_vals.shape = len(sc_vals)//steps,steps
sc_vals = sc_vals.mean(1)
#end if
#end if
if len(ed_vals)!=len(sc_vals):
exit_fail('energy density per block test cannot be completed\nnumber of energy density and scalar blocks do not match\nenergy density blocks: {0}\nscalar file blocks: {1}'.format(len(ed_vals),len(sc_vals)))
#end if
for i,(edv,scv) in enumerate(zip(ed_vals,sc_vals)):
if abs((edv-scv)/scv)>ftol:
ed_success = False
msg += ' {0} {1} {2}!={3}\n'.format(k,i,edv,scv)
#end if
#end for
#end for
if ed_success:
fmsg = 'all per block sums match the scalar file'
else:
fmsg = 'some per block sums do not match the scalar file'
#end if
vlog(fmsg,n=2)
msg += ' '+fmsg+'\n'
msg += ' status of this test: {0}\n'.format(passfail[ed_success])
success &= ed_success
#end if
# function used immediately below to test a mean value vs reference
def check_mean(label,mean_comp,error_comp,mean_ref,error_ref,nsigma):
msg='\n Testing quantity: {0}\n'.format(label)
# ensure error_ref is large enough for non-statistical quantities
ctol = 1e-12
if abs(error_ref/mean_ref)<ctol:
error_ref = ctol*mean_ref
#end if
quant_success = abs(mean_comp-mean_ref) <= nsigma*error_ref
delta = mean_comp-mean_ref
delta_err = sqrt(error_comp**2+error_ref**2)
msg+=' reference mean value : {0: 12.8f}\n'.format(mean_ref)
msg+=' reference error bar : {0: 12.8f}\n'.format(error_ref)
msg+=' computed mean value : {0: 12.8f}\n'.format(mean_comp)
msg+=' computed error bar : {0: 12.8f}\n'.format(error_comp)
msg+=' pass tolerance : {0: 12.8f} ({1: 12.8f} sigma)\n'.format(nsigma*error_ref,nsigma)
if error_ref > 0.0:
msg+=' deviation from reference : {0: 12.8f} ({1: 12.8f} sigma)\n'.format(delta,delta/error_ref)
#end if
msg+=' error bar of deviation : {0: 12.8f}\n'.format(delta_err)
if error_ref > 0.0:
msg+=' significance probability : {0: 12.8f} (gaussian statistics)\n'.format(erf(abs(delta/error_ref)/math.sqrt(2.0)))
#end if
msg+=' status of this test : {0}\n'.format(passfail[quant_success])
return quant_success,msg
#end def check_mean
# check full and partial sums vs the reference
vlog('checking full and partial sums',n=1)
for dname in dnames:
vals = values[dname]
ref_vals = ref_values[dname]
# check full sum
vlog('checking full sum mean for "{0}"'.format(dname),n=2)
qsuccess,qmsg = check_mean(
label = '{0} full sum'.format(dname),
mean_comp = vals.full_mean,
error_comp = vals.full_error,
mean_ref = ref_vals.full_mean,
error_ref = ref_vals.full_error,
nsigma = options.nsigma,
)
vlog('status for full sum: {0}'.format(passfail[qsuccess]),n=3)
msg += qmsg
success &= qsuccess
# check partial sums
vlog('checking partial sum means for "{0}"'.format(dname),n=2)
for p in range(len(ref_vals.partial_mean)):
qsuccess,qmsg = check_mean(
label = '{0} partial sum {1}'.format(dname,p),
mean_comp = vals.partial_mean[p],
error_comp = vals.partial_error[p],
mean_ref = ref_vals.partial_mean[p],
error_ref = ref_vals.partial_error[p],
nsigma = nsigma_partial,
)
vlog('status for partial sum {0}: {1}'.format(p,passfail[qsuccess]),n=3)
msg += qmsg
success &= qsuccess
#end for
#end for
except Exception as e:
import traceback
import io
ftmp = io.StringIO()
traceback.print_exc(file=ftmp)
exit_fail('error during value check:\n'+str(ftmp.getvalue()))
#end try
return success,msg
#end def check_values
def check_values_abs_err(options,values):
# check_values is for deterministic tests based on an absolute tolerance
vlog('\nchecking against reference values')
success = True
msg = ''
try:
msg += '\nTests for series {0} quantity "{1}"\n'.format(options.series,options.quantity)
N = options.npartial_sums
# read in the reference file
ref_values,dnames = read_reference_file(options)
# function used immediately below to test a value vs reference
def check_abs_err(label,value_comp,value_ref,abs_tol):
msg='\n Testing quantity: {0}\n'.format(label)
vdiff = float_diff(value_comp,value_ref,atol=abs_tol,rtol=0.0)
quant_success = not vdiff
msg+=' reference value : {0: 16.12f}\n'.format(value_ref)
msg+=' computed value : {0: 16.12f}\n'.format(value_comp)
msg+=' pass tolerance : {0: 16.12f}\n'.format(abs_tol)
msg+=' deviation from reference : {0: 16.12f}\n'.format(value_comp-value_ref)
#end if
msg+=' status of this test : {0}\n'.format(passfail[quant_success])
return quant_success,msg
#end def check_abs_err
# check full and partial sums vs the reference
vlog('checking full and partial sums',n=1)
for dname in dnames:
vals = values[dname]
ref_vals = ref_values[dname]
# check full sum
vlog('checking full sum absolute error for "{0}"'.format(dname),n=2)
qsuccess,qmsg = check_abs_err(
label = '{0} full sum'.format(dname),
value_comp = vals.full_mean,
value_ref = ref_vals.full_mean,
abs_tol = max(ref_vals.full_error,options.abs_err)
)
vlog('status for full sum: {0}'.format(passfail[qsuccess]),n=3)
msg += qmsg
success &= qsuccess
# check partial sums
vlog('checking partial sum absolute errors for "{0}"'.format(dname),n=2)
for p in range(len(ref_vals.partial_mean)):
qsuccess,qmsg = check_abs_err(
label = '{0} partial sum {1}'.format(dname,p),
value_comp = vals.partial_mean[p],
value_ref = ref_vals.partial_mean[p],
abs_tol = max(ref_vals.partial_error[p],options.abs_err),
)
vlog('status for partial sum {0}: {1}'.format(p,passfail[qsuccess]),n=3)
msg += qmsg
success &= qsuccess
#end for
#end for
except Exception as e:
import traceback
import io
ftmp = io.StringIO()
traceback.print_exc(file=ftmp)
exit_fail('error during value check:\n'+str(ftmp.getvalue()))
#end try
return success,msg
#end def check_values_abs_err
# Main execution
if __name__=='__main__':
# Read and interpret command line options.
options = read_command_line()
# Compute means of desired quantities from a stat.h5 file.
values = process_stat_file(options)
if options.make_reference:
# Make reference files to create tests from.
make_reference_files(options,values)
exit()
else:
# Check computed means against reference solutions.
if options.abs_err:
success,msg = check_values_abs_err(options,values)
else:
success,msg = check_values(options,values)
#end if
# Pass success/failure exit codes and strings to the OS.
if success:
exit_pass(msg)
else:
exit_fail(msg)
#end if
#end if
#end if