Fix type issue in using bucketing with Trainer (#18051)

* Fix type issue in using bucketing with Trainer

- Fix type issues in LengthGrouperSampler,
  DistributedLengthGroupedSampler

refs: #18003

* Change logging type in LengthGroupedSampler

- Change `logger.warning` to `logger.info`

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Change logging type in DistributedLengthGroupedSampler

- Change `logger.warning` to `logger.info`

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove adundant clause in LengthGroupedSampler

- Use `elif`

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove adundant clause in DistributedLengthGroupedSampler

- Use `elif`

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply black, isort to modified codes in the script

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
BOSEOP KIM 2022-07-09 00:06:00 +09:00 committed by GitHub
parent 9bd3968509
commit 94ca7d2faa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 0 deletions

View File

@ -558,6 +558,12 @@ class LengthGroupedSampler(Sampler):
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
elif isinstance(lengths, torch.Tensor):
logger.info(
"If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
)
lengths = lengths.tolist()
self.lengths = lengths
self.generator = generator
@ -614,6 +620,13 @@ class DistributedLengthGroupedSampler(DistributedSampler):
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
elif isinstance(lengths, torch.Tensor):
logger.info(
"If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
" List[int]..."
)
lengths = lengths.tolist()
self.lengths = lengths
# If the dataset length is evenly divisible by # of replicas, then there