[Flax] Allow retraining from save checkpoint (#12559)
* fix_torch_device_generate_test
* remove @
* finish
This commit is contained in:
parent
1d6623c6a2
commit
7d321b7689
|
@@ -478,7 +478,14 @@ if __name__ == "__main__":
|
|||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
if model_args.model_name_or_path:
|
||||
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
else:
|
||||
model = FlaxAutoModelForMaskedLM.from_config(
|
||||
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
|
|
|
@@ -588,7 +588,12 @@ if __name__ == "__main__":
|
|||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
if model_args.model_name_or_path:
|
||||
model = FlaxT5ForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
else:
|
||||
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
|
||||
# Data collator
|
||||
# This one will take care of randomly masking the tokens.
|
||||
|
|
|
@@ -427,7 +427,14 @@ if __name__ == "__main__":
|
|||
rng = jax.random.PRNGKey(training_args.seed)
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
model = FlaxAutoModelForMaskedLM.from_config(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
if model_args.model_name_or_path:
|
||||
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
else:
|
||||
model = FlaxAutoModelForMaskedLM.from_config(
|
||||
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
||||
)
|
||||
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
|
|
Loading…
Reference in New Issue