Clean up old Accelerate checks (#24279)

* Clean up old Accelerate checks

* Put back imports
Sylvain Gugger 2023-06-14 12:44:09 -04:00 committed by GitHub
parent 860d11ff7c
commit 26a2ec56d7
4 changed files with 14 additions and 45 deletions
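
All of the checks removed below follow the same pattern: a symbol is imported only if the installed Accelerate is new enough, otherwise it is bound to `None`, and every call site then has to guard against `None`. With the dependency floor raised to `accelerate>=0.20.3`, each of those symbols is guaranteed to exist, so the guards collapse into plain imports. A minimal illustrative sketch of the pattern (not the exact transformers code):

```python
# Sketch of the version-gated import pattern this commit removes.
# The symbol name mirrors the diff; the surrounding module layout is assumed.
from packaging import version

try:
    from accelerate import __version__ as accelerate_version
except ImportError:
    accelerate_version = None

if accelerate_version is not None and version.parse(accelerate_version) > version.parse("0.19.0"):
    from accelerate.utils import check_tied_parameters_on_same_device
else:
    check_tied_parameters_on_same_device = None  # every caller must test for None first

# With accelerate>=0.20.3 required at install time, the same line becomes simply:
# from accelerate.utils import check_tied_parameters_on_same_device
```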

View File

@@ -98,7 +98,7 @@ if stale_egg_info.exists():
 # 2. once modified, run: `make deps_table_update` to update src/transformers/dependency_versions_table.py
 _deps = [
     "Pillow",
-    "accelerate>=0.20.2",
+    "accelerate>=0.20.3",
     "av==9.2.0",  # Latest version of PyAV (10.0.0) has issues with audio stream.
     "beautifulsoup4",
     "black~=23.1",

View File

@@ -3,7 +3,7 @@
 # 2. run `make deps_table_update``
 deps = {
     "Pillow": "Pillow",
-    "accelerate": "accelerate>=0.20.2",
+    "accelerate": "accelerate>=0.20.3",
     "av": "av==9.2.0",
     "beautifulsoup4": "beautifulsoup4",
     "black": "black~=23.1",

View File

@@ -82,27 +82,17 @@ XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

 if is_accelerate_available():
-    from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
     from accelerate.utils import (
+        check_tied_parameters_on_same_device,
         find_tied_parameters,
+        get_balanced_memory,
         load_offloaded_weights,
         offload_weight,
         save_offload_index,
         set_module_tensor_to_device,
     )

-    if version.parse(accelerate_version) > version.parse("0.11.0"):
-        from accelerate.utils import get_balanced_memory
-    else:
-        get_balanced_memory = None
-    if version.parse(accelerate_version) > version.parse("0.19.0"):
-        from accelerate.utils import check_tied_parameters_on_same_device
-    else:
-        check_tied_parameters_on_same_device = None
-else:
-    find_tied_parameters = None
-
 if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.torch import load_file as safe_load_file
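
These helpers back the low-memory loading path in `from_pretrained`: the model skeleton is built on the meta device, a device map is inferred, and the real weights are then loaded and dispatched. A condensed, hedged sketch of that flow with a toy module standing in for a pretrained checkpoint (illustrative only, not the transformers loading code):

```python
import torch
from torch import nn

from accelerate import infer_auto_device_map, init_empty_weights

# Build the module structure without allocating any weight storage (meta device).
with init_empty_weights():
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))

# Plan where each submodule should live for a given (hypothetical) memory budget.
device_map = infer_auto_device_map(model, max_memory={"cpu": "1GB"}, dtype=torch.float16)
print(device_map)  # e.g. {'': 'cpu'} for a model this small

# from_pretrained then loads the checkpoint weights and calls
# dispatch_model(model, device_map=device_map) to place and hook the submodules.
```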
@@ -2792,8 +2782,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
                     "'sequential'."
                 )
-            elif device_map in ["balanced", "balanced_low_0"] and get_balanced_memory is None:
-                raise ValueError(f"`device_map={device_map}` requires a source install of Accelerate.")

             kwargs = {"no_split_module_classes": no_split_modules}
             if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
@@ -2803,7 +2791,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     "This model has some weights that should be kept in higher precision, you need to upgrade "
                     "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
                 )
-            if device_map != "sequential" and get_balanced_memory is not None:
+            if device_map != "sequential":
                 max_memory = get_balanced_memory(
                     model,
                     dtype=target_dtype,
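
`get_balanced_memory` is what turns `device_map="balanced"` or `"balanced_low_0"` into a concrete per-device budget before `infer_auto_device_map` runs; since it always exists with accelerate>=0.20.3, the source-install error above could be dropped. A rough standalone sketch with hand-picked byte budgets (hypothetical numbers, and behavior details may vary across Accelerate versions):

```python
import torch
from torch import nn

from accelerate import init_empty_weights
from accelerate.utils import get_balanced_memory

with init_empty_weights():
    model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])

# Propose how much of each device to use so the model is spread evenly.
# low_zero=True (used for "balanced_low_0") keeps GPU 0 mostly free, e.g. for generation.
max_memory = get_balanced_memory(
    model,
    max_memory={0: 2 * 1024**3, 1: 2 * 1024**3, "cpu": 8 * 1024**3},  # hypothetical budgets in bytes
    dtype=torch.float16,
    low_zero=False,
)
print(max_memory)  # the trimmed budgets are then passed to infer_auto_device_map
```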
@@ -2838,8 +2826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             model.tie_weights()
             tied_params = find_tied_parameters(model)
             # check if we don't have tied param in different devices
-            if check_tied_parameters_on_same_device is not None:
-                check_tied_parameters_on_same_device(tied_params, device_map)
+            check_tied_parameters_on_same_device(tied_params, device_map)

         if from_tf:
             if resolved_archive_file.endswith(".index"):
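
`find_tied_parameters` and `check_tied_parameters_on_same_device` can now be called unconditionally. A self-contained sketch of what the pair does, with a toy tied-embedding model and a hand-written device map (both purely illustrative):

```python
from torch import nn

from accelerate.utils import check_tied_parameters_on_same_device, find_tied_parameters


class TiedLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.head = nn.Linear(16, 100, bias=False)
        self.head.weight = self.embed.weight  # tie input and output embeddings


model = TiedLM()
tied_params = find_tied_parameters(model)  # e.g. [["embed.weight", "head.weight"]]
device_map = {"embed": 0, "head": "cpu"}   # hypothetical split that separates the tied pair
check_tied_parameters_on_same_device(tied_params, device_map)  # logs a warning for this map
```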
@@ -3031,7 +3018,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))

-        if find_tied_parameters is not None:
+        if is_accelerate_available():
             tied_params = find_tied_parameters(model)
         else:
             tied_params = []

View File

@@ -32,8 +32,6 @@ from collections.abc import Mapping
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

-from tqdm.auto import tqdm
-

 # Integrations must be imported before ML frameworks:
 # isort: off
@@ -206,14 +204,9 @@ if is_peft_available():
     from peft import PeftModel


-skip_first_batches = None
 if is_accelerate_available():
+    from accelerate import Accelerator, skip_first_batches
     from accelerate import __version__ as accelerate_version
-
-    if version.parse(accelerate_version) >= version.parse("0.16"):
-        from accelerate import skip_first_batches
-
-    from accelerate import Accelerator
     from accelerate.utils import DistributedDataParallelKwargs

     if version.parse(accelerate_version) > version.parse("0.20.3"):
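
`Accelerator` and `DistributedDataParallelKwargs` are likewise imported without a version gate. A minimal example of the pattern those two imports serve in standalone Accelerate usage (not the Trainer's internal wiring):

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# kwargs_handlers customizes how torch's DistributedDataParallel wrapper is built,
# e.g. tolerating parameters that receive no gradient in some forward passes.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

# Typical next step (model/optimizer/dataloader defined elsewhere):
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```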
@@ -322,6 +315,7 @@ class Trainer:
     """

     # Those are used as methods of the Trainer in examples.
     from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state

     def __init__(
@@ -1714,22 +1708,10 @@ class Trainer:
             logger.info(f"  Continuing training from epoch {epochs_trained}")
             logger.info(f"  Continuing training from global step {self.state.global_step}")
             if not args.ignore_data_skip:
-                if skip_first_batches is None:
-                    logger.info(
-                        f"  Will skip the first {epochs_trained} epochs then the first"
-                        f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
-                        " you can install the latest version of Accelerate with `pip install -U accelerate`.You can"
-                        " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
-                        " training on data already seen by your model."
-                    )
-                else:
-                    logger.info(
-                        f"  Will skip the first {epochs_trained} epochs then the first"
-                        f" {steps_trained_in_current_epoch} batches in the first epoch."
-                    )
-                if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
-                    steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
-                    steps_trained_progress_bar.set_description("Skipping the first batches")
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first"
+                    f" {steps_trained_in_current_epoch} batches in the first epoch."
+                )

         # Update the references
         self.callback_handler.model = self.model
@@ -1787,7 +1769,7 @@ class Trainer:
             rng_to_sync = False
             steps_skipped = 0
-            if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
+            if steps_trained_in_current_epoch > 0:
                 epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                 steps_skipped = steps_trained_in_current_epoch
                 steps_trained_in_current_epoch = 0
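
`skip_first_batches` replaces the old manual skipping logic and tqdm progress bar: it wraps a dataloader so that resuming mid-epoch fast-forwards past batches the model has already seen. A short illustrative example with synthetic data (the number of already-trained steps is assumed, e.g. recovered from a checkpoint):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import skip_first_batches

dataset = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=2)

steps_trained_in_current_epoch = 3  # assumed resume point
resumed_iterator = skip_first_batches(dataloader, steps_trained_in_current_epoch)

for batch in resumed_iterator:
    print(batch)  # starts at the 4th batch, mirroring the resume logic in the diff above
```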