Faster list concat for trainer_pt_utils.get_length_grouped_indices() (#11825)

get_length_grouped_indices() in LengthGroupedSampler and DistributedLengthGroupedSampler is prohibitively slow for a large number of megabatches (in the test case it takes hours for ~270k megabatches of 100 items each) because list concatenation with sum(megabatches, []) is quadratic in the total number of items. Replacing it with a list comprehension flattens the megabatches in linear time.

Resolves: #11795
Co-authored-by: ctheodoris <cvtheodo@ds.dfci.harvard.edu>
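A hypothetical toy illustration (not from the patch) of why the two expressions differ in cost: `sum(lists, [])` repeatedly evaluates `result = result + next_list`, copying the whole accumulated result on every step, so flattening n items spread over k lists does O(n * k) element copies, while a list comprehension appends each item once.

```python
# Toy stand-in for the sampler's megabatches of indices.
megabatches = [[3, 1], [4, 1], [5, 9]]

# Quadratic: each '+' copies the entire accumulated list so far.
flat_sum = sum(megabatches, [])

# Linear: every index is appended exactly once.
flat_comp = [i for megabatch in megabatches for i in megabatch]

assert flat_sum == flat_comp == [3, 1, 4, 1, 5, 9]
```

The two forms produce identical output; only the running time differs, which is why the swap is safe for the sampler.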
This commit is contained in:
parent da22245ed9
commit 73fde1defe
@@ -495,7 +495,7 @@ def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, genera
     # Switch to put the longest element in first position
     megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]

-    return sum(megabatches, [])
+    return [i for megabatch in megabatches for i in megabatch]


 class LengthGroupedSampler(Sampler):
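For reference, the standard library offers an equally linear flattening via `itertools.chain.from_iterable`; the patch uses a list comprehension instead, which behaves the same and avoids the extra import. This is a sketch of the equivalence, using a made-up toy input rather than real sampler output.

```python
import itertools

# Toy example; real megabatches hold dataset indices grouped by length.
megabatches = [[0, 2], [1, 3], [4]]

# Stdlib alternative to the comprehension used in the patch.
flat_chain = list(itertools.chain.from_iterable(megabatches))

assert flat_chain == [i for mb in megabatches for i in mb] == [0, 2, 1, 3, 4]
```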