mirror of https://github.com/QMCPACK/qmcpack.git
2219 lines
69 KiB
Python
2219 lines
69 KiB
Python
##################################################################
|
|
## (c) Copyright 2015- by Jaron T. Krogel ##
|
|
##################################################################
|
|
|
|
|
|
import os
|
|
import types as pytypes
|
|
import keyword
|
|
from numpy import fromstring,empty,array,float64,\
|
|
loadtxt,ndarray,dtype,sqrt,pi,arange,exp,eye,ceil,mod
|
|
from StringIO import StringIO
|
|
from superstring import string2val
|
|
from generic import obj
|
|
from xmlreader import XMLreader,XMLelement
|
|
from periodic_table import pt as periodic_table
|
|
from project_base import Pobj
|
|
from structure import Structure
|
|
from physical_system import PhysicalSystem
|
|
from simulation import SimulationInput
|
|
|
|
|
|
|
|
# Python bool -> spelling used when writing boolean XML attribute values.
yesno     = dict([(True, 'yes'),  (False, 'no')])
truefalse = dict([(True, 'true'), (False, 'false')])
onezero   = dict([(True, '1'),    (False, '0')])

# Inverse lookup: every recognized boolean spelling -> Python bool.
boolmap = dict((word, flag)
               for table in (yesno, truefalse, onezero)
               for flag, word in table.items())
|
|
|
|
def SQDbool(b):
    """Translate a boolean spelling ('yes', 'false', '1', ...) to a bool.

    Raises KeyError for spellings that are not keys of ``boolmap``.
    """
    flag = boolmap[b]
    return flag
#end def SQDbool
|
|
|
|
def is_bool(var):
    """Return True if ``var`` is one of the boolean spellings in ``boolmap``.

    Unhashable inputs (lists, numpy arrays, ...) cannot be dictionary keys;
    they are reported as non-boolean instead of raising TypeError.
    """
    try:
        return var in boolmap
    except TypeError:
        return False
    #end try
#end def is_bool
|
|
|
|
def is_int(var):
    """Return True if ``int(var)`` succeeds, False otherwise.

    Strings such as '3.7' are rejected, but non-string numbers such as 3.7
    are accepted because int() truncates them.  Inputs that int() rejects
    with TypeError (None, lists, ...) return False instead of raising.
    """
    try:
        int(var)
    except (TypeError,ValueError):
        return False
    #end try
    return True
#end def is_int
|
|
|
|
def is_float(var):
    """Return True if ``float(var)`` succeeds, False otherwise.

    Inputs that float() rejects with TypeError (None, lists, ...) return
    False instead of raising.
    """
    try:
        float(var)
    except (TypeError,ValueError):
        return False
    #end try
    return True
#end def is_float
|
|
|
|
def is_array(var,type):
    """Return True if ``var`` converts to a numpy array of dtype ``type``.

    Strings are split on whitespace first, so '1 2 3' is an int array.
    A blank string yields an empty array and therefore counts as True.
    Conversion failures of either kind (ValueError for bad text, TypeError
    for inputs such as None) return False.
    """
    values = var.split() if isinstance(var,str) else var
    try:
        array(values,type)
    except (TypeError,ValueError):
        return False
    #end try
    return True
#end def is_array
|
|
|
|
def attribute_to_value(attr):
    """Convert an XML attribute string into the most specific value.

    Tried in order: int, float, whitespace-separated int array (reshaped
    to 3x3 when it holds exactly nine entries), whitespace-separated float
    array.  Anything else is returned unchanged.
    """
    if is_int(attr):
        return int(attr)
    if is_float(attr):
        return float(attr)
    if is_array(attr,int):
        ints = array(attr.split(),int)
        if ints.size==9:
            # nine integers are interpreted as a 3x3 matrix
            ints.shape = 3,3
        return ints
    if is_array(attr,float):
        return array(attr.split(),float)
    return attr
#end def attribute_to_value
|
|
|
|
|
|
|
|
|
|
def is_list(s):
    """Return True if evaluating the string ``s`` produces a Python list.

    SECURITY NOTE: this uses eval(); ``s`` must come from trusted input
    (the user's own script or input file), never from an untrusted source.
    """
    try:
        t = eval(s)
    except Exception:
        # eval() of arbitrary text can fail in many ways: SyntaxError for
        # 'two words', NameError for bare identifiers, TypeError for
        # non-string input, and so on.  Any failure simply means "not a list".
        return False
    #end try
    return isinstance(t,list)
#end def is_list
|
|
|
|
def string_to_val(s):
    """Convert ``s`` to an int, float or list when possible, else return it.

    Conversions are tried in that order.  List detection and conversion
    use eval(), so ``s`` must come from trusted input.
    """
    if is_int(s):
        return int(s)
    if is_float(s):
        return float(s)
    if is_list(s):
        return eval(s)
    return s
#end def string_to_val
|
|
|
|
|
|
class SQDobj(Pobj):
    """Common base class of every object in the SQD input hierarchy."""
#end class SQDobj
|
|
|
|
|
|
class meta(obj):
    """Parameter metadata container.

    Param.read stores, under each <parameter> name, an obj holding that
    tag's attributes other than ``name``.
    """
#end class meta
|
|
|
|
|
|
class section(SQDobj):
    """Deferred constructor arguments for an SQDxml element.

    SQDxml.__init__ unpacks a section given as its sole argument into
    init_from_inputs(section.args, section.kwargs).
    """
    def __init__(self,*args,**kwargs):
        self.args   = args     # positional inputs, stored as given
        self.kwargs = kwargs   # keyword inputs, stored as given
    #end def __init__
#end class section
|
|
|
|
|
|
class collection(SQDobj):
    """Keyed container of SQDxml elements of a single kind.

    Each positional element is stored under a key derived from its
    ``identifier`` class attribute:
      * string identifier: the value of that attribute;
      * sequence identifier with identifier_type tuple:
        element.tuple(*identifier);
      * sequence identifier with identifier_type str (the default): the
        concatenation of the values of the listed attributes present.
    Keyword arguments are stored directly under their keyword.
    """
    def __init__(self,*elements,**kwargs):
        if len(elements)==1 and isinstance(elements[0],list):
            # accept collection([e1,e2,...]) as well as collection(e1,e2,...)
            elements = elements[0]
        #end if
        for element in elements:
            identifier = element.identifier
            if isinstance(identifier,str):
                key = element[identifier]
            else:
                if 'identifier_type' in element.__class__.__dict__:
                    identifier_type = element.identifier_type
                else:
                    identifier_type = str
                #end if
                if identifier_type==tuple:
                    key = element.tuple(*identifier)
                elif identifier_type==str:
                    key = ''
                    for ident in identifier:
                        if ident in element:
                            key += element[ident]
                        #end if
                    #end for
                else:
                    key = ''
                    self.error('identifier_type '+str(identifier_type)+' has not yet been implemented')
                #end if
            #end if
            self[key] = element
        #end for
        for key,element in kwargs.iteritems():
            self[key] = element
        #end for
    #end def __init__

    def get_single(self,preference):
        """Return the element keyed ``preference`` if present, otherwise the
        first element.  An empty collection returns itself."""
        if len(self)==0:
            return self
        #end if
        if preference in self:
            return self[preference]
        #end if
        return self.list()[0]
    #end def get_single
#end class collection
|
|
|
|
def make_collection(elements):
    """Wrap a list of elements in a collection.

    When the first element's class defines ``identifier``, the elements are
    keyed by identifier (see collection); otherwise they are keyed by
    position 0..n-1.
    """
    keyed = len(elements)>0 and 'identifier' in elements[0].__class__.__dict__
    if keyed:
        return collection(*elements)
    #end if
    coll = collection()
    for index,element in enumerate(elements):
        coll[index] = element
    #end for
    return coll
#end def make_collection
|
|
|
|
|
|
class classcollection(SQDobj):
    """A list of classes used as a default value.

    SQDxml.incorporate_defaults instantiates each class and combines the
    instances into a collection.
    """
    def __init__(self,*classes):
        unwrap = len(classes)==1 and isinstance(classes[0],list)
        self.classes = classes[0] if unwrap else classes
    #end def __init__
#end class classcollection
|
|
|
|
|
|
class Names(SQDobj):
|
|
names = set([
|
|
'atom','c','condition','eigensolve','grid','hamiltonian','id',
|
|
'l','m','n','name','npts','num_closed_shells','orbital',
|
|
'orbitalset','parameter','project','rf','ri','s','scale',
|
|
'series','simulation','type'])
|
|
|
|
bools = dict()
|
|
|
|
condensed_names = obj()
|
|
expanded_names = None
|
|
|
|
escape_names = set(keyword.kwlist)
|
|
escaped_names = list(escape_names)
|
|
for i in range(len(escaped_names)):
|
|
escaped_names[i]+='_'
|
|
#end for
|
|
escaped_names = set(escaped_names)
|
|
|
|
@staticmethod
|
|
def set_expanded_names(**kwargs):
|
|
Names.expanded_names = obj(**kwargs)
|
|
#end def set_expanded_names
|
|
|
|
|
|
def expand_name(self,condensed):
|
|
expanded = condensed
|
|
if expanded in self.escaped_names:
|
|
expanded = expanded[:-1]
|
|
#end if
|
|
if expanded in self.expanded_names:
|
|
expanded = self.expanded_names[expanded]
|
|
#end if
|
|
return expanded
|
|
#end def expand_name
|
|
|
|
def condense_name(self,expanded):
|
|
condensed = expanded
|
|
condensed = condensed.replace('___','_').replace('__','_')
|
|
condensed = condensed.replace('-','_').replace(' ','_')
|
|
condensed = condensed.lower()
|
|
if condensed in self.escape_names:
|
|
condensed += '_'
|
|
#end if
|
|
self.condensed_names[expanded]=condensed
|
|
return condensed
|
|
#end def condense_name
|
|
|
|
|
|
def condense_names(self,*namelists):
|
|
out = []
|
|
for namelist in namelists:
|
|
exp = obj()
|
|
for expanded in namelist:
|
|
condensed = self.condense_name(expanded)
|
|
exp[condensed]=expanded
|
|
#end for
|
|
out.append(exp)
|
|
#end for
|
|
return out
|
|
#end def condense_names
|
|
|
|
def condensed_name_report(self):
|
|
print
|
|
print 'Condensed Name Report:'
|
|
print '----------------------'
|
|
keylist = array(self.condensed_names.keys())
|
|
order = array(self.condensed_names.values()).argsort()
|
|
keylist = keylist[order]
|
|
for expanded in keylist:
|
|
condensed = self.condensed_names[expanded]
|
|
if expanded!=condensed:
|
|
print " {0:15} = '{1}'".format(condensed,expanded)
|
|
#end if
|
|
#end for
|
|
print
|
|
print
|
|
#end def condensed_name_report
|
|
#end class Names
|
|
|
|
|
|
|
|
|
|
class SQDxml(Names):
|
|
|
|
def init_from_args(self,args):
|
|
print
|
|
print args
|
|
print
|
|
self.not_implemented()
|
|
#end def init_from_args
|
|
|
|
|
|
|
|
@classmethod
|
|
def init_class(cls):
|
|
vars = cls.__dict__.keys()
|
|
init_vars = dict(tag = cls.__name__,
|
|
attributes = [],
|
|
elements = [],
|
|
text = None,
|
|
parameters = [],
|
|
attribs = [],
|
|
costs = [],
|
|
h5tags = [],
|
|
defaults = obj()
|
|
)
|
|
for var,val in init_vars.iteritems():
|
|
if not var in vars:
|
|
cls.__dict__[var] = val
|
|
#end if
|
|
#end for
|
|
for v in ['attributes','elements','parameters','attribs','costs','h5tags']:
|
|
names = cls.__dict__[v]
|
|
for i in range(len(names)):
|
|
if names[i] in cls.escape_names:
|
|
names[i]+='_'
|
|
#end if
|
|
#end for
|
|
#end for
|
|
cls.params = cls.parameters + cls.attribs + cls.costs + cls.h5tags
|
|
cls.plurals_inv = obj()
|
|
for e in cls.elements:
|
|
if e in plurals_inv:
|
|
cls.plurals_inv[e] = plurals_inv[e]
|
|
#end if
|
|
#end for
|
|
cls.plurals = cls.plurals_inv.inverse()
|
|
#end def init_class
|
|
|
|
|
|
def write(self,indent_level=0,pad=' ',first=False):
|
|
indent = indent_level*pad
|
|
ip = indent+pad
|
|
ipp= ip+pad
|
|
c = indent+'<'+self.tag
|
|
for a in self.attributes:
|
|
if a in self:
|
|
val = self[a]
|
|
if isinstance(val,str):
|
|
val = self.expand_name(val)
|
|
#end if
|
|
c += ' '+self.expand_name(a)+'='
|
|
if a in self.bools and (val==True or val==False):
|
|
c += '"'+self.bools[a][val]+'"'
|
|
else:
|
|
c += '"'+param.write(val)+'"'
|
|
#end if
|
|
#end if
|
|
#end for
|
|
if first:
|
|
None
|
|
#c+=' xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="http://www.mcc.uiuc.edu/qmc/schema/molecu.xsd"'
|
|
#end if
|
|
#no_contents = len(set(self.keys())-set(self.elements)-set(self.plurals.keys()))==0
|
|
no_contents = len(set(self.keys())-set(self.attributes))==0
|
|
if no_contents:
|
|
c += '/>\n'
|
|
else:
|
|
c += '>\n'
|
|
for v in self.h5tags:
|
|
if v in self:
|
|
c += param.write(self[v],name=self.expand_name(v),tag='h5tag',mode='elem',pad=ip)
|
|
#end if
|
|
#end for
|
|
for v in self.costs:
|
|
if v in self:
|
|
c += param.write(self[v],name=self.expand_name(v),tag='cost',mode='elem',pad=ip)
|
|
#end if
|
|
#end for
|
|
for p in self.parameters:
|
|
if p in self:
|
|
c += param.write(self[p],name=self.expand_name(p),mode='elem',pad=ip)
|
|
#end if
|
|
#end for
|
|
for a in self.attribs:
|
|
if a in self:
|
|
c += param.write(self[a],name=self.expand_name(a),tag='attrib',mode='elem',pad=ip)
|
|
#end if
|
|
#end for
|
|
for e in self.elements:
|
|
if e in self:
|
|
elem = self[e]
|
|
if isinstance(elem,SQDxml):
|
|
c += self[e].write(indent_level+1)
|
|
else:
|
|
begin = '<'+e+'>'
|
|
contents = param.write(elem)
|
|
end = '</'+e+'>'
|
|
if contents.strip()=='':
|
|
c += ip+begin+end+'\n'
|
|
else:
|
|
c += ip+begin+'\n'
|
|
c += ipp+contents+'\n'
|
|
c += ip+end+'\n'
|
|
#end if
|
|
#end if
|
|
elif e in plurals_inv and plurals_inv[e] in self:
|
|
coll = self[plurals_inv[e]]
|
|
coll_len = len(coll)
|
|
if 0 in coll and coll_len-1 in coll:
|
|
for i in range(coll_len):
|
|
instance = coll[i]
|
|
c += instance.write(indent_level+1)
|
|
#end for
|
|
else:
|
|
keys = coll.keys()
|
|
keys.sort()
|
|
for key in keys:
|
|
instance = coll[key]
|
|
c += instance.write(indent_level+1)
|
|
#end for
|
|
#end if
|
|
#end if
|
|
#end for
|
|
#for p in self.plurals.keys():
|
|
# if p in self:
|
|
# for instance in self[p]:
|
|
# c += instance.write(indent_level+1)
|
|
# #end for
|
|
# #end if
|
|
##end for
|
|
if self.text!=None:
|
|
#c+=ip+param.write(self[self.text])+'\n'
|
|
c+=param.write(self[self.text],mode='elem',pad=ip,tag=None)
|
|
#end if
|
|
c+=indent+'</'+self.tag+'>\n'
|
|
#end if
|
|
return c
|
|
#end def write
|
|
|
|
|
|
def __init__(self,*args,**kwargs):
|
|
if Param.metadata==None:
|
|
Param.metadata = meta()
|
|
#end if
|
|
if len(args)==1:
|
|
a = args[0]
|
|
if isinstance(a,XMLelement):
|
|
self.init_from_xml(a)
|
|
elif isinstance(a,section):
|
|
self.init_from_inputs(a.args,a.kwargs)
|
|
elif isinstance(a,self.__class__):
|
|
self.transfer_from(a)
|
|
else:
|
|
self.init_from_inputs(args,kwargs)
|
|
#end if
|
|
else:
|
|
self.init_from_inputs(args,kwargs)
|
|
#end if
|
|
#end def __init__
|
|
|
|
|
|
def init_from_xml(self,xml):
|
|
al,el = self.condense_names(xml._attributes.keys(),xml._elements.keys())
|
|
xa,sa = set(al.keys()) , set(self.attributes)
|
|
attr = xa & sa
|
|
junk = xa-attr
|
|
junk_elem = []
|
|
for e,ecap in el.iteritems():
|
|
value = xml._elements[ecap]
|
|
if (isinstance(value,list) or isinstance(value,tuple)) and e in self.plurals_inv.keys():
|
|
p = self.plurals_inv[e]
|
|
plist = []
|
|
for instance in value:
|
|
plist.append(types[e](instance))
|
|
#end for
|
|
self[p] = make_collection(plist)
|
|
elif e in self.elements:
|
|
self[e] = types[e](value)
|
|
elif e in ['parameter','attrib','cost','h5tag']:
|
|
if isinstance(value,XMLelement):
|
|
value = [value]
|
|
#end if
|
|
for p in value:
|
|
name = self.condense_name(p.name)
|
|
if name in self.params:
|
|
self[name] = param(p)
|
|
else:
|
|
junk_elem.append(name)
|
|
#end if
|
|
#end for
|
|
else:
|
|
junk_elem.append(e)
|
|
#end if
|
|
#end for
|
|
junk = junk | set(junk_elem)
|
|
self.check_junk(junk)
|
|
for a in attr:
|
|
if a in self.bools:
|
|
self[a] = boolmap[xml._attributes[al[a]]]
|
|
else:
|
|
self[a] = attribute_to_value(xml._attributes[al[a]])
|
|
#end if
|
|
#end for
|
|
if self.text!=None:
|
|
self[self.text] = param(xml)
|
|
#end if
|
|
#end def init_from_xml
|
|
|
|
|
|
def init_from_inputs(self,args,kwargs):
|
|
if len(args)>0:
|
|
if len(args)==1 and isinstance(args[0],self.__class__):
|
|
self.transfer_from(args[0])
|
|
else:
|
|
self.init_from_args(args)
|
|
#end if
|
|
#end if
|
|
self.init_from_kwargs(kwargs)
|
|
#end def init_from_inputs
|
|
|
|
|
|
def init_from_kwargs(self,kwargs):
|
|
ks=[]
|
|
kmap = dict()
|
|
for key,val in kwargs.iteritems():
|
|
ckey = self.condense_name(key)
|
|
ks.append(ckey)
|
|
kmap[ckey] = val
|
|
#end for
|
|
ks = set(ks)
|
|
kwargs = kmap
|
|
h5tags = ks & set(self.h5tags)
|
|
costs = ks & set(self.costs)
|
|
parameters = ks & set(self.parameters)
|
|
attribs = ks & set(self.attribs)
|
|
attr = ks & set(self.attributes)
|
|
elem = ks & set(self.elements)
|
|
plur = ks & set(self.plurals.keys())
|
|
if self.text!=None:
|
|
text = ks & set([self.text])
|
|
else:
|
|
text = set()
|
|
#end if
|
|
junk = ks -attr -elem -plur -h5tags -costs -parameters -attribs -text
|
|
self.check_junk(junk)
|
|
for v in h5tags:
|
|
self[v] = param(kwargs[v])
|
|
#end for
|
|
for v in costs:
|
|
self[v] = param(kwargs[v])
|
|
#end for
|
|
for v in parameters:
|
|
self[v] = param(kwargs[v])
|
|
#end for
|
|
for v in attribs:
|
|
self[v] = param(kwargs[v])
|
|
#end for
|
|
for a in attr:
|
|
self[a] = kwargs[a]
|
|
#end for
|
|
for e in elem:
|
|
self[e] = types[e](kwargs[e])
|
|
#end for
|
|
for p in plur:
|
|
plist = []
|
|
e = self.plurals[p]
|
|
kwcoll = kwargs[p]
|
|
if isinstance(kwcoll,collection):
|
|
cobj = collection()
|
|
for name,instance in kwcoll.iteritems():
|
|
iobj = types[e](instance)
|
|
if isinstance(iobj.identifier,str):
|
|
iobj[iobj.identifier]=name
|
|
#end if
|
|
cobj[name] = iobj
|
|
#end for
|
|
self[p] = cobj
|
|
else:
|
|
for instance in kwargs[p]:
|
|
plist.append(types[e](instance))
|
|
#end for
|
|
self[p] = make_collection(plist)
|
|
#end if
|
|
#end for
|
|
for t in text:
|
|
self[t] = kwargs[t]
|
|
#end for
|
|
#end def init_from_kwargs
|
|
|
|
|
|
def incorporate_defaults(self,elements=False,overwrite=False,propagate=True):
|
|
for name,value in self.defaults.iteritems():
|
|
valtype = type(value)
|
|
defval=None
|
|
if isinstance(value,classcollection):
|
|
if elements:
|
|
coll=[]
|
|
for cls in value.classes:
|
|
ins = cls()
|
|
ins.incorporate_defaults()
|
|
coll.append(ins)
|
|
#end for
|
|
defval = make_collection(coll)
|
|
#end if
|
|
elif valtype==pytypes.ClassType:
|
|
if elements:
|
|
defval = value()
|
|
#end if
|
|
elif valtype==pytypes.FunctionType:
|
|
defval = value()
|
|
else:
|
|
defval = value
|
|
#end if
|
|
if defval!=None:
|
|
if overwrite or not name in self:
|
|
self[name] = defval
|
|
#end if
|
|
#end if
|
|
#end for
|
|
if propagate:
|
|
for name,value in self.iteritems():
|
|
if isinstance(value,SQDxml):
|
|
value.incorporate_defaults(elements,overwrite)
|
|
elif isinstance(value,collection):
|
|
for v in value:
|
|
if isinstance(v,SQDxml):
|
|
v.incorporate_defaults(elements,overwrite)
|
|
#end if
|
|
#end for
|
|
#end if
|
|
#end for
|
|
#end if
|
|
#end def incorporate_defaults
|
|
|
|
|
|
def check_junk(self,junk):
|
|
if len(junk)>0:
|
|
msg = self.tag+' does not have the following attributes/elements:\n'
|
|
for jname in junk:
|
|
msg+=' '+jname+'\n'
|
|
#end for
|
|
self.error(msg,'SqdInput',exit=False,trace=False)
|
|
#self.error(msg,'SqdInput')
|
|
#end if
|
|
#end def check_junk
|
|
|
|
|
|
def get_single(self,preference):
|
|
return self
|
|
#end def get_single
|
|
|
|
|
|
def get(self,names,namedict=None,host=False,root=True):
|
|
if namedict is None:
|
|
namedict = {}
|
|
#end if
|
|
if isinstance(names,str):
|
|
names = [names]
|
|
#end if
|
|
if root and not host:
|
|
if 'identifier' in self.__class__.__dict__ and self.identifier in self:
|
|
identity = self[self.identifier]
|
|
else:
|
|
identity = None
|
|
#end if
|
|
for name in names:
|
|
if name==self.tag:
|
|
namedict[name]=self
|
|
elif name==identity:
|
|
namedict[name]=self
|
|
#end if
|
|
#end for
|
|
#end if
|
|
for name in names:
|
|
loc = None
|
|
if name in self:
|
|
loc = name
|
|
elif name in plurals_inv and plurals_inv[name] in self:
|
|
loc = plurals_inv[name]
|
|
#end if
|
|
name_absent = not name in namedict
|
|
not_element = False
|
|
if not name_absent:
|
|
not_xml = not isinstance(namedict[name],SQDxml)
|
|
not_coll = not isinstance(namedict[name],collection)
|
|
not_element = not_xml and not_coll
|
|
#end if
|
|
if loc!=None and (name_absent or not_element):
|
|
if host:
|
|
namedict[name] = self
|
|
else:
|
|
namedict[name] = self[loc]
|
|
#end if
|
|
#end if
|
|
#end for
|
|
for name,value in self.iteritems():
|
|
if isinstance(value,SQDxml):
|
|
value.get(names,namedict,host,root=False)
|
|
elif isinstance(value,collection):
|
|
for n,v in value.iteritems():
|
|
name_absent = not n in namedict
|
|
not_element = False
|
|
if not name_absent:
|
|
not_xml = not isinstance(namedict[n],SQDxml)
|
|
not_coll = not isinstance(namedict[n],collection)
|
|
not_element = not_xml and not_coll
|
|
#end if
|
|
if n in names and (name_absent or not_element):
|
|
if host:
|
|
namedict[n] = value
|
|
else:
|
|
namedict[n] = v
|
|
#end if
|
|
#end if
|
|
if isinstance(v,SQDxml):
|
|
v.get(names,namedict,host,root=False)
|
|
#end if
|
|
#end if
|
|
#end if
|
|
#end for
|
|
if root:
|
|
namelist = []
|
|
for name in names:
|
|
if name in namedict:
|
|
namelist.append(namedict[name])
|
|
else:
|
|
namelist.append(None)
|
|
#end if
|
|
#end for
|
|
if len(namelist)==1:
|
|
return namelist[0]
|
|
else:
|
|
return namelist
|
|
#end if
|
|
#end if
|
|
#end def get
|
|
|
|
def remove(self,*names):
|
|
if len(names)==1 and not isinstance(names[0],str):
|
|
names = names[0]
|
|
#end if
|
|
remove = []
|
|
for name in names:
|
|
attempt = True
|
|
if name in self:
|
|
rname = name
|
|
elif name in plurals_inv and plurals_inv[name] in self:
|
|
rname = plurals_inv[name]
|
|
else:
|
|
attempt = False
|
|
#end if
|
|
if attempt:
|
|
val = self[rname]
|
|
if isinstance(val,SQDxml) or isinstance(val,collection):
|
|
remove.append(rname)
|
|
#end if
|
|
#end if
|
|
#end for
|
|
for name in remove:
|
|
del self[name]
|
|
#end for
|
|
for name,value in self.iteritems():
|
|
if isinstance(value,SQDxml):
|
|
value.remove(*names)
|
|
elif isinstance(value,collection):
|
|
for element in value:
|
|
if isinstance(element,SQDxml):
|
|
element.remove(*names)
|
|
#end if
|
|
#end for
|
|
#end if
|
|
#end for
|
|
#end def remove
|
|
|
|
|
|
def replace(self,*args,**kwargs):
|
|
if len(args)==2 and isinstance(args[0],str) and isinstance(args[1],str):
|
|
vold,vnew = args
|
|
args = [(vold,vnew)]
|
|
#end for
|
|
for valpair in args:
|
|
vold,vnew = valpair
|
|
for var,val in self.iteritems():
|
|
not_coll = not isinstance(val,collection)
|
|
not_xml = not isinstance(val,SQDxml)
|
|
not_arr = not isinstance(val,ndarray)
|
|
if not_coll and not_xml and not_arr and val==vold:
|
|
self[var] = vnew
|
|
#end if
|
|
#end for
|
|
#end for
|
|
for var,valpair in kwargs.iteritems():
|
|
vold,vnew = valpair
|
|
if var in self:
|
|
val = self[var]
|
|
if vold==None:
|
|
self[var] = vnew
|
|
else:
|
|
not_coll = not isinstance(val,collection)
|
|
not_xml = not isinstance(val,SQDxml)
|
|
not_arr = not isinstance(val,ndarray)
|
|
if not_coll and not_xml and not_arr and val==vold:
|
|
self[var] = vnew
|
|
#end if
|
|
#end if
|
|
#end if
|
|
#end for
|
|
for vname,val in self.iteritems():
|
|
if isinstance(val,SQDxml):
|
|
val.replace(*args,**kwargs)
|
|
elif isinstance(val,collection):
|
|
for v in val:
|
|
if isinstance(v,SQDxml):
|
|
v.replace(*args,**kwargs)
|
|
#end if
|
|
#end for
|
|
#end if
|
|
#end for
|
|
#end def replace
|
|
|
|
|
|
def combine(self,other):
|
|
#elemental combine only
|
|
for name,element in other.iteritems():
|
|
plural = isinstance(element,collection)
|
|
single = isinstance(element,SQDxml)
|
|
if single or plural:
|
|
elem = []
|
|
single_name = None
|
|
plural_name = None
|
|
if single:
|
|
elem.append(element)
|
|
single_name = name
|
|
if name in plurals_inv:
|
|
plural_name = plurals_inv[name]
|
|
#end if
|
|
else:
|
|
elem.extend(element.values())
|
|
plural_name = name
|
|
single_name = plurals[name]
|
|
#end if
|
|
if single_name in self:
|
|
elem.append(self[single_name])
|
|
del self[single_name]
|
|
elif plural_name!=None and plural_name in self:
|
|
elem.append(self[plural_name])
|
|
del self[plural_name]
|
|
#end if
|
|
if len(elem)==1:
|
|
self[single_name]=elem[0]
|
|
elif plural_name==None:
|
|
self.error('attempting to combine non-plural elements: '+single_name)
|
|
else:
|
|
self[plural_name] = make_collection(elem)
|
|
#end if
|
|
#end if
|
|
#end for
|
|
#end def combine
|
|
|
|
|
|
def move(self,**elemdests):
|
|
names = elemdests.keys()
|
|
hosts = self.get_host(names)
|
|
dests = self.get(elemdests.values())
|
|
if len(names)==1:
|
|
hosts = [hosts]
|
|
dests = [dests]
|
|
#end if
|
|
for i in range(len(names)):
|
|
name = names[i]
|
|
host = hosts[i]
|
|
dest = dests[i]
|
|
if host!=None and dest!=None and id(host)!=id(dest):
|
|
if not name in host:
|
|
name = plurals_inv[name]
|
|
#end if
|
|
dest[name] = host[name]
|
|
del host[name]
|
|
#end if
|
|
#end for
|
|
#end def move
|
|
|
|
|
|
|
|
def pluralize(self):
|
|
make_plural = []
|
|
for name,value in self.iteritems():
|
|
if name in plurals_inv:
|
|
make_plural.append(name)
|
|
#end if
|
|
if isinstance(value,SQDxml):
|
|
value.pluralize()
|
|
elif isinstance(value,collection):
|
|
for v in value:
|
|
if isinstance(v,SQDxml):
|
|
v.pluralize()
|
|
#end if
|
|
#end for
|
|
#end if
|
|
#end for
|
|
for name in make_plural:
|
|
value = self[name]
|
|
del self[name]
|
|
plural_name = plurals_inv[name]
|
|
self[plural_name] = make_collection([value])
|
|
#end for
|
|
#end def pluralize
|
|
|
|
|
|
def difference(self,other,root=True):
|
|
if root:
|
|
q1 = self.copy()
|
|
q2 = other.copy()
|
|
else:
|
|
q1 = self
|
|
q2 = other
|
|
#end if
|
|
if q1.__class__!=q2.__class__:
|
|
different = True
|
|
diff = None
|
|
d1 = q1
|
|
d2 = q2
|
|
else:
|
|
cls = q1.__class__
|
|
s1 = set(q1.keys())
|
|
s2 = set(q2.keys())
|
|
shared = s1 & s2
|
|
unique1 = s1 - s2
|
|
unique2 = s2 - s1
|
|
different = len(unique1)>0 or len(unique2)>0
|
|
diff = cls()
|
|
d1 = cls()
|
|
d2 = cls()
|
|
d1.transfer_from(q1,unique1)
|
|
d2.transfer_from(q2,unique2)
|
|
for k in shared:
|
|
value1 = q1[k]
|
|
value2 = q2[k]
|
|
is_coll1 = isinstance(value1,collection)
|
|
is_coll2 = isinstance(value2,collection)
|
|
is_qxml1 = isinstance(value1,SQDxml)
|
|
is_qxml2 = isinstance(value2,SQDxml)
|
|
if is_coll1!=is_coll2 or is_qxml1!=is_qxml2:
|
|
self.error('values for '+k+' are of inconsistent types\n difference could not be taken')
|
|
#end if
|
|
if is_qxml1 and is_qxml2:
|
|
kdifferent,kdiff,kd1,kd2 = value1.difference(value2,root=False)
|
|
elif is_coll1 and is_coll2:
|
|
ks1 = set(value1.keys())
|
|
ks2 = set(value2.keys())
|
|
kshared = ks1 & ks2
|
|
kunique1 = ks1 - ks2
|
|
kunique2 = ks2 - ks1
|
|
kdifferent = len(kunique1)>0 or len(kunique2)>0
|
|
kd1 = collection()
|
|
kd2 = collection()
|
|
kd1.transfer_from(value1,kunique1)
|
|
kd2.transfer_from(value2,kunique2)
|
|
kdiff = collection()
|
|
for kk in kshared:
|
|
v1 = value1[kk]
|
|
v2 = value2[kk]
|
|
if isinstance(v1,SQDxml) and isinstance(v2,SQDxml):
|
|
kkdifferent,kkdiff,kkd1,kkd2 = v1.difference(v2,root=False)
|
|
kdifferent = kdifferent or kkdifferent
|
|
if kkdiff!=None:
|
|
kdiff[kk]=kkdiff
|
|
#end if
|
|
if kkd1!=None:
|
|
kd1[kk]=kkd1
|
|
#end if
|
|
if kkd2!=None:
|
|
kd2[kk]=kkd2
|
|
#end if
|
|
#end if
|
|
#end for
|
|
else:
|
|
if isinstance(value1,ndarray):
|
|
a1 = value1.ravel()
|
|
else:
|
|
a1 = array([value1])
|
|
#end if
|
|
if isinstance(value2,ndarray):
|
|
a2 = value2.ravel()
|
|
else:
|
|
a2 = array([value2])
|
|
#end if
|
|
if len(a1)!=len(a2):
|
|
kdifferent = True
|
|
elif len(a1)==0:
|
|
kdifferent = False
|
|
elif (isinstance(a1[0],float) or isinstance(a2[0],float)) and not (isinstance(a1[0],str) or isinstance(a2[0],str)):
|
|
kdifferent = abs(a1-a2).max()/max(1e-99,abs(a1).max(),abs(a2).max()) > 1e-6
|
|
else:
|
|
kdifferent = not (a1==a2).all()
|
|
#end if
|
|
if kdifferent:
|
|
kdiff = (value1,value2)
|
|
kd1 = value1
|
|
kd2 = value2
|
|
else:
|
|
kdiff = None
|
|
kd1 = None
|
|
kd2 = None
|
|
#end if
|
|
#end if
|
|
different = different or kdifferent
|
|
if kdiff!=None:
|
|
diff[k] = kdiff
|
|
#end if
|
|
if kd1!=None:
|
|
d1[k] = kd1
|
|
#end if
|
|
if kd2!=None:
|
|
d2[k] = kd2
|
|
#end if
|
|
#end for
|
|
#end if
|
|
if root:
|
|
if diff!=None:
|
|
diff.remove_empty()
|
|
#end if
|
|
d1.remove_empty()
|
|
d2.remove_empty()
|
|
#end if
|
|
return different,diff,d1,d2
|
|
#end def difference
|
|
|
|
def remove_empty(self):
|
|
names = list(self.keys())
|
|
for name in names:
|
|
value = self[name]
|
|
if isinstance(value,SQDxml):
|
|
value.remove_empty()
|
|
if len(value)==0:
|
|
del self[name]
|
|
#end if
|
|
elif isinstance(value,collection):
|
|
ns = list(value.keys())
|
|
for n in ns:
|
|
v = value[n]
|
|
if isinstance(v,SQDxml):
|
|
v.remove_empty()
|
|
if len(v)==0:
|
|
del value[n]
|
|
#end if
|
|
#end if
|
|
#end for
|
|
if len(value)==0:
|
|
del self[name]
|
|
#end if
|
|
#end if
|
|
#end for
|
|
#end def remove_empty
|
|
|
|
def get_host(self,names):
|
|
return self.get(names,host=True)
|
|
#end def get_host
|
|
#end class SQDxml
|
|
|
|
|
|
|
|
class SQDxmlFactory(Names):
    """Callable that constructs the appropriate SQDxml subclass.

    The subtype name is taken, in order of precedence, from: the class of an
    already-built instance passed first, keyword/attribute ``typekey``,
    keyword/attribute ``typekey2``, ``default``, or the positional argument
    at ``typeindex``.  It is condensed and looked up in ``types``.
    """
    def __init__(self,name,types,typekey='',typeindex=-1,typekey2='',default=None):
        self.name      = name
        self.types     = types
        self.typekey   = typekey
        self.typekey2  = typekey2
        self.typeindex = typeindex
        self.default   = default
    #end def __init__

    def __call__(self,*args,**kwargs):
        # mimic SQDxml.__init__ just far enough to find the subtype name
        positional = args
        keywords   = kwargs
        found_type = False
        if len(args)>0:
            first = args[0]
            if isinstance(first,XMLelement):
                keywords = first._attributes
            elif isinstance(first,section):
                positional = first.args
                keywords   = first.kwargs
            elif isinstance(first,tuple(self.types.values())):
                found_type = True
                subtype = first.__class__.__name__
            #end if
        #end if
        if not found_type:
            if self.typekey in keywords.keys():
                subtype = keywords[self.typekey]
            elif self.typekey2 in keywords.keys():
                subtype = keywords[self.typekey2]
            elif self.default!=None:
                subtype = self.default
            else:
                subtype = positional[self.typeindex]
            #end if
        #end if
        subtype = self.condense_name(subtype)
        if subtype in self.types:
            return self.types[subtype](*args,**kwargs)
        #end if
        msg  = self.name+' factory is not aware of the following subtype:\n'
        msg += '    '+subtype+'\n'
        self.error(msg,exit=False,trace=False)
    #end def __call__

    def init_class(self):
        None # this is for compatibility with SQDxml only (do not overwrite)
    #end def init_class
#end class SQDxmlFactory
|
|
|
|
|
|
|
|
class Param(Names):
    """Reader/writer for named <parameter>-style tags and plain values.

    Used through the module-level singleton ``param``.  Attributes of a tag
    other than ``name`` are recorded in ``Param.metadata`` so that they can
    be written back out.
    """
    # meta() instance; SQDxml.__init__ creates one when None and
    # SqdInput.__init__ resets it to None
    metadata = None

    def __call__(self,*args,**kwargs):
        """Read an XMLelement; any other value is returned unchanged."""
        if len(args)==0:
            self.error('no arguments provided, should have recieved one XMLelement')
        elif not isinstance(args[0],XMLelement):
            return args[0]
        #end if
        return self.read(args[0])
    #end def __call__

    def read(self,xml):
        """Parse the text of ``xml`` into an int/float array, a string array
        or a scalar (when it holds one value).  Returns '' if there is no
        text content."""
        val = ''
        attr = set(xml._attributes.keys())
        other_attr = attr-set(['name'])
        if 'name' in attr and len(other_attr)>0:
            oa = obj()
            for a in other_attr:
                oa[a] = xml._attributes[a]
            #end for
            self.metadata[xml.name] = oa
        #end if
        if 'text' in xml:
            # Sniff the type from the first token of the whole text.  Using
            # only the first line raised IndexError when the text began with
            # a blank line.
            tokens = xml.text.split(None,1)
            if len(tokens)>0:
                token = tokens[0]
                if is_int(token):
                    val = loadtxt(StringIO(xml.text),int)
                elif is_float(token):
                    val = loadtxt(StringIO(xml.text),float)
                else:
                    val = array(xml.text.split())
                #end if
                if val.size==1:
                    val = val.ravel()[0]
                #end if
            #end if
        #end if
        return val
    #end def read

    def write(self,value,mode='attr',tag='parameter',name=None,pad=' '):
        """Format ``value`` as XML text.

        mode='attr'  return the bare text for an attribute value (arrays
                     are space separated)
        mode='elem'  return a <tag name="..."> ... </tag> block indented by
                     ``pad``; with tag=None only the value text is written
        1-D and 2-D arrays are supported in elem mode.
        """
        c = ''
        attr_mode = mode=='attr'
        elem_mode = mode=='elem'
        if not attr_mode and not elem_mode:
            self.error(mode+' is not a valid mode.  Options are attr,elem.')
        #end if
        if isinstance(value,list) or isinstance(value,tuple):
            value = array(value)
        #end if
        if attr_mode:
            if isinstance(value,ndarray):
                c = ' '.join(str(v) for v in value.ravel())
            else:
                c = str(value)
            #end if
        elif elem_mode:
            c+=pad
            value_is_array = isinstance(value,ndarray)
            is_single = not (value_is_array and value.size>1)
            if tag is not None:
                if is_single:
                    max_len = 20
                    rem_len = max(0,max_len-len(name))
                else:
                    rem_len = 0
                #end if
                other=''
                # metadata may be unset (None) outside SQDxml construction
                if self.metadata is not None and name in self.metadata:
                    for a,v in self.metadata[name].iteritems():
                        other +=' '+self.expand_name(a)+'="'+str(v)+'"'
                    #end for
                #end if
                c+='<'+tag+' name="'+name+'"'+other+rem_len*' '+'>'
                pp = pad+'  '
            else:
                pp = pad
            #end if
            if value_is_array:
                if tag is not None:
                    c+='\n'
                #end if
                ndim = len(value.shape)
                if ndim==1:
                    if tag is not None:
                        c+=pp
                    #end if
                    c+=' '.join(str(v) for v in value)+'\n'
                elif ndim==2:
                    if value.dtype == dtype(float):
                        vfmt = ':16.8e'
                    else:
                        vfmt = ''
                    #end if
                    # one '{i...}' field per column; join keeps every brace
                    # closed (trimming a trailing '} ' dropped the last one)
                    fields = ['{'+str(nc)+vfmt+'}' for nc in range(value.shape[1])]
                    fmt = pp+' '.join(fields)+'\n'
                    for row in value:
                        c+=fmt.format(*row)
                    #end for
                else:
                    self.error('only 1 and 2 dimensional arrays are supported for xml formatting.\n  Received '+str(ndim)+' dimensional array.')
                #end if
            else:
                # name is None when writing bare text content (tag=None)
                if name is not None:
                    cname = self.condense_name(name)
                else:
                    cname = None
                #end if
                if cname in self.bools and (value==0 or value==1):
                    val = self.bools[cname][value]
                else:
                    val = value
                #end if
                c += '    '+str(val)
            #end if
            if tag is not None:
                c+=pad+'</'+tag+'>\n'
            #end if
        #end if
        return c
    #end def write

    def init_class(self):
        None
    #end def init_class
#end class Param


param = Param()
|
|
|
|
|
|
|
|
|
|
class simulation(SQDxml):
    """<simulation> element holding project, atom and eigensolve children."""
    elements = ['project', 'atom', 'eigensolve']
#end class simulation
|
|
|
|
|
|
class project(SQDxml):
    """<project> element with attributes id and series."""
    attributes = ['id', 'series']
#end class project
|
|
|
|
class atom(SQDxml):
    """<atom> element: name, closed-shell count, and grid/orbitalset/
    hamiltonian children."""
    attributes = ['name', 'num_closed_shells']
    elements   = ['grid', 'orbitalset', 'hamiltonian']
    # identifier = 'name'   (intentionally disabled)
#end class atom
|
|
|
|
class grid(SQDxml):
    """<grid> element with attributes type, ri, rf, npts and scale."""
    attributes = ['type', 'ri', 'rf', 'npts', 'scale']
#end class grid
|
|
|
|
class orbitalset(SQDxml):
    """<orbitalset> element containing <orbital> children."""
    attributes = ['condition']
    elements   = ['orbital']
#end class orbitalset
|
|
|
|
class orbital(SQDxml):
    """<orbital> element; keyed in collections by its (n,l,m,s) tuple."""
    attributes      = ['n', 'l', 'm', 's', 'c']
    identifier      = ('n', 'l', 'm', 's')
    identifier_type = tuple
#end class orbital
|
|
|
|
class hamiltonian(SQDxml):
    """<hamiltonian> element: a type attribute plus <parameter name="z">."""
    attributes = ['type']
    parameters = ['z']
#end class hamiltonian
|
|
|
|
class eigensolve(SQDxml):
    """<eigensolve> element; its settings are written as <parameter> tags."""
    parameters = ['max_iter', 'etot_tol', 'eig_tol', 'mix_ratio']
#end class eigensolve
|
|
|
|
|
|
|
|
|
|
# Every concrete SQDxml element class defined above.
classes = [   #standard classes
    simulation,project,atom,grid,orbitalset,orbital,hamiltonian,eigensolve
    ]
# Element name -> class (or factory); filled in by the loop below.
types = dict( #simple types and factories
    )
# Plural collection name -> singular element name.
plurals = obj(
    orbitals = 'orbital'
    )
plurals_inv = plurals.inverse()
plural_names = set(plurals.keys())
single_names = set(plurals.values())
# No condensed -> expanded name overrides are installed here.
Names.set_expanded_names(
    )
# SQDxml.init_class reads the module-level plurals_inv, so this loop must
# stay after plurals_inv is defined.
for c in classes:
    c.init_class()
    types[c.__name__] = c
#end for
|
|
|
|
|
|
# register default values for each element class
# (presumably applied through incorporate_defaults -- see SqdInput)
simulation.defaults.set(project=project, atom=atom, eigensolve=eigensolve)
project.defaults.set(series=0)
atom.defaults.set(grid=grid, orbitalset=orbitalset, hamiltonian=hamiltonian)
grid.defaults.set(type='log', ri=1e-6, rf=400, npts=2001)
orbitalset.defaults.set(condition='spin_space')
hamiltonian.defaults.set(type='nuclear')
eigensolve.defaults.set(
    max_iter=1000,
    etot_tol=1e-8,
    eig_tol=1e-12,
    mix_ratio=.3,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SqdInput(SimulationInput,Names):
    """In-memory representation of an SQD xml input file.

    The input holds exactly one base xml element (normally 'simulation',
    see simulation_type) stored under its name, plus an optional
    '_metadata' meta object that is installed as Param.metadata while
    reading and writing.
    """

    opt_methods = set([])

    simulation_type = simulation

    default_metadata = meta()


    def __init__(self,arg0=None,arg1=None):
        """Initialize from nothing, a file path, an element, or metadata+element.

        Accepted forms:
          SqdInput()                   empty input
          SqdInput(filepath)           read an existing xml file
          SqdInput(element)            store an SQDxml element (rebuilt via its class)
          SqdInput(metadata,element)   as above, also storing the meta object
        Any other combination is reported through self.error.
        """
        Param.metadata = None
        filepath = None
        metadata = None
        element = None
        if arg0 is None and arg1 is None:
            pass
        elif isinstance(arg0,str) and arg1 is None:
            filepath = arg0
        elif isinstance(arg0,SQDxml) and arg1 is None:
            element = arg0
        elif isinstance(arg0,meta) and isinstance(arg1,SQDxml):
            metadata = arg0
            element = arg1
        else:
            # report the type of each argument (arg1 was previously mislabeled as arg0)
            self.error('input arguments of types '+arg0.__class__.__name__+' and '+arg1.__class__.__name__+' cannot be used to initialize SqdInput')
        #end if
        if metadata is not None:
            self._metadata = metadata
        #end if
        if filepath is not None:
            self.read(filepath)
        elif element is not None:
            elem_class = element.__class__
            # store under the class identifier when the class defines one
            if 'identifier' in elem_class.__dict__:
                name = elem_class.identifier
            else:
                name = elem_class.__name__
            #end if
            self[name] = elem_class(element)
        #end if
        Param.metadata = None
    #end def __init__


    def get_base(self):
        """Return the single base element of this input."""
        return self[self.get_basename()]
    #end def get_base


    def get_basename(self):
        """Return the name of the single base element.

        Reports an error via self.error if there is no base element or if
        there is more than one ('_metadata' is not counted).
        """
        elem_names = [name for name in self.keys() if name!='_metadata']
        if len(elem_names)>1:
            self.error('sqd input cannot have more than one base element\n You have provided '+str(len(elem_names))+': '+str(elem_names))
        elif len(elem_names)==0:
            # previously this surfaced as an opaque IndexError
            self.error('sqd input does not contain a base element')
        #end if
        return elem_names[0]
    #end def get_basename


    def read(self,filepath):
        """Populate this input from the SQD xml file at filepath.

        A top level 'simulation' element is read directly; otherwise every
        non-private top level element must be a known type (see `types`).
        """
        if not os.path.exists(filepath):
            self.error('the filepath you provided does not exist.\n Input filepath: '+filepath)
        #end if
        element_joins=[]
        element_aliases=dict()
        xml = XMLreader(filepath,element_joins,element_aliases,warn=False).obj
        xml.condense()
        self._metadata = meta() #store parameter/attrib attribute metadata
        Param.metadata = self._metadata
        if 'simulation' in xml:
            self.simulation = simulation(xml.simulation)
        else:
            #try to determine the type
            found = []  # (key, element) pairs in read order
            error = False
            for key,value in xml.iteritems():
                if isinstance(key,str) and not key.startswith('_'):
                    if key in types:
                        found.append((key,types[key](value)))
                    else:
                        self.error('element '+key+' is not a recognized type',exit=False)
                        error = True
                    #end if
                #end if
            #end for
            if error:
                self.error('cannot read input xml file')
            #end if
            if len(found)==0:
                self.error('no valid elements were found for input xml file')
            #end if
            for key,elem in found:
                if isinstance(elem,SQDxml):
                    if 'identifier' in elem.__class__.__dict__:
                        name = elem.identifier
                    else:
                        name = elem.tag
                    #end if
                else:
                    name = key
                #end if
                self[name] = elem
            #end for
        #end if
        Param.metadata = None
    #end def read


    def write_contents(self):
        """Return the complete xml text of this input.

        Uses this input's '_metadata' if present, otherwise any metadata
        already installed on Param, otherwise default_metadata.
        """
        header = '''<?xml version="1.0"?>
'''
        c = header
        if '_metadata' in self:
            Param.metadata = self._metadata
        elif Param.metadata is None:
            Param.metadata = self.default_metadata
        #end if
        base = self.get_base()
        c += base.write(first=True)
        Param.metadata = None
        return c
    #end def write_contents


    def unroll_calculations(self,modify=True):
        """Flatten simulation.calculations (or simulation.qmc) into a collection.

        Sections that are `loop` instances are expanded via their unroll()
        method.  If modify is True the result replaces
        self.simulation.calculations.
        NOTE(review): `loop` is not defined in the visible part of this
        module -- this method appears to mirror the QMCPACK input interface;
        confirm it is reachable for SQD inputs.
        """
        qmc = []
        sim = self.simulation
        if 'calculations' in sim:
            calcs = sim.calculations
        elif 'qmc' in sim:
            calcs = [sim.qmc]
        else:
            calcs = []
        #end if
        # index-based access is kept since calcs may be an indexable collection
        for i in range(len(calcs)):
            c = calcs[i]
            if isinstance(c,loop):
                qmc.extend(c.unroll())
            else:
                qmc.append(c)
            #end if
        #end for
        qmc = make_collection(qmc)
        if modify:
            self.simulation.calculations = qmc
        #end if
        return qmc
    #end def unroll_calculations


    def get(self,*names):
        """Look up names within the base element (delegates to base.get)."""
        base = self.get_base()
        return base.get(names)
    #end def get


    def remove(self,*names):
        """Remove names from the base element (delegates to base.remove)."""
        base = self.get_base()
        base.remove(*names)
    #end def remove


    def replace(self,*args,**kwargs):# input is list of keyword=(oldval,newval)
        """Replace values within the base element (delegates to base.replace)."""
        base = self.get_base()
        base.replace(*args,**kwargs)
    #end def replace


    def move(self,**elemdests):
        """Move elements within the base element (delegates to base.move)."""
        base = self.get_base()
        base.move(**elemdests)
    #end def move


    def get_host(self,names):
        """Return the section hosting names (delegates to base.get_host)."""
        base = self.get_base()
        return base.get_host(names)
    #end def get_host


    def incorporate_defaults(self,elements=False,overwrite=False,propagate=False):
        """Apply class default values to the base element."""
        base = self.get_base()
        base.incorporate_defaults(elements,overwrite,propagate)
    #end def incorporate_defaults


    def pluralize(self):
        """Convert singular sections to plural collections in the base element."""
        base = self.get_base()
        base.pluralize()
    #end def pluralize


    def standard_placements(self):
        """Move particleset/wavefunction/hamiltonian under qmcsystem."""
        self.move(particleset='qmcsystem',wavefunction='qmcsystem',hamiltonian='qmcsystem')
    #end def standard_placements


    def difference(self,other):
        """Compare this input with other.

        Returns (different, diff, d1, d2).  If the base element names differ,
        d1/d2 are the two base elements and diff is None; otherwise the
        result of the base element's difference method is returned.  Both
        inputs are copied first, so neither is modified.
        """
        s1 = self.copy()
        s2 = other.copy()
        b1 = s1.get_basename()
        b2 = s2.get_basename()
        q1 = s1[b1]
        q2 = s2[b2]
        if b1!=b2:
            different = True
            d1 = q1
            d2 = q2
            diff = None
        else:
            s1.standard_placements()
            s2.standard_placements()
            s1.pluralize()
            s2.pluralize()
            different,diff,d1,d2 = q1.difference(q2,root=False)
        #end if
        if diff is not None:
            diff.remove_empty()
        #end if
        d1.remove_empty()
        d2.remove_empty()
        return different,diff,d1,d2
    #end def difference


    def remove_empty(self):
        """Remove empty sections from the base element."""
        base = self.get_base()
        base.remove_empty()
    #end def remove_empty


    def read_xml(self,filepath):
        """Read and condense the xml file at filepath and return the xml object."""
        if not os.path.exists(filepath):
            # exit before `xml` would be referenced while unbound
            self.error('the filepath you provided does not exist.\n Input filepath: '+filepath)
        #end if
        element_joins=['qmcsystem']
        element_aliases=dict(loop='qmc')
        xml = XMLreader(filepath,element_joins,element_aliases,warn=False).obj
        xml.condense()
        return xml
    #end def read_xml


    def include_xml(self,xmlfile,replace=True,exists=True):
        """Merge the top level elements of xmlfile into this input.

        Each element is placed in the section returned by get_host.  With
        replace=True an existing section of the same (or plural) name is
        replaced; otherwise the new element is added to a collection or
        combined into the existing section.  If no host section exists, an
        error is reported when exists=True and the element is skipped
        otherwise.
        """
        xml = self.read_xml(xmlfile)
        Param.metadata = self._metadata
        for name,exml in xml.iteritems():
            if name.startswith('_'):
                continue
            #end if
            qxml = types[name](exml)
            qname = qxml.tag
            host = self.get_host(qname)
            if host is None:
                if exists:
                    self.error('host xml section for '+qname+' not found','SqdInput')
                #end if
                # nothing to attach to (previously a TypeError below)
                continue
            #end if
            if qname in host:
                section_name = qname
            elif qname in plurals_inv and plurals_inv[qname] in host:
                section_name = plurals_inv[qname]
            else:
                section_name = None
            #end if
            if replace:
                if section_name is not None:
                    del host[section_name]
                #end if
                host[qname] = qxml
            elif section_name is None:
                host[qname] = qxml
            else:
                section = host[section_name]
                if isinstance(section,collection):
                    section[qxml.identifier] = qxml
                elif section_name in plurals_inv:
                    # promote the single section to a plural collection
                    coll = collection()
                    coll[section.identifier] = section
                    coll[qxml.identifier] = qxml
                    del host[section_name]
                    host[plurals_inv[section_name]] = coll
                else:
                    section.combine(qxml)
                #end if
            #end if
        #end for
        Param.metadata = None
    #end def include_xml


    def get_output_info(self,list=True):
        """Return output file names built from the project id and series.

        The prefix is '<id>.s<series padded to 3 digits>.'.  Returns
        outfiles.list() if list is True, otherwise the obj keyed by kind.
        """
        project = self.simulation.project
        prefix = project.id+'.s'+str(project.series).zfill(3)+'.'
        outfiles = obj(
            exchange = prefix+'exchange',
            h5       = prefix+'h5',
            hartree  = prefix+'hartree',
            log      = prefix+'log',
            orb      = prefix+'orb.dat',
            qmc      = prefix+'qmc.xml',
            Vext     = prefix+'Vext.xml'
            )
        if list:
            return outfiles.list()
        else:
            return outfiles
        #end if
    #end def get_output_info


    def incorporate_system(self,system):
        """Set atom name, grid scale and nuclear charge from system.

        Uses the first element of system.structure; its atomic number is
        taken from the periodic table.
        """
        element = system.structure.elem[0]
        Z = periodic_table[element].atomic_number
        atom = self.simulation.atom
        atom.name = element
        atom.grid.scale = Z
        atom.hamiltonian.z = Z
    #end def incorporate_system


    def return_system(self):
        """Return a PhysicalSystem holding this input's atom at the origin."""
        system = PhysicalSystem(
            structure = Structure(
                elem = [self.simulation.atom.name],
                pos  = [[0,0,0]]
                )
            )
        return system
    #end def return_system

#end class SqdInput
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Shells(Pobj):
    """Occupied atomic orbitals organized by shell (n) and channel (l).

    An instance holds self.shells: n -> obj(channel letter -> list of
    occupied m values).  Class-level tables list all valid shells, the
    noble gas cores, and the orbital fill order used by
    hunds_rule_filling.
    """

    channel_names = tuple('spdfghik')

    # channel letter -> m values ordered l,l-1,...,-l
    channels = obj()
    for l in range(len(channel_names)):
        channels[channel_names[l]] = list(range(l,-l-1,-1))
    #end for

    max_shells = 7

    # n -> obj(channel letter -> all m values) for every allowed channel of shell n
    all_shells = obj()
    for n in range(1,max_shells+1):
        shell = obj()
        for l in range(n):
            shell[channel_names[l]] = list(range(l,-l-1,-1))
        #end for
        all_shells[n] = shell
    #end for

    # channel letter -> l
    channel_indices = obj()
    for i in range(len(channel_names)):
        channel_indices[channel_names[i]] = i
    #end for

    core_names = ['He','Ne','Ar','Kr','Xe','Rn']

    core_orbitals = obj(
        He = '1s',
        Ne = '1s 2s 2p',
        Ar = '1s 2s 2p 3s 3p',
        Kr = '1s 2s 2p 3s 3p 4s 3d 4p',
        Xe = '1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p',
        Rn = '1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p'
        )

    # value stored as num_closed_shells for each core
    ncore_shells = obj(
        He = 1,
        Ne = 2,
        Ar = 3,
        Kr = 4,
        Xe = 5,
        Rn = 6
        )

    # convert core orbital strings into lists of (n, channel letter)
    for cname in core_orbitals.keys():
        co = core_orbitals[cname].split()
        olist = []
        for o in co:
            n = int(o[0])
            ls = o[1]
            olist.append((n,ls))
        #end for
        core_orbitals[cname] = olist
    #end for

    # core name -> obj(n -> obj(channel letter -> all m values))
    cores = obj()
    for cname,orblist in core_orbitals.iteritems():
        core = obj()
        cores[cname] = core
        for (n,ls) in orblist:
            if not n in core:
                core[n] = obj()
            #end if
            shell = core[n]
            if not ls in shell:
                shell[ls] = list(all_shells[n][ls])
            #end if
        #end for
    #end for

    orbital_fill_order_list = '1s 2s 2p 3s 3p 4s 3d 4p 5s 4d 5p 6s 4f 5d 6p 7s 5f 6d 7p'.split()

    orbital_fill_order = []
    for orbital in orbital_fill_order_list:
        n = int(orbital[0])
        ls = orbital[1]
        l = channel_indices[ls]
        mult = len(channels[ls])
        mlist = list(channels[ls])
        orbital_fill_order.append(obj(n=n,l=l,ls=ls,mult=mult,mlist=mlist,total_mult=2*mult))
    #end for
    #m_l fill order is max m_l to min m_l, i.e. l,l-1,l-2,...,-l+2,-l+1,-l


    @classmethod
    def hunds_rule_filling(cls,atom,net_charge=0,net_spin='ground',location='hunds_rule_filling'):
        """Determine orbital occupations for an atom or ion via Hund's rules.

        Parameters
          atom       : atomic symbol (str) or atomic number (int, 1-109)
          net_charge : net charge; the electron count is Z-net_charge
          net_spin   : 'ground'/None to fill the open subshell with as many
                       up electrons as fit, or an integer nup-ndown
          location   : label used in error messages
        Returns
          (up, down, updown, net_spin_found)
            up/down   : open subshell occupation strings such as '2p(1, 0)',
                        or None if empty
            updown    : doubly occupied subshells such as '1s2s', or None
            net_spin_found : nup-ndown of the resulting configuration
        """
        if isinstance(atom,str) and atom in periodic_table:
            Z = periodic_table[atom].atomic_number
        elif isinstance(atom,int) and 0<atom<110:
            Z = atom
        else:
            cls.class_error('expected atomic symbol or atomic number for atom\n you provided '+str(atom),location)
        #end if

        nelectrons = Z-net_charge
        if nelectrons<1:
            # with no electrons there is no open subshell to fill
            cls.class_error('atom must have at least one electron\n Z={0} net_charge={1}'.format(Z,net_charge),location)
        #end if

        if isinstance(net_spin,int):
            nup   = float(nelectrons + net_spin)/2
            ndown = float(nelectrons - net_spin)/2
            if abs(nup-int(nup))>1e-3:
                cls.class_error('requested spin state {0} incompatible with {1} electrons'.format(net_spin,nelectrons),location)
            #end if
            nup   = int(nup)
            ndown = int(ndown)
        elif net_spin=='ground' or net_spin is None:
            net_spin = None
            nup      = None
            ndown    = None
        else:
            cls.class_error("net_spin must be 'ground'/None or integer\n you provided "+str(net_spin),location)
        #end if

        if net_spin is not None and nup+ndown!=nelectrons:
            cls.class_error('number of up and down electrons does not add up to the total\n this may reflect an error in your input or the code\n please check: nel={0} nup={1} ndown={2} Z={3} net_charge={4} net_spin={5}'.format(nelectrons,nup,ndown,Z,net_charge,net_spin),location)
        #end if

        # fill subshells in order until the one that receives the last electron
        closed_orbitals = []
        open_orbital = None
        nel = 0
        nud_closed = 0   # electrons per spin in closed subshells
        for orbital in cls.orbital_fill_order:
            nel_orb = orbital.total_mult
            if nel+nel_orb<nelectrons:
                closed_orbitals.append(orbital)
                nud_closed += orbital.mult
            else:
                open_orbital = orbital
                break
            #end if
            nel += nel_orb
        #end for
        if open_orbital is None:
            cls.class_error('{0} electrons exceed the capacity of the tabulated orbitals'.format(nelectrons),location)
        #end if
        o = open_orbital

        if net_spin is not None:
            nup_open   = nup   - nud_closed
            ndown_open = ndown - nud_closed
            if nup_open<0 or ndown_open<0:
                cls.class_error('requested spin state {0} cannot be realized with doubly occupied closed subshells\n nup={1} ndown={2} closed electrons per spin={3}'.format(net_spin,nup,ndown,nud_closed),location)
            #end if
            if nup_open>o.mult or ndown_open>o.mult:
                cls.class_error('more up or down electrons in open shell than will fit\n open_shell={0}{1}, up/down_size={2}, nup={3}, ndown={4}'.format(o.n,o.ls,o.mult,nup_open,ndown_open),'Developer')
            #end if
        else:
            nopen      = nelectrons - 2*nud_closed
            nup_open   = min(nopen,o.mult)
            ndown_open = nopen - nup_open
        #end if

        updown = ''.join('{0}{1}'.format(orb.n,orb.ls) for orb in closed_orbitals)
        up   = ''
        down = ''
        if nup_open>0:
            up = '{0}{1}{2}'.format(o.n,o.ls,str(tuple(o.mlist[:nup_open]))).replace(',)',')')
        #end if
        if ndown_open>0:
            down = '{0}{1}{2}'.format(o.n,o.ls,str(tuple(o.mlist[:ndown_open]))).replace(',)',')')
        #end if

        # empty descriptions are reported as None
        if up=='':
            up = None
        #end if
        if down=='':
            down = None
        #end if
        if updown=='':
            updown = None
        #end if

        net_spin_found = nup_open-ndown_open
        if net_spin is not None and net_spin_found!=net_spin:
            # classmethod: use class_error (the instance method error needs self)
            cls.class_error('spin state determined incorrectly\n net_spin requested: {0}\n net_spin found: {1}\n nup,ndown: {2},{3}'.format(net_spin,net_spin_found,nup_open,ndown_open),location)
        #end if

        return up,down,updown,net_spin_found
    #end def hunds_rule_filling


    def __init__(self,shells=None,location='Shells'):
        """Create a Shells object.

        shells may be None (empty), a shell string (see read_shell_string),
        or a dict/obj keyed by shell index n.  location labels error
        messages.  The result is validated with check_shells.
        """
        self.location = location
        self.shells = obj()
        self.core = None
        self.nclosed = 0
        if shells is None:
            return
        #end if
        if isinstance(shells,str):
            self.read_shell_string(shells)
        elif isinstance(shells,(obj,dict)):
            if set(shells.keys()) <= set(self.all_shells.keys()):
                self.shells.transfer_from(shells,copy=True)
            else:
                self.error('unexpected key values for shells\n expected values: '+str(self.all_shells.keys())+'\n you provided '+str(shells.keys()))
            #end if
        else:
            self.error('must provide a string, dict, or obj describing atomic shells\n you provided '+str(shells))
        #end if
        self.check_shells()
    #end def __init__


    def _add_shell(self,n,shell):
        """Merge the channels of shell into self.shells[n] (creating it if needed)."""
        if n in self.shells:
            self.shells[n].transfer_from(shell)
        else:
            self.shells[n] = shell
        #end if
    #end def _add_shell


    def read_shell_string(self,ss):
        """Parse a shell string such as '[Ne]3s2p(1,0)' into self.shells.

        An optional leading noble gas core name is expanded first.  Each
        integer starts shell n; each channel letter that follows adds that
        channel, either with all m values or with the m values given in a
        following parenthesized list.
        NOTE(review): tokens are converted with string_to_val, which is not
        visible here -- assumed to return int/str/list for '3'/'s'/'[1,0]'.
        """
        self.shell_string = str(ss)
        # turn (...) into bracketed lists and tighten whitespace around them
        ss = ss.replace('[',' ').replace(']',' ').replace('(',' [').replace(')','] ')
        ss = ss.replace(', ',',').replace(', ',',').replace(', ',',')
        ss = ss.replace(' ,',',').replace(' ,',',').replace(' ,',',')
        ss = ss.replace('[ ','[').replace('[ ','[').replace('[ ','[')
        ss = ss.replace(' ]',']').replace(' ]',']').replace(' ]',']')
        # separate core names and channel letters into their own tokens
        for core_name in self.core_names:
            ss = ss.replace(core_name,' '+core_name+' ')
        #end for
        for channel in self.channel_names:
            ss = ss.replace(channel,' '+channel+' ')
        #end for

        sl = ss.split()
        if len(sl)==0:
            return
        #end if
        i = 0
        if sl[i] in self.cores:
            self.shells.transfer_from(self.cores[sl[i]],copy=True)
            i+=1
        #end if
        n = None
        shell = None
        while i<len(sl):
            v = string_to_val(sl[i])
            if isinstance(v,int):
                # a new shell index: store the shell collected so far
                if n is not None:
                    self._add_shell(n,shell)
                #end if
                n = v
                if n>self.max_shells:
                    self.error('maximum shell number is {0}\n you requested {1}'.format(self.max_shells,n))
                #end if
                shell = obj()
            elif isinstance(v,str):
                l = v.lower()
                if not l in self.channel_names:
                    self.error('you requested an invalid channel: '+str(l)+'\n allowed channels are '+str(self.channel_names))
                #end if
                ln = self.channel_indices[l]
                channel = None
                if i+1<len(sl):
                    v1 = string_to_val(sl[i+1])
                    if isinstance(v1,list):
                        m = v1
                        # an empty list has no |m| to check
                        if len(m)>0:
                            mmax = abs(array(m)).max()
                            if mmax>ln:
                                self.error('maximum |m| for {0} channel is {1}\n you requested {2}: {3}'.format(l,ln,mmax,m))
                            #end if
                        #end if
                        channel = m
                        i+=1
                    #end if
                #end if
                if channel is None:
                    channel = list(range(-ln,ln+1))
                #end if
                if n is not None:
                    shell[l] = channel
                #end if
            #end if
            i+=1
        #end while
        if n is not None:
            self._add_shell(n,shell)
        #end if
    #end def read_shell_string


    def check_shells(self):
        """Validate shell indices, channel letters and m values against all_shells.

        All problems are reported (without exiting), then the reference and
        current shells are logged and a final error is raised.
        """
        ref = self.all_shells
        shells = self.shells
        rkn = set(ref.keys())
        skn = set(shells.keys())
        errors = False
        if not skn <= rkn:
            self.error('shell indices (n) are invalid\n options for valid shell indices: '+str(list(rkn))+'\n shell indices of self: '+str(list(skn)),exit=False)
            errors = True
        #end if
        for n,shell in shells.iteritems():
            if n not in ref:
                continue # already reported above; avoid a KeyError
            #end if
            rshell = ref[n]
            rkl = set(rshell.keys())
            skl = set(shell.keys())
            if not skl<=rkl:
                self.error('channel keys (l) are invalid\n options for valid channel keys: '+str(list(rkl))+'\n channel keys of self: '+str(list(skl)),exit=False)
                errors = True
            #end if
            for l,channel in shell.iteritems():
                if l not in rshell:
                    continue # already reported above; avoid a KeyError
                #end if
                rkm = set(rshell[l])
                skm = set(channel)
                if not skm<=rkm:
                    self.error('azimuthal indices (m) are invalid\n options for valid azimuthal indices: '+str(list(rkm))+'\n azimuthal indices of self: '+str(list(skm)),exit=False)
                    errors = True
                #end if
            #end for
        #end for
        if errors:
            self.log('\nreference shells:\n'+str(ref))
            self.log('\nself shells:\n'+str(shells))
            self.error('encountered errors')
        #end if
    #end def check_shells


    def partition(self):
        """Split off the largest noble gas core whose subshells are all filled.

        Sets self.core and self.nclosed for that core, then removes its
        subshells (and any shells left empty) from self.shells.
        """
        shells = self.shells
        # (n, channel letter) pairs whose channel holds every m value
        closed = set()
        for n,shell in shells.iteritems():
            for l,channel in shell.iteritems():
                if set(channel)==set(self.all_shells[n][l]):
                    closed.add((n,l))
                #end if
            #end for
        #end for
        # search from the largest core (Rn) down to He
        core_orb = set()
        for core_name in reversed(self.core_names):
            core_orbitals = set(self.core_orbitals[core_name])
            if core_orbitals<=closed:
                self.core = core_name
                self.nclosed = self.ncore_shells[core_name]
                core_orb = core_orbitals
                break
            #end if
        #end for
        for (n,l) in core_orb:
            del shells[n][l]
        #end for
        for n in list(shells.keys()):
            if len(shells[n])==0:
                del shells[n]
            #end if
        #end for
    #end def partition


    def orbitals(self,spin):
        """Return one orbital element (c=1.0) per occupied (n,l,m).

        spin may be 'up' or 1 (s=1), or 'down' or -1 (s=-1).
        """
        # integer spins were previously accepted but left s unassigned
        if spin=='up' or spin==1:
            s = 1
        elif spin=='down' or spin==-1:
            s = -1
        else:
            self.error('spin must be up/1 or down/-1\n you provided: '+str(spin))
        #end if
        orbitals = []
        for n,shell in self.shells.iteritems():
            for lname,channel in shell.iteritems():
                l = self.channel_indices[lname]
                for m in channel:
                    orbitals.append(orbital(n=n,l=l,m=m,s=s,c=1.0))
                #end for
            #end for
        #end for
        return orbitals
    #end def orbitals


    def error(self,msg,exit=True):
        """Report an error labeled with this object's location."""
        Pobj.error(self,msg,self.location,exit=exit)
    #end def error

#end class Shells
|
|
# module-level alias for the Shells.hunds_rule_filling classmethod
hunds_rule_filling = Shells.hunds_rule_filling
|
|
|
|
|
|
|
|
def generate_orbitalset(up=None,down=None,updown=None,location='generate_orbitalset'):
    """Build an orbitalset element from shell descriptions.

    Parameters
      up, down : shells occupied only by spin-up / spin-down electrons
      updown   : shells occupied by both spins; a fully occupied noble gas
                 core found here is removed and reported through nclosed
      location : label used in error messages
    Each of up/down/updown may be a shell string (e.g. '[Ne]3s',
    '2p(1,0)'), a dict/obj accepted by Shells, or None.

    Returns (orbset, nclosed).
    """
    nclosed = 0
    uorb = []
    dorb = []
    # any form Shells accepts is used (non-string updown was previously ignored)
    if updown is not None:
        bshells = Shells(updown,location)
        bshells.partition()
        nclosed = bshells.nclosed
        uorb = bshells.orbitals('up')
        dorb = bshells.orbitals('down')
    #end if
    ushells = Shells(up,location)
    dshells = Shells(down,location)
    uorb += ushells.orbitals('up')
    dorb += dshells.orbitals('down')
    orbset = orbitalset(
        orbitals = make_collection(uorb+dorb)
        )
    return orbset,nclosed
#end def generate_orbitalset
|
|
|
|
|
|
|
|
|
|
def _core_electron_count(nclosed):
    """Return the electron count of the noble gas core with nclosed closed shells (0 if none)."""
    for cname in Shells.core_names:
        if Shells.ncore_shells[cname]==nclosed:
            return sum(2*len(Shells.channels[ls]) for n,ls in Shells.core_orbitals[cname])
        #end if
    #end for
    return 0
#end def _core_electron_count


def generate_sqd_input(id         = None,
                       series     = 0,
                       system     = None,
                       filling    = None,
                       net_spin   = 'none',
                       up         = None,
                       down       = None,
                       updown     = None,
                       grid_type  = 'log',
                       ri         = 1e-6,
                       rf         = 400,
                       npts       = 10001,
                       max_iter   = 1000,
                       etot_tol   = 1e-8,
                       eig_tol    = 1e-12,
                       mix_ratio  = 0.7 ):
    """Generate an SqdInput for an isolated atom.

    Parameters
      id        : project id (defaults to the atomic symbol)
      series    : project series number
      system    : atomic symbol or PhysicalSystem
      filling   : 'hund' to derive occupations with hunds_rule_filling
      net_spin  : nup-ndown for Hund filling; 'none' uses the system's
                  net_spin (or the ground state for a bare symbol)
      up, down, updown : explicit shell descriptions (see
                  generate_orbitalset); ignored when filling='hund'
      grid_type, ri, rf, npts : radial grid settings
      max_iter, etot_tol, eig_tol, mix_ratio : eigensolver settings
    """
    location = 'generate_sqd_input'

    # validate the atomic system
    is_symbol = isinstance(system,str) and system in periodic_table
    if system is None:
        SqdInput.class_error('system (atom) must be provided',location)
    elif not is_symbol and not isinstance(system,PhysicalSystem):
        SqdInput.class_error('system must be an atomic symbol or a PhysicalSystem object','generate_sqd_input')
    #end if

    # generate spin state via Hund's rule if requested
    if filling is None and up is None and down is None and updown is None:
        SqdInput.class_error('filling or up/down/updown must be provided',location)
    #end if
    if filling is not None:
        if not isinstance(filling,str):
            SqdInput.class_error('filling must be a string\n you provided '+str(filling),location)
        elif filling.lower()!='hund':
            SqdInput.class_error("{0} is not a valid choice for filling\n valid options are: 'hund'".format(filling),location)
        #end if
        if is_symbol:
            # a bare atomic symbol describes the neutral atom
            atom = system
            net_charge = 0
            if net_spin=='none':
                net_spin = 'ground'
            #end if
        else:
            atom = system.structure.elem[0]
            net_charge = system.net_charge
            if net_spin=='none':
                net_spin = system.net_spin
            #end if
        #end if
        up,down,updown,net_spin = hunds_rule_filling(
            atom       = atom,
            net_charge = net_charge,
            net_spin   = net_spin,
            location   = location
            )
    #end if

    # generate the orbitalset (must precede building a system from a symbol)
    orbset,nclosed = generate_orbitalset(
        up     = up,
        down   = down,
        updown = updown
        )

    if is_symbol:
        # derive charge and spin from the occupations; each orbital holds one
        # electron, and core electrons are represented by nclosed
        Z = periodic_table[system].atomic_number
        net_charge = Z - _core_electron_count(nclosed)
        spin = 0
        for orb in orbset.orbitals:
            net_charge -= 1
            spin += orb.s
        #end for
        system = PhysicalSystem(
            structure  = Structure(elem=[system],pos=[[0,0,0]]),
            net_charge = net_charge,
            net_spin   = spin
            )
    #end if

    metadata = SqdInput.default_metadata.copy()
    si = SqdInput(
        metadata,
        simulation(
            project = section(
                series = series
                ),
            atom = section(
                grid        = section(type=grid_type,ri=ri,rf=rf,npts=npts),
                orbitalset  = section(),
                hamiltonian = section()
                ),
            eigensolve = section(
                max_iter  = max_iter,
                etot_tol  = etot_tol,
                eig_tol   = eig_tol,
                mix_ratio = mix_ratio
                )
            )
        )

    # set the atomic system
    si.incorporate_system(system)

    if id is None:
        id = system.structure.elem[0]
    #end if
    sim = si.simulation
    sim.project.id = id
    sim.atom.set(
        num_closed_shells = nclosed,
        orbitalset        = orbset
        )

    si.incorporate_defaults(elements=False,overwrite=False,propagate=True)

    return si
#end def generate_sqd_input
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|