[Wav2Vec2] Flax - Adapt wav2vec2 script (#12520)
* fix_torch_device_generate_test * remove @ * adapt flax pretrain script
This commit is contained in:
parent
4605b2b8ec
commit
7d6285a921
|
@ -64,6 +64,12 @@ class ModelArguments:
|
|||
gumbel_temperature_decay: Optional[float] = field(
|
||||
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
|
@ -197,7 +203,7 @@ def configure_logger(model_args: ModelArguments, training_args: TrainingArgument
|
|||
logger.setLevel(logging_level)
|
||||
|
||||
|
||||
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
||||
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||
summary_writer.scalar("train_time", train_time, step)
|
||||
|
||||
train_metrics = get_metrics(train_metrics)
|
||||
|
@ -206,6 +212,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|||
for i, val in enumerate(vals):
|
||||
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||
|
||||
|
||||
def write_eval_metric(summary_writer, eval_metrics, step):
|
||||
for metric_name, value in eval_metrics.items():
|
||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||
|
||||
|
@ -342,9 +350,7 @@ def main():
|
|||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
||||
)
|
||||
|
||||
model = FlaxWav2Vec2ForPreTraining(
|
||||
config, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
|
||||
data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
|
||||
model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
|
||||
|
@ -501,11 +507,11 @@ def main():
|
|||
state = jax_utils.replicate(state)
|
||||
|
||||
train_time = 0
|
||||
train_metrics = []
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
# ======================== Training ================================
|
||||
train_start = time.time()
|
||||
train_metrics = []
|
||||
|
||||
# Create sampling rng
|
||||
rng, input_rng = jax.random.split(rng)
|
||||
|
@ -516,7 +522,7 @@ def main():
|
|||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||
|
||||
# Gather the indexes for creating the batch and do a training step
|
||||
for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
||||
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
||||
samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||
model_inputs = data_collator(samples)
|
||||
model_inputs = shard(model_inputs.data)
|
||||
|
@ -527,11 +533,20 @@ def main():
|
|||
)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_time += time.time() - train_start
|
||||
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
||||
|
||||
epochs.write(
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
||||
)
|
||||
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||
# Save metrics
|
||||
train_metric = jax_utils.unreplicate(train_metric)
|
||||
train_time += time.time() - train_start
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||
|
||||
epochs.write(
|
||||
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
||||
)
|
||||
|
||||
train_metrics = []
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
num_eval_samples = len(vectorized_datasets["validation"])
|
||||
|
@ -560,7 +575,7 @@ def main():
|
|||
# Save metrics
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
|
||||
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
||||
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
|
|
Loading…
Reference in New Issue