mirror of https://gitlab.com/QEF/q-e.git
272 lines
11 KiB
Python
272 lines
11 KiB
Python
'''
testcode2.validation
--------------------

Classes and functions for comparing data.

:copyright: (c) 2012 James Spencer.
:license: modified BSD; see LICENSE for more details.
'''
|
|
|
|
from __future__ import print_function
|
|
import re
|
|
import sys
|
|
import warnings
|
|
|
|
import testcode2.ansi as ansi
|
|
import testcode2.compatibility as compat
|
|
import testcode2.exceptions as exceptions
|
|
|
|
class Status:
    '''Enum-esque object for storing whether an object passed a comparison.

    bools: iterable of boolean objects.  If all booleans are True (False) then
        the status is set to pass (fail) and if only some booleans are True,
        the status is set to warning (partial pass).  A falsy value (None, an
        empty list, ...) gives an unknown status.
    status: existing status to use.  bools is ignored if status is supplied.
    name: name of status (unknown, skipped, passed, partial, failed) to use.
        Setting name overrides bools and status.
    '''

    # Status levels.  They are ordered so that a larger value is a "worse"
    # outcome, which is what __add__ relies on when combining two statuses.
    (_unknown, _skipped) = (-2, -1)
    (_passed, _partial, _failed) = (0, 1, 2)

    def __init__(self, bools=None, status=None, name=None):
        if name is not None:
            # Raises AttributeError if name is not one of the known levels.
            self.status = getattr(self, '_' + name)
        elif status is not None:
            self.status = status
        elif bools:
            # bools is inspected twice (all, then any).  Materialise it first
            # so that a one-shot iterator is not exhausted by the first pass,
            # which would turn a partial pass into a failure.
            flags = list(bools)
            if compat.compat_all(flags):
                self.status = self._passed
            elif compat.compat_any(flags):
                self.status = self._partial
            else:
                self.status = self._failed
        else:
            self.status = self._unknown

    def unknown(self):
        '''Return true if stored status is unknown.'''
        return self.status == self._unknown

    def skipped(self):
        '''Return true if stored status is skipped.'''
        return self.status == self._skipped

    def passed(self):
        '''Return true if stored status is passed.'''
        return self.status == self._passed

    def warning(self):
        '''Return true if stored status is a partial pass.'''
        return self.status == self._partial

    def failed(self):
        '''Return true if stored status is failed.'''
        return self.status == self._failed

    def print_status(self, msg=None, verbose=1, vspace=True):
        '''Print status.

        msg: optional message to print out after status.
        verbose: 0 (or less): suppress all output except for a single
                    character written without a newline: . (pass),
                    U (unknown), S (skipped), W (warning/partial pass) or
                    F (fail).  Standard output is flushed afterwards.
                 1: print 'Unknown.', 'Passed.', 'SKIPPED.', 'WARNING.' or
                    '**FAILED**.' on its own line.
                 2 or more: as for 1 plus print msg (if supplied).
        vspace: print out an extra blank line afterwards if verbose > 1.
        '''
        if verbose > 0:
            if self.status == self._unknown:
                print('Unknown.')
            elif self.status == self._passed:
                print('Passed.')
            elif self.status == self._skipped:
                print(('%s.' % ansi.ansi_format('SKIPPED', 'blue')))
            elif self.status == self._partial:
                print(('%s.' % ansi.ansi_format('WARNING', 'blue')))
            else:
                # Anything else is reported as a failure.
                print(('%s.' % ansi.ansi_format('**FAILED**', 'red', 'normal', 'bold')))
            if msg and verbose > 1:
                print(msg)
            if vspace and verbose > 1:
                print('')
        else:
            if self.status == self._unknown:
                sys.stdout.write('U')
            elif self.status == self._skipped:
                sys.stdout.write('S')
            elif self.status == self._passed:
                sys.stdout.write('.')
            elif self.status == self._partial:
                sys.stdout.write('W')
            else:
                sys.stdout.write('F')
            # No newline is written, so flush to make progress visible.
            sys.stdout.flush()

    def __add__(self, other):
        '''Add two status objects.

        Return the maximum level (ie most "failed") status.'''
        return Status(status=max(self.status, other.status))
|
|
|
|
class Tolerance:
    '''Store absolute and relative tolerances.

    Two numbers are regarded as equal if they are within these tolerances.

    name: name of tolerance object.
    absolute: threshold for absolute difference between two numbers.
    relative: threshold for relative difference between two numbers.
    strict: if true, then require numbers to be within both thresholds.

    A TestCodeError is raised if neither tolerance is given; a tolerance which
    is falsy (None or 0) counts as not given.
    '''
    def __init__(self, name='', absolute=None, relative=None, strict=True):
        self.name = name
        self.absolute = absolute
        self.relative = relative
        if not self.absolute and not self.relative:
            err = 'Neither absolute nor relative tolerance given.'
            raise exceptions.TestCodeError(err)
        self.strict = strict

    def __repr__(self):
        return (self.absolute, self.relative, self.strict).__repr__()

    def __hash__(self):
        # Equal objects have equal names (see __eq__), so hashing the name
        # alone is consistent with equality.
        return hash(self.name)

    def __eq__(self, other):
        return (isinstance(other, self.__class__) and
                self.__dict__ == other.__dict__)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != would
        # compare identities there.
        return not self.__eq__(other)

    def validate(self, test_val, benchmark_val, key=''):
        '''Compare test and benchmark values to within the tolerances.

        test_val: value obtained from the test.
        benchmark_val: value obtained from the benchmark.
        key: optional label; if given (and there is a message) it is placed on
            the first line of the returned message.

        Returns a (status, msg) tuple, where status is a Status object and msg
        is a string describing any tolerance which was exceeded.

        Values for which the numerical comparison raises a TypeError are
        instead required to be equal according to !=.
        '''
        status = Status([True])
        msg = ['values are within tolerance.']
        compare = '(Test: %s. Benchmark: %s.)' % (test_val, benchmark_val)
        try:
            # Check float is not NaN (which we can't compare).
            if compat.isnan(test_val) or compat.isnan(benchmark_val):
                status = Status([False])
                msg = ['cannot compare NaNs.']
            else:
                # Check if values are within tolerances.
                (status_absolute, msg_absolute) = \
                        self.validate_absolute(benchmark_val, test_val)
                (status_relative, msg_relative) = \
                        self.validate_relative(benchmark_val, test_val)
                if self.absolute and self.relative and not self.strict:
                    # Require only one of thresholds to be met.
                    status = Status([status_relative.passed(),
                                     status_absolute.passed()])
                else:
                    # Only have one or other of thresholds (require active one
                    # to be met) or have both and strict mode is on (require
                    # both to be met).
                    status = status_relative + status_absolute
                err_stat = ''
                if status.warning():
                    err_stat = 'Warning: '
                elif status.failed():
                    err_stat = 'ERROR: '
                msg = []
                # Only report on tolerances which are actually set; the
                # "no tolerance set" notes from the helpers are dropped.
                if self.absolute and msg_absolute:
                    msg.append('%s%s %s' % (err_stat, msg_absolute, compare))
                if self.relative and msg_relative:
                    msg.append('%s%s %s' % (err_stat, msg_relative, compare))
        except TypeError:
            if test_val != benchmark_val:
                # require test and benchmark values to be equal (within python's
                # definition of equality).
                status = Status([False])
                msg = ['values are different. ' + compare]
        if key and msg:
            msg.insert(0, key)
            msg = '\n '.join(msg)
        else:
            msg = '\n'.join(msg)
        return (status, msg)

    def validate_absolute(self, benchmark_val, test_val):
        '''Compare test and benchmark values to the absolute tolerance.

        Returns a (status, msg) tuple.  The comparison passes if the absolute
        difference is strictly less than the absolute tolerance, or if no
        absolute tolerance is set.
        '''
        if self.absolute:
            diff = test_val - benchmark_val
            err = abs(diff)
            passed = err < self.absolute
            msg = ''
            if not passed:
                msg = ('absolute error %.2e greater than %.2e.' %
                       (err, self.absolute))
        else:
            passed = True
            msg = 'No absolute tolerance set. Passing without checking.'
        return (Status([passed]), msg)

    def validate_relative(self, benchmark_val, test_val):
        '''Compare test and benchmark values to the relative tolerance.

        Returns a (status, msg) tuple.  The comparison passes if the
        difference relative to the benchmark value is strictly less than the
        relative tolerance, or if no relative tolerance is set.  A zero
        benchmark value gives a relative error of zero if the test value is
        also zero and of infinity otherwise.
        '''
        # Local import: the module does not import operator at the top level.
        import operator
        if self.relative:
            diff = test_val - benchmark_val
            if benchmark_val == 0 and diff == 0:
                err = 0
            elif benchmark_val == 0:
                err = float("Inf")
            else:
                # Use true division explicitly: under Python 2 the / operator
                # floors when both operands are integers, which would
                # understate the relative error.
                err = abs(operator.truediv(diff, benchmark_val))
            passed = err < self.relative
            msg = ''
            if not passed:
                msg = ('relative error %.2e greater than %.2e.' %
                       (err, self.relative))
        else:
            passed = True
            msg = 'No relative tolerance set. Passing without checking.'
        return (Status([passed]), msg)
|
|
|
|
|
|
def compare_data(benchmark, test, default_tolerance, tolerances,
                 ignore_fields=None):
    '''Compare two data dictionaries.

    benchmark, test: dictionaries mapping a field name to a sequence of values.
    default_tolerance: Tolerance used for a field without a specific entry in
        tolerances.
    tolerances: dictionary mapping a field name to a Tolerance.  If a field
        resolves to a tolerance equal to default_tolerance, the names of the
        tolerances are also tried as regular expressions against the field
        name and the first match is used.
    ignore_fields: optional iterable of field names to leave out of the
        comparison.

    Returns a (comparable, status, msg) tuple: whether both data sets hold the
    same fields with the same number of values, the combined Status of all
    comparisons and a newline-separated description of any differences.
    '''
    skipped = compat.compat_set(ignore_fields or tuple())
    bench_fields = compat.compat_set(benchmark) - skipped
    test_fields = compat.compat_set(test) - skipped

    # Net number of values per field: positive if the benchmark holds more
    # values than the test, negative if the test holds more.
    surplus = dict.fromkeys(bench_fields | test_fields, 0)
    for (field, values) in benchmark.items():
        if field not in skipped:
            surplus[field] += len(values)
    for (field, values) in test.items():
        if field not in skipped:
            surplus[field] -= len(values)

    # Both the field names and the number of values per field must agree.
    comparable = (bench_fields == test_fields and
                  compat.compat_all(count == 0 for count in surplus.values()))

    overall = Status()
    report = []

    if not comparable:
        overall = Status([False])
        report.append('Different sets of data extracted from benchmark and test.')
        only_in_bench = bench_fields - test_fields
        only_in_test = test_fields - bench_fields
        if only_in_bench:
            report.append(" Data only in benchmark: %s." % ", ".join(only_in_bench))
        if only_in_test:
            report.append(" Data only in test: %s." % ", ".join(only_in_test))
        # Fields present in both data sets but with differing value counts.
        extra_in_bench = [field for field in surplus
                          if surplus[field] > 0 and field not in only_in_bench]
        extra_in_test = [field for field in surplus
                         if surplus[field] < 0 and field not in only_in_test]
        if extra_in_bench:
            report.append(" More data in benchmark than in test: %s." %
                          ", ".join(extra_in_bench))
        if extra_in_test:
            report.append(" More data in test than in benchmark: %s." %
                          ", ".join(extra_in_test))

    for field in (bench_fields & test_fields):
        field_tol = tolerances.get(field, default_tolerance)
        if field_tol == default_tolerance:
            # No distinct tolerance for this field: see if the name of any
            # tolerance matches the field when used as a regular expression.
            candidates = [candidate for candidate in tolerances.values()
                          if candidate.name and re.match(candidate.name, field)]
            if candidates:
                field_tol = candidates[0]
                if len(candidates) > 1:
                    warnings.warn('Multiple tolerance regexes match. Using %s.'
                                  % (field_tol.name))
        # zip stops at the shorter sequence; a length mismatch has already
        # been reported above.
        for (bench_value, test_value) in zip(benchmark[field], test[field]):
            (outcome, detail) = field_tol.validate(test_value, bench_value, field)
            overall = overall + outcome
            if detail and not outcome.passed():
                report.append(detail)

    return (comparable, overall, "\n".join(report))
|