Faster list concat for trainer_pt_utils.get_length_grouped_indices() (#11825)

get_length_grouped_indices(), used by LengthGroupedSampler and DistributedLengthGroupedSampler,
is prohibitively slow for a large number of megabatches (in a test case it takes hours for ~270k
megabatches with 100 items each) because sum(megabatches, []) copies the accumulated list on
every addition. Flatten the megabatches with a list comprehension instead.

Resolves: #11795

Co-authored-by: ctheodoris <cvtheodo@ds.dfci.harvard.edu>
ctheodoris 2021-05-22 10:27:20 -04:00 committed by GitHub
parent da22245ed9
commit 73fde1defe
1 changed files with 1 additions and 1 deletions

@@ -495,7 +495,7 @@ def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, genera
     # Switch to put the longest element in first position
     megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
-    return sum(megabatches, [])
+    return [i for megabatch in megabatches for i in megabatch]
 class LengthGroupedSampler(Sampler):
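
The slowdown described in the commit message comes from how sum() handles lists: each addition builds a new list and copies everything accumulated so far, so total work grows quadratically with the number of megabatches. The sketch below is only an illustration with made-up stand-in data (the variable names flat_sum, flat_comp and flat_chain are hypothetical, not from the repository). It checks that the list comprehension from this change returns the same flat list as the old expression. itertools.chain.from_iterable is shown as another linear-time option that this commit does not use.

```python
from itertools import chain

# Hypothetical stand-in data: 500 megabatches of 100 indices each.
megabatches = [list(range(k * 100, (k + 1) * 100)) for k in range(500)]

# Old expression: every "+" creates a new list and copies all earlier items.
flat_sum = sum(megabatches, [])

# Expression from this change: one pass, each index is appended exactly once.
flat_comp = [i for megabatch in megabatches for i in megabatch]

# Alternative (not used in the commit): also a single linear pass.
flat_chain = list(chain.from_iterable(megabatches))

# All three produce the same flattened list in the same order.
assert flat_sum == flat_comp == flat_chain
```

The stand-in data is kept small so that the quadratic sum() call still finishes quickly. At the ~270k megabatches mentioned in the commit message, only the linear variants stay practical.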