diff --git a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py index 67de00c00e..9f602e1c85 100644 --- a/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py +++ b/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py @@ -439,6 +439,7 @@ def main(): data_args.dataset_config_name, split=data_args.train_split_name, cache_dir=data_args.dataset_cache_dir, + num_proc=data_args.preprocessing_num_workers, token=True if model_args.use_auth_token else None, ) @@ -448,6 +449,7 @@ def main(): data_args.dataset_config_name, split=data_args.eval_split_name, cache_dir=data_args.dataset_cache_dir, + num_proc=data_args.preprocessing_num_workers, token=True if model_args.use_auth_token else None, ) @@ -551,7 +553,7 @@ def main(): prepare_dataset, remove_columns=next(iter(raw_datasets.values())).column_names, num_proc=num_workers, - desc="preprocess train dataset", + desc="preprocess train and eval dataset", ) # filter training data with inputs longer than max_input_length