From 805db1fe13b3155d61ac5571439f5d619e47022f Mon Sep 17 00:00:00 2001 From: Alex Punnen Date: Tue, 2 May 2023 22:37:30 +0530 Subject: [PATCH] num_noise_spans should be <= num_items #22246 (#22938) --- examples/flax/language-modeling/run_t5_mlm_flax.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 152760f4bf..f3cec97b2e 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -418,13 +418,14 @@ class FlaxDataCollatorForT5MLM: orig_length = length num_noise_tokens = int(np.round(length * self.noise_density)) + num_nonnoise_tokens = length - num_noise_tokens # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) - num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length)) + # num_noise_tokens should be less than num_noise_tokens and num_nonnoise_tokens + num_noise_spans = int(np.round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length)) # avoid degeneracy by ensuring positive number of noise spans num_noise_spans = max(num_noise_spans, 1) - num_nonnoise_tokens = length - num_noise_tokens # pick the lengths of the noise spans and the non-noise spans def _random_segmentation(num_items, num_segments):