[trainer] remove `--model_parallel` (#9451)

* fix bad merge - dropped code

* remove --model_parallel

* Deal with TrainingArguments

* Use a private attr and fix batch sizes

* fix _n_gpu

* add is_parallel helper wrapper

* fix attribute

* introduce a new attribute is_model_parallel

* docs

* docs

* Put back init False and rearrange doc

* Ignore non-init args in HFArgumentParser

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Stas Bekman 2021-01-11 06:39:28 -08:00 committed by GitHub
parent 6f63501383
commit 33b7422839
3 changed files with 30 additions and 35 deletions
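For context, the user-facing effect of this commit is roughly the following: instead of passing `--model_parallel` on the command line, you put the model into model-parallel mode yourself before building the `Trainer`, and the `Trainer` picks that up through the new `is_model_parallel` attribute. A hedged sketch, assuming a multi-GPU machine and a model that implements `parallelize()` (e.g. GPT-2 or T5 at the time of this commit):

    from transformers import GPT2LMHeadModel, Trainer, TrainingArguments

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.parallelize()  # split the layers across the visible GPUs (model parallelism)

    args = TrainingArguments(output_dir="out")  # no model_parallel argument anymore
    trainer = Trainer(model=model, args=args)   # trainer.is_model_parallel is detected as True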

src/transformers/hf_argparser.py

@@ -53,6 +53,8 @@ class HfArgumentParser(ArgumentParser):
     def _add_dataclass_arguments(self, dtype: DataClassType):
         for field in dataclasses.fields(dtype):
+            if not field.init:
+                continue
             field_name = f"--{field.name}"
             kwargs = field.metadata.copy()
             # field.metadata is not used at all by Data Classes,
@@ -148,7 +150,7 @@ class HfArgumentParser(ArgumentParser):
         namespace, remaining_args = self.parse_known_args(args=args)
         outputs = []
         for dtype in self.dataclass_types:
-            keys = {f.name for f in dataclasses.fields(dtype)}
+            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
            inputs = {k: v for k, v in vars(namespace).items() if k in keys}
            for k in keys:
                delattr(namespace, k)
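The net effect of the two hf_argparser changes above: dataclass fields declared with `init=False` (like the private `_n_gpu` field added to `TrainingArguments` further down) no longer become command-line flags and are ignored when the parsed namespace is mapped back onto the dataclass. A minimal sketch of the idea with a throwaway dataclass (not the real `TrainingArguments`):

    import dataclasses
    from dataclasses import dataclass, field

    @dataclass
    class ExampleArgs:
        # regular field: HfArgumentParser would expose it as --learning_rate
        learning_rate: float = field(default=5e-5, metadata={"help": "Learning rate."})
        # init=False field: internal state only, must not become a CLI flag
        _n_gpu: int = field(init=False, repr=False, default=0)

    # mirrors the two `f.init` filters in the diff above
    cli_fields = [f.name for f in dataclasses.fields(ExampleArgs) if f.init]
    print(cli_fields)  # ['learning_rate'], _n_gpu is skipped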

src/transformers/trainer.py

@@ -219,15 +219,16 @@ class Trainer:
         :class:`~transformers.AdamW` on your model and a scheduler given by
         :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
 
-    Important accessors:
+    Important attributes:
 
-        ``self.model`` - always points to the core model. If using a transformers model, it will be a
-        :class:`PreTrainedModel` subclass.
-        ``self.model_wrapped`` - always points to the most external model in case one or more other modules wrap the
-        original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
-        the inner model is wrapped in ``DeepSpeed`` and then again in ``DistributedDataParallel``. If the inner model
-        hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
+        - **model** -- Always points to the core model. If using a transformers model, it will be a
+          :class:`~transformers.PreTrainedModel` subclass.
+        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
+          original model. This is the model that should be used for the forward pass. For example, under ``DeepSpeed``,
+          the inner model is wrapped in ``DeepSpeed`` and then again in ``torch.nn.DistributedDataParallel``. If the
+          inner model hasn't been wrapped, then ``self.model_wrapped`` is the same as ``self.model``.
+        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
+          data parallelism, this means some of the model layers are split on different GPUs).
     """
 
     def __init__(
@@ -267,6 +268,11 @@ class Trainer:
             )
         self.model_init = model_init
 
+        if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
+            self.is_model_parallel = True
+        else:
+            self.is_model_parallel = False
+
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
@@ -274,8 +280,11 @@ class Trainer:
         self.tokenizer = tokenizer
 
         # Model parallel
-        if not self.args.model_parallel:
+        if not self.is_model_parallel:
             model = model.to(args.device)
+        else:
+            # Force n_gpu to 1 to avoid DataParallel.
+            self.args._n_gpu = 1
 
         # later use `self.model is self.model_wrapped` to check if it's wrapped or not
         self.model_wrapped = model
@@ -669,7 +678,7 @@ class Trainer:
             set_seed(self.args.seed)
             model = self.call_model_init(trial)
-            if not self.args.model_parallel:
+            if not self.is_model_parallel:
                 model = model.to(self.args.device)
 
             self.model = model
@@ -719,7 +728,7 @@ class Trainer:
             model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
 
         # Multi-gpu training (should be after apex fp16 initialization)
-        if self.args.n_gpu > 1 and not self.args.model_parallel:
+        if self.args.n_gpu > 1:
             model = torch.nn.DataParallel(model)
 
         # Distributed training (should be after apex fp16 initialization)
@@ -930,7 +939,7 @@ class Trainer:
             )
             if isinstance(self.model, PreTrainedModel):
                 self.model = self.model.from_pretrained(self.state.best_model_checkpoint)
-                if not self.args.model_parallel:
+                if not self.is_model_parallel:
                     self.model = self.model.to(self.args.device)
             else:
                 state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
@@ -1481,7 +1490,7 @@ class Trainer:
         model = self.model
         # multi-gpu eval
-        if self.args.n_gpu > 1 and not self.args.model_parallel:
+        if self.args.n_gpu > 1:
             model = torch.nn.DataParallel(model)
         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
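The Trainer changes above replace the explicit flag with detection based on the model itself: if the model advertises `is_parallelizable` and has already been switched into model-parallel mode (its `model_parallel` attribute is True), the Trainer skips `model.to(device)` and forces `_n_gpu` to 1 so the model is never wrapped in `DataParallel`. A hedged sketch of just that detection logic, using a dummy object that carries only the two attributes the diff checks:

    class DummyParallelizedModel:
        # stand-in with only the two attributes the Trainer inspects
        is_parallelizable = True   # the architecture supports model parallelism
        model_parallel = True      # layers have already been spread over several GPUs

    def detect_model_parallel(model) -> bool:
        # same condition as the Trainer.__init__ hunk above
        return bool(
            hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel
        )

    print(detect_model_parallel(DummyParallelizedModel()))  # True  -> skip model.to(device), force _n_gpu = 1
    print(detect_model_parallel(object()))                  # False -> usual single-device / DataParallel path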

src/transformers/training_args.py

@@ -210,9 +210,6 @@ class TrainingArguments:
             - :obj:`True` if :obj:`metric_for_best_model` is set to a value that isn't :obj:`"loss"` or
               :obj:`"eval_loss"`.
             - :obj:`False` if :obj:`metric_for_best_model` is not set, or set to :obj:`"loss"` or :obj:`"eval_loss"`.
-        model_parallel (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If the model supports model parallelism and there is more than one device, whether to use model parallelism
-            to distribute the model's modules across devices or not.
         ignore_skip_data (:obj:`bool`, `optional`, defaults to :obj:`False`):
             When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
             stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
@@ -245,15 +242,6 @@ class TrainingArguments:
     do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
     do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
     do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
-    model_parallel: bool = field(
-        default=False,
-        metadata={
-            "help": (
-                "If there are more than one devices, whether to use model parallelism to distribute the "
-                "model's modules across devices."
-            )
-        },
-    )
     evaluation_strategy: EvaluationStrategy = field(
         default="no",
         metadata={"help": "The evaluation strategy to use."},
@@ -410,6 +398,7 @@ class TrainingArguments:
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )
     adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace Adam by Adafactor."})
+    _n_gpu: int = field(init=False, repr=False, default=0)
 
     def __post_init__(self):
         if self.disable_tqdm is None:
@@ -430,6 +419,7 @@ class TrainingArguments:
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
+        self._n_gpu = torch.cuda.device_count()
 
     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
@@ -451,10 +441,7 @@ class TrainingArguments:
                 "version. Using `--per_device_train_batch_size` is preferred."
             )
         per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
-        if not self.model_parallel:
-            train_batch_size = per_device_batch_size * max(1, self.n_gpu)
-        else:
-            train_batch_size = per_device_batch_size
+        train_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return train_batch_size
 
     @property
@@ -468,10 +455,7 @@ class TrainingArguments:
                 "version. Using `--per_device_eval_batch_size` is preferred."
             )
         per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
-        if not self.model_parallel:
-            eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
-        else:
-            eval_batch_size = per_device_batch_size
+        eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return eval_batch_size
 
     @cached_property
@@ -492,7 +476,7 @@ class TrainingArguments:
             # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
             # will use the first GPU in that env, i.e. GPU#1
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-            n_gpu = torch.cuda.device_count()
+            n_gpu = self._n_gpu
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
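With `model_parallel` gone from `TrainingArguments`, the batch-size properties always multiply by `max(1, n_gpu)`; the model-parallel case is handled instead by the Trainer forcing `args._n_gpu = 1`, which collapses the multiplier back to the per-device value. A small illustration of that arithmetic, plain Python with hypothetical numbers:

    def train_batch_size(per_device_batch_size: int, n_gpu: int) -> int:
        # mirrors TrainingArguments.train_batch_size after the simplification above
        return per_device_batch_size * max(1, n_gpu)

    print(train_batch_size(8, 4))  # 32 -> data-parallel run on 4 GPUs, 8 samples per device per step
    print(train_batch_size(8, 1))  # 8  -> model-parallel run, the Trainer has set args._n_gpu = 1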