Import changes from pmg

gmatteo 2019-09-04 11:26:59 +03:00
parent 38706526dd
commit 19c5748074
8 changed files with 89 additions and 67 deletions

View File

@@ -111,6 +111,7 @@ class A2f(object):
if verbose:
for mustar in (0.1, 0.12, 0.2):
app("\tFor mustar %s: McMillan Tc: %s [K]" % (mustar, self.get_mcmillan_tc(mustar)))
if verbose > 1:
# $\int dw [a2F(w)/w] w^n$
for n in [0, 4]:
@@ -911,13 +912,19 @@ class A2fFile(AbinitNcFile, Has_Structure, Has_ElectronBands, NotebookWriter):
yield self.plot(show=False)
#yield self.plot_eph_strength(show=False)
yield self.plot_with_a2f(show=False)
for qsamp in ["qcoarse", "qintp"]:
a2f = self.get_a2f_qsamp(qsamp)
yield a2f.plot_with_lambda(show=False)
yield a2f.plot_with_lambda(title="q-sampling: %s (%s)" % (str(a2f.ngqpt), qsamp), show=False)
#yield self.plot_nuterms(show=False)
#yield self.plot_a2(show=False)
#yield self.plot_tc_vs_mustar(show=False)
#if self.has_a2ftr:
# ncfile.a2ftr.plot();
def write_notebook(self, nbpath=None):
"""
Write a jupyter_ notebook to ``nbpath``. If nbpath is None, a temporary file in the current

View File

@@ -433,7 +433,7 @@ class EventsParser(object):
event = yaml.load(doc.text) # Can't use ruamel safe_load!
#yaml.load(doc.text, Loader=ruamel.yaml.Loader)
#print(event.yaml_tag, type(event))
except:
except Exception:
#raise
# Wrong YAML doc. Check the doc tag and instantiate the proper event.
message = "Malformed YAML document at line: %d\n" % doc.lineno
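A note on the recurring bare "except:" → "except Exception:" change in this commit: a bare except also catches KeyboardInterrupt and SystemExit, which derive from BaseException rather than Exception, so a Ctrl-C inside the guarded block would be silently swallowed. A minimal sketch of the difference (illustrative only, not part of the commit):

    import time

    try:
        time.sleep(60)     # press Ctrl-C here
    except Exception:
        pass               # KeyboardInterrupt still propagates
    # with a bare "except:", the interrupt would be caught and discarded too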

View File

@@ -20,7 +20,7 @@ from pprint import pprint
from tabulate import tabulate
from pydispatch import dispatcher
from collections import OrderedDict
from monty.collections import as_set, dict2namedtuple
from monty.collections import dict2namedtuple
from monty.string import list_strings, is_string, make_banner
from monty.operator import operator_from_str
from monty.io import FileLock
@@ -61,6 +61,21 @@ __all__ = [
]
def as_set(obj):
"""
Convert obj into a set; return None if obj is None.
>>> assert as_set(None) is None and as_set(1) == set([1]) and as_set(range(1,3)) == set([1, 2])
"""
if obj is None or isinstance(obj, collections.abc.Set):
return obj
if not isinstance(obj, collections.abc.Iterable):
return set((obj,))
else:
return set(obj)
class FlowResults(NodeResults):
JSON_SCHEMA = NodeResults.JSON_SCHEMA.copy()
@@ -307,7 +322,7 @@ class Flow(Node, NodeContainer, MSONable):
if remove_lock and os.path.exists(filepath + ".lock"):
try:
os.remove(filepath + ".lock")
except:
except Exception:
pass
with FileLock(filepath):
@@ -2954,7 +2969,7 @@ def phonon_flow(workdir, scf_input, ph_inputs, with_nscf=False, with_ddk=False,
# Parse the file to get the perturbations.
try:
irred_perts = yaml_read_irred_perts(fake_task.log_file.path)
except:
except Exception:
print("Error in %s" % fake_task.log_file.path)
raise
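The as_set helper added above replaces the function this commit drops from the monty.collections import; it assumes collections.abc is importable at module level (that import sits outside the hunk shown here). A quick usage sketch, not part of the commit:

    assert as_set(None) is None
    assert as_set(42) == {42}
    assert as_set([1, 1, 2]) == {1, 2}
    assert as_set({3, 4}) == {3, 4}   # sets pass through unchanged

Note that strings are iterable, so as_set("ab") gives {'a', 'b'} rather than {"ab"}.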

View File

@@ -665,7 +665,7 @@ class PyFlowScheduler(object):
"""The function that will be executed by the scheduler."""
try:
return self._callback()
except:
except Exception:
# All exceptions raised here will trigger the shutdown!
s = straceback()
self.exceptions.append(s)
@@ -855,7 +855,7 @@
"""
try:
return self._send_email(msg, tag)
except:
except Exception:
self.exceptions.append(straceback())
return -2

View File

@@ -580,7 +580,7 @@ class Node(metaclass=abc.ABCMeta):
if self.is_task:
try:
return self.pos_str
except:
except Exception:
return os.path.basename(self.workdir)
else:
return os.path.basename(self.workdir)
@@ -1146,7 +1146,7 @@ class HistoryRecord(object):
if self.args:
try:
msg = msg % self.args
except:
except Exception:
msg += str(self.args)
if asctime: msg = "[" + self.asctime + "] " + msg

View File

@@ -312,7 +312,7 @@ def make_qadapter(**kwargs):
Return the concrete :class:`QueueAdapter` class from a string.
Note that one can register a customized version with:
.. code-block:: python
.. example::
from qadapters import SlurmAdapter
@@ -352,7 +352,7 @@ class MaxNumLaunchesError(QueueAdapterError):
"""Raised by `submit_to_queue` if we try to submit more than `max_num_launches` times."""
class QueueAdapter(six.with_metaclass(abc.ABCMeta, MSONable)):
class QueueAdapter(MSONable, metaclass=abc.ABCMeta):
"""
The `QueueAdapter` is responsible for all interactions with a specific queue management system.
This includes handling all details of queue script format as well as queue submission and management.
@@ -1265,28 +1265,28 @@ $${qverbatim}
"""
def set_qname(self, qname):
super(SlurmAdapter, self).set_qname(qname)
super().set_qname(qname)
if qname:
self.qparams["partition"] = qname
def set_mpi_procs(self, mpi_procs):
"""Set the number of CPUs used for MPI."""
super(SlurmAdapter, self).set_mpi_procs(mpi_procs)
super().set_mpi_procs(mpi_procs)
self.qparams["ntasks"] = mpi_procs
def set_omp_threads(self, omp_threads):
super(SlurmAdapter, self).set_omp_threads(omp_threads)
super().set_omp_threads(omp_threads)
self.qparams["cpus_per_task"] = omp_threads
def set_mem_per_proc(self, mem_mb):
"""Set the memory per process in megabytes"""
super(SlurmAdapter, self).set_mem_per_proc(mem_mb)
super().set_mem_per_proc(mem_mb)
self.qparams["mem_per_cpu"] = self.mem_per_proc
# Remove mem if it's defined.
#self.qparams.pop("mem", None)
def set_timelimit(self, timelimit):
super(SlurmAdapter, self).set_timelimit(timelimit)
super().set_timelimit(timelimit)
self.qparams["time"] = qu.time2slurm(timelimit)
def cancel(self, job_id):
@@ -1338,7 +1338,7 @@ $${qverbatim}
# output should be of the form '2561553.sdb' or '352353.jessup' - just grab the first part for the job id
queue_id = int(out.split()[3])
logger.info('Job submission was successful and queue_id is {}'.format(queue_id))
except:
except Exception:
# probably an error parsing the job code
logger.critical('Could not parse job id following slurm...')
return SubmitResults(qid=queue_id, out=out, err=err, process=process)
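The queue-id parsing above assumes sbatch's usual reply, "Submitted batch job 123456", where split()[3] is the id; the '2561553.sdb' comment looks inherited from the PBS adapters below. A minimal sketch of the parse under that assumption:

    out = "Submitted batch job 123456"   # typical sbatch stdout
    queue_id = int(out.split()[3])       # -> 123456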
@@ -1404,17 +1404,17 @@ $${qverbatim}
"""
def set_qname(self, qname):
super(PbsProAdapter, self).set_qname(qname)
super().set_qname(qname)
if qname:
self.qparams["queue"] = qname
def set_timelimit(self, timelimit):
super(PbsProAdapter, self).set_timelimit(timelimit)
super().set_timelimit(timelimit)
self.qparams["walltime"] = qu.time2pbspro(timelimit)
def set_mem_per_proc(self, mem_mb):
"""Set the memory per process in megabytes"""
super(PbsProAdapter, self).set_mem_per_proc(mem_mb)
super().set_mem_per_proc(mem_mb)
#self.qparams["mem"] = self.mem_per_proc
def cancel(self, job_id):
@@ -1699,7 +1699,7 @@ $${qverbatim}
try:
# output should be of the form '2561553.sdb' or '352353.jessup' - just grab the first part for the job id
queue_id = int(out.split('.')[0])
except:
except Exception:
# probably an error parsing the job code
logger.critical("Could not parse job id following qsub...")
return SubmitResults(qid=queue_id, out=out, err=err, process=process)
@@ -1812,26 +1812,26 @@ class SGEAdapter(QueueAdapter):
$${qverbatim}
"""
def set_qname(self, qname):
super(SGEAdapter, self).set_qname(qname)
super().set_qname(qname)
if qname:
self.qparams["queue_name"] = qname
def set_mpi_procs(self, mpi_procs):
"""Set the number of CPUs used for MPI."""
super(SGEAdapter, self).set_mpi_procs(mpi_procs)
super().set_mpi_procs(mpi_procs)
self.qparams["ncpus"] = mpi_procs
def set_omp_threads(self, omp_threads):
super(SGEAdapter, self).set_omp_threads(omp_threads)
super().set_omp_threads(omp_threads)
logger.warning("Cannot use omp_threads with SGE")
def set_mem_per_proc(self, mem_mb):
"""Set the memory per process in megabytes"""
super(SGEAdapter, self).set_mem_per_proc(mem_mb)
super().set_mem_per_proc(mem_mb)
self.qparams["mem_per_slot"] = str(int(self.mem_per_proc)) + "M"
def set_timelimit(self, timelimit):
super(SGEAdapter, self).set_timelimit(timelimit)
super().set_timelimit(timelimit)
# Same convention as pbspro, e.g. [hours:minutes:]seconds
self.qparams["walltime"] = qu.time2pbspro(timelimit)
@@ -1854,7 +1854,7 @@ $${qverbatim}
# output should be of the form
# Your job 1659048 ("NAME_OF_JOB") has been submitted
queue_id = int(out.split(' ')[2])
except:
except Exception:
# probably an error parsing the job code
logger.critical("Could not parse job id following qsub...")
return SubmitResults(qid=queue_id, out=out, err=err, process=process)
@@ -1915,15 +1915,15 @@ $${qverbatim}
def set_mpi_procs(self, mpi_procs):
"""Set the number of CPUs used for MPI."""
super(MOABAdapter, self).set_mpi_procs(mpi_procs)
super().set_mpi_procs(mpi_procs)
self.qparams["procs"] = mpi_procs
def set_timelimit(self, timelimit):
super(MOABAdapter, self).set_timelimit(timelimit)
super().set_timelimit(timelimit)
self.qparams["walltime"] = qu.time2slurm(timelimit)
def set_mem_per_proc(self, mem_mb):
super(MOABAdapter, self).set_mem_per_proc(mem_mb)
super().set_mem_per_proc(mem_mb)
#TODO
#raise NotImplementedError("set_mem_per_cpu")
@@ -1948,7 +1948,7 @@ $${qverbatim}
try:
# output should be the queue_id
queue_id = int(out.split()[0])
except:
except Exception:
# probably an error parsing the job code
logger.critical('Could not parse job id following msub...')
@@ -2007,27 +2007,27 @@ $${qverbatim}
"""
def set_qname(self, qname):
super(BlueGeneAdapter, self).set_qname(qname)
super().set_qname(qname)
if qname:
self.qparams["class"] = qname
#def set_mpi_procs(self, mpi_procs):
# """Set the number of CPUs used for MPI."""
# super(BlueGeneAdapter, self).set_mpi_procs(mpi_procs)
# super().set_mpi_procs(mpi_procs)
# #self.qparams["ntasks"] = mpi_procs
#def set_omp_threads(self, omp_threads):
# super(BlueGeneAdapter, self).set_omp_threads(omp_threads)
# super().set_omp_threads(omp_threads)
# #self.qparams["cpus_per_task"] = omp_threads
#def set_mem_per_proc(self, mem_mb):
# """Set the memory per process in megabytes"""
# super(BlueGeneAdapter, self).set_mem_per_proc(mem_mb)
# super().set_mem_per_proc(mem_mb)
# #self.qparams["mem_per_cpu"] = self.mem_per_proc
def set_timelimit(self, timelimit):
"""Limits are specified with the format hh:mm:ss (hours:minutes:seconds)"""
super(BlueGeneAdapter, self).set_timelimit(timelimit)
super().set_timelimit(timelimit)
self.qparams["wall_clock_limit"] = qu.time2loadlever(timelimit)
def cancel(self, job_id):
@@ -2068,7 +2068,7 @@ $${qverbatim}
token = out.split()[3]
s = token.split(".")[-1].replace('"', "")
queue_id = int(s)
except:
except Exception:
# probably an error parsing the job code
logger.critical("Could not parse job id following llsubmit...")
raise
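The repeated super(SomeAdapter, self).method(...) → super().method(...) rewrites in this file (and in the task classes below) switch to Python 3's zero-argument super(): inside a method the compiler supplies the enclosing class and the instance implicitly, so both spellings resolve to the same parent call. A minimal sketch with hypothetical classes, not taken from this commit:

    class Base:
        def set_qname(self, qname):
            self.qname = qname

    class Derived(Base):
        def set_qname(self, qname):
            # identical to super(Derived, self).set_qname(qname)
            super().set_qname(qname)
            self.partition = qname or "default"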

View File

@@ -108,7 +108,7 @@ class TaskResults(NodeResults):
@classmethod
def from_node(cls, task):
"""Initialize an instance from an :class:`AbinitTask` instance."""
new = super(TaskResults, cls).from_node(task)
new = super().from_node(task)
new.update(
executable=task.executable,
@@ -168,7 +168,7 @@ class ParalConf(AttrDict):
}
def __init__(self, *args, **kwargs):
super(ParalConf, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
# Add default values if not already in self.
for k, v in self._DEFAULTS.items():
@@ -232,7 +232,7 @@ class ParalHintsParser(object):
try:
d = yaml.safe_load(doc.text_notag)
return ParalHints(info=d["info"], confs=d["configurations"])
except:
except Exception:
import traceback
sexc = traceback.format_exc()
err_msg = "Wrong YAML doc:\n%s\n\nException:\n%s" % (doc.text, sexc)
@@ -1221,7 +1221,7 @@ class MyTimedelta(datetime.timedelta):
def __str__(self):
"""Remove microseconds from timedelta default __str__"""
s = super(MyTimedelta, self).__str__()
s = super().__str__()
microsec = s.find(".")
if microsec != -1: s = s[:microsec]
return s
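MyTimedelta.__str__ above simply truncates the default timedelta rendering at the decimal point. A sketch of the effect, using only the standard library:

    import datetime

    s = str(datetime.timedelta(seconds=62, microseconds=345678))  # '0:01:02.345678'
    dot = s.find(".")
    if dot != -1:
        s = s[:dot]   # '0:01:02'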
@@ -1334,7 +1334,7 @@ class Task(Node, metaclass=abc.ABCMeta):
None means that this Task has no dependency.
"""
# Init the node
super(Task, self).__init__()
super().__init__()
self._input = input
@@ -1672,7 +1672,7 @@ class Task(Node, metaclass=abc.ABCMeta):
self._returncode = self.process.wait()
try:
self.process.stderr.close()
except:
except Exception:
pass
self.set_status(self.S_DONE, "status set to Done")
@@ -2014,7 +2014,7 @@ class Task(Node, metaclass=abc.ABCMeta):
self.history.info('Found unknown message in the queue qerr file: %s' % str(qerr_info))
#try:
# rt = self.datetimes.get_runtime().seconds
#except:
#except Exception:
# rt = -1.0
#tl = self.manager.qadapter.timelimit
#if rt > tl:
@@ -3067,7 +3067,7 @@ class AbinitTask(Task):
self.manager.exclude_nodes(error.nodes)
self.reset_from_scratch()
self.set_status(self.S_READY, msg='excluding nodes')
except:
except Exception:
raise FixQueueCriticalError
else:
self.set_status(self.S_ERROR, msg='Node error but no node identified.')
@@ -3286,7 +3286,7 @@ class ScfTask(GsTask):
return None
def get_results(self, **kwargs):
results = super(ScfTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
# Open the GSR file and add its data to results.out
with self.open_gsr() as gsr:
@@ -3304,7 +3304,7 @@ class CollinearThenNonCollinearScfTask(ScfTask):
initialized from the previous WFK file.
"""
def __init__(self, input, workdir=None, manager=None, deps=None):
super(CollinearThenNonCollinearScfTask, self).__init__(input, workdir=workdir, manager=manager, deps=deps)
super().__init__(input, workdir=workdir, manager=manager, deps=deps)
# Enforce nspinor = 1, nsppol = 2 and prtwf = 1.
self._input = self.input.deepcopy()
self.input.set_spin_mode("polarized")
@@ -3312,7 +3312,7 @@ class CollinearThenNonCollinearScfTask(ScfTask):
self.collinear_done = False
def _on_ok(self):
results = super(CollinearThenNonCollinearScfTask, self)._on_ok()
results = super()._on_ok()
if not self.collinear_done:
self.input.set_spin_mode("spinor")
self.collinear_done = True
@@ -3355,7 +3355,7 @@ class NscfTask(GsTask):
else:
self.set_vars(ngfft=den_mesh)
super(NscfTask, self).setup()
super().setup()
def restart(self):
"""NSCF calculations can be restarted only if we have the WFK file."""
@@ -3376,7 +3376,7 @@ class NscfTask(GsTask):
return self._restart()
def get_results(self, **kwargs):
results = super(NscfTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
# Read the GSR file.
with self.open_gsr() as gsr:
@@ -3507,7 +3507,7 @@ class RelaxTask(GsTask, ProduceHist):
raise ValueError("Wrong value for what %s" % what)
def get_results(self, **kwargs):
results = super(RelaxTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
# Open the GSR file and add its data to results.out
with self.open_gsr() as gsr:
@@ -3530,7 +3530,7 @@ class RelaxTask(GsTask, ProduceHist):
This change is needed so that we can specify dependencies with the syntax {node: "DEN"}
without having to know the number of iterations needed to converge the run in node!
"""
super(RelaxTask, self).fix_ofiles()
super().fix_ofiles()
# Find the last TIM?_DEN file.
last_timden = self.outdir.find_last_timden_file()
@@ -3756,7 +3756,7 @@ class DdeTask(DfptTask):
color_rgb = np.array((61, 158, 255)) / 255
def get_results(self, **kwargs):
results = super(DdeTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
return results.register_gridfs_file(DDB=(self.outdir.has_abiext("DDE"), "t"))
@@ -3767,10 +3767,10 @@ class DteTask(DfptTask):
# @check_spectator
def start(self, **kwargs):
kwargs['autoparal'] = False
return super(DteTask, self).start(**kwargs)
return super().start(**kwargs)
def get_results(self, **kwargs):
results = super(DteTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
return results.register_gridfs_file(DDB=(self.outdir.has_abiext("DDE"), "t"))
@@ -3780,7 +3780,7 @@ class DdkTask(DfptTask):
#@check_spectator
def _on_ok(self):
super(DdkTask, self)._on_ok()
super()._on_ok()
# Client code expects to find du/dk in the DDK file.
# Here I create a symbolic link out_1WF13 --> out_DDK
# so that we can use deps={ddk_task: "DDK"} in the high-level API.
@@ -3789,7 +3789,7 @@ class DdkTask(DfptTask):
self.outdir.symlink_abiext('1WF', 'DDK')
def get_results(self, **kwargs):
results = super(DdkTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
return results.register_gridfs_file(DDK=(self.outdir.has_abiext("DDK"), "t"))
@@ -3823,7 +3823,7 @@ class PhononTask(DfptTask):
return scf_cycle.plot(**kwargs)
def get_results(self, **kwargs):
results = super(PhononTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
return results.register_gridfs_files(DDB=(self.outdir.has_abiext("DDB"), "t"))
@@ -3987,7 +3987,7 @@ class SigmaTask(ManyBodyTask):
raise RuntimeError("Cannot find SIGRES file!")
def get_results(self, **kwargs):
results = super(SigmaTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
# Open the SIGRES file and add its data to results.out
with self.open_sigres() as sigres:
@@ -4109,7 +4109,7 @@ class BseTask(ManyBodyTask):
return None
def get_results(self, **kwargs):
results = super(BseTask, self).get_results(**kwargs)
results = super().get_results(**kwargs)
with self.open_mdf() as mdf:
#results["out"].update(mdf.as_dict())
@@ -4152,11 +4152,11 @@ class OpticTask(Task):
deps.update({self.nscf_node: "WFK"})
super(OpticTask, self).__init__(optic_input, workdir=workdir, manager=manager, deps=deps)
super().__init__(optic_input, workdir=workdir, manager=manager, deps=deps)
def set_workdir(self, workdir, chroot=False):
"""Set the working directory of the task."""
super(OpticTask, self).set_workdir(workdir, chroot=chroot)
super().set_workdir(workdir, chroot=chroot)
# Small hack: the log file of optics is actually the main output file.
self.output_file = self.log_file
@@ -4234,7 +4234,7 @@ class OpticTask(Task):
"""
def get_results(self, **kwargs):
return super(OpticTask, self).get_results(**kwargs)
return super().get_results(**kwargs)
def fix_abicritical(self):
"""
@@ -4317,7 +4317,7 @@ class OpticTask(Task):
self.manager.exclude_nodes(error.nodes)
self.reset_from_scratch()
self.set_status(self.S_READY, msg='excluding nodes')
except:
except Exception:
raise FixQueueCriticalError
else:
self.set_status(self.S_ERROR, msg='Node error but no node identified.')
@@ -4526,7 +4526,7 @@ class AnaddbTask(Task):
if self.ddk_node is not None:
deps.update({self.ddk_node: "DDK"})
super(AnaddbTask, self).__init__(input=anaddb_input, workdir=workdir, manager=manager, deps=deps)
super().__init__(input=anaddb_input, workdir=workdir, manager=manager, deps=deps)
@classmethod
def temp_shell_task(cls, inp, ddb_node, mpi_procs=1,

View File

@@ -133,7 +133,7 @@ class File(object):
"""Remove the file."""
try:
os.remove(self.path)
except:
except Exception:
pass
def move(self, dst):
@@ -227,7 +227,7 @@ class Directory(object):
for path in self.list_filepaths():
try:
os.remove(path)
except:
except Exception:
pass
def path_in(self, file_basename):