Don't reset the dataset type + add a flag to disable removing unused columns (#6683)
* Don't reset the type of the dataset
* Formatting
* Update trainer.py

Co-authored-by: Teven <teven.lescao@gmail.com>
This commit is contained in:
parent
1a779ad7ec
commit
b30879fe0c
|
@ -244,6 +244,8 @@ class Trainer:
|
|||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
|
||||
def _remove_unused_columns(self, dataset: "nlp.Dataset", description: Optional[str] = None):
|
||||
if not self.args.remove_unused_columns:
|
||||
return
|
||||
# Inspect model forward signature to keep only the arguments it accepts.
|
||||
signature = inspect.signature(self.model.forward)
|
||||
signature_columns = list(signature.parameters.keys())
|
||||
|
@ -255,7 +257,10 @@ class Trainer:
|
|||
logger.info(
|
||||
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
||||
)
|
||||
dataset.set_format(columns=columns)
|
||||
ds_type = dataset.format["type"]
|
||||
if ds_type == "python":
|
||||
ds_type = None
|
||||
dataset.set_format(type=ds_type, columns=columns)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
||||
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
||||
|
|
|
@ -114,6 +114,11 @@ class TrainingArguments:
|
|||
at the next training step under the keyword argument ``mems``.
|
||||
run_name (:obj:`str`, `optional`):
|
||||
A descriptor for the run. Notably used for wandb logging.
|
||||
remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
|
||||
forward method.
|
||||
|
||||
(Note: this behavior is not implemented for :class:`~transformers.TFTrainer` yet.)
|
||||
"""
|
||||
|
||||
output_dir: str = field(
|
||||
|
@ -234,6 +239,10 @@ class TrainingArguments:
|
|||
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
|
||||
)
|
||||
|
||||
remove_unused_columns: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
|
||||
)
|
||||
|
||||
@property
|
||||
def train_batch_size(self) -> int:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue