Compare commits

...

55 Commits

Author SHA1 Message Date
ydshieh d3393528d4 update 2024-03-21 09:47:39 +01:00
ydshieh 093e376c46 update 2024-03-21 09:46:44 +01:00
ydshieh 6c97e80b0e empty 2024-03-21 09:37:47 +01:00
ydshieh 62408e13b8 update 2024-03-21 09:31:08 +01:00
ydshieh ba476ae57e empty 2024-03-21 09:27:18 +01:00
ydshieh d440260cdd update 2024-03-21 09:27:05 +01:00
ydshieh 559fd7e852 empty 2024-03-21 09:21:38 +01:00
ydshieh a8d66fbb7b update 2024-03-21 09:21:32 +01:00
ydshieh 20c82e9b4e empty 2024-03-21 09:20:45 +01:00
ydshieh 5d5e5a5c1d update 2024-03-21 09:20:41 +01:00
ydshieh 0a2efe085f empty 2024-03-21 09:19:20 +01:00
ydshieh 0145157b20 update 2024-03-21 09:18:51 +01:00
ydshieh 3dae978bec update 2024-03-20 17:49:52 +01:00
ydshieh d60ffc4b4a update 2024-03-20 17:25:18 +01:00
ydshieh 2977be1bdc update 2024-03-20 17:23:27 +01:00
ydshieh 47b0550816 update 2024-03-20 17:21:47 +01:00
ydshieh aae1fbf76e update 2024-03-20 17:17:25 +01:00
ydshieh e81191c2da update 2024-03-20 17:08:35 +01:00
ydshieh d08fd0b7bf update 2024-03-20 16:05:24 +01:00
ydshieh fed1dd9990 update 2024-03-20 15:58:00 +01:00
ydshieh 29a478007d update 2024-03-20 15:55:34 +01:00
ydshieh 6ad1e309b5 update 2024-03-20 15:47:57 +01:00
ydshieh 0670a0110b update 2024-03-20 15:03:41 +01:00
ydshieh 185435c607 update 2024-03-20 11:47:28 +01:00
ydshieh 5f98ee7b46 update 2024-03-20 11:22:22 +01:00
ydshieh ff3b978e5a update 2024-03-20 10:38:13 +01:00
ydshieh 472ce78baf update 2024-03-18 10:48:09 +01:00
ydshieh 2a04bd3386 update 2024-03-18 10:27:09 +01:00
ydshieh 93d6ccaae0 update 2024-03-18 10:26:27 +01:00
ydshieh f25b457534 update 2024-03-18 10:18:11 +01:00
ydshieh a5070e90af update 2024-03-18 10:17:29 +01:00
ydshieh 751c4a4f86 update 2024-03-18 10:17:02 +01:00
ydshieh ec2a34a5af update 2024-03-18 10:16:23 +01:00
ydshieh 35ff6456c2 update 2024-03-18 10:11:20 +01:00
ydshieh cd2db2fdec update 2024-03-15 18:42:58 +01:00
ydshieh bb5da0421d update 2024-03-15 18:26:19 +01:00
ydshieh 01f89c9256 update 2024-03-15 17:07:50 +01:00
ydshieh 65224b5357 update 2024-03-15 16:54:38 +01:00
Yih-Dar 37e178847a Update src/transformers/benchmark/benchmark_utils_generic.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>) 2024-03-13 19:38:48 +01:00
Yih-Dar 47837ddd28 Update src/transformers/benchmark/from_pretrained_benchmark.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>) 2024-03-13 19:38:36 +01:00
Yih-Dar e825a4169e Update src/transformers/benchmark/from_pretrained_benchmark.py (Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>) 2024-03-13 19:38:12 +01:00
ydshieh d556c587fa update 2024-03-13 19:23:09 +01:00
ydshieh f9a86a78f1 update 2024-03-13 14:38:59 +01:00
ydshieh 324ef81649 update 2024-03-13 14:38:59 +01:00
ydshieh 2eb4a9d45b save 2024-03-13 14:38:59 +01:00
ydshieh d7dfdf7281 save 2024-03-13 14:38:59 +01:00
ydshieh 335d241be1 save 2024-03-13 14:38:59 +01:00
ydshieh 101b639651 style 2024-03-13 14:38:59 +01:00
ydshieh b00a094787 fix 2024-03-13 14:38:59 +01:00
ydshieh c27bdbe8e3 rename classes 2024-03-13 14:38:59 +01:00
ydshieh 5a7bf5c937 rename 2024-03-13 14:38:59 +01:00
ydshieh b4fe856387 update 2024-03-13 14:38:58 +01:00
ydshieh 4c4010b136 update 2024-03-13 14:38:58 +01:00
ydshieh dc47f69d3f update 2024-03-13 14:38:58 +01:00
ydshieh 28aedd00b7 benchmark 2024-03-13 14:38:58 +01:00
5 changed files with 557 additions and 0 deletions

src/transformers/benchmark/benchmark_utils_generic.py
@@ -0,0 +1,151 @@
# Copyright 2024 The HuggingFace Team and the AllenNLP authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmark utilities - module containing base classes for benchmarking
"""
import json
import os.path
import timeit


class BenchMark:
    """Base class specifying the methods to be implemented in benchmark subclasses.

    All the methods except `run` are designed to be private: only the `run` method should be used by an end user.
    """

    def __init__(self, *args, **kwargs):
        self._buffer = {"init_kwargs": {}, "runs": []}
        self._run_buffer = {
            "config": {
                "inputs_kwargs": {},
                "target_kwargs": {},
                "measure_kwargs": {},
                "report_kwargs": {},
            },
            "result": None,
        }

    def _reset_run_buffer(self):
        self._run_buffer = {
            "config": {
                "inputs_kwargs": {},
                "target_kwargs": {},
                "measure_kwargs": {},
                "report_kwargs": {},
            },
            "result": None,
        }
    def _measure(self, func, **measure_kwargs):
        """Return a callable that, when called, will return some measurement results for the argument `func`.

        See `SpeedBenchMark` for an example implementation.
        """
        raise NotImplementedError

    def _target(self, **target_kwargs):
        """Return a callable against which we would like to perform a benchmark.

        See `FromPretrainedBenchMark` and `CacheBenchMark` for example implementations.
        """
        raise NotImplementedError

    def _inputs(self, **inputs_kwargs):
        return {}
    def _convert_to_json(self, report):
        if isinstance(report, list):
            return [self._convert_to_json(x) for x in report]
        if isinstance(report, dict):
            return {k: self._convert_to_json(v) for k, v in report.items()}
        if isinstance(report, type):
            return report.__name__
        return report

    def _report(self, result, output_path=None, only_result=False):
        self._run_buffer["config"]["report_kwargs"]["output_path"] = output_path
        self._run_buffer["config"]["report_kwargs"]["only_result"] = only_result
        self._run_buffer["result"] = result
        self._buffer["runs"].append(self._run_buffer)

        report = {"result": result}
        if not only_result:
            report = self._buffer["runs"][-1]

        complete_report = self._convert_to_json(self._buffer)
        if output_path is not None:
            if not os.path.isdir(output_path):
                os.makedirs(output_path)
            output_path = os.path.join(output_path, "benchmark_report.json")
            with open(output_path, "w", encoding="UTF-8") as fp:
                json.dump(complete_report, fp, ensure_ascii=False, indent=4)

        report = self._convert_to_json(report)
        return report
    def run(self, measure_kwargs=None, target_kwargs=None, inputs_kwargs=None, report_kwargs=None):
        self._reset_run_buffer()

        if measure_kwargs is None:
            measure_kwargs = {}
        if target_kwargs is None:
            target_kwargs = {}
        if inputs_kwargs is None:
            inputs_kwargs = {}
        if report_kwargs is None:
            report_kwargs = {}

        target = self._target(**target_kwargs)

        all_inputs_kwargs = [inputs_kwargs] if isinstance(inputs_kwargs, dict) else inputs_kwargs
        results = []
        for _inputs_kwargs in all_inputs_kwargs:
            inputs = self._inputs(**_inputs_kwargs)
            result = self._measure(target, **measure_kwargs)(**inputs)
            results.append(result)

        # A single `inputs_kwargs` dict means a single run: unwrap the result list.
        if isinstance(inputs_kwargs, dict):
            results = results[0]

        return self._report(results, **report_kwargs)


class SpeedBenchMark(BenchMark):
    """A simple class used to benchmark the running time of a callable."""

    def _measure(self, func, number=3, repeat=1):
        self._run_buffer["config"]["measure_kwargs"]["number"] = number
        self._run_buffer["config"]["measure_kwargs"]["repeat"] = repeat

        def wrapper(*args, **kwargs):
            # As https://docs.python.org/2/library/timeit.html#timeit.Timer.repeat notes,
            # the min of the repeated runtimes should be taken rather than the average.
            def _func():
                func(*args, **kwargs)

            runtimes = timeit.repeat(
                _func,
                repeat=repeat,
                number=number,
            )
            return {"time": min(runtimes) / number}

        return wrapper
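
For orientation, here is a minimal usage sketch, not part of the diff: the subclass name, the `n` parameter, and the sorting workload are illustrative, and it assumes the module above is importable. It shows how a `SpeedBenchMark` subclass plugs a workload into `run`:

class SortSpeedBenchMark(SpeedBenchMark):
    """Hypothetical subclass: time Python's built-in `sorted` on a reversed list."""

    def _target(self, n):
        # Record the target configuration, as the real subclasses do.
        self._run_buffer["config"]["target_kwargs"]["n"] = n
        data = list(range(n, 0, -1))

        def target():
            sorted(data)

        return target


report = SortSpeedBenchMark().run(
    measure_kwargs={"number": 5, "repeat": 2},
    target_kwargs={"n": 100_000},
)
# `run` returns the last run buffer: {"config": {...}, "result": {"time": ...}}.
print(report["result"]["time"])  # seconds per call, min over the 2 repeats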

File diff suppressed because one or more lines are too long

src/transformers/benchmark/from_pretrained_benchmark.py
@@ -0,0 +1,75 @@
# Copyright 2024 The HuggingFace Team and the AllenNLP authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmark for models' `from_pretrained` method
"""
import argparse
import json

from benchmark_utils_generic import BenchMark, SpeedBenchMark

import transformers


class FromPretrainedBenchMark(BenchMark):
    def _target(self, model_class, repo_id):
        self._run_buffer["config"]["target_kwargs"]["model_class"] = model_class
        self._run_buffer["config"]["target_kwargs"]["repo_id"] = repo_id

        def target():
            _ = getattr(transformers, model_class).from_pretrained(repo_id)

        return target


class FromPretrainedSpeedBenchMark(SpeedBenchMark, FromPretrainedBenchMark):
    pass

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config_path",
        default=None,
        type=str,
        required=False,
        help="Path to a prepared run file or a previously run output file.",
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="Path to the output directory where the run's info will be saved."
    )
    args = parser.parse_args()

    if args.config_path is None:
        init_kwargs = {}
        repo_id = "bert-base-uncased"
        run_kwargs = {
            "measure_kwargs": {"number": 2, "repeat": 3},
            "target_kwargs": {"model_class": "AutoModel", "repo_id": repo_id},
            "inputs_kwargs": [{}],
            "report_kwargs": {},
        }
        run_configs = [run_kwargs]
    else:
        with open(args.config_path) as fp:
            config = json.load(fp)
        init_kwargs = config["init_kwargs"]
        run_configs = [run["config"] for run in config["runs"]]

    benchmark = FromPretrainedSpeedBenchMark(**init_kwargs)

    for run_config in run_configs:
        run_config["report_kwargs"]["output_path"] = args.output_path
        result = benchmark.run(**run_config)
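
The same benchmark can also be driven programmatically; a minimal sketch, assuming the two modules above sit together on the Python path (the output directory name is illustrative, the other values mirror the script's defaults):

from from_pretrained_benchmark import FromPretrainedSpeedBenchMark

benchmark = FromPretrainedSpeedBenchMark()
report = benchmark.run(
    measure_kwargs={"number": 2, "repeat": 3},
    target_kwargs={"model_class": "AutoModel", "repo_id": "bert-base-uncased"},
    inputs_kwargs=[{}],
    report_kwargs={"output_path": "bench_reports/manual"},
)
# `inputs_kwargs` is a list, so `result` is a list with one timing dict per entry.
print(report["result"])  # e.g. [{"time": 1.23}]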

utils/benchmark.py
@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Team and the AllenNLP authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A handy script to run benchmark script(s) against a list of git commits.
Example: python utils/benchmark.py --benchmark_path src/transformers/benchmark/from_pretrained_benchmark.py --base_output_path "./bench_reports" --commits "62408e13,6c97e80b0"
"""
import argparse
import os
from contextlib import contextmanager
from pathlib import Path

from git import Repo


PATH_TO_REPO = Path(__file__).parent.parent.resolve()


@contextmanager
def checkout_commit(repo: Repo, commit_id: str):
    """
    Context manager that checks out a given commit when entered, but gets back to the reference it was at on exit.

    Args:
        repo (`git.Repo`): A git repository (for instance the Transformers repo).
        commit_id (`str`): The commit reference to checkout inside the context manager.
    """
    current_head = repo.head.commit if repo.head.is_detached else repo.head.ref

    try:
        repo.git.checkout(commit_id)
        yield

    finally:
        repo.git.checkout(current_head)

if __name__ == "__main__":

    def list_str(values):
        return values.split(",")

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--benchmark_path",
        type=str,
        required=True,
        help="Path to the benchmark script to run.",
    )
    parser.add_argument(
        "--base_output_path",
        type=str,
        required=True,
        help="Base path to the output directory where each run's info will be saved.",
    )
    parser.add_argument(
        "--config_path",
        default=None,
        type=str,
        required=False,
        help="Path to a prepared run file or a previously run output file.",
    )
    parser.add_argument(
        "--commits",
        type=list_str,
        required=True,
        help="Comma-separated list of commit SHA values against which the benchmark will be run.",
    )
    args = parser.parse_args()

    repo = Repo(PATH_TO_REPO)

    for commit in args.commits:
        with checkout_commit(repo, commit):
            print(f"benchmark against commit: {repo.head.commit}")
            output_path = os.path.join(args.base_output_path, f"{commit}")
            commandline_args = f"--output_path {output_path}"
            if args.config_path is not None:
                commandline_args += f" --config_path {args.config_path}"
            # TODO: use `subprocess`
            os.system(f"python {args.benchmark_path} {commandline_args}")
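
The TODO above could plausibly be addressed with `subprocess.run`; a hedged sketch of a drop-in for the last three lines of the loop body (an assumption, not the author's implementation). It avoids building a shell string and surfaces a failing benchmark run as an exception:

import subprocess
import sys

cmd = [sys.executable, args.benchmark_path, "--output_path", output_path]
if args.config_path is not None:
    cmd += ["--config_path", args.config_path]
# check=True turns a non-zero exit status into a CalledProcessError.
subprocess.run(cmd, check=True)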

@@ -332,6 +332,9 @@ src/transformers/benchmark/benchmark_args_tf.py
src/transformers/benchmark/benchmark_args_utils.py
src/transformers/benchmark/benchmark_tf.py
src/transformers/benchmark/benchmark_utils.py
src/transformers/benchmark/benchmark_utils_generic.py
src/transformers/benchmark/cache_benchmark.py
src/transformers/benchmark/from_pretrained_benchmark.py
src/transformers/commands/add_new_model.py
src/transformers/commands/add_new_model_like.py
src/transformers/commands/convert.py