mirror of https://github.com/QMCPACK/qmcpack.git
nexus: scans of arbitrary parameters in qmcpack workflows
git-svn-id: https://subversion.assembla.com/svn/qmcdev/trunk@7040 e5b18d87-469d-4833-9cc0-8cdfa06e9491
This commit is contained in:
parent
dccbcf66c3
commit
ae59ce1e66
|
@ -786,6 +786,7 @@ class Projwfc(PostProcessSimulation):
|
|||
lowdin_file = self.identifier+'.lowdin'
|
||||
filepath = os.path.join(self.locdir,lowdin_file)
|
||||
analyzer.write_lowdin(filepath)
|
||||
analyzer.write_lowdin(filepath+'_long',long=True)
|
||||
except:
|
||||
None
|
||||
#end try
|
||||
|
|
|
@ -161,6 +161,31 @@ def prevent_invalid_input(invalid,loc):
|
|||
#end def prevent_invalid_input
|
||||
|
||||
|
||||
def render_parameter_key(v,loc='render_parameter_key'):
|
||||
if isinstance(v,list):
|
||||
vkey = tuple(v)
|
||||
elif isinstance(v,ndarray):
|
||||
vkey = tuple(v.ravel())
|
||||
else:
|
||||
vkey = v
|
||||
#end if
|
||||
if not hashable(vkey):
|
||||
error('parameter value is not hashable\nvalue provided: {0}\nvalue type: {1}\nplease restrict parameter values to basic types such as str,int,float,tuple and combinations of these'.format(v,v.__class__.__name__),loc)
|
||||
#end if
|
||||
return vkey
|
||||
#end def render_parameter_key
|
||||
|
||||
|
||||
def render_parameter_label(vkey,loc='render_parameter_label'):
|
||||
if isinstance(vkey,(int,float,str)):
|
||||
vlabel = str(vkey)
|
||||
elif isinstance(vkey,tuple):
|
||||
vlabel = str(vkey).replace('(','').replace(')','').replace(' ','').replace(',','_')
|
||||
else:
|
||||
error('cannot transform parameter value key into a directory name\nvalue key: {0}\ntype: {1}\nplease restrict parameter values to basic types such as str,int,float,tuple and combinations of these'.format(vkey,vkey.__class__.__name__),loc)
|
||||
#end if
|
||||
return vlabel
|
||||
#end def render_parameter_label
|
||||
|
||||
|
||||
|
||||
|
@ -1394,8 +1419,10 @@ def system_scan(
|
|||
systems = missing,
|
||||
sysdirs = missing,
|
||||
syskeys = None,
|
||||
same_jastrow = False,
|
||||
same_jastrow = False, # deprecated
|
||||
jastrow_key = missing,
|
||||
J2_source = None,
|
||||
J3_source = None,
|
||||
loc = 'system_scan',
|
||||
**kwargs
|
||||
):
|
||||
|
@ -1419,10 +1446,9 @@ def system_scan(
|
|||
error('must provide one key per system via syskeys keyword\nnumber of keys provided: {0}\nnumber of systems: {1}'.format(len(syskeys),len(systems)),loc)
|
||||
#end if
|
||||
|
||||
same_jastrow = not missing(jastrow_key)
|
||||
if same_jastrow:
|
||||
if missing(jastrow_key):
|
||||
error('requested same jastrow across scan but no system key was provided via jastrow_key keyword',loc)
|
||||
elif jastrow_key not in set(syskeys):
|
||||
if jastrow_key not in set(syskeys):
|
||||
error('key used to identify jastrow for use across scan was not found\njastrow key provided: {0}\nsystem keys present: {1}'.format(jastrow_key,sorted(system_keys)),loc)
|
||||
#end if
|
||||
sys = obj()
|
||||
|
@ -1442,8 +1468,6 @@ def system_scan(
|
|||
#end for
|
||||
#end if
|
||||
|
||||
J2_source = None
|
||||
J3_source = None
|
||||
sims = obj()
|
||||
for n in xrange(len(systems)):
|
||||
qckw = process_qmcpack_chain_kwargs(
|
||||
|
@ -1452,15 +1476,11 @@ def system_scan(
|
|||
loc = loc,
|
||||
**kwargs
|
||||
)
|
||||
qckw.basepath = os.path.join(basepath,dirname,sysdirs[n])
|
||||
if J2_source is not None:
|
||||
qckw.J2_source = J2_source
|
||||
#end if
|
||||
if J3_source is not None:
|
||||
qckw.J3_source = J3_source
|
||||
#end if
|
||||
qckw.basepath = os.path.join(basepath,dirname,sysdirs[n])
|
||||
qckw.J2_source = J2_source
|
||||
qckw.J3_source = J3_source
|
||||
qcsims = qmcpack_chain(**qckw)
|
||||
if same_jastrow:
|
||||
if same_jastrow and n==0:
|
||||
J2_source = qcsims.get_optional('optJ2',None)
|
||||
J3_source = qcsims.get_optional('optJ3',None)
|
||||
#end if
|
||||
|
@ -1472,51 +1492,41 @@ def system_scan(
|
|||
|
||||
|
||||
def system_parameter_scan(
|
||||
basepath = missing,
|
||||
dirname = 'system_param_scan',
|
||||
sysfunc = missing,
|
||||
variable = missing,
|
||||
values = missing,
|
||||
fixed = None,
|
||||
loc = 'system_parameter_scan',
|
||||
basepath = missing,
|
||||
dirname = 'system_param_scan',
|
||||
sysfunc = missing,
|
||||
parameter = missing,
|
||||
variable = missing, # same as parameter, to be deprecated
|
||||
values = missing,
|
||||
fixed = None,
|
||||
loc = 'system_parameter_scan',
|
||||
**kwargs
|
||||
):
|
||||
|
||||
set_loc(loc)
|
||||
|
||||
require('basepath',basepath)
|
||||
require('sysfunc' ,sysfunc )
|
||||
require('variable',variable)
|
||||
require('values' ,values )
|
||||
if missing(parameter):
|
||||
parameter = variable
|
||||
#end if
|
||||
|
||||
require('basepath' ,basepath )
|
||||
require('sysfunc' ,sysfunc )
|
||||
require('parameter',parameter)
|
||||
require('values' ,values )
|
||||
|
||||
systems = []
|
||||
sysdirs = []
|
||||
syskeys = []
|
||||
for v in values:
|
||||
params = obj()
|
||||
params[variable] = v
|
||||
params[parameter] = v
|
||||
if fixed!=None:
|
||||
params.set(**fixed)
|
||||
#end if
|
||||
system = sysfunc(**params)
|
||||
if isinstance(v,list):
|
||||
vkey = tuple(v)
|
||||
elif isinstance(v,ndarray):
|
||||
vkey = tuple(v.ravel())
|
||||
else:
|
||||
vkey = v
|
||||
#end if
|
||||
if not hashable(vkey):
|
||||
error('inputted system generation variable value is not hashable\nvalue provided: {0}\nvalue type: {1}\nplease restrict system generation variables to basic types such as str,int,float,tuple and combinations of these'.format(v,v.__class__.__name__),loc)
|
||||
#end if
|
||||
if isinstance(vkey,(int,float,str)):
|
||||
vstr = str(vkey)
|
||||
elif isinstance(vkey,tuple):
|
||||
vstr = str(vkey).replace('(','').replace(')','').replace(' ','').replace(',','_')
|
||||
else:
|
||||
error('cannot convert system generation variable value into a directory name\nvalue provided: {0}\nvalue type: {1}\nplease restrict system generation variables to basic types such as str,int,float,tuple and combinations of these'.format(v,v.__class__.__name__),loc)
|
||||
#end if
|
||||
sysdir = '{0}_{1}'.format(variable,vstr)
|
||||
vkey = render_parameter_key(v,loc)
|
||||
vlabel = render_parameter_label(vkey,loc)
|
||||
sysdir = '{0}_{1}'.format(parameter,vlabel)
|
||||
systems.append(system)
|
||||
sysdirs.append(sysdir)
|
||||
syskeys.append(vkey)
|
||||
|
@ -1539,6 +1549,98 @@ def system_parameter_scan(
|
|||
|
||||
|
||||
|
||||
def input_parameter_scan(
|
||||
basepath = missing,
|
||||
dirname = 'input_param_scan',
|
||||
section = missing,
|
||||
parameter = missing,
|
||||
variable = missing, # same as parameter, to be deprecated
|
||||
values = missing,
|
||||
tags = missing,
|
||||
jastrow_key = missing,
|
||||
J2_source = None,
|
||||
J3_source = None,
|
||||
loc = 'input_parameter_scan',
|
||||
**kwargs
|
||||
):
|
||||
|
||||
set_loc(loc)
|
||||
|
||||
if missing(parameter):
|
||||
parameter = variable
|
||||
#end if
|
||||
|
||||
require('basepath' ,basepath )
|
||||
require('section' ,section )
|
||||
require('parameter',parameter)
|
||||
require('values' ,values )
|
||||
|
||||
if not missing(tags) and len(tags)!=len(values):
|
||||
error('must provide one tag (directory label) per parameter value\nnumber of values: {0}\nnumber of tags: {1}\nvalues provided: {2}\ntags provided: {3}'.format(len(values),len(tags),values,tags),loc)
|
||||
#end if
|
||||
|
||||
paramdirs = []
|
||||
paramkeys = []
|
||||
n = 0
|
||||
for v in values:
|
||||
if not missing(tags):
|
||||
vkey = tags[n]
|
||||
else:
|
||||
vkey = render_parameter_key(v,loc)
|
||||
#end if
|
||||
vlabel = render_parameter_label(vkey,loc)
|
||||
paramdir = '{0}_{1}'.format(parameter,vlabel)
|
||||
paramdirs.append(paramdir)
|
||||
paramkeys.append(vkey)
|
||||
n+=1
|
||||
#end for
|
||||
|
||||
same_jastrow = not missing(jastrow_key)
|
||||
if same_jastrow:
|
||||
jastrow_key = render_key(jastrow_key,loc)
|
||||
if not jastrow_key in set(paramkeys):
|
||||
error('input parameter value for fixed Jastrow not found\njastrow key provided: {0}\nparameter keys present: {1}'.format(jastrow_key,paramkeys),loc)
|
||||
#end if
|
||||
values = list(values)
|
||||
index = paramkeys.index(jastrow_key)
|
||||
paramdirs.insert(0,paramdirs.pop(index))
|
||||
paramkeys.insert(0,paramkeys.pop(index))
|
||||
values.insert(0,values.pop(index))
|
||||
#end if
|
||||
|
||||
ip_sims = obj()
|
||||
for n in xrange(len(paramkeys)):
|
||||
qckw = process_qmcpack_chain_kwargs(
|
||||
defaults = qmcpack_chain_defaults,
|
||||
loc = loc,
|
||||
**kwargs
|
||||
)
|
||||
# just do the simplest thing for now: duplicate full workflows
|
||||
# later should branch workflow only downstream from varying sim
|
||||
if 'inputs' not in section:
|
||||
error('section must be a name of the form *_inputs\nsection provided: {0}\ninput sections present: {1}'.format(section,[s for s in sorted(qckw.keys()) if 'inputs' in s]),loc)
|
||||
elif section not in qckw:
|
||||
error('input section not found\ninput section provided: {0}\ninput sections present: {1}'.format(section,[s for s in sorted(qckw.keys()) if 'inputs' in s]),loc)
|
||||
#end if
|
||||
qckw[section][parameter] = values[n]
|
||||
|
||||
qckw.basepath = os.path.join(basepath,dirname,paramdirs[n])
|
||||
qckw.J2_source = J2_source
|
||||
qckw.J3_source = J3_source
|
||||
qcsims = qmcpack_chain(**qckw)
|
||||
if same_jastrow and n==0:
|
||||
J2_source = qcsims.get_optional('optJ2',None)
|
||||
J3_source = qcsims.get_optional('optJ3',None)
|
||||
#end if
|
||||
ip_sims[paramkeys[n]] = qcsims
|
||||
#end for
|
||||
|
||||
return ip_sims
|
||||
#end def input_parameter_scan
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
print 'simple driver for qmcpack_workflows functions'
|
||||
|
||||
|
|
Loading…
Reference in New Issue