[Wav2Vec2] Flax - Adapt wav2vec2 script (#12520)

* fix_torch_device_generate_test

* remove @

* adapt flax pretrain script
Patrick von Platen 2021-07-05 23:49:47 +01:00 committed by GitHub
parent 4605b2b8ec
commit 7d6285a921
1 changed file with 26 additions and 11 deletions


@@ -64,6 +64,12 @@ class ModelArguments:
     gumbel_temperature_decay: Optional[float] = field(
         default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
     )
+    dtype: Optional[str] = field(
+        default="float32",
+        metadata={
+            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
+        },
+    )
 
 
 @flax.struct.dataclass
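A note on how the new `dtype` field is consumed: later in the diff (see the model-instantiation hunk below) the string is turned into a JAX dtype object via `getattr(jnp, model_args.dtype)`. A minimal sketch of that lookup, with `dtype_str` as a hypothetical stand-in for `model_args.dtype`:

import jax.numpy as jnp

# Hypothetical stand-in for model_args.dtype; each of the three options
# named in the help string is an attribute of jax.numpy.
dtype_str = "bfloat16"
dtype = getattr(jnp, dtype_str)
assert dtype is jnp.bfloat16

On TPU, `bfloat16` is the usual half-precision choice, while `float16` targets GPUs.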
@@ -197,7 +203,7 @@ def configure_logger(model_args: ModelArguments, training_args: TrainingArguments):
     logger.setLevel(logging_level)
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
 
     train_metrics = get_metrics(train_metrics)
@@ -206,6 +212,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
@@ -342,9 +350,7 @@ def main():
             "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
         )
 
-    model = FlaxWav2Vec2ForPreTraining(
-        config, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
-    )
+    model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
     data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
         model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
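The removed call passed `config` both positionally and as the keyword `config=config`, which Python rejects outright. A hypothetical minimal repro of that bug class (the `Toy` class is not from the script):

class Toy:
    def __init__(self, config, seed=0, dtype=None):
        self.config = config

cfg = object()
Toy(cfg, config=cfg)
# TypeError: got multiple values for argument 'config'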
@@ -501,11 +507,11 @@ def main():
     state = jax_utils.replicate(state)
 
     train_time = 0
+    train_metrics = []
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
-        train_metrics = []
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
@@ -516,7 +522,7 @@ def main():
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
 
         # Gather the indexes for creating the batch and do a training step
-        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples)
             model_inputs = shard(model_inputs.data)
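Renaming the loop variable from `i` to `step` sets up the global-step bookkeeping introduced in the next hunk: with a fixed number of batches per epoch, the global step is `epoch * steps_per_epoch + step`. A quick worked sketch under assumed sizes (the numbers are illustrative only):

num_train_samples = 1000  # assumed dataset size
train_batch_size = 8      # assumed global batch size
steps_per_epoch = num_train_samples // train_batch_size  # 125

# within-epoch step 5 of epoch 2 is global step 2 * 125 + 5 = 255
epoch, step = 2, 5
cur_step = epoch * steps_per_epoch + step
assert cur_step == 255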
@@ -527,11 +533,20 @@ def main():
             )
             train_metrics.append(train_metric)
 
-        train_time += time.time() - train_start
-
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
-        )
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
+
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+                )
+
+                train_metrics = []
 
         # ======================== Evaluating ==============================
         num_eval_samples = len(vectorized_datasets["validation"])
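One detail worth flagging in the new logging block: under `pmap`, each leaf of `train_metric` carries a leading device axis, so `jax_utils.unreplicate` strips one replica before the scalars are read. A minimal standalone sketch of that behaviour (not taken from the script):

import jax
from flax import jax_utils

n = jax.local_device_count()
replicated = jax_utils.replicate({"loss": 1.5})
assert replicated["loss"].shape == (n,)  # one copy per local device
metric = jax_utils.unreplicate(replicated)
assert metric["loss"].shape == ()        # single copy, ready to log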
@@ -560,7 +575,7 @@ def main():
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
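As elsewhere in the script, `jax.process_index() == 0` restricts side effects such as TensorBoard writes and checkpointing to the first host in a multi-host setup. A minimal sketch of the pattern; the `checkpoints.save_checkpoint` call stands in for whatever save the script actually performs and is not its real save path:

import jax
from flax.training import checkpoints

def maybe_save(params, ckpt_dir, step):
    # Every process executes this function, but only host 0 touches disk,
    # mirroring the `jax.process_index() == 0` guards in the diff above.
    if jax.process_index() == 0:
        checkpoints.save_checkpoint(ckpt_dir, params, step)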