update
This commit is contained in:
parent
f9a86a78f1
commit
d556c587fa
|
@ -34,13 +34,13 @@ class BenchMark:
|
|||
self._buffer["measure_kwargs"] = {}
|
||||
self._buffer["report_kwargs"] = {}
|
||||
|
||||
def measure(self, func, **measure_kwargs):
|
||||
def _measure(self, func, **measure_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def target(self, **target_kwargs):
|
||||
def _target(self, **target_kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def inputs(self, **inputs_kwargs):
|
||||
def _inputs(self, **inputs_kwargs):
|
||||
return {}
|
||||
|
||||
def _convert_to_json(self, report):
|
||||
|
@ -52,7 +52,7 @@ class BenchMark:
|
|||
return report.__name__
|
||||
return report
|
||||
|
||||
def report(self, result, only_result=False, output_path=None):
|
||||
def _report(self, result, only_result=False, output_path=None):
|
||||
report = {"result": result}
|
||||
if not only_result:
|
||||
report["configuration"] = self._buffer
|
||||
|
@ -78,23 +78,23 @@ class BenchMark:
|
|||
if measure_kwargs is None:
|
||||
measure_kwargs = {}
|
||||
|
||||
target = self.target(**target_kwargs)
|
||||
target = self._target(**target_kwargs)
|
||||
|
||||
all_inputs_kwargs = [inputs_kwargs] if isinstance(inputs_kwargs, dict) else inputs_kwargs
|
||||
results = []
|
||||
for _inputs_kwargs in all_inputs_kwargs:
|
||||
inputs = self.inputs(**_inputs_kwargs)
|
||||
result = self.measure(target, **measure_kwargs)(**inputs)
|
||||
inputs = self._inputs(**_inputs_kwargs)
|
||||
result = self._measure(target, **measure_kwargs)(**inputs)
|
||||
results.append(result)
|
||||
|
||||
if isinstance(inputs_kwargs, dict):
|
||||
results = results[0]
|
||||
|
||||
return self.report(results, **report_kwargs)
|
||||
return self._report(results, **report_kwargs)
|
||||
|
||||
|
||||
class SpeedBenchMark(BenchMark):
|
||||
def measure(self, func, number=3, repeat=1):
|
||||
def _measure(self, func, number=3, repeat=1):
|
||||
self._buffer["measure_kwargs"]["number"] = number
|
||||
self._buffer["measure_kwargs"]["repeat"] = repeat
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -4,7 +4,7 @@ from benchmark_utils_generic import BenchMark, SpeedBenchMark
|
|||
|
||||
|
||||
class FromPretrainedBenchMark(BenchMark):
|
||||
def target(self, model_class, repo_id):
|
||||
def _target(self, model_class, repo_id):
|
||||
self._buffer["target_kwargs"]["model_class"] = model_class
|
||||
self._buffer["target_kwargs"]["repo_id"] = repo_id
|
||||
|
||||
|
|
Loading…
Reference in New Issue