ydshieh 2024-03-20 11:22:22 +01:00
parent ff3b978e5a
commit 5f98ee7b46
3 changed files with 47 additions and 27 deletions


@@ -27,19 +27,23 @@ class BenchMark:
     def __init__(self, *arg, **kwargs):
         self._buffer = {"init_kwargs": {}, "runs": []}
         self._run_buffer = {
-            "inputs_kwargs": {},
-            "target_kwargs": {},
-            "measure_kwargs": {},
-            "report_kwargs": {},
+            "config": {
+                "inputs_kwargs": {},
+                "target_kwargs": {},
+                "measure_kwargs": {},
+                "report_kwargs": {},
+            },
             "result": None,
         }

     def _reset_run_buffer(self):
         self._run_buffer = {
-            "inputs_kwargs": {},
-            "target_kwargs": {},
-            "measure_kwargs": {},
-            "report_kwargs": {},
+            "config": {
+                "inputs_kwargs": {},
+                "target_kwargs": {},
+                "measure_kwargs": {},
+                "report_kwargs": {},
+            },
             "result": None,
         }
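The practical effect of nesting the four kwargs dicts under a `config` key is that each entry appended to `self._buffer["runs"]` now separates what was requested from what was measured. A minimal sketch of the resulting structure after one run (values are illustrative, not taken from the repository):

# Illustrative only: the buffer shape implied by the new nesting (values made up).
buffer = {
    "init_kwargs": {},
    "runs": [
        {
            "config": {
                "inputs_kwargs": {},
                "target_kwargs": {"model_class": "AutoModel", "repo_id": "bert-base-uncased"},
                "measure_kwargs": {"number": 2, "repeat": 3},
                "report_kwargs": {"output_path": "benchmark_report.json"},
            },
            "result": 0.1234,  # e.g. seconds per call for a speed benchmark (made up)
        }
    ],
}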
@@ -119,8 +123,8 @@ class SpeedBenchMark(BenchMark):
     """A simple class used to benchmark the running time of a callable."""

     def _measure(self, func, number=3, repeat=1):
-        self._run_buffer["measure_kwargs"]["number"] = number
-        self._run_buffer["measure_kwargs"]["repeat"] = repeat
+        self._run_buffer["config"]["measure_kwargs"]["number"] = number
+        self._run_buffer["config"]["measure_kwargs"]["repeat"] = repeat

         def wrapper(*args, **kwargs):
             # as written in https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat, min should be taken rather than the average
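For context, the truncated `wrapper` follows the advice in the linked timeit documentation: take the minimum over `repeat` measurements rather than the mean. A minimal sketch of that pattern, not the repository's exact implementation (`measure_speed` is a made-up name):

import timeit

def measure_speed(func, number=3, repeat=1):
    def wrapper(*args, **kwargs):
        # timeit.repeat returns `repeat` totals, each covering `number` calls;
        # the minimum is the least noisy estimate, per the timeit docs.
        runtimes = timeit.repeat(lambda: func(*args, **kwargs), number=number, repeat=repeat)
        return min(runtimes) / number
    return wrapper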

File diff suppressed because one or more lines are too long


@@ -14,6 +14,7 @@
 """
 Benchmark for models' `from_pretrained` method
 """
+import argparse
 import json

 from benchmark_utils_generic import BenchMark, SpeedBenchMark
@@ -21,8 +22,8 @@ from benchmark_utils_generic import BenchMark, SpeedBenchMark

 class FromPretrainedBenchMark(BenchMark):
     def _target(self, model_class, repo_id):
-        self._run_buffer["target_kwargs"]["model_class"] = model_class
-        self._run_buffer["target_kwargs"]["repo_id"] = repo_id
+        self._run_buffer["config"]["target_kwargs"]["model_class"] = model_class
+        self._run_buffer["config"]["target_kwargs"]["repo_id"] = repo_id

         def target():
             _ = model_class.from_pretrained(repo_id)
@@ -37,15 +38,30 @@ class FromPretrainedSpeedBenchMark(SpeedBenchMark, FromPretrainedBenchMark):
 if __name__ == "__main__":
     from transformers import AutoModel

-    repo_id = "bert-base-uncased"
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config_path", default=None, type=str, required=False, help="Path to a json file containing the benchmark run configurations."
+    )
+    args = parser.parse_args()

-    benchmark = FromPretrainedSpeedBenchMark()
+    if args.config_path is None:
+        init_kwargs = {}

-    run_kwargs = {
-        "measure_kwargs": {"number": 2, "repeat": 3},
-        "target_kwargs": {"model_class": AutoModel, "repo_id": repo_id},
-        "inputs_kwargs": [{}],
-        "report_kwargs": {"output_path": "benchmark_report.json"},
-    }
-    result = benchmark.run(**run_kwargs)
-    print(json.dumps(result, indent=4))
+        repo_id = "bert-base-uncased"
+        run_kwargs = {
+            "measure_kwargs": {"number": 2, "repeat": 3},
+            "target_kwargs": {"model_class": AutoModel, "repo_id": repo_id},
+            "inputs_kwargs": [{}],
+            "report_kwargs": {"output_path": "benchmark_report.json"},
+        }
+        run_configs = [run_kwargs]
+    else:
+        with open(args.config_path) as fp:
+            config = json.load(fp)
+        init_kwargs = config["init_kwargs"]
+        run_configs = [run["config"] for run in config["runs"]]
+
+    benchmark = FromPretrainedSpeedBenchMark(**init_kwargs)
+
+    for run_config in run_configs:
+        result = benchmark.run(**run_config)
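For reference, the shape `--config_path` expects follows directly from the loading code above: a top-level `init_kwargs` dict plus a `runs` list whose `config` entries are forwarded to `benchmark.run`. A hypothetical way to produce such a file; the filename and the string form of `model_class` are assumptions, only the keys shown are required by the script:

import json

config = {
    "init_kwargs": {},
    "runs": [
        {
            "config": {
                "measure_kwargs": {"number": 2, "repeat": 3},
                # How model_class is serialized on disk is an assumption here.
                "target_kwargs": {"model_class": "AutoModel", "repo_id": "bert-base-uncased"},
                "inputs_kwargs": [{}],
                "report_kwargs": {"output_path": "benchmark_report.json"},
            },
            "result": None,
        }
    ],
}

with open("benchmark_config.json", "w") as fp:
    json.dump(config, fp, indent=4)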