[Flax] Fix another bug in logging steps (#12516)
* fix_torch_device_generate_test * remove @ * up
This commit is contained in:
parent
d0f7508abe
commit
4605b2b8ec
|
@ -606,7 +606,7 @@ if __name__ == "__main__":
|
|||
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
cur_step = epoch * num_train_samples + step
|
||||
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
||||
|
||||
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||
# Save metrics
|
||||
|
|
|
@ -722,7 +722,7 @@ if __name__ == "__main__":
|
|||
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
cur_step = epoch * num_train_samples + step
|
||||
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
||||
|
||||
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||
# Save metrics
|
||||
|
|
Loading…
Reference in New Issue