nexus: scans of arbitrary parameters in qmcpack workflows

git-svn-id: https://subversion.assembla.com/svn/qmcdev/trunk@7040 e5b18d87-469d-4833-9cc0-8cdfa06e9491
This commit is contained in:
Jaron Krogel 2016-07-29 18:50:43 +00:00
parent dccbcf66c3
commit ae59ce1e66
2 changed files with 147 additions and 44 deletions

View File

@ -786,6 +786,7 @@ class Projwfc(PostProcessSimulation):
lowdin_file = self.identifier+'.lowdin'
filepath = os.path.join(self.locdir,lowdin_file)
analyzer.write_lowdin(filepath)
analyzer.write_lowdin(filepath+'_long',long=True)
except:
None
#end try

View File

@ -161,6 +161,31 @@ def prevent_invalid_input(invalid,loc):
#end def prevent_invalid_input
def render_parameter_key(v,loc='render_parameter_key'):
if isinstance(v,list):
vkey = tuple(v)
elif isinstance(v,ndarray):
vkey = tuple(v.ravel())
else:
vkey = v
#end if
if not hashable(vkey):
error('parameter value is not hashable\nvalue provided: {0}\nvalue type: {1}\nplease restrict parameter values to basic types such as str,int,float,tuple and combinations of these'.format(v,v.__class__.__name__),loc)
#end if
return vkey
#end def render_parameter_key
def render_parameter_label(vkey,loc='render_parameter_label'):
if isinstance(vkey,(int,float,str)):
vlabel = str(vkey)
elif isinstance(vkey,tuple):
vlabel = str(vkey).replace('(','').replace(')','').replace(' ','').replace(',','_')
else:
error('cannot transform parameter value key into a directory name\nvalue key: {0}\ntype: {1}\nplease restrict parameter values to basic types such as str,int,float,tuple and combinations of these'.format(vkey,vkey.__class__.__name__),loc)
#end if
return vlabel
#end def render_parameter_label
@ -1394,8 +1419,10 @@ def system_scan(
systems = missing,
sysdirs = missing,
syskeys = None,
same_jastrow = False,
same_jastrow = False, # deprecated
jastrow_key = missing,
J2_source = None,
J3_source = None,
loc = 'system_scan',
**kwargs
):
@ -1419,10 +1446,9 @@ def system_scan(
error('must provide one key per system via syskeys keyword\nnumber of keys provided: {0}\nnumber of systems: {1}'.format(len(syskeys),len(systems)),loc)
#end if
same_jastrow = not missing(jastrow_key)
if same_jastrow:
if missing(jastrow_key):
error('requested same jastrow across scan but no system key was provided via jastrow_key keyword',loc)
elif jastrow_key not in set(syskeys):
if jastrow_key not in set(syskeys):
error('key used to identify jastrow for use across scan was not found\njastrow key provided: {0}\nsystem keys present: {1}'.format(jastrow_key,sorted(system_keys)),loc)
#end if
sys = obj()
@ -1442,8 +1468,6 @@ def system_scan(
#end for
#end if
J2_source = None
J3_source = None
sims = obj()
for n in xrange(len(systems)):
qckw = process_qmcpack_chain_kwargs(
@ -1452,15 +1476,11 @@ def system_scan(
loc = loc,
**kwargs
)
qckw.basepath = os.path.join(basepath,dirname,sysdirs[n])
if J2_source is not None:
qckw.J2_source = J2_source
#end if
if J3_source is not None:
qckw.J3_source = J3_source
#end if
qckw.basepath = os.path.join(basepath,dirname,sysdirs[n])
qckw.J2_source = J2_source
qckw.J3_source = J3_source
qcsims = qmcpack_chain(**qckw)
if same_jastrow:
if same_jastrow and n==0:
J2_source = qcsims.get_optional('optJ2',None)
J3_source = qcsims.get_optional('optJ3',None)
#end if
@ -1472,51 +1492,41 @@ def system_scan(
def system_parameter_scan(
basepath = missing,
dirname = 'system_param_scan',
sysfunc = missing,
variable = missing,
values = missing,
fixed = None,
loc = 'system_parameter_scan',
basepath = missing,
dirname = 'system_param_scan',
sysfunc = missing,
parameter = missing,
variable = missing, # same as parameter, to be deprecated
values = missing,
fixed = None,
loc = 'system_parameter_scan',
**kwargs
):
set_loc(loc)
require('basepath',basepath)
require('sysfunc' ,sysfunc )
require('variable',variable)
require('values' ,values )
if missing(parameter):
parameter = variable
#end if
require('basepath' ,basepath )
require('sysfunc' ,sysfunc )
require('parameter',parameter)
require('values' ,values )
systems = []
sysdirs = []
syskeys = []
for v in values:
params = obj()
params[variable] = v
params[parameter] = v
if fixed!=None:
params.set(**fixed)
#end if
system = sysfunc(**params)
if isinstance(v,list):
vkey = tuple(v)
elif isinstance(v,ndarray):
vkey = tuple(v.ravel())
else:
vkey = v
#end if
if not hashable(vkey):
error('inputted system generation variable value is not hashable\nvalue provided: {0}\nvalue type: {1}\nplease restrict system generation variables to basic types such as str,int,float,tuple and combinations of these'.format(v,v.__class__.__name__),loc)
#end if
if isinstance(vkey,(int,float,str)):
vstr = str(vkey)
elif isinstance(vkey,tuple):
vstr = str(vkey).replace('(','').replace(')','').replace(' ','').replace(',','_')
else:
error('cannot convert system generation variable value into a directory name\nvalue provided: {0}\nvalue type: {1}\nplease restrict system generation variables to basic types such as str,int,float,tuple and combinations of these'.format(v,v.__class__.__name__),loc)
#end if
sysdir = '{0}_{1}'.format(variable,vstr)
vkey = render_parameter_key(v,loc)
vlabel = render_parameter_label(vkey,loc)
sysdir = '{0}_{1}'.format(parameter,vlabel)
systems.append(system)
sysdirs.append(sysdir)
syskeys.append(vkey)
@ -1539,6 +1549,98 @@ def system_parameter_scan(
def input_parameter_scan(
basepath = missing,
dirname = 'input_param_scan',
section = missing,
parameter = missing,
variable = missing, # same as parameter, to be deprecated
values = missing,
tags = missing,
jastrow_key = missing,
J2_source = None,
J3_source = None,
loc = 'input_parameter_scan',
**kwargs
):
set_loc(loc)
if missing(parameter):
parameter = variable
#end if
require('basepath' ,basepath )
require('section' ,section )
require('parameter',parameter)
require('values' ,values )
if not missing(tags) and len(tags)!=len(values):
error('must provide one tag (directory label) per parameter value\nnumber of values: {0}\nnumber of tags: {1}\nvalues provided: {2}\ntags provided: {3}'.format(len(values),len(tags),values,tags),loc)
#end if
paramdirs = []
paramkeys = []
n = 0
for v in values:
if not missing(tags):
vkey = tags[n]
else:
vkey = render_parameter_key(v,loc)
#end if
vlabel = render_parameter_label(vkey,loc)
paramdir = '{0}_{1}'.format(parameter,vlabel)
paramdirs.append(paramdir)
paramkeys.append(vkey)
n+=1
#end for
same_jastrow = not missing(jastrow_key)
if same_jastrow:
jastrow_key = render_key(jastrow_key,loc)
if not jastrow_key in set(paramkeys):
error('input parameter value for fixed Jastrow not found\njastrow key provided: {0}\nparameter keys present: {1}'.format(jastrow_key,paramkeys),loc)
#end if
values = list(values)
index = paramkeys.index(jastrow_key)
paramdirs.insert(0,paramdirs.pop(index))
paramkeys.insert(0,paramkeys.pop(index))
values.insert(0,values.pop(index))
#end if
ip_sims = obj()
for n in xrange(len(paramkeys)):
qckw = process_qmcpack_chain_kwargs(
defaults = qmcpack_chain_defaults,
loc = loc,
**kwargs
)
# just do the simplest thing for now: duplicate full workflows
# later should branch workflow only downstream from varying sim
if 'inputs' not in section:
error('section must be a name of the form *_inputs\nsection provided: {0}\ninput sections present: {1}'.format(section,[s for s in sorted(qckw.keys()) if 'inputs' in s]),loc)
elif section not in qckw:
error('input section not found\ninput section provided: {0}\ninput sections present: {1}'.format(section,[s for s in sorted(qckw.keys()) if 'inputs' in s]),loc)
#end if
qckw[section][parameter] = values[n]
qckw.basepath = os.path.join(basepath,dirname,paramdirs[n])
qckw.J2_source = J2_source
qckw.J3_source = J3_source
qcsims = qmcpack_chain(**qckw)
if same_jastrow and n==0:
J2_source = qcsims.get_optional('optJ2',None)
J3_source = qcsims.get_optional('optJ3',None)
#end if
ip_sims[paramkeys[n]] = qcsims
#end for
return ip_sims
#end def input_parameter_scan
if __name__=='__main__':
print 'simple driver for qmcpack_workflows functions'