This commit is contained in:
ydshieh 2024-03-13 19:23:09 +01:00
parent f9a86a78f1
commit d556c587fa
3 changed files with 11 additions and 14 deletions

View File

@ -34,13 +34,13 @@ class BenchMark:
self._buffer["measure_kwargs"] = {}
self._buffer["report_kwargs"] = {}
def measure(self, func, **measure_kwargs):
def _measure(self, func, **measure_kwargs):
raise NotImplementedError
def target(self, **target_kwargs):
def _target(self, **target_kwargs):
raise NotImplementedError
def inputs(self, **inputs_kwargs):
def _inputs(self, **inputs_kwargs):
return {}
def _convert_to_json(self, report):
@ -52,7 +52,7 @@ class BenchMark:
return report.__name__
return report
def report(self, result, only_result=False, output_path=None):
def _report(self, result, only_result=False, output_path=None):
report = {"result": result}
if not only_result:
report["configuration"] = self._buffer
@ -78,23 +78,23 @@ class BenchMark:
if measure_kwargs is None:
measure_kwargs = {}
target = self.target(**target_kwargs)
target = self._target(**target_kwargs)
all_inputs_kwargs = [inputs_kwargs] if isinstance(inputs_kwargs, dict) else inputs_kwargs
results = []
for _inputs_kwargs in all_inputs_kwargs:
inputs = self.inputs(**_inputs_kwargs)
result = self.measure(target, **measure_kwargs)(**inputs)
inputs = self._inputs(**_inputs_kwargs)
result = self._measure(target, **measure_kwargs)(**inputs)
results.append(result)
if isinstance(inputs_kwargs, dict):
results = results[0]
return self.report(results, **report_kwargs)
return self._report(results, **report_kwargs)
class SpeedBenchMark(BenchMark):
def measure(self, func, number=3, repeat=1):
def _measure(self, func, number=3, repeat=1):
self._buffer["measure_kwargs"]["number"] = number
self._buffer["measure_kwargs"]["repeat"] = repeat

File diff suppressed because one or more lines are too long

View File

@ -4,7 +4,7 @@ from benchmark_utils_generic import BenchMark, SpeedBenchMark
class FromPretrainedBenchMark(BenchMark):
def target(self, model_class, repo_id):
def _target(self, model_class, repo_id):
self._buffer["target_kwargs"]["model_class"] = model_class
self._buffer["target_kwargs"]["repo_id"] = repo_id