quantum-espresso/test-suite/testcode/lib/testcode2/validation.py

271 lines
11 KiB
Python

'''
testcode2.validation
--------------------
Classes and functions for comparing data.
:copyright: (c) 2012 James Spencer.
:license: modified BSD; see LICENSE for more details.
'''
import re
import sys
import warnings
import testcode2.ansi as ansi
import testcode2.compatibility as compat
import testcode2.exceptions as exceptions
class Status:
    '''Enum-esque object for storing whether an object passed a comparison.

    bools: iterable of boolean objects.  If all booleans are True (False) then
        the status is set to pass (fail) and if only some booleans are True,
        the status is set to warning (partial pass).  A falsy bools (None or
        an empty sequence) gives the unknown status.
    status: existing status to use.  bools is ignored if status is supplied.
    name: name of status (unknown, skipped, passed, partial, failed) to use.
        Setting name overrides bools and status.
    '''
    def __init__(self, bools=None, status=None, name=None):
        # Status levels are ordered so that a larger value is "more failed";
        # __add__ relies on this ordering.
        (self._unknown, self._skipped) = (-2, -1)
        (self._passed, self._partial, self._failed) = (0, 1, 2)
        if name is not None:
            # Raises AttributeError if name is not one of the known statuses.
            self.status = getattr(self, '_' + name)
        elif status is not None:
            self.status = status
        elif bools:
            # bools may be a one-shot iterator (e.g. a generator expression).
            # Walk it exactly once: testing "all" and then "any" on the same
            # iterator would let the first test consume the values needed by
            # the second and misreport a partial pass as a failure.
            flags = [bool(flag) for flag in bools]
            if False not in flags:
                self.status = self._passed
            elif True in flags:
                self.status = self._partial
            else:
                self.status = self._failed
        else:
            self.status = self._unknown

    def unknown(self):
        '''Return true if stored status is unknown.'''
        return self.status == self._unknown

    def skipped(self):
        '''Return true if stored status is skipped.'''
        return self.status == self._skipped

    def passed(self):
        '''Return true if stored status is passed.'''
        return self.status == self._passed

    def warning(self):
        '''Return true if stored status is a partial pass.'''
        return self.status == self._partial

    def failed(self):
        '''Return true if stored status is failed.'''
        return self.status == self._failed

    def print_status(self, msg=None, verbose=1, vspace=True):
        '''Print status.

        msg: optional message to print out after status.
        verbose: 0 (or less): suppress all output except for . (for pass),
                    U (for unknown), S (for skipped), W (for warning/partial
                    pass) and F (for fail), written without a newline.
                 1: print 'Passed', 'Unknown', 'SKIPPED', 'WARNING' or
                    '**FAILED**'.
                 2 or more: as for 1 plus print msg (if supplied).
        vspace: print out extra new line afterwards if verbose > 1.
        '''
        if verbose > 0:
            if self.status == self._unknown:
                print('Unknown.')
            elif self.status == self._passed:
                print('Passed.')
            elif self.status == self._skipped:
                print('%s.' % ansi.ansi_format('SKIPPED', 'blue'))
            elif self.status == self._partial:
                print('%s.' % ansi.ansi_format('WARNING', 'blue'))
            else:
                print('%s.' % ansi.ansi_format('**FAILED**', 'red', 'normal', 'bold'))
            if msg and verbose > 1:
                print(msg)
            if vspace and verbose > 1:
                print('')
        else:
            # Terse mode: a single character per status, flushed immediately
            # so progress shows up even though no newline is written.
            if self.status == self._unknown:
                sys.stdout.write('U')
            elif self.status == self._skipped:
                sys.stdout.write('S')
            elif self.status == self._passed:
                sys.stdout.write('.')
            elif self.status == self._partial:
                sys.stdout.write('W')
            else:
                sys.stdout.write('F')
            sys.stdout.flush()

    def __add__(self, other):
        '''Add two status objects.

        Return the maximum level (ie most "failed") status.'''
        return Status(status=max(self.status, other.status))
class Tolerance:
    '''Store absolute and relative tolerances.

    Given numbers are regarded as equal if they are within these tolerances.

    name: name of tolerance object.
    absolute: threshold for absolute difference between two numbers.
    relative: threshold for relative difference between two numbers.
    strict: if true, then require numbers to be within both thresholds.

    Raises exceptions.TestCodeError if neither threshold is given (a
    threshold of 0 or None counts as not given).
    '''
    def __init__(self, name='', absolute=None, relative=None, strict=True):
        self.name = name
        self.absolute = absolute
        self.relative = relative
        if not self.absolute and not self.relative:
            err = 'Neither absolute nor relative tolerance given.'
            raise exceptions.TestCodeError(err)
        self.strict = strict

    def __repr__(self):
        return (self.absolute, self.relative, self.strict).__repr__()

    def __hash__(self):
        # Equal tolerances have equal names (see __eq__), so hashing on the
        # name alone is consistent with equality.
        return hash(self.name)

    def __eq__(self, other):
        return (isinstance(other, self.__class__) and
                self.__dict__ == other.__dict__)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; define it explicitly so
        # both operators always agree.
        return not self.__eq__(other)

    def validate(self, test_val, benchmark_val, key=''):
        '''Compare test and benchmark values to within the tolerances.

        test_val, benchmark_val: values to compare.  Values for which the
            numerical checks raise TypeError are compared with != instead.
        key: optional label placed on the first line of the message.

        Returns (status, msg), where status is a Status object and msg is a
        (possibly empty) string describing the outcome.
        '''
        status = Status([True])
        msg = ['values are within tolerance.']
        compare = '(Test: %s. Benchmark: %s.)' % (test_val, benchmark_val)
        try:
            # Check float is not NaN (which we can't compare).
            if compat.isnan(test_val) or compat.isnan(benchmark_val):
                status = Status([False])
                msg = ['cannot compare NaNs.']
            else:
                # Check if values are within tolerances.
                (status_absolute, msg_absolute) = \
                        self.validate_absolute(benchmark_val, test_val)
                (status_relative, msg_relative) = \
                        self.validate_relative(benchmark_val, test_val)
                if self.absolute and self.relative and not self.strict:
                    # Require only one of thresholds to be met.  If exactly
                    # one is met the combined status is a partial pass.
                    status = Status([status_relative.passed(),
                                     status_absolute.passed()])
                else:
                    # Only have one or other of thresholds (require active one
                    # to be met) or have both and strict mode is on (require
                    # both to be met).
                    status = status_relative + status_absolute
                err_stat = ''
                if status.warning():
                    err_stat = 'Warning: '
                elif status.failed():
                    err_stat = 'ERROR: '
                # Only report messages from thresholds which are actually set;
                # an unset threshold returns an informational message which
                # must not be shown as an error.
                msg = []
                if self.absolute and msg_absolute:
                    msg.append('%s%s %s' % (err_stat, msg_absolute, compare))
                if self.relative and msg_relative:
                    msg.append('%s%s %s' % (err_stat, msg_relative, compare))
        except TypeError:
            if test_val != benchmark_val:
                # require test and benchmark values to be equal (within python's
                # definition of equality).
                status = Status([False])
                msg = ['values are different. ' + compare]
        if key and msg:
            msg.insert(0, key)
            msg = '\n '.join(msg)
        else:
            msg = '\n'.join(msg)
        return (status, msg)

    def validate_absolute(self, benchmark_val, test_val):
        '''Compare test and benchmark values to the absolute tolerance.

        Returns (status, msg).  msg is empty if the values are within the
        tolerance.  If no absolute tolerance is set the check passes and msg
        says so.
        '''
        if self.absolute:
            diff = test_val - benchmark_val
            err = abs(diff)
            passed = err < self.absolute
            msg = ''
            if not passed:
                msg = ('absolute error %.2e greater than %.2e.' %
                       (err, self.absolute))
        else:
            passed = True
            msg = 'No absolute tolerance set. Passing without checking.'
        return (Status([passed]), msg)

    def validate_relative(self, benchmark_val, test_val):
        '''Compare test and benchmark values to the relative tolerance.

        The error is measured relative to the benchmark value.  Returns
        (status, msg) in the same form as validate_absolute.
        '''
        if self.relative:
            diff = test_val - benchmark_val
            if benchmark_val == 0 and diff == 0:
                # Both values are zero: identical, and avoids 0/0.
                err = 0
            elif benchmark_val == 0:
                # Any deviation from an exactly-zero benchmark is infinitely
                # large in relative terms.
                err = float("Inf")
            else:
                err = abs(diff/benchmark_val)
            passed = err < self.relative
            msg = ''
            if not passed:
                msg = ('relative error %.2e greater than %.2e.' %
                       (err, self.relative))
        else:
            passed = True
            msg = 'No relative tolerance set. Passing without checking.'
        return (Status([passed]), msg)
def compare_data(benchmark, test, default_tolerance, tolerances,
        ignore_fields=None):
    '''Compare two data dictionaries.

    benchmark, test: dictionaries mapping a data key to a sequence of values.
    default_tolerance: tolerance object used for keys with no specific entry.
    tolerances: dictionary of tolerance objects.  A key is looked up directly
        first; failing that, each tolerance name is tried as a regular
        expression against the key.
    ignore_fields: optional iterable of keys to exclude from the comparison.

    Returns (comparable, status, msg): comparable is False if the two
    dictionaries hold different keys or different numbers of values for a key,
    status is the combined Status of all comparisons and msg is a
    newline-separated report of every difference found.

    Keys are reported and compared in sorted order so that the output is
    reproducible from run to run.
    '''
    def ordered(keys):
        # Set/dict iteration order is arbitrary; sort for stable reports.
        return sorted(keys, key=str)

    ignored_params = compat.compat_set(ignore_fields or tuple())
    bench_params = compat.compat_set(benchmark) - ignored_params
    test_params = compat.compat_set(test) - ignored_params

    # Check both the key names and the number of keys in case there are
    # different numbers of duplicate keys.
    comparable = (bench_params == test_params)
    # Net count per key: positive means the benchmark has more values,
    # negative means the test has more.
    key_counts = dict((key, 0) for key in bench_params | test_params)
    for (key, val) in benchmark.items():
        if key not in ignored_params:
            key_counts[key] += len(val)
    for (key, val) in test.items():
        if key not in ignored_params:
            key_counts[key] -= len(val)
    comparable = comparable and compat.compat_all(kc == 0 for kc in key_counts.values())

    status = Status()
    msg = []

    if not comparable:
        status = Status([False])
        bench_only = bench_params - test_params
        test_only = test_params - bench_params
        msg.append('Different sets of data extracted from benchmark and test.')
        if bench_only:
            msg.append(" Data only in benchmark: %s." %
                       ", ".join(ordered(bench_only)))
        if test_only:
            msg.append(" Data only in test: %s." %
                       ", ".join(ordered(test_only)))
        # Keys present on both sides but with differing numbers of values.
        bench_more = [key for key in ordered(key_counts)
                      if key_counts[key] > 0 and key not in bench_only]
        test_more = [key for key in ordered(key_counts)
                     if key_counts[key] < 0 and key not in test_only]
        if bench_more:
            msg.append(" More data in benchmark than in test: %s." %
                       ", ".join(bench_more))
        if test_more:
            msg.append(" More data in test than in benchmark: %s." %
                       ", ".join(test_more))

    # Compare whatever the two data sets have in common, even if they are not
    # fully comparable, so that as many differences as possible are reported.
    for param in ordered(bench_params & test_params):
        param_tol = tolerances.get(param, default_tolerance)
        if param_tol == default_tolerance:
            # See if there's a regex that matches.
            tol_matches = [tol for tol in tolerances.values()
                           if tol.name and re.match(tol.name, param)]
            if tol_matches:
                param_tol = tol_matches[0]
                if len(tol_matches) > 1:
                    warnings.warn('Multiple tolerance regexes match. '
                                  'Using %s.' % (param_tol.name))
        # zip stops at the shorter sequence; any length mismatch has already
        # been recorded above via key_counts.
        for bench_value, test_value in zip(benchmark[param], test[param]):
            key_status, err = param_tol.validate(test_value, bench_value, param)
            status += key_status
            if not key_status.passed() and err:
                msg.append(err)

    return (comparable, status, "\n".join(msg))