This commit is contained in:
ydshieh 2024-03-21 09:18:51 +01:00
parent 3dae978bec
commit 0145157b20
3 changed files with 9 additions and 3 deletions

View File

@ -15,6 +15,7 @@
Benchmark utilities - module containing base classes for benchmarking
"""
import json
import os.path
import timeit
@ -73,7 +74,7 @@ class BenchMark:
return report.__name__
return report
def _report(self, result, output_path=None, only_result=False):
def _report(self, result, output_path=None, only_result=False, overwrite=False):
self._run_buffer["config"]["report_kwargs"]["output_path"] = output_path
self._run_buffer["config"]["report_kwargs"]["only_result"] = only_result
@ -86,6 +87,13 @@ class BenchMark:
complete_report = self._convert_to_json(self._buffer)
if output_path is not None:
if not os.path.isdir(output_path):
os.makedirs(output_path)
output_path = os.path.join("benchmark_report.json")
if os.path.isfile(output_path) and not overwrite:
raise ValueError(f"TODO: add error")
with open(output_path, "w", encoding="UTF-8") as fp:
json.dump(complete_report, fp, ensure_ascii=False, indent=4)

View File

@ -214,7 +214,6 @@ if __name__ == "__main__":
"mode": compile,
},
"inputs_kwargs": {},
"report_kwargs": {"output_path": "benchmark_report.json"},
}
run_configs.append(run_kwargs)
else:

View File

@ -59,7 +59,6 @@ if __name__ == "__main__":
"measure_kwargs": {"number": 2, "repeat": 3},
"target_kwargs": {"model_class": "AutoModel", "repo_id": repo_id},
"inputs_kwargs": [{}],
"report_kwargs": {"output_path": "benchmark_report.json"},
}
run_configs = [run_kwargs]
else: